Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c7d08b7c
"include/vscode:/vscode.git/clone" did not exist on "f4ea00fc631e60b0b7abb1d0c454c51d0c6a2ecf"
Commit
c7d08b7c
authored
Dec 01, 2024
by
coderfeli
Browse files
use hasmainloop; no spill for 3tail
parent
532eb870
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
70 additions
and
97 deletions
+70
-97
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+2
-2
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+3
-51
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+2
-4
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+56
-39
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+5
-1
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
c7d08b7c
...
...
@@ -33,7 +33,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
Warp_Size
=
64
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
...
...
@@ -48,6 +47,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
// constexpr ck_tile::index_t Warp_Size = 64;
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
// CDataType,
// M_Warp * N_Warp * K_Warp * Warp_Size,
...
...
@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
true
,
3
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @
c7d08b7c
...
...
@@ -454,7 +454,7 @@ struct tile_window_linear
CK_TILE_DEVICE
constexpr
auto
get_num_of_access
()
const
{
return
traits
::
NumAccess
;
}
template
<
typename
DistributedTensor
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
DistributedTensor
dst_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
CK_TILE_DEVICE
auto
load
(
DistributedTensor
&
dst_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
...
...
@@ -508,56 +508,8 @@ struct tile_window_linear
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
// read from bottom tensor
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
bool_constant
<
oob_conditional_check
>
{});
#if 1
// data index [y0, y1, ...]
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_index
(
IAccess
);
// write into distributed tensor
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_diff_ys
[
jj
]
+
j
)
:
idx_diff_ys
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
vec_value
.
template
get_as
<
DataType
>()[
j
];
});
#else
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
traits
::
ScalarPerVector
==
0
);
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
number
<
d
/
traits
::
ScalarPerVector
>
{})
=
bit_cast
<
vector_t
>
(
vec_value
);
#endif
};
WINDOW_DISPATCH_ISSUE
();
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
TileDstr
{});
load
(
dst_tensor
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
return
dst_tensor
;
}
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
c7d08b7c
...
...
@@ -222,13 +222,11 @@ struct BlockGemmARegBRegCRegV2
// Prefetch lds
template
<
typename
BlockWindow
,
typename
BlockTensor
>
CK_TILE_DEVICE
static
auto
PrefetchLds
(
const
BlockWindow
&
block_window
,
BlockTensor
&
block_tensor
)
CK_TILE_DEVICE
static
void
PrefetchLds
(
const
BlockWindow
&
block_window
,
BlockTensor
&
block_tensor
)
{
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
return
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
// load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
// return;
}
// C = A * B
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
c7d08b7c
...
...
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
kHasHotLoop
=
Problem
::
kHasHotLoop
;
static
constexpr
auto
kTailNum
=
Problem
::
kTailNum
;
// CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
// {
...
...
@@ -131,6 +133,18 @@ struct GemmPipelineAGmemBGmemCRegV1
0x008
,
num_mfma_inst
/
num_issue
-
3
,
0
);
// MFMA : 5
});
__builtin_amdgcn_sched_barrier
(
0
);
// static_for<0, 8, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read : 2
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write : 1
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
// __builtin_amdgcn_sched_group_barrier(0x008, 5, 0); // MFMA : 5
// });
__builtin_amdgcn_sched_barrier
(
0
);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockSubTile
()
{
...
...
@@ -261,60 +275,63 @@ struct GemmPipelineAGmemBGmemCRegV1
ALdsTile
a_block_tile1
;
BLdsTile
b_block_tile1
;
while
(
iCounter
>
1
)
{
// ping
if
(
kHasHotLoop
)
{
do
{
// ping
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
HotLoopScheduler
();
}
// pong
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
HotLoopScheduler
();
}
iCounter
-=
2
;
}
while
(
iCounter
>
1
);
}
//tail 3
if
(
kTailNum
==
3
)
{
// 3
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
HotLoopScheduler
();
}
//
pong
//
2
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
HotLoopScheduler
();
}
iCounter
-=
2
;
}
//tail 3
// if (iCounter == 1) {
// // 3
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
// LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
// LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// // 2
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// }
// //1
// {
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// //tail 2
// } else
//1
{
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
}
else
{
// //tail 2
{
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
c7d08b7c
...
...
@@ -32,6 +32,8 @@ struct GemmPipelineProblemBase
static
constexpr
bool
kPadM
=
GemmTraits
::
kPadM
;
static
constexpr
bool
kPadN
=
GemmTraits
::
kPadN
;
static
constexpr
bool
kPadK
=
GemmTraits
::
kPadK
;
static
constexpr
bool
kHasHotLoop
=
GemmTraits
::
HasHotLoop
;
static
constexpr
auto
kTailNum
=
GemmTraits
::
TailNum
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
c7d08b7c
...
...
@@ -12,7 +12,9 @@ template <bool kPadM_,
bool
kPadK_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
>
typename
CLayout_
,
bool
HasHotLoop_
,
index_t
TailNum_
>
struct
TileGemmTraits
{
static
constexpr
bool
kPadM
=
kPadM_
;
...
...
@@ -24,6 +26,8 @@ struct TileGemmTraits
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
CLayout
=
CLayout_
;
static
constexpr
bool
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment