gaoqiong / composable_kernel_ROCM · Commits · ec959387

Unverified commit ec959387, authored Feb 13, 2025 by rocking, committed by GitHub on Feb 13, 2025
Merge branch 'develop' into ck_tile/fmha_receipt_aiter
Parents: c1e2fef7 0e5e29c4
Showing 20 changed files with 2394 additions and 479 deletions
include/ck_tile/ops/fused_moe.hpp (+2, -1)
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp (+1, -1)
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp (+634, -59)
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp (+52, -0)
include/ck_tile/ops/gemm.hpp (+3, -0)
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp (+59, -29)
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp (+5, -2)
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp (+21, -7)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (+208, -69)
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp (+237, -46)
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp (+18, -8)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp (+39, -21)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp (+164, -35)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp (+559, -0)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp (+92, -0)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp (+255, -61)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp (+2, -1)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp (+25, -19)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp (+7, -117)
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp (+11, -3)
include/ck_tile/ops/fused_moe.hpp
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
...
...
@@ -14,6 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...
...
@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
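The updated worst-case bound can be sanity-checked on the host. A minimal sketch (the helper name is hypothetical, not part of this diff): every (token, topk) pair is emitted once, and each expert can add padding of up to one unit of M_a, minus the topk entries that are guaranteed to exist:

int max_num_tokens_padded_bound(int input_tokens, int topk, int num_experts, int M_a)
{
    // each of the topk * input_tokens assignments appears once; each expert may
    // round up to a whole unit of M_a
    return topk * input_tokens + num_experts * M_a - topk;
}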
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
...
...
@@ -15,6 +15,10 @@ namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
    static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))

+#ifndef MOE_SORTING_USE_EX_KERNEL
+#define MOE_SORTING_USE_EX_KERNEL 1
+#endif

// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
...
...
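For reference, a hedged sketch of what the mock-id packing above encodes: the token id sits in the low 24 bits and the topk id in the high 8 bits (these unpack helpers are hypothetical, added for illustration only):

// e.g. MOE_SORTING_MOCK_ID(5, 2) == (5 & 0x00ffffff) | ((2 & 0xff) << 24) == 0x02000005
static inline uint32_t mock_id_token(uint32_t mock) { return mock & 0x00ffffffu; }
static inline uint32_t mock_id_topk(uint32_t mock) { return (mock >> 24) & 0xffu; }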
@@ -28,7 +32,7 @@ namespace ck_tile {
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
@@ -55,6 +59,34 @@ namespace ck_tile {
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// skip_experts_with_zero_tokens (SkipExpertsWithZeroTokens)
// if enabled, an expert with no tokens will be skipped, instead of being padded to at least 1 unit_size (M_a)
//
// (pack the tensor below, skipping elements marked with `-`)
//                         Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  -  -  -  -  Y  Y  Y  Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
//                        |- exp-0 -|- exp-1 -|- exp-2 -|-    exp-3    -|-  exp-4  -|- exp-5 -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
// num_tokens_post_padded_ptr : [24]
//
// * local_expert_mask : the local expert mask used on the current GPU (for the EP case).
//   It also modifies the output expert-ID, because only enabled experts live on a specific GPU.
//   We call the expert input to this kernel the "global expert id", and the output the "local expert id".
//
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id = 1, 4)
//
// (pack the tensor below, skipping elements marked with `-`)
//                         Y  Y  Y  Y  -  -  -  -  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  Y  -  -  -  -  Y  Y  Y  Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
//                        |- exp-0 -|- exp-1 -|- exp-2 -|-    exp-3    -|-  exp-4  -|- exp-5 -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note: originally it was expert-id = 0, 2, 3, 5, but we produce the "local expert id")
// num_tokens_post_padded_ptr : [20]
//
// * different from vLLM
//   1) the token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id
//   2) need sorted_weight_ptr
...
...
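A host-side model (illustrative sketch, not part of this diff) of the padding/skip rule described above: each expert's token count is rounded up to whole units of M_a; a zero-count expert is either skipped entirely (SkipExpertsWithZeroTokens) or padded to one full unit:

#include <vector>
std::vector<int> padded_cumsum(const std::vector<int>& cnt_per_expert, int M_a, bool skip_zero)
{
    std::vector<int> cumsum(cnt_per_expert.size() + 1, 0);
    for(size_t e = 0; e < cnt_per_expert.size(); ++e)
    {
        int units = (cnt_per_expert[e] + M_a - 1) / M_a; // round up to whole units
        if(!skip_zero && units < 1)
            units = 1; // pad an empty expert to one unit
        cumsum[e + 1] = cumsum[e] + units * M_a;
    }
    return cumsum; // cumsum.back() == num_tokens_post_padded
}

With the counts from the example above ([1, 3, 2, 5, 0, 4] and M_a = 4), skip_zero = false gives 28 and skip_zero = true gives 24, matching the two num_tokens_post_padded_ptr values shown.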
@@ -67,10 +99,80 @@ namespace ck_tile {
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)

CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
{
    /*               num_experts + 1
     * +--------------------------------------+
     * |                                      |
     * |                                      |
     * |                                      |  * -> sub-tokens
     * |                                      |
     * |                                      |
     * +--------------------------------------+
     * |                                      |  2 -> cumsum buffer
     * +--------------------------------------+
     */
    int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
    int smem_rows = [&]() {
        index_t target_occupancy_     = 2;
        constexpr index_t total_      = 65536 / sizeof(int);
        constexpr index_t sub_unroll  = 8;
        constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
        // at least 2 rows, one for sub_token unroll, one for cumsum. should be enough
        if((total_ / target_occupancy_) < ((cumsum_bufs + sub_unroll) * smem_cols))
        {
            if((total_ / 1) < ((cumsum_bufs + sub_unroll) * smem_cols))
                throw std::runtime_error("too many num_experts, can't allocate smem");
            target_occupancy_ = 1;
        }
        int r = total_ / target_occupancy_ / smem_cols;
        // round to sub_unroll multiple
        int r_for_sub_token = r - cumsum_bufs;
        r_for_sub_token     = min(r_for_sub_token, num_tokens_);
        r_for_sub_token     = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
        r_for_sub_token     = max(r_for_sub_token, 1);
        if(r_for_sub_token > 1)
        {
            int r_unroll_ = r_for_sub_token / sub_unroll;
            // round to 1x/2x/4x/8x number of sub_unroll
            int clz_  = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
            int mask_ = (1 << (31 - clz_)) - 1;
            mask_     = mask_ > 0b111 ? 0b111 : mask_; // clamp to 8x at most
            mask_     = ~mask_;
            // printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
            r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
        }
        // final check
        if((r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_) >= total_)
        {
            throw std::runtime_error("can't run this kernel, request LDS over size");
        }
        return r_for_sub_token + cumsum_bufs;
    }();
    // printf("r:%d, c:%d\n", smem_rows, smem_cols);
    return ck_tile::make_tuple(smem_rows, smem_cols);
}
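A minimal host-side sketch of how this helper is meant to be consumed (mirroring GetSmemSize further below; variable names are illustrative):

auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(tokens, num_experts);
size_t lds_bytes = static_cast<size_t>(smem_rows) * smem_cols * sizeof(int);
// smem_rows - 2 rows hold sub-token slices; the last 2 rows are cumsum/count buffers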
struct MoeSortingHostArgs
{
    const void* p_topk_ids; // [token, topk]
    const void* p_weights;  // [token, topk]
    const void* p_local_expert_mask;
    void* p_sorted_token_ids;
    void* p_sorted_weights;
    void* p_sorted_expert_ids;
...
...
@@ -101,6 +203,7 @@ struct MoeSortingKernel
    {
        const void* p_topk_ids;
        const void* p_weights;
+       const void* p_local_expert_mask;
        void* p_sorted_token_ids;
        void* p_sorted_weights;
        void* p_sorted_expert_ids;
...
...
@@ -111,8 +214,11 @@ struct MoeSortingKernel
        index_t moe_buf_bytes;
        index_t tokens_per_thread;
+       index_t smem_rows;
        mdiv unit_size_mdiv;
        mdiv topk_mdiv;
+       mdiv expert_mdiv;
+       // mdiv sub_tokens_mdiv;
    };

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
...
...
@@ -123,15 +229,25 @@ struct MoeSortingKernel
    CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
    {
#if MOE_SORTING_USE_EX_KERNEL
        (void)h;
        return dim3(256);
#else
        return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
#endif
    }

    // in byte
    CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
    {
#if MOE_SORTING_USE_EX_KERNEL
        auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
        return smem_rows * smem_cols * sizeof(int);
#else
        const auto blocks = BlockSize(h);
        // usually num_experts is power of 2, we pad 1 dword here for the row-size
        return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
#endif
    }

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
...
...
@@ -139,6 +255,7 @@ struct MoeSortingKernel
        Kargs k;
        k.p_topk_ids          = h.p_topk_ids;
        k.p_weights           = h.p_weights;
+       k.p_local_expert_mask = h.p_local_expert_mask;
        k.p_sorted_token_ids  = h.p_sorted_token_ids;
        k.p_sorted_weights    = h.p_sorted_weights;
        k.p_sorted_expert_ids = h.p_sorted_expert_ids;
...
...
@@ -152,10 +269,18 @@ struct MoeSortingKernel
        k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
        k.unit_size_mdiv    = mdiv{static_cast<uint32_t>(h.unit_size)};
        k.topk_mdiv         = mdiv{static_cast<uint32_t>(h.topk)};
        k.smem_rows         = [&]() {
            auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
            (void)c_;
            return r_;
        }();
        k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
        // k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
        return k;
    }

    // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
    // NOTE: wave_size need at least be 16!! dpp 16 is one row
    template <typename data_t, int wave_size>
    __device__ inline void wave_cumsum(data_t& thread_data) const
    {
...
...
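A scalar model (illustrative only) of the DPP-based scan in this function: each step adds the value shifted down by 2^s lanes (the row_shr:1/2/4/8 moves), which yields an inclusive prefix sum after log2(wave_size) steps:

#include <vector>
std::vector<int> wave_cumsum_model(std::vector<int> v)
{
    const int n = static_cast<int>(v.size()); // power-of-two wave size
    for(int shift = 1; shift < n; shift <<= 1)
    {
        std::vector<int> next(v);
        for(int lane = shift; lane < n; ++lane)
            next[lane] = v[lane] + v[lane - shift]; // lanes below `shift` keep their value
        v = next;
    }
    return v; // v[i] == v0[0] + ... + v0[i]
}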
@@ -196,6 +321,40 @@ struct MoeSortingKernel
                                        bank_mask,
                                        bound_ctrl))); // row_shr:4
        }
        if constexpr(wave_size == 8)
        {
            // wave-size=8 need one extra shift
            thread_data = reduce_op(thread_data,
                                    __builtin_bit_cast(data_t,
                                                       __builtin_amdgcn_mov_dpp(
                                                           __builtin_bit_cast(int, thread_data),
                                                           0x118,
                                                           row_mask,
                                                           bank_mask,
                                                           bound_ctrl))); // row_shr:8
#if 0
            constexpr int bank_mask_0_7 = 0b1100;
            auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
            thread_data = reduce_op_r(thread_data,
                                      __builtin_bit_cast(data_t,
                                                         __builtin_amdgcn_update_dpp(0, /* old value */
                                                                                     __builtin_bit_cast(int, thread_data),
                                                                                     0x157,
                                                                                     row_mask,
                                                                                     bank_mask_0_7,
                                                                                     bound_ctrl)) // row_newbcast:7
            );
#else
            data_t xxx = __builtin_bit_cast(data_t,
                                            __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
                                                                     0x157,
                                                                     row_mask,
                                                                     bank_mask,
                                                                     bound_ctrl)); // row_newbcast:7
            data_t yyy  = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
            thread_data = thread_data - yyy;
#endif
        }
        if constexpr(wave_size > 8)
        {
            thread_data =
...
...
@@ -224,6 +383,36 @@ struct MoeSortingKernel
        }
    }

    // reduce single pixel within a wave
    template <typename T, typename F, index_t wave_size_ = warpSize>
    __device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
    {
        // constexpr int wave_size = 64;
        // constexpr int reduce_stage = 6; // 1<<6=64
        // clang-format off
        constexpr int reduce_stage = [](){
            if constexpr(wave_size_ == 2) return 1;
            else if constexpr(wave_size_ == 4) return 2;
            else if constexpr(wave_size_ == 8) return 3;
            else if constexpr(wave_size_ == 16) return 4;
            else if constexpr(wave_size_ == 32) return 5;
            else if constexpr(wave_size_ == 64) return 6;
            else return 0;
        }();
        // clang-format on
        T v_local = local;
#pragma unroll reduce_stage
        for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
        {
            int src_lane = __lane_id() ^ (1 << i_stage);
            int32_t v_remote_tmp =
                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
            T v_remote = bit_cast<T>(v_remote_tmp);
            v_local    = reduce_f(v_local, v_remote);
        }
        return v_local;
    }
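The XOR pattern above is a butterfly reduction: at stage s each lane exchanges with the lane whose index differs in bit s, so after log2(wave_size_) stages every lane holds the full reduction. A host-side scalar model (illustrative only):

#include <vector>
std::vector<int> wave_reduce_model(std::vector<int> v, int wave_size)
{
    for(int stage = 0; (1 << stage) < wave_size; ++stage)
    {
        std::vector<int> next(v);
        for(int lane = 0; lane < wave_size; ++lane)
            next[lane] = v[lane] + v[lane ^ (1 << stage)]; // partner differs in bit `stage`
        v = next;
    }
    return v; // every lane now holds the sum of all wave_size inputs
}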
    CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
    {
        return row * total_col + col;
...
...
@@ -257,37 +446,37 @@ struct MoeSortingKernel
        index_t* shared_mem = reinterpret_cast<index_t*>(smem);

        index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
        index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)

        for(int i = 0; i < num_experts; ++i)
        {
            tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
        }

#pragma unroll Problem_::InternalLoadUnroll
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
        }

        __syncthreads();
#if 1
        if(tid < num_experts)
        {
            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
            index_t local_c[8];
            index_t prev_c = 0;
            // TODO: manually unroll. pragma unroll does not work well when we have dependency
            for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
            {
                local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
                local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
                local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
                local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
                local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
                local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
                local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
                local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
                local_c[0] += prev_c;
                local_c[1] += local_c[0];
...
...
@@ -299,51 +488,57 @@ struct MoeSortingKernel
                local_c[7] += local_c[6];
                prev_c = local_c[7];
                tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
                tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
                tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
                tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
                tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
                tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
                tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
                tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
            }
        }
#else
        // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
        // heuristic
        {
            if(tid < num_experts)
                tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
            for(int i = 0; i < num_experts; i += 8)
            {
                index_t local_c[8];
#pragma unroll
                for(int j = 0; j < 8; j++)
                {
                    local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
                }
#pragma unroll
                for(int j = 0; j < 8; j++)
                {
                    wave_cumsum<int, 64>(local_c[j]);
                }
#pragma unroll
                for(int j = 0; j < 8; j++)
                {
                    tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
                }
            }
        }
#endif
        __syncthreads();

        if constexpr(Problem::ExpertTile == 0)
        {
            if(tid == 0)
            {
                cumsum[0] = 0;
                for(int i = 1; i <= num_experts; ++i)
                {
                    auto current_units = [&]() {
                        index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
                                     unit_size_mdiv.divisor - 1;
                        index_t y_ = unit_size_mdiv.div(x_);
                        return max(y_, 1) * unit_size_mdiv.divisor;
                    }();
...
...
@@ -351,20 +546,24 @@ struct MoeSortingKernel
            }
            *p_total_tokens_post_pad = cumsum[num_experts];
        }
        else
        {
            // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
            // expert) for simplicity, not check experts here.
            int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
            int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
            int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
            int local_cumsum = padded_tokens_per_expert;
            wave_cumsum<int, 64>(local_cumsum);
            if(tid == (num_experts - 1))
            {
                cumsum[0] = 0;
                *p_total_tokens_post_pad = local_cumsum;
            }
            if(tid < num_experts)
            {
                cumsum[tid + 1] = local_cumsum;
            }
...
...
@@ -373,7 +572,7 @@ struct MoeSortingKernel
        if(tid < num_experts)
        {
            int e_start = cumsum[tid];
            int e_end   = cumsum[tid + 1];
            for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
            {
                p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
...
...
@@ -383,8 +582,8 @@ struct MoeSortingKernel
#pragma unroll Problem_::InternalLoadUnroll
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            index_t expert_id = topk_id[i];
            index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
            index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            uint32_t curr_token_id, curr_topk_id;
...
...
@@ -393,16 +592,17 @@ struct MoeSortingKernel
#else
            p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
            p_sorted_weights[rank_post_pad] = weights[i];
            tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
        }

        if constexpr(Problem::ExpertTile == 0)
        {
            const index_t prefill_token = topk_mdiv.div(numel);
            if(tid < num_experts)
            {
                index_t expert_offset =
                    cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
                index_t expert_end = cumsum[tid + 1];
                while(expert_offset < expert_end)
                {
...
@@ -417,16 +617,19 @@ struct MoeSortingKernel
}
}
}
else
{
else
{
const
index_t
prefill_token
=
topk_mdiv
.
div
(
numel
);
// TODO: only support expert-tile like 8, 16, 32
static
constexpr
index_t
experts_per_wave
=
warpSize
/
Problem
::
ExpertTile
;
{
index_t
eid
=
tid
/
experts_per_wave
;
index_t
expert_offset
=
cumsum
[
eid
]
+
tokens_cnts
[
calc_index
(
num_experts
+
1
,
blockDim
.
x
,
eid
)]
+
tid
%
experts_per_wave
;
index_t
eid
=
tid
/
experts_per_wave
;
index_t
expert_offset
=
cumsum
[
eid
]
+
tokens_cnts
[
calc_index
(
num_experts
+
1
,
blockDim
.
x
,
eid
)]
+
tid
%
experts_per_wave
;
index_t
expert_end
=
cumsum
[
eid
+
1
];
if
(
eid
<
num_experts
)
{
if
(
eid
<
num_experts
)
{
while
(
expert_offset
<
expert_end
)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
...
...
@@ -436,10 +639,363 @@ struct MoeSortingKernel
                        p_sorted_token_ids[expert_offset] = prefill_token;
#endif
                        p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
                        expert_offset += experts_per_wave;
                    }
                }
            }
        }
    }

    // only support index_t, and single pixel access
    struct simple_smem_indexer
    {
        index_t* smem;
        index_t row_stride;
        // this is 2D
        CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
            : smem(smem_), row_stride(row_stride_)
        {
        }
        CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
        {
            return smem[i_row * row_stride + i_col];
        }
        CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
        {
            return smem[i_row * row_stride + i_col];
        }
        // this is 1D or linear
        CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
        CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
        CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
    };

    CK_TILE_DEVICE void
    moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
                                   const WeightType* __restrict__ weights,
                                   const IndexType* __restrict__ local_expert_mask,
                                   index_t* p_sorted_token_ids,
                                   WeightType* p_sorted_weights,
                                   index_t* p_sorted_expert_ids,
                                   index_t* p_total_tokens_post_pad,
                                   const index_t num_experts,
                                   const index_t tokens,
                                   const mdiv unit_size_mdiv,
                                   const mdiv topk_mdiv,
                                   const mdiv expert_mdiv,
                                   const index_t smem_rows,
                                   void* smem) const
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
        const index_t lid = __lane_id();
        constexpr index_t block_size = 256;       // blockDim.x;
        const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
        const index_t topk = topk_mdiv.divisor;
        auto f_sum = [](auto x_, auto y_) { return x_ + y_; };

        const index_t smem_cols = num_experts + 1;
        simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
        simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
        simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
                                        smem_cols};

        // #pragma unroll 8
        for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
        {
            uint32_t curr_token_id, curr_expert_id;
            expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
            smem_tokens(curr_token_id, curr_expert_id) = 0;
        }
        __syncthreads();

        for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
        {
            // NOTE: below for loop can't have barrier inside!!
            for(int i = tid; i < (sub_tokens * topk); i += block_size)
            {
                uint32_t curr_token_id, curr_topk_id;
                topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
                int i_t = i_token + curr_token_id;
                if(i_t < tokens)
                {
                    int eid = topk_id[i_t * topk + curr_topk_id];
                    if constexpr(Problem::SubTokenOneShot)
                        smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
                    else
                        smem_tokens(curr_token_id, eid)++;
                }
                __builtin_amdgcn_s_waitcnt(0xc07f);
            }
            __syncthreads(); // make sure different i_token iterations do not overlap across waves
        }

        // counting
        if(tid == 0)
        {
            smem_cumsum(0) = 0;
            // smem_cumdup(0) = 0;
        }
        {
            constexpr int lane_group_sz = 8;
            int lane_group_id = tid / lane_group_sz;
            int lane_group_os = tid % lane_group_sz;
            constexpr int lane_group_nm = block_size / lane_group_sz;
            for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
            {
                index_t local_c[Problem::SubTokenTile];
                index_t cnt = 0;
                for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
                {
#pragma unroll Problem::SubTokenTile
                    for(int j = 0; j < Problem::SubTokenTile; j++)
                    {
                        local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
                        if constexpr(Problem::SubTokenOneShot)
                        {
                            local_c[j] = local_c[j] != 0 ? 1 : 0;
                        }
                    }
#pragma unroll Problem::SubTokenTile
                    for(int j = 0; j < Problem::SubTokenTile; j++)
                    {
                        cnt += wave_reduce(local_c[j], f_sum, number<8>{});
                    }
                }
                if(lane_group_os == 0)
                    smem_cumsum(i_e + 1) = cnt;
            }
        }
        if constexpr(Problem::LocalExpertMasking)
        {
            smem_cumdup(0) = 0;
            for(int i_e = tid; i_e < num_experts; i_e += block_size)
            {
                // reuse this buffer
                smem_cumdup(i_e + 1) = local_expert_mask[i_e];
            }
        }
        __syncthreads();
        {
            if(wid == 0)
            {
                // NOTE: under this block can never use __syncthreads!
                int i_e_ = 0;
                int local_cumsum_ = 0;
                for(; i_e_ < num_experts; i_e_ += warpSize)
                {
                    int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
                    int local_cnt   = smem_cumsum(i_e_ + lid + 1);
                    int blocks_pers_expert =
                        unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
                    int pre_cumsum_masking = [&]() {
                        if constexpr(Problem::LocalExpertMasking)
                            return smem_cumdup(lid == 0 ? i_e_ : 0);
                        else
                            return 0; // not used
                    }();
                    int local_masking = [&]() {
                        if constexpr(Problem::LocalExpertMasking)
                            return smem_cumdup(i_e_ + lid + 1);
                        else
                            return 0; // not used
                    }();
                    int padded_tokens_per_expert = [&]() {
                        int x_ = [&]() {
                            if constexpr(Problem::SkipExpertsWithZeroTokens)
                            {
                                // if local_cnt is zero, blocks_pers_expert will be zero
                                // this is what we want to achieve
                                return blocks_pers_expert * unit_size_mdiv.divisor;
                            }
                            else
                            {
                                return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
                            }
                        }();
                        if constexpr(Problem::LocalExpertMasking)
                        {
                            return local_masking ? x_ : 0;
                        }
                        else
                            return x_;
                    }();
                    local_cumsum_ = padded_tokens_per_expert;
                    local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after the local
                                                  // cumsum is padded, in case the local cumsum is
                                                  // zero but pre_cumsum has a value, which would
                                                  // result in a zero local cumsum (but we want at
                                                  // least padded)
                    wave_cumsum<int, warpSize>(local_cumsum_);
                    if((i_e_ + lid) < num_experts)
                        smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
                    if constexpr(Problem::LocalExpertMasking)
                    {
                        local_masking += pre_cumsum_masking;
                        wave_cumsum<int, warpSize>(local_masking);
                        if((i_e_ + lid) < num_experts)
                            smem_cumdup(i_e_ + lid + 1) = local_masking;
                    }
                    // NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
                    // for above write; however __syncthreads would cause a barrier with waves
                    // other than 0 (which is not what we want)
                    __builtin_amdgcn_s_waitcnt(0xc07f);
                }
                if((lid + i_e_ - warpSize) == (num_experts - 1))
                {
                    *p_total_tokens_post_pad = local_cumsum_;
                }
            }
            __syncthreads();
        }
        for(int i_e = tid; i_e < num_experts; i_e += block_size)
        {
            int e_start = smem_cumsum(i_e);
            int e_end   = smem_cumsum(i_e + 1);
            int expert_id = [&]() {
                if constexpr(Problem::LocalExpertMasking)
                {
                    // local expert id from cumsum
                    return smem_cumdup(i_e);
                }
                else
                    return i_e;
            }();
            smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
            if constexpr(Problem::SkipExpertsWithZeroTokens)
            {
                if(e_start == e_end) // skip zero token expert
                    continue;
            }
            if constexpr(Problem::LocalExpertMasking)
            {
                if(local_expert_mask[i_e] == 0)
                    continue;
            }
            for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
            {
                p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
            }
        }
        smem_cumdup(num_experts) = smem_cumsum(num_experts);

        // fill the p_sorted_token_ids/p_sorted_weights
        for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
        {
            if constexpr(!Problem::SubTokenOneShot)
            {
                // clear every time
                for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
                {
                    uint32_t curr_token_id, curr_expert_id;
                    expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
                    smem_tokens(curr_token_id, curr_expert_id) = 0;
                }
                __syncthreads();
                // load again
                for(int i = tid; i < (sub_tokens * topk); i += block_size)
                {
                    uint32_t curr_token_id_, curr_topk_id_;
                    topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
                    int curr_token_id = static_cast<int>(curr_token_id_);
                    int curr_topk_id  = static_cast<int>(curr_topk_id_);
                    int i_t = i_token + curr_token_id;
                    if(i_t < tokens)
                    {
                        int eid = topk_id[i_t * topk + curr_topk_id];
                        smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
                    }
                }
                __syncthreads();
            }
            {
                constexpr int lane_group_sz = 8;
                int lane_group_id = tid / lane_group_sz;
                int lane_group_os = tid % lane_group_sz;
                constexpr int lane_group_nm = block_size / lane_group_sz;
                for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
                {
                    if constexpr(Problem::LocalExpertMasking)
                    {
                        if(local_expert_mask[eid] == 0)
                            continue;
                    }
                    int position = smem_cumsum(eid);
                    for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
                        i_sub_token += lane_group_sz)
                    {
                        auto x = smem_tokens(i_sub_token, eid);
                        int local_cnt_cache = x != 0 ? 1 : 0;
                        int local_cnt = local_cnt_cache;
                        wave_cumsum<int, lane_group_sz>(local_cnt);
                        if(x != 0)
                        {
                            // now x is topk value
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                            p_sorted_token_ids[position + local_cnt - 1] =
                                MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
                            p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
#endif
                            p_sorted_weights[position + local_cnt - 1] =
                                weights[(i_token + i_sub_token) * topk + x - 1];
                        }
                        int remote_cnt = __builtin_amdgcn_ds_bpermute(
                            (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
                        position += remote_cnt;
                    }
                    smem_cumsum(eid) = position;
                }
            }
            __syncthreads();
        }
        // add the skip number
        for(int eid = tid; eid < num_experts; eid += block_size)
        {
            int e_start = smem_cumsum(eid);
            int e_end   = smem_cumdup(eid + 1);
            if constexpr(Problem::SkipExpertsWithZeroTokens)
            {
                if(e_start == e_end) // skip zero token expert
                    continue;
            }
            while(e_start < e_end)
            {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
#else
                p_sorted_token_ids[e_start] = tokens;
#endif
                p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
                e_start++;
            }
        }
    }
...
...
@@ -456,6 +1012,24 @@ struct MoeSortingKernel
        }
        const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
        extern __shared__ char smem[];
+#if MOE_SORTING_USE_EX_KERNEL
+        (void)numel;
+        return moe_align_block_size_kernel_ex(
+            static_cast<const IndexType*>(kargs.p_topk_ids),
+            static_cast<const WeightType*>(kargs.p_weights),
+            static_cast<const IndexType*>(kargs.p_local_expert_mask),
+            static_cast<IndexType*>(kargs.p_sorted_token_ids),
+            static_cast<WeightType*>(kargs.p_sorted_weights),
+            static_cast<IndexType*>(kargs.p_sorted_expert_ids),
+            static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
+            kargs.num_experts,
+            kargs.tokens,
+            kargs.unit_size_mdiv,
+            kargs.topk_mdiv,
+            kargs.expert_mdiv,
+            kargs.smem_rows,
+            smem);
+#else
        return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
                                           static_cast<const WeightType*>(kargs.p_weights),
                                           static_cast<IndexType*>(kargs.p_sorted_token_ids),
...
                                           kargs.unit_size_mdiv,
                                           kargs.topk_mdiv,
                                           smem);
+#endif
    }
};
...
...
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp → include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
...
...
@@ -25,4 +25,28 @@ struct MoeSortingProblem
        InternalLoadUnroll_; // TODO: need better design (like tile size)
    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};

template <typename IndexType_,
          typename WeightType_,
          index_t SubTokenTile_,    // 1,2,4,8, or 0 in the future
          bool SubTokenOneShot_,    // if we only loop over once or not
          bool LocalExpertMasking_, // used in EP case
          bool SkipExpertsWithZeroTokens_ = true,
          index_t ExpertTile_             = 0>
struct MoeSortingProblemEx
{
    // TODO: this kernel only support warp per row
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t WarpSize      = get_warp_size();
    static constexpr index_t WarpsPerBlock = 1;
    static constexpr index_t SubTokenTile  = SubTokenTile_;
    static constexpr bool SubTokenOneShot  = SubTokenOneShot_;
    static constexpr bool LocalExpertMasking        = LocalExpertMasking_;
    static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
    static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile
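A hypothetical instantiation (types and values assumed, for illustration only): 32-bit indices, fp32 weights, 8 sub-token rows per lane-group pass, one-shot load, EP masking on, with the default zero-token skipping:

using SortProblem = ck_tile::MoeSortingProblemEx<ck_tile::index_t, // IndexType_
                                                 float,            // WeightType_
                                                 8,                // SubTokenTile_
                                                 true,             // SubTokenOneShot_
                                                 true>;            // LocalExpertMasking_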
include/ck_tile/ops/gemm.hpp
...
...
@@ -29,6 +29,8 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
...
...
@@ -46,3 +48,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...
...
@@ -14,24 +14,54 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
-    using Problem        = remove_cvref_t<Problem_>;
-    using Policy         = remove_cvref_t<Policy_>;
-    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
-    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
-    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
-    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
-
-    static constexpr index_t kBlockSize = Problem::kBlockSize;
-    static constexpr index_t MPerBlock  = BlockGemmShape::kM;
-    static constexpr index_t NPerBlock  = BlockGemmShape::kN;
-    static constexpr index_t KPerBlock  = BlockGemmShape::kK;
-
-    static constexpr auto config   = Policy::template GetWarpGemmMWarpNWarp<Problem>();
-    using WG                       = remove_cvref_t<decltype(config.template at<0>())>;
-    static constexpr index_t MWarp = config.template at<1>();
-    static constexpr index_t NWarp = config.template at<2>();
-
-    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-    static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
-    static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+    private:
+    template <typename PipelineProblem_, typename GemmPolicy_>
+    struct GemmTraits_
+    {
+        using Problem        = remove_cvref_t<PipelineProblem_>;
+        using Policy         = remove_cvref_t<GemmPolicy_>;
+        using ADataType      = remove_cvref_t<typename Problem::ADataType>;
+        using BDataType      = remove_cvref_t<typename Problem::BDataType>;
+        using CDataType      = remove_cvref_t<typename Problem::CDataType>;
+        using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
+
+        static constexpr index_t kBlockSize = Problem::kBlockSize;
+        static constexpr index_t MPerBlock  = BlockGemmShape::kM;
+        static constexpr index_t NPerBlock  = BlockGemmShape::kN;
+        static constexpr index_t KPerBlock  = BlockGemmShape::kK;
+
+        static constexpr auto config   = Policy::template GetWarpGemmMWarpNWarp<Problem>();
+        using WarpGemm                 = remove_cvref_t<decltype(config.template at<0>())>;
+        static constexpr index_t MWarp = config.template at<1>();
+        static constexpr index_t NWarp = config.template at<2>();
+
+        static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
+        static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
+        static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
+        static constexpr index_t KPack        = WarpGemm::kKPerThread;
+    };
+
+    public:
+    using Problem        = remove_cvref_t<Problem_>;
+    using Policy         = remove_cvref_t<Policy_>;
+    using Traits         = GemmTraits_<Problem, Policy>;
+    using WarpGemm       = typename Traits::WarpGemm;
+    using BlockGemmShape = typename Traits::BlockGemmShape;
+    using ADataType      = remove_cvref_t<typename Traits::ADataType>;
+    using BDataType      = remove_cvref_t<typename Traits::BDataType>;
+    using CDataType      = remove_cvref_t<typename Traits::CDataType>;
+
+    static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
+    static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
+    static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
+    static constexpr index_t MWarp        = Traits::MWarp;
+    static constexpr index_t NWarp        = Traits::NWarp;

    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
    {
...
...
@@ -43,7 +73,7 @@ struct BlockGemmARegBRegCRegV1
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
+            a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
        return a_block_dstr_encode;
    }
...
...
@@ -58,7 +88,7 @@ struct BlockGemmARegBRegCRegV1
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
+            b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
        return b_block_dstr_encode;
    }
...
...
@@ -73,7 +103,7 @@ struct BlockGemmARegBRegCRegV1
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
+            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        return c_block_dstr_encode;
    }
...
...
@@ -112,13 +142,13 @@ struct BlockGemmARegBRegCRegV1
                .get_static_tile_distribution_encoding())>>,
            "C distribution is wrong!");

-        using AWarpDstr = typename WG::AWarpDstr;
-        using BWarpDstr = typename WG::BWarpDstr;
-        using CWarpDstr = typename WG::CWarpDstr;
+        using AWarpDstr = typename WarpGemm::AWarpDstr;
+        using BWarpDstr = typename WarpGemm::BWarpDstr;
+        using CWarpDstr = typename WarpGemm::CWarpDstr;

-        using AWarpTensor = typename WG::AWarpTensor;
-        using BWarpTensor = typename WG::BWarpTensor;
-        using CWarpTensor = typename WG::CWarpTensor;
+        using AWarpTensor = typename WarpGemm::AWarpTensor;
+        using BWarpTensor = typename WarpGemm::BWarpTensor;
+        using CWarpTensor = typename WarpGemm::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
...
...
@@ -157,7 +187,7 @@ struct BlockGemmARegBRegCRegV1
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                // warp GEMM
-                WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
+                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                // write C warp tensor into C block tensor
                c_block_tensor.set_y_sliced_thread_data(
...
...
@@ -180,7 +210,7 @@ struct BlockGemmARegBRegCRegV1
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
+            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
...
...
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
...
...
@@ -79,8 +79,11 @@ struct BlockUniversalGemmAsBsCr
    // TODO: Should we have two policies? Interwave & Intrawave ??
    static constexpr index_t InterWaveSchedulingMacClusters = 1;

-    static constexpr index_t KPack      = WarpGemm::kKPerThread;
-    static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack;
+    // should be at least equal to: WarpGemm::Impl::kABKPerLane
+    // and the question is how to assess upper limit or exact value?
+    // TODO: Should we introduce AK1/BK1 parameters ?
+    static constexpr index_t KPack      = 8;
+    static constexpr index_t KPerThread = KIterPerWarp * KPack;
+    static constexpr index_t KRepeat    = KPerThread / KPack;
};
...
...
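Illustrative arithmetic for the new definitions (the numbers are assumed, not taken from this diff): with a per-block K of 32 and a warp-gemm K extent of 16, KIterPerWarp would be 2, so:

constexpr int KPack        = 8;                    // fixed per-thread vector width
constexpr int KIterPerWarp = 2;                    // assumed KPerBlock / WarpGemm::kK
constexpr int KPerThread   = KIterPerWarp * KPack; // 16
constexpr int KRepeat      = KPerThread / KPack;   // 2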
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {
...
...
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
    using BLayout = typename Base::BLayout;
    using CLayout = typename Base::CLayout;

+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        using P_ = GemmPipeline;
+        return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
+                      concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
+                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
+                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
+        // clang-format on
+    }
+
    struct BatchedGemmKernelArgs : GemmKernelArgs
    {
        index_t batch_stride_A;
@@ -70,7 +84,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
    __host__ static constexpr auto
    GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
    {
-        return TilePartitioner::GridSize(M, N, KBatch * batch_count);
+        return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
    }

    __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
...
...
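A small host-side check (illustrative values) contrasting the two mappings: the old scheme packed batch and split-K into a single blockIdx.z and decoded with a divide, while the new grid gives each its own axis (y = batch, z = split-K), which the operator() hunk below reads back directly:

#include <cassert>
int main()
{
    const int KBatch = 2, batch_count = 8;
    for(int z_old = 0; z_old < batch_count * KBatch; ++z_old)
    {
        const int i_batch = z_old / KBatch;          // old decode
        const int i_k     = z_old - i_batch * KBatch;
        assert(i_batch * KBatch + i_k == z_old);     // new scheme reads these from y/z instead
    }
    return 0;
}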
@@ -101,14 +115,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
    CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
    {
-        const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

-        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
-        const auto i_k     = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
-        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
+        const auto i_batch  = __builtin_amdgcn_readfirstlane(blockIdx.y);
+        const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);

        // options
        const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
...
...
@@ -128,7 +142,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

-        if(kargs.KBatch == 1)
+        if(kargs.k_batch == 1)
        {
            this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
...
...
@@ -8,7 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/host/concat.hpp"

namespace ck_tile {
...
...
@@ -69,18 +69,26 @@ struct GemmKernel
    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    // Below type is actually accumulation data type - the output of block GEMM.
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();

-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
-    {
-        return TilePartitioner::GridSize(M, N, KBatch);
-    }
-
-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
+        // clang-format on
+    }
+
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+    {
+        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
+    }
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    struct GemmKernelArgs
    {
...
...
@@ -93,7 +101,7 @@ struct GemmKernel
        index_t stride_A;
        index_t stride_B;
        index_t stride_C;
-        index_t KBatch;
+        index_t k_batch;
    };

    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
...
...
@@ -121,7 +129,7 @@ struct GemmKernel
                          const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
-            const index_t K_t = kargs.KBatch * K1;
+            const index_t K_t = kargs.k_batch * K1;
            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...
...
@@ -142,13 +150,13 @@ struct GemmKernel
                b_k_split_offset = k_id * KRead;
            }

-            if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
+            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
            {
                splitted_k = KRead;
            }
            else
            {
-                splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
+                splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
            }
        }
...
...
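A worked example (illustrative numbers) of the slicing above: with K = 100, k_batch = 3 and K1 = 16, K_t = 48 and KRead = ((100 + 47) / 48) * 16 = 48, so the first two slices read 48 elements of K and the last reads 100 - 48 * 2 = 4, covering K exactly:

#include <cassert>
int main()
{
    const int K = 100, k_batch = 3, K1 = 16;
    const int K_t   = k_batch * K1;
    const int KRead = (K + K_t - 1) / K_t * K1;
    int total = 0;
    for(int k_id = 0; k_id < k_batch; ++k_id)
        total += (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
    assert(total == K); // the k_batch slices partition K
    return 0;
}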
@@ -159,15 +167,12 @@ struct GemmKernel
    CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
    {
-        constexpr bool is_output_c_reg_transposed =
-            EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
-        if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
-                        std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
-                        is_output_c_reg_transposed) ||
-                       !(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
+        if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
+                     is_any_of<CDataType, fp16_t, bf16_t>::value)
        {
-            if(kargs.KBatch != 1)
+            if(kargs.k_batch != 1)
            {
                std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
                return false;
            }
        }
...
...
@@ -176,10 +181,14 @@ struct GemmKernel
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
-            if(kargs.K % GemmPipeline::VectorSizeA != 0)
+            if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }
...
...
@@ -187,10 +196,14 @@ struct GemmKernel
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
-            if(kargs.M % GemmPipeline::VectorSizeA != 0)
+            if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }
...
...
@@ -199,10 +212,14 @@ struct GemmKernel
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
-            if(kargs.N % GemmPipeline::VectorSizeB != 0)
+            if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }
...
...
@@ -210,10 +227,14 @@ struct GemmKernel
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
-            if(kargs.K % GemmPipeline::VectorSizeB != 0)
+            if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }
...
...
@@ -222,10 +243,14 @@ struct GemmKernel
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
-            if(kargs.N % GemmPipeline::VectorSizeC != 0)
+            if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
...
...
@@ -233,10 +258,14 @@ struct GemmKernel
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
-            if(kargs.M % GemmPipeline::VectorSizeC != 0)
+            if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
...
...
@@ -257,16 +286,16 @@ struct GemmKernel
                a_ptr,
                make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                make_tuple(kargs.stride_A, 1),
-                number<GemmPipeline::VectorSizeA>{},
+                number<GemmPipeline::GetVectorSizeA()>{},
                number<1>{});
        }
        else
        {
            return make_naive_tensor_view<address_space_enum::global>(
                a_ptr,
-                make_tuple(kargs.M, splitk_batch_offset.splitted_k),
-                make_tuple(1, kargs.stride_A),
-                number<1>{},
+                make_tuple(splitk_batch_offset.splitted_k, kargs.M),
+                make_tuple(kargs.stride_A, 1),
+                number<GemmPipeline::GetVectorSizeA()>{},
                number<1>{});
        }
    }();
...
...
@@ -276,9 +305,9 @@ struct GemmKernel
        {
            return make_naive_tensor_view<address_space_enum::global>(
                b_ptr,
-                make_tuple(kargs.N, splitk_batch_offset.splitted_k),
-                make_tuple(1, kargs.stride_B),
-                number<1>{},
+                make_tuple(splitk_batch_offset.splitted_k, kargs.N),
+                make_tuple(kargs.stride_B, 1),
+                number<GemmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }
        else
...
...
@@ -287,11 +316,12 @@ struct GemmKernel
                b_ptr,
                make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                make_tuple(kargs.stride_B, 1),
-                number<GemmPipeline::VectorSizeB>{},
+                number<GemmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }
    }();

    // TODO: enable vector write for C in ColMajor
    const auto& c_tensor_view = [&]() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
...
...
@@ -299,7 +329,7 @@ struct GemmKernel
                c_ptr,
                make_tuple(kargs.M, kargs.N),
                make_tuple(kargs.stride_C, 1),
-                number<GemmPipeline::VectorSizeC>{},
+                number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
                number<1>{});
        }
        else
...
...
@@ -331,9 +361,9 @@ struct GemmKernel
        else
        {
            return pad_tensor_view(
                a_tensor_view,
-                make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-                sequence<GemmPipeline::kPadM, false>{});
+                make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
+                sequence<false, GemmPipeline::kPadM>{});
        }
    }();
...
...
@@ -349,12 +379,13 @@ struct GemmKernel
        else
        {
            return pad_tensor_view(
                b_tensor_view,
-                make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-                sequence<GemmPipeline::kPadN, false>{});
+                make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+                sequence<false, GemmPipeline::kPadN>{});
        }
    }();

    // TODO vector write in for C in ColMajor
    const auto& c_pad_view = [&]() {
        const auto& c_tensor_view = views.at(I2);
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
...
...
@@ -380,20 +411,45 @@ struct GemmKernel
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view = views.at(I0);
-        const auto& a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_m, 0});
-
        const auto& b_pad_view = views.at(I1);
-        const auto& b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_n, 0});
-
        const auto& c_pad_view = views.at(I2);

+        const auto& a_block_window = [&]() {
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
+            {
+                return make_tile_window(a_pad_view,
+                                        make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                   number<TilePartitioner::KPerBlock>{}),
+                                        {i_m, 0});
+            }
+            else
+            {
+                return make_tile_window(a_pad_view,
+                                        make_tuple(number<TilePartitioner::KPerBlock>{},
+                                                   number<TilePartitioner::MPerBlock>{}),
+                                        {0, i_m});
+            }
+        }();
+
+        const auto& b_block_window = [&]() {
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
+            {
+                return make_tile_window(b_pad_view,
+                                        make_tuple(number<TilePartitioner::NPerBlock>{},
+                                                   number<TilePartitioner::KPerBlock>{}),
+                                        {i_n, 0});
+            }
+            else
+            {
+                return make_tile_window(b_pad_view,
+                                        make_tuple(number<TilePartitioner::KPerBlock>{},
+                                                   number<TilePartitioner::NPerBlock>{}),
+                                        {0, i_n});
+            }
+        }();
+
        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});
...
...
@@ -407,7 +463,9 @@ struct GemmKernel
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param c_ptr output C pointer
+    * @param smem_ptr_0 The start memory pointer of the shared memory block.
     * @param kargs GEMM kernel arguments
     * @param splitk_batch_offset Utility structure used to calculate k batch.
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     *
...
...
@@ -417,7 +475,7 @@ struct GemmKernel
    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
                                       const BDataType* b_ptr,
                                       CDataType* c_ptr,
-                                      void* smem_ptr,
+                                      void* smem_ptr_0,
                                       const GemmKernelArgs& kargs,
                                       const SplitKBatchOffset& splitk_batch_offset,
                                       const index_t block_idx_m,
...
...
@@ -435,28 +493,72 @@ struct GemmKernel
        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);

-       const auto& c_block_tile = GemmPipeline{}.template operator()(
-           a_block_window, b_block_window, num_loop, smem_ptr);
+       const auto& c_block_tile = GemmPipeline{}.template operator()(
+           a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);

-       constexpr bool is_output_c_reg_transposed =
-           EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
-       if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
-                    (GemmPipeline::VectorSizeC % 2 == 0 &&
-                     std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
-                     is_output_c_reg_transposed))
-       {
-           EpiloguePipeline{}
-               .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
-                   c_block_window, c_block_tile);
-       }
+       EpiloguePipeline{}
+           .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+               c_block_window, c_block_tile, smem_ptr_0);
    }

+   /**
+    * @brief Runs single GEMM problem cooperatively by whole workgroup.
+    *
+    * @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
+    *
+    * @param a_ptr input A pointer
+    * @param b_ptr input B pointer
+    * @param c_ptr output C pointer
+    * @param smem_ptr_0 The starting pointer of 1st shared memory block.
+    * @param smem_ptr_1 The starting pointer of 2nd shared memory block.
+    * @param kargs GEMM kernel arguments
+    * @param splitk_batch_offset Utility structure used to calculate k batch.
+    * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+    * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+    *
+    * @tparam DstInMemOp Destination memory operation (default: set).
+    */
+   template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+   CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
+                                          const BDataType* b_ptr,
+                                          CDataType* c_ptr,
+                                          void* __restrict__ smem_ptr_0,
+                                          void* __restrict__ smem_ptr_1,
+                                          const GemmKernelArgs& kargs,
+                                          const SplitKBatchOffset& splitk_batch_offset,
+                                          const index_t block_idx_m,
+                                          const index_t block_idx_n)
+   {
+       // Create Gemm tensor views, pad views and tile windows
+       const auto& gemm_tensor_views_tuple =
+           MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
+       const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+       auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+       const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
+
+       // Run GEMM cooperatively by whole workgroup.
+       const auto& a_block_window = gemm_tile_windows.at(I0);
+       const auto& b_block_window = gemm_tile_windows.at(I1);
+
+       const auto& c_block_tile = GemmPipeline{}.template operator()(
+           a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
+
+       // Run Epilogue Pipeline
+       auto& c_block_window = gemm_tile_windows.at(I2);
+       EpiloguePipeline{}
+           .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+               c_block_window, c_block_tile, smem_ptr_0);
+   }
    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
    {
-       const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+       const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
...
...
@@ -469,16 +571,53 @@ struct GemmKernel
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
-       __shared__ char smem_ptr[GetSmemSize()];
+       __shared__ char smem_ptr_0[GetSmemSize()];
+       __shared__ char smem_ptr_1[GetSmemSize()];

-       if(kargs.KBatch == 1)
+       if(kargs.k_batch == 1)
        {
-           RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+           if constexpr(GemmPipeline::DoubleSmemBuffer == true)
+           {
+               RunGemm2LDS(
+                   a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
+           }
+           else
+           {
+               RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+           }
        }
        else
        {
-           RunGemm<memory_operation_enum::atomic_add>(
-               a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+           // Do not compile in case where we have unsupported
+           // VectorSizeC & data type configuration.
+           if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
+                          is_any_of<CDataType, fp16_t, bf16_t>::value))
+           {
+               if constexpr(GemmPipeline::DoubleSmemBuffer == true)
+               {
+                   RunGemm2LDS<memory_operation_enum::atomic_add>(
+                       a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs,
+                       splitk_batch_offset, i_m, i_n);
+               }
+               else
+               {
+                   RunGemm<memory_operation_enum::atomic_add>(
+                       a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+               }
+           }
        }
    }
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @ ec959387

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

/**
 * @file
 * GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
 */

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

-/** @brief Struct representing 2D block index mapping into 3D output tile space. */
+/**
+ * @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
+ */
template <typename BlockGemmShapeType>
struct GemmTile2DPartitioner
{
...
...
@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

-   /** @brief Returns 3D grid size. */
-   CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
-       noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
+   CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete;
+   CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M,
+                                             [[maybe_unused]] index_t N) noexcept;
+
+   /**
+    * @brief Calculates GEMM kernel grid size.
+    *
+    * @param M GEMM's M dimension.
+    * @param N GEMM's N dimension.
+    * @return dim3 Structure holding grid's X,Y and Z dimensions.
+    */
+   CK_TILE_HOST static auto
+   GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
    {
        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
-       const index_t GridDimZ = batch_size;
-       return dim3(GridDimX, GridDimY, GridDimZ);
+       return dim3(GridDimX, GridDimY, 1);
    }

    /**
-    * @brief Returns the number of loops.
-    * @param [in] K is dimension
+    * @brief Calculate number of loop iterations over GEMM's K dimension.
+    *
+    * @param K GEMM's K dimension.
+    * @return index_t The number of loop iterations over K dimension.
     */
-   CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
+   CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
    {
        return integer_divide_ceil(K, KPerBlock);
    }
...
...
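For intuition, here is a standalone sketch of the GridSize() arithmetic above; the 128x128 tile shape and problem sizes are illustrative assumptions, not values from this commit.

// Standalone illustration of the ceiling-division grid-size computation.
#include <cstdio>

int main()
{
    const int MPerBlock = 128, NPerBlock = 128; // hypothetical tile shape
    const int M = 1000, N = 384;                // hypothetical problem sizes
    const int GridDimX = (M + MPerBlock - 1) / MPerBlock; // 8, i.e. ceil(1000/128)
    const int GridDimY = (N + NPerBlock - 1) / NPerBlock; // 3, i.e. ceil(384/128)
    std::printf("grid = (%d, %d, 1)\n", GridDimX, GridDimY); // grid = (8, 3, 1)
    return 0;
}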
@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner
-    * @param [in] blockIdy is blockIdx.y
-    * @return Returns the output tile indexes.
-    */
-   CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
-                                                           index_t blockIdy) noexcept
+   /**
+    * @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
+    *
+    * @param blockIdx WGP's X index.
+    * @param blockIdy WGP's Y index.
+    * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
+    */
+   CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
+       -> const tuple<index_t, index_t>
    {
        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
...
...
@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner
};

/**
- * @brief Struct representing 1D block index mapping into 2D output tile space.
+ * @brief Class providing 1D WGP index mapping into 2D output C-tile space.
+ *
+ * @tparam BlockGemmShape_ A class providing basic GEMM parameters. \link TileGemmShape
 */
-template <typename BlockGemmShapeType>
+template <typename BlockGemmShape_>
struct GemmTile1DPartitioner
{
-   using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
+   using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

-   /** @brief delete default ctr with no any object */
-   constexpr GemmTile1DPartitioner() noexcept = delete;
-
-   /** @brief constructs an object that does contain a N value. */
-   constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
+   CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete;

-   /** @brief Returns 1D grid size. */
-   CK_TILE_HOST static constexpr auto
-   GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
+   /**
+    * @brief Construct a new GemmTile1DPartitioner object.
+    *
+    * @param M GEMM's M dimension.
+    * @param N GEMM's N dimension.
+    */
+   CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept
    {
-       const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
-       const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
-       return dim3(GridDimX * GridDimY, 1, 1);
+       N_ = N;
    }

    /**
-    * @brief Returns the number of blocks in N.
-    * @param [in] N is dimension
+    * @brief Calculates GEMM kernel grid size.
+    *
+    * @param M GEMM's M dimension.
+    * @param N GEMM's N dimension.
+    * @return index_t Total number of workgroups in the 1D grid.
     */
-   CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
+   CK_TILE_HOST static auto
+   GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
    {
-       return integer_divide_ceil(N, NPerBlock);
+       const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+       const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+       return GridDimX * GridDimY;
    }

    /**
-    * @brief Returns the number of loops.
-    * @param [in] K is dimension
+    * @brief Calculate number of loop iterations over GEMM's K dimension.
+    *
+    * @param K GEMM's K dimension.
+    * @return index_t The number of loop iterations over K dimension.
     */
-   CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
+   CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
    {
        return integer_divide_ceil(K, KPerBlock);
    }

    /**
-    * @brief The function returns 2D output tile space.
-    * @param [in] blockIdx is blockIdx.x - block_start.
-    */
-   CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
+    * @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
+    *
+    * @param blockIdx WGP's index.
+    * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
+    */
+   CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
        -> const tuple<index_t, index_t>
    {
-       const index_t NBlock = GetNBlock(N_);
-       const index_t iM     = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
-       const index_t iN     = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
+       const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
+       const index_t iM      = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
+       const index_t iN      = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
        return make_tuple(iM, iN);
    }
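A quick host-side check of the 1D-to-2D mapping above; NPerBlock, N and the block index are illustrative assumptions.

// Standalone check of the linear-id -> (iM, iN) mapping (row-major over N blocks).
#include <cassert>

int main()
{
    const int N_ = 512, NPerBlock = 128;                  // hypothetical values
    const int NBlocks = (N_ + NPerBlock - 1) / NPerBlock; // 4
    const int blockIdx = 9;                               // linear workgroup id
    const int iM = blockIdx / NBlocks;                    // 2
    const int iN = blockIdx - iM * NBlocks;               // 1
    assert(iM == 2 && iN == 1); // workgroup 9 computes C-tile (2, 1)
    return 0;
}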
...
...
@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn
 * enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
 * otherwise std::false_type.
 */
-template <typename PartitionerFn,
-          typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
+template <typename TilePartitioner,
+          typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
struct OffsettedTile1DPartitioner
{
    /**
     * @brief The function subtracts the block's start (offset) from 1D raw-indexes.
-    * @param [in] block_start is `blockIdx.x - block_start`.
-    * @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
+    * @param [in] block_start Workgroup offset.
+    * @param [in] M Gemm's M dimension.
+    * @param [in] N Gemm's N dimension.
+    * @return Returns a `tuple` [Im, In] with shifted index.
     */
-   [[nodiscard]] CK_TILE_DEVICE static constexpr auto
-   GetOffsetedTileIndex(index_t block_start, index_t N) noexcept
+   [[nodiscard]] CK_TILE_DEVICE static auto
+   GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
        -> const tuple<index_t, index_t>
    {
-       const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
+       const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
        return make_tuple(iM, iN);
    }
};
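For context, a host-side emulation of how GetOffsetedTileIndex is meant to be used by the grouped-GEMM kernel further down; the helper and all values are hypothetical, mirroring GemmTile1DPartitioner's row-over-N mapping.

// Host-side emulation of the offsetted tile-index lookup for one GEMM group.
// block_start is the first grid block assigned to this group (hypothetical values).
#include <cassert>
#include <utility>

std::pair<int, int> get_offseted_tile_index(int blockIdx_x, int block_start,
                                            int M, int N, int NPerBlock)
{
    (void)M; // unused in the row-over-N mapping, kept to mirror the real signature
    const int local_id = blockIdx_x - block_start;     // index within the group
    const int NBlocks  = (N + NPerBlock - 1) / NPerBlock;
    return {local_id / NBlocks, local_id % NBlocks};   // {iM, iN}
}

int main()
{
    // e.g. blockIdx.x = 12, block_start = 8, N = 256, NPerBlock = 128
    auto [iM, iN] = get_offseted_tile_index(12, 8, 512, 256, 128);
    assert(iM == 2 && iN == 0); // local_id = 4, NBlocks = 2 -> tile {2, 0}
    return 0;
}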
/**
 * @brief Class mapping 1D block index into 2D output tile space.
 *
 * @note It groups workgroups spatially in order to better utilize caches.
 *       It uses a grouped rows-of-column-vectors WGP pattern, and is optimized
 *       for gfx94x-like multiple-die chips.
 *
 * @tparam GroupNum - The number of big groups.
 * @tparam M01 - The number of groups in M dim within spatially local WGPs.
 */
template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
struct GemmSpatiallyLocalTilePartitioner
{
    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete;
    CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept
        : M(M_), N(N_)
    {
    }

    /**
     * @brief Calculates GEMM kernel grid size.
     *
     * @param M GEMM's M dimension.
     * @param N GEMM's N dimension.
     * @return index_t A total number of workgroups.
     */
    CK_TILE_HOST static auto
    GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
    {
        const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
        const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
        return GridDimX * GridDimY;
    }

    /**
     * @brief Calculate number of loop iterations over GEMM's K dimension.
     *
     * @param K GEMM's K dimension.
     * @return index_t The number of loop iterations over K dimension.
     */
    CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
    {
        return integer_divide_ceil(K, KPerBlock);
    }

    /**
     * @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
     *
     * @param [in] block_1d_id WGP's index.
     * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
     */
    CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
        -> const tuple<index_t, index_t>
    {
        const auto M0 = integer_divide_ceil(M, MPerBlock);
        const auto N0 = integer_divide_ceil(N, NPerBlock);

        if(M0 == 1)
        {
            return make_tuple(0, block_1d_id);
        }
        else if(N0 == 1)
        {
            return make_tuple(block_1d_id, 0);
        }
        // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
        else
        {
            const auto group_size    = integer_divide_ceil(M0 * N0, GroupNum);
            const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
            const auto group_id_y    = block_1d_id / GroupNum;
            const auto group_id_x    = block_1d_id - group_id_y * GroupNum;
            const auto remap_block_1d_id =
                group_id_x <= big_group_num
                    ? group_id_x * group_size + group_id_y
                    : group_id_x * group_size + big_group_num - group_id_x + group_id_y;

            const index_t idx_M0 = remap_block_1d_id / N0;
            const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;

            const index_t M0_tmp     = M0 / M01;
            const index_t M0_mod_M01 = M0 - M0_tmp * M01;

            const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;

            const index_t idx_M00          = idx_M0 / M01;
            const index_t idx_M01          = idx_M0 - idx_M00 * M01;
            const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

            /**
             *  idxN0
             *
             *  |<        mtx   N        >|
             *
             *    NPerBlock   NPerBlock   NPerBlock   NPerBlock
             *      N_0         N_1         N_2        N_3
             *  - |-----------|-----------|-----------|-----|-----|-
             *  ^ | - -   0   |/---->  2  |           |     |     |
             *  | |     |    /|           |           |     |     |  M_0  MPerBlock
             *  |   M   |   / |           |           |     |     |
             *  |-0-----|--/--|-----|-----|-----------|-----|-----|-
             *  |   1   |  /  |     |     |  blockid  |     |     |
             *  idxM0 | | /   |     V     |     5     |     |     |  M_1  MPerBlock
             *  | -   V 1     | -   3     |           |     |     |
             *    |-----------|-----------|-----------|-----|-----|-
             *  mtx M |       |           |           |     |     |
             *  |     |       |           |           |     |     |  M_2  MPerBlock
             *  |     |       |           |           |     |     |
             *  |-----------|-----------|-----------|-----|-----|-
             *  |     |       |           |           |     |     |
             *  |     |       |           |           |     |     |  M_3  MPerBlock
             *  |     |       |           |           |     |     |
             *  |-----------|-----------|-----------|-----|-----|-
             *  V     |       |           |           |     |     |
             *  - |-----------|-----------|-----------|-----|-----|-  M_4  MPerBlock
             *        |       |           |           |     |     |
             *    |-----------|-----------|-----------|-----|-----|-
             *  Example:
             *   assume:
             *      M0 = 5
             *      N0 = 4
             *      block_1d_id = 5
             *      M01 = 2
             *
             *   idx_N0 = 1
             *   idx_M0 = 1
             *   M01_adapt = 2
             *   idx_M00 = 0
             *   idx_M01 = 1
             *   idx_N0_M01_local = 5
             *   output {1, 2}
             */
            const index_t N_out           = idx_N0_M01_local / M01_adapt;
            const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;

            return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
        }
    }

    private:
    index_t M;
    index_t N;
};

} // namespace ck_tile
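The worked example in the comment above can be replayed directly; this standalone sketch runs the second (M01) swizzle stage with the documented inputs and recovers the documented output {1, 2}.

// Reproduces the comment's example (after the GroupNum remap stage):
// M0 = 5, N0 = 4, remapped block id = 5, M01 = 2.
#include <cassert>

int main()
{
    const int M0 = 5, N0 = 4, M01 = 2;
    const int remap_block_1d_id = 5;

    const int idx_M0 = remap_block_1d_id / N0;          // 1
    const int idx_N0 = remap_block_1d_id - idx_M0 * N0; // 1

    const int M0_mod_M01 = M0 % M01;                    // 1
    const int M01_adapt  = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01; // 2

    const int idx_M00          = idx_M0 / M01;          // 0
    const int idx_M01          = idx_M0 - idx_M00 * M01;// 1
    const int idx_N0_M01_local = idx_N0 + idx_M01 * N0; // 5

    const int N_out = idx_N0_M01_local / M01_adapt;                          // 2
    const int M_out = idx_N0_M01_local - N_out * M01_adapt + idx_M00 * M01;  // 1
    assert(M_out == 1 && N_out == 2); // output tile {1, 2}, as documented
    return 0;
}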
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
View file @ ec959387
...
...
@@ -50,7 +50,6 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
    using GemmKernelArgs = typename Base::GemmKernelArgs;

    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
-   static constexpr index_t KBatch          = 1;

    struct GemmTransKernelArg
    {
...
...
@@ -65,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
        }
    };

+   [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+   {
+       // clang-format off
+       using P_ = GemmPipeline;
+       return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
+                     concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
+                     concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
+                     concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
+       // clang-format on
+   }
+
    __host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
        -> std::size_t
    {
...
...
@@ -78,8 +89,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
-           const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
-           grid_size += dim3.x * dim3.y * 1;
+           const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
+           grid_size += local_grid_size * it_desc.k_batch;
        }
        return dim3(grid_size, 1, 1);
    }
...
...
@@ -107,8 +118,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_c = gemm_descs[i].stride_C;

-           const auto dim3             = TilePartitioner::GridSize(M, N);
-           const index_t grid_size_grp = dim3.x;
+           const index_t grid_size_grp =
+               TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;

            const index_t block_start = grid_size;
            const index_t block_end   = grid_size + grid_size_grp;
...
...
@@ -124,7 +134,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
                                    stride_a,
                                    stride_b,
                                    stride_c,
-                                   KBatch};
+                                   gemm_descs[i].k_batch};
            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }
...
...
@@ -139,8 +149,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
    CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
    {
-       const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start,
-                                                                           kargs.group_karg.N);
+       const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(
+           kargs.block_start, kargs.group_karg.M, kargs.group_karg.N);

        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @ ec959387

// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {
...
...
@@ -12,18 +13,23 @@ struct GemmPipelineAgBgCrImplBase
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
+   using ALayout        = remove_cvref_t<typename Problem::ALayout>;
+   using BLayout        = remove_cvref_t<typename Problem::BLayout>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

-   template <typename DstBlockTile, typename SrcTileWindow>
+   CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
+
+   template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
    CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
-                                      SrcTileWindow& dram_tile_window) const
+                                      SrcTileWindow& dram_tile_window,
+                                      const DramTileWindowStep& dram_tile_window_step) const
    {
        load_tile(dst_block_tile, dram_tile_window);
-       move_tile_window(dram_tile_window, {0, KPerBlock});
+       move_tile_window(dram_tile_window, dram_tile_window_step);
    }

    template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
...
...
@@ -35,20 +41,26 @@ struct GemmPipelineAgBgCrImplBase
        store_tile(lds_tile_window, block_tile_tmp);
    }

+   template <typename DstBlockTile, typename SrcTileWindow>
+   CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
+                                     const SrcTileWindow& lds_tile_window) const
+   {
+       load_tile(dst_block_tile, lds_tile_window);
+   }
+
    CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
    {
        // A tile in LDS
-       ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
+       ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
        constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        // TODO: LDS alignment should come from Policy!
-       constexpr index_t a_lds_block_space_size_aligned =
-           integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
-           16;
+       constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
+           sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);

        // B tile in LDS
-       BDataType* p_b_lds = static_cast<BDataType*>(
+       BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
            static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
...
...
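The switch from the open-coded round-up to integer_least_multiple is behavior-preserving: both yield the smallest multiple of 16 bytes covering the A tile. A standalone sketch (the byte count is illustrative):

// Equivalent 16-byte round-up computations for an LDS footprint.
#include <cstddef>

constexpr std::size_t round_up_16(std::size_t bytes)
{
    return ((bytes + 15) / 16) * 16; // same as integer_divide_ceil(bytes, 16) * 16
}

int main()
{
    constexpr std::size_t bytes = 2 * 8192 + 8; // 16392 bytes, not 16-aligned
    static_assert(round_up_16(bytes) == 16400, "least multiple of 16 covering 16392");
    return 0;
}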
@@ -60,19 +72,21 @@ struct GemmPipelineAgBgCrImplBase
    CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                    const ALdsTensorView& a_lds_block_view) const
    {
+       constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
+       using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
+       using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
+
        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
-                            make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
+                            make_tuple(YPerTile{}, XPerTile{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store
-       auto a_copy_lds_window =
-           make_tile_window(a_lds_block_view,
-                            make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
-                            {0, 0},
-                            a_copy_dram_window.get_tile_distribution());
+       auto a_copy_lds_window = make_tile_window(
+           a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

        auto a_lds_gemm_window = make_tile_window(
            a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
...
...
@@ -86,18 +100,22 @@ struct GemmPipelineAgBgCrImplBase
    CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                    const BLdsTensorView& b_lds_block_view) const
    {
+       constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
+       using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
+       using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
+
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
-                            make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
+                            make_tuple(YPerTile{}, XPerTile{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

-       // TODO: Do we really need those two tile windows???
-       // They're exactly same...
        // B LDS tile window for store
-       auto b_copy_lds_window =
-           make_tile_window(b_lds_block_view,
-                            make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
-                            {0, 0},
-                            b_copy_dram_window.get_tile_distribution());
+       auto b_copy_lds_window = make_tile_window(
+           b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

        auto b_lds_gemm_window = make_tile_window(
            b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @ ec959387

// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

+#include <string>
+#include <sstream>
+
#include "ck_tile/core.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
+#include "ck_tile/host/concat.hpp"

namespace ck_tile {
...
...
@@ -20,6 +24,8 @@ struct BaseGemmPipelineAgBgCrCompV3
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

+   CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
+
    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
...
...
@@ -37,7 +43,7 @@ struct BaseGemmPipelineAgBgCrCompV3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
-template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
    using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
...
...
@@ -62,27 +68,85 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

-   static constexpr index_t VectorSizeA = Problem::VectorSizeA;
-   static constexpr index_t VectorSizeB = Problem::VectorSizeB;
-   static constexpr index_t VectorSizeC = Problem::VectorSizeC;
+   static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
+   static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
+   static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    // Where is the right place for HasHotLoop and TailNum ???
+   static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

+   using Base::PrefetchStages;
+
+   [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+   {
+       // clang-format off
+       return concat('_', "pipeline_AgBgCrCompV3", BlockSize,
+                     concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
+                     concat('x', kPadM, kPadN, kPadK));
+       // clang-format on
+   }
+
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

+   CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
+
+   CK_TILE_HOST static std::string Print()
+   {
+       constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
+       constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
+       constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
+
+       constexpr index_t WaveSize = 64;
+       constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
+       constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
+
+       // Below should be equal to AK1|BK1
+       constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
+       constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
+
+       constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
+       constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
+
+       constexpr index_t A_Buffer_Load_Inst_Num =
+           MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
+       constexpr index_t B_Buffer_Load_Inst_Num =
+           NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
+
+       constexpr index_t A_LDS_Write_Inst_Num =
+           MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
+       constexpr index_t B_LDS_Write_Inst_Num =
+           NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
+
+       constexpr index_t A_LDS_Read_Inst_Num =
+           WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
+       constexpr index_t B_LDS_Read_Inst_Num =
+           WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
+
+       constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
+                                           (BlockSize / WaveSize) /
+                                           (MPerXDL * NPerXDL * KPerXDL);
+
+       auto str = std::stringstream{};
+       str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
+           << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
+           << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num << "\n"
+           << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num << "\n"
+           << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
+           << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
+           << "KPack: " << BlockGemm::Traits::KPack << "\n"
+           << "PrefetchStages: " << PrefetchStages << "\n";
+       return str.str();
+   }
+
    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
...
...
@@ -96,29 +160,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
        CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
        {
-           constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
-           constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
-           constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
+           constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
+           constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
+           constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

            constexpr index_t WaveSize = 64;
            constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
            constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

-           constexpr index_t A_LDS_Read_Width = KPerXDL;
-           constexpr index_t B_LDS_Read_Width = KPerXDL;
+           // Below should be equal to AK1|BK1
+           constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
+           constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
+
+           constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
+           constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();

            constexpr index_t A_Buffer_Load_Inst_Num =
-               MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
+               MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
            constexpr index_t B_Buffer_Load_Inst_Num =
-               NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
+               NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

-           constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
-           constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
+           constexpr index_t A_LDS_Write_Inst_Num =
+               MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
+           constexpr index_t B_LDS_Write_Inst_Num =
+               NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

            constexpr index_t A_LDS_Read_Inst_Num =
-               WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+               WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
            constexpr index_t B_LDS_Read_Inst_Num =
-               WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+               WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

            constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
                                                (BlockSize / WaveSize) /
...
...
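To make the instruction-count bookkeeping concrete, here is a standalone instance of the formulas above; the tile shape, block size, wave layout and vector widths are assumptions for illustration only.

// Worked instance of the per-block instruction counts used by HotLoopScheduler.
#include <cassert>

int main()
{
    const int MPerBlock = 256, NPerBlock = 128, KPerBlock = 64; // assumed tile
    const int BlockSize = 256, WaveSize = 64;
    const int VectorSizeA = 8, A_LDS_Read_Width = 8, WaveNumN = 2;
    const int MPerXDL = 32, NPerXDL = 32, KPerXDL = 8;

    const int A_Buffer_Load_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * VectorSizeA);             // 8
    const int A_LDS_Read_Inst_Num =
        WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); // 16
    const int C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
                                (BlockSize / WaveSize) /
                                (MPerXDL * NPerXDL * KPerXDL);         // 64
    assert(A_Buffer_Load_Inst_Num == 8 && A_LDS_Read_Inst_Num == 16 &&
           C_MFMA_Inst_Num == 64);
    return 0;
}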
@@ -248,11 +318,22 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
                          "A/B Dram block window should have the same data type as appropriate "
                          "([A|B]DataType) defined in Problem definition!");

-           static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                             NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                             KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
-                         "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
-                         " or KPerBlock!");
+           constexpr bool is_a_col_major =
+               std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
+           constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
+
+           static_assert(is_a_col_major
+                             ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                             : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                         "A block window has incorrect lengths for defined ALayout!");
+           static_assert(is_b_row_major
+                             ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                             : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                                KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                         "B block window has incorrect lengths for defined BLayout!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles
...
...
@@ -287,23 +368,51 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
            ABlockTile a_block_tile;
            BBlockTile b_block_tile;

+           using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
+           using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
+
+           constexpr ADramTileWindowStep a_dram_tile_window_step =
+               is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+           constexpr BDramTileWindowStep b_dram_tile_window_step =
+               is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+
            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
-           Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
-           Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
+           Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
+           Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
-           Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
-           Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
+           if constexpr(is_a_col_major)
+           {
+               auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                   Policy::template MakeShuffledARegTileDistribution<Problem>());
+               transpose_tile2d(a_shuffle_tmp, a_block_tile);
+               Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+           }
+           else
+           {
+               Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
+           }
+           if constexpr(is_b_row_major)
+           {
+               auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                   Policy::template MakeShuffledBRegTileDistribution<Problem>());
+               transpose_tile2d(b_shuffle_tmp, b_block_tile);
+               Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+           }
+           else
+           {
+               Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
+           }

-           Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
-           Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
+           Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
+           Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);

            block_sync_lds();
            block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
...
...
@@ -318,11 +427,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
                {
                    block_sync_lds();

-                   Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
-                   Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
-                   Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
-                   Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
+                   if constexpr(is_a_col_major)
+                   {
+                       auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                           Policy::template MakeShuffledARegTileDistribution<Problem>());
+                       transpose_tile2d(a_shuffle_tmp, a_block_tile);
+                       Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                   }
+                   else
+                   {
+                       Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
+                   }
+                   if constexpr(is_b_row_major)
+                   {
+                       auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                           Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                       transpose_tile2d(b_shuffle_tmp, b_block_tile);
+                       Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                   }
+                   else
+                   {
+                       Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
+                   }
+                   Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
+                   Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);

                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
0 → 100644
View file @ ec959387

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"

namespace ck_tile {

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV4
{
    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1)
        {
            return TailNumber::Three;
        }
        else
        {
            return TailNumber::Two;
        }
    }
};
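With PrefetchStages == 2, the tail selection above simply distinguishes odd from even K-loop counts; a standalone sketch:

// GetBlockLoopTailNum with PrefetchStages == 2: odd num_loop -> Three, even -> Two.
enum class TailNumber { Two, Three };

constexpr TailNumber get_tail(int num_loop, int prefetch_stages = 2)
{
    return (num_loop % prefetch_stages == 1) ? TailNumber::Three : TailNumber::Two;
}

static_assert(get_tail(7) == TailNumber::Three, "odd K-loop count leaves 3 staged tiles");
static_assert(get_tail(8) == TailNumber::Two, "even K-loop count leaves 2 staged tiles");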
/**
 * @brief Compute optimized pipeline version 4
 *
 * This version introduces a dual LDS window mechanism using a ping-pong buffer approach
 * for more efficient data handling from global memory. Unlike compute version 3, this method
 * lets one LDS buffer be filled from global memory while the warps run MFMA matrix
 * multiplication out of the other. This dual operation keeps the warp units continuously
 * busy, significantly reducing memory load stalls and enhancing overall performance.
 *
 * @note This version shows improved performance over compute version 3 with the same block tile.
 * It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
 * even when compute version 3's block size is twice that of compute version 4.
 */
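A minimal sketch of the ping-pong schedule the comment describes; fetch_global, fill_lds and consume_mfma are hypothetical placeholders for the GlobalPrefetch / LocalPrefill / block_gemm calls in this file, not real APIs.

// Skeleton of the double-LDS ping-pong main loop (assumptions as noted above).
void ping_pong_main_loop(int num_loop)
{
    // Two LDS buffers: while buffer 'cur' feeds the MFMAs, buffer 'nxt' is refilled.
    int cur = 0, nxt = 1;
    // ... prologue: prefetch two K-steps and fill both buffers (omitted) ...
    for(int k = 0; k < num_loop - 2; ++k)
    {
        // barrier();          // make buffer 'cur' visible to all waves
        // fill_lds(nxt);      // LDS writes for the next K-step overlap...
        // fetch_global();     // ...with global loads for the step after that
        // consume_mfma(cur);  // MFMA math on the current buffer
        const int tmp = cur;   // swap buffer roles each iteration
        cur = nxt;
        nxt = tmp;
    }
    // ... epilogue: drain the last staged K-steps (TailNumber::Two or ::Three) ...
}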
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
    using Base             = BaseGemmPipelineAgBgCrCompV4<Problem>;
    using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;

    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;

    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t BlockSize = Problem::kBlockSize;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
    static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
    static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
    {
        return Policy::template IsTransposeC<Problem>();
    }

    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
    };

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
    {
        using Base = PipelineImplBase;

        CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
        {
            constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
            constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
            constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});

            constexpr index_t WaveSize = 64;
            constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
            constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

            constexpr index_t A_LDS_Read_Width = KPerXDL;
            constexpr index_t B_LDS_Read_Width = KPerXDL;

            constexpr index_t A_Buffer_Load_Inst_Num =
                MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
            constexpr index_t B_Buffer_Load_Inst_Num =
                NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t A_LDS_Read_Inst_Num =
                WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Read_Inst_Num =
                WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
                                                (BlockSize / WaveSize) /
                                                (MPerXDL * NPerXDL * KPerXDL);

            constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
                                                    ? A_LDS_Read_Inst_Num
                                                    : A_LDS_Read_Inst_Num / 2;
            constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
                                                    ? B_LDS_Read_Inst_Num
                                                    : B_LDS_Read_Inst_Num / 2;

            constexpr auto num_ds_read_inst  = num_ds_read_inst_a + num_ds_read_inst_b;
            constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;

            constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;

            constexpr auto num_issue = num_buffer_load_inst;

            static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
                __builtin_amdgcn_sched_group_barrier(
                    0x100, num_ds_read_inst / num_issue, 0);       // DS read : 2
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
                __builtin_amdgcn_sched_group_barrier(
                    0x200, num_ds_write_inst / num_issue, 0);      // DS write : 1
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
                __builtin_amdgcn_sched_group_barrier(
                    0x008, C_MFMA_Inst_Num / num_issue - 3, 0);    // MFMA : 5
            });
            __builtin_amdgcn_sched_barrier(0);
        }
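Assuming the standard LLVM AMDGPU mask encoding, the sched_group_barrier masks used above decode as: 0x008 = MFMA/WMMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write. Each call pins the given count of instructions of that class at its position in issue order, which is how the loop interleaves MFMAs with LDS and global traffic. A minimal device-side illustration (hypothetical standalone function, not from this commit):

// Pin one MFMA, then one VMEM (buffer) read, at this point of the instruction stream.
__device__ void sched_hint_example()
{
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM read
}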
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
__restrict__
p_smem_0
,
void
*
__restrict__
p_smem_1
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
"Data Type conflict on A and B matrix input data type."
);
constexpr
bool
is_a_col_major
=
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
////////////// global window & register /////////////////
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window_linear
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window_linear
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// A register tile for global load
constexpr
auto
ABlockTileDistr
=
a_copy_dram_window
.
get_tile_distribution
();
constexpr
auto
BBlockTileDistr
=
b_copy_dram_window
.
get_tile_distribution
();
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
));
ABlockTile
a_global_load_tile
;
BBlockTile
b_global_load_tile
;
using
ADramTileWindowStep
=
typename
ADramBlockWindowTmp
::
BottomTensorIndex
;
using
BDramTileWindowStep
=
typename
BDramBlockWindowTmp
::
BottomTensorIndex
;
constexpr
ADramTileWindowStep
a_dram_tile_window_step
=
is_a_col_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
constexpr
BDramTileWindowStep
b_dram_tile_window_step
=
is_b_row_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
// global prefetch 0
// global read 0
Base
::
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
,
b_dram_tile_window_step
);
////////////// LDS desc, window & register /////////////////
auto
&&
[
a_lds_block0
,
b_lds_block0
]
=
Base
::
GetABLdsTensorViews
(
p_smem_0
);
auto
&&
[
a_lds_block1
,
b_lds_block1
]
=
Base
::
GetABLdsTensorViews
(
p_smem_1
);
auto
a_copy_lds_window0
=
make_tile_window
(
a_lds_block0
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
auto
a_copy_lds_window1
=
make_tile_window
(
a_lds_block1
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
auto
b_copy_lds_window0
=
make_tile_window
(
b_lds_block0
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
auto
b_copy_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_global_load_tile
);
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window0
,
a_global_load_tile
,
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_global_load_tile
);
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window0
,
b_global_load_tile
,
b_element_func
);
}
        // global read 1
        Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
        Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

        block_sync_lds();

        constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
            BlockGemm::MakeABlockDistributionEncode())){};
        constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
            BlockGemm::MakeBBlockDistributionEncode())){};

        using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
        using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));

        ALdsTile a_block_tile0;
        ALdsTile a_block_tile1;
        BLdsTile b_block_tile0;
        BLdsTile b_block_tile1;

        auto a_lds_ld_window0 = make_tile_window_linear(
            a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0}, ALdsTileDistr);
        auto a_lds_ld_window1 = make_tile_window_linear(
            a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0}, ALdsTileDistr);
        auto b_lds_ld_window0 = make_tile_window_linear(
            b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}, BLdsTileDistr);
        auto b_lds_ld_window1 = make_tile_window_linear(
            b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}, BLdsTileDistr);

        Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
        Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);

        // LDS write 1
        if constexpr(is_a_col_major)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                Policy::template MakeShuffledARegTileDistribution<Problem>());
            transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
            Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
        }
        else
        {
            Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
        }
        if constexpr(is_b_row_major)
        {
            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                Policy::template MakeShuffledBRegTileDistribution<Problem>());
            transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
            Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
        }
        else
        {
            Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
        }

        Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
        Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

        if(HasHotLoop)
        {
            // minus 2 because we have ping-pong double buffer.
            index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
            do
            {
                // ping
                {
                    block_sync_lds();

                    Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
                    Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);

                    if constexpr(is_a_col_major)
                    {
                        auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                            Policy::template MakeShuffledARegTileDistribution<Problem>());
                        transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                        Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
                    }
                    if constexpr(is_b_row_major)
                    {
                        auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                            Policy::template MakeShuffledBRegTileDistribution<Problem>());
                        transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                        Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
                    }

                    Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
                    Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

                    // gemm
                    block_gemm(c_block_tile, a_block_tile0, b_block_tile0);

                    HotLoopScheduler();
                    __builtin_amdgcn_sched_barrier(0);
                }
                // pong
                {
                    block_sync_lds();

                    Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
                    Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);

                    if constexpr(is_a_col_major)
                    {
                        auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                            Policy::template MakeShuffledARegTileDistribution<Problem>());
                        transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                        Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
                    }
                    if constexpr(is_b_row_major)
                    {
                        auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                            Policy::template MakeShuffledBRegTileDistribution<Problem>());
                        transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                        Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
                    }

                    Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
                    Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

                    // gemm
                    block_gemm(c_block_tile, a_block_tile1, b_block_tile1);

                    HotLoopScheduler();
                    __builtin_amdgcn_sched_barrier(0);
                }

                iCounter -= 2;
            } while(iCounter > 1);
        }
        // tail 3
        if(TailNum == TailNumber::Three)
        {
            // 3
            {
                block_sync_lds();

                Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
                Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);

                if constexpr(is_a_col_major)
                {
                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                        Policy::template MakeShuffledARegTileDistribution<Problem>());
                    transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                    Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
                }
                else
                {
                    Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
                }
                if constexpr(is_b_row_major)
                {
                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
                    transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                    Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
                }
                else
                {
                    Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
                }

                block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
            }
            // 2
            {
                block_sync_lds();

                Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
                Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);

                block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
            }
            // 1
            {
                block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
                __builtin_amdgcn_sched_barrier(0);
            }
        }
        else
        {
            // 2
            {
                block_sync_lds();

                Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
                Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);

                block_gemm(c_block_tile, a_block_tile0, b_block_tile0);

                static_for<0, 8, 1>{}([&](auto i) {
                    ignore = i;
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
                });
                __builtin_amdgcn_sched_barrier(0);
            }
            // 1
            {
                block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
                __builtin_amdgcn_sched_barrier(0);
            }
        }

        return c_block_tile;
    }
    };

    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem_0,
                                   void* p_smem_1) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            a_element_func,
            b_dram_block_window_tmp,
            b_element_func,
            num_loop,
            p_smem_0,
            p_smem_1);
    }

    public:
    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const index_t num_loop,
                                   void* __restrict__ p_smem_0,
                                   void* __restrict__ p_smem_1) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem_0,
            p_smem_1);
    }
};
}
// namespace ck_tile
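To make the control flow above easier to follow, here is a standalone model of the ping-pong schedule (plain host C++, no GPU intrinsics; the trip count is illustrative). It shows why the hot-loop counter starts at num_loop - 2: two K-tiles are already resident in the double-buffered LDS before the loop begins.

// Standalone sketch, not CK code: ping-pong double-buffer bookkeeping.
#include <cstdio>

int main()
{
    const int num_loop = 7; // illustrative K-tile count
    int gemms_done = 0;

    // Two tiles are prefilled into LDS buffers 0 and 1 before the hot loop,
    // hence the "num_loop - 2" in the pipeline above.
    int iCounter = num_loop - 2;
    if(iCounter > 0) // corresponds to HasHotLoop
    {
        do
        {
            std::printf("ping: gemm on buf0 while buf0 is refilled for later\n");
            ++gemms_done;
            std::printf("pong: gemm on buf1 while buf1 is refilled for later\n");
            ++gemms_done;
            iCounter -= 2;
        } while(iCounter > 1);
    }
    // The tail then issues the remaining 2 or 3 gemms on already-filled buffers.
    std::printf("tail gemms: %d\n", num_loop - gemms_done);
}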
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAGmemBGmemCregComputeV4. Except for the block gemm method, it
// shares the same vector-size implementation, SmemSize, and global-memory tile distribution as
// the UniversalGemm pipeline policy.
// The default policy class should not be templated; put templates on member functions instead.
struct GemmPipelineAgBgCrCompV4DefaultPolicy
    : public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV4DefaultPolicy>
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        using namespace ck_tile;

        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack      = GetSmemPackA<Problem>();

        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kKPerBlock / KPack>{}, number<kMPerBlock>{}, number<KPack>{}),
            make_tuple(number<kMPerBlock * KPack>{}, number<KPack>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
                       make_merge_transform(
                           make_tuple(number<kKPerBlock>{} / KPack, number<KPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }
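The descriptor above stores A in LDS as a 3-D (K/KPack, M, KPack) box and then merges the split K coordinates back into a single logical K. A minimal sketch of the resulting address arithmetic, using plain integer math under the strides shown above (this is not the CK descriptor machinery):

// Standalone sketch: flattened LDS offset of logical element (m, k) under the
// (K/KPack, M, KPack) layout with strides (kMPerBlock*KPack, KPack, 1).
constexpr int lds_offset(int m, int k, int kMPerBlock, int KPack)
{
    const int k0 = k / KPack; // slow k dimension, stride kMPerBlock * KPack
    const int k1 = k % KPack; // packed fast k dimension, stride 1
    return k0 * kMPerBlock * KPack + m * KPack + k1;
}

int main()
{
    // e.g. kMPerBlock = 128, KPack = 8: k runs contiguously in groups of 8
    static_assert(lds_offset(0, 0, 128, 8) == 0);
    static_assert(lds_offset(0, 7, 128, 8) == 7);       // same pack, consecutive
    static_assert(lds_offset(1, 0, 128, 8) == 8);       // next row starts a new pack
    static_assert(lds_offset(0, 8, 128, 8) == 128 * 8); // next k-slab
    return 0;
}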
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack      = GetSmemPackB<Problem>();

        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kKPerBlock / KPack>{}, number<kNPerBlock>{}, number<KPack>{}),
            make_tuple(number<kNPerBlock * KPack>{}, number<KPack>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
                       make_merge_transform(
                           make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return b_lds_block_desc;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        using AccDataType = float;
        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                                   typename Problem::BDataType,
                                                   AccDataType,
                                                   WarpTile::at(I0),
                                                   WarpTile::at(I1),
                                                   WarpTile::at(I2),
                                                   Problem::TransposeC>;
        using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
                                                                    typename Problem::BDataType,
                                                                    typename Problem::CDataType,
                                                                    BlockWarps,
                                                                    WarpGemm>;
        return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
    }
};
}
// namespace ck_tile
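For context, the pattern by which a pipeline consumes such a policy is roughly the following; the pipeline and problem names here are hypothetical placeholders, not declarations from this commit:

// Illustrative only: how a pipeline typically binds the policy's block gemm.
// "MyProblem" and "MyPipeline" are assumed names for the sketch.
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct MyPipeline
{
    // The block gemm type is derived from the policy, so swapping the policy
    // swaps the warp-level MFMA configuration without touching pipeline code.
    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
};

Because GetBlockGemm() here returns a BlockGemmARegBRegCRegV1 (both operands in registers), the V4 pipeline can overlap its LDS reads with MFMA issue, which is what the ping-pong scheduling above exploits.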
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -20,6 +21,8 @@ struct BaseGemmPipelineAgBgCrMem
     using BDataType      = remove_cvref_t<typename Problem::BDataType>;
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

+    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
+
     static constexpr index_t BlockSize = Problem::kBlockSize;
     static constexpr index_t MPerBlock = BlockGemmShape::kM;
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
...
...
@@ -88,7 +91,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
-template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
 struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
 {
     using Base = BaseGemmPipelineAgBgCrMem<Problem>;
...
...
@@ -113,19 +116,31 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
     static constexpr index_t KPerBlock = BlockGemmShape::kK;

-    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
-    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
-    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
+    static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
+    static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
+    static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

     static constexpr bool kPadM = Problem::kPadM;
     static constexpr bool kPadN = Problem::kPadN;
     static constexpr bool kPadK = Problem::kPadK;

     static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

     // Where is the right place for HasHotLoop and TailNum ???
     static constexpr bool HasHotLoop = Problem::HasHotLoop;
     static constexpr auto TailNum    = Problem::TailNum;
     static constexpr auto Scheduler  = Problem::Scheduler;

+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "pipeline_AgBgCrMe",
+                      concat('x', MPerBlock, NPerBlock, KPerBlock),
+                      concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
+                      concat('x', kPadM, kPadN, kPadK));
+        // clang-format on
+    }
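GetName() only builds a human-readable pipeline identifier. A standalone sketch of the assumed concat semantics (join arguments with the separator character) and the kind of string the call above produces:

// Standalone sketch, not the ck_tile/host/concat.hpp implementation.
#include <iostream>
#include <sstream>
#include <string>

template <typename T>
void append(std::ostringstream& os, char, const T& t) { os << t; }

template <typename T, typename U, typename... Ts>
void append(std::ostringstream& os, char sep, const T& t, const U& u, const Ts&... ts)
{
    os << t << sep;
    append(os, sep, u, ts...); // join remaining arguments with the separator
}

template <typename... Ts>
std::string concat(char sep, const Ts&... ts)
{
    std::ostringstream os;
    append(os, sep, ts...);
    return os.str();
}

int main()
{
    // e.g. a 256x128x64 tile, vector sizes 8/8/8, padding only on K:
    std::cout << concat('_', "pipeline_AgBgCrMe",
                        concat('x', 256, 128, 64),
                        concat('x', 8, 8, 8),
                        concat('x', false, false, true))
              << "\n"; // prints: pipeline_AgBgCrMe_256x128x64_8x8x8_0x0x1
}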
     using Base::PrefetchStages;

     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
...
...
@@ -133,8 +148,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         return Policy::template GetSmemSize<Problem>();
     }

-    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
-
     template <GemmPipelineScheduler Scheduler>
     struct PipelineImpl : public PipelineImplBase
     {
...
...
@@ -165,11 +178,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
constexpr
bool
is_a_col_major
=
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
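The reworked asserts encode that a col-major A window arrives as (K, M) while a row-major one arrives as (M, K), and symmetrically for B. A compile-time sketch of the same predicate with plain values standing in for the CK window types:

// Standalone illustration of the layout-dependent window-length check above.
constexpr bool check_a_window(bool is_a_col_major, int len0, int len1, int M, int K)
{
    return is_a_col_major ? (len0 == K && len1 == M) : (len0 == M && len1 == K);
}

static_assert(check_a_window(true, 64, 256, 256, 64), "col-major A window is (K, M)");
static_assert(check_a_window(false, 256, 64, 256, 64), "row-major A window is (M, K)");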
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
...
...
@@ -213,25 +237,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
         tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

+        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
+        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
+
+        constexpr ADramTileWindowStep a_dram_tile_window_step =
+            is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+        constexpr BDramTileWindowStep b_dram_tile_window_step =
+            is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
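The step constants choose which window coordinate advances by KPerBlock per K iteration: coordinate 0 when K is the outer dimension of the DRAM window (col-major A, row-major B), coordinate 1 otherwise. A small standalone model of the stepping:

// Standalone sketch with plain arrays instead of CK tile windows.
#include <array>
#include <cstdio>

int main()
{
    constexpr int KPerBlock       = 64;
    constexpr bool is_a_col_major = true;
    // col-major A is (K, M) in memory, so K is coordinate 0; row-major is (M, K).
    std::array<int, 2> step = is_a_col_major ? std::array<int, 2>{KPerBlock, 0}
                                             : std::array<int, 2>{0, KPerBlock};
    std::array<int, 2> origin{0, 0};
    for(int i = 0; i < 3; ++i) // three GlobalPrefetch steps
    {
        origin[0] += step[0];
        origin[1] += step[1];
    }
    std::printf("window origin after 3 steps: (%d, %d)\n", origin[0], origin[1]);
}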
         // -----------------------------------------------------------------------------------------
         // Gemm pipeline start

         // prefetch
         // global read 0
-        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
-        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
+        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

         // initialize C
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

         // LDS write 0
-        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
-        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        if constexpr(is_a_col_major)
+        {
+            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                Policy::template MakeShuffledARegTileDistribution<Problem>());
+            transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
+            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+        }
+        if constexpr(is_b_row_major)
+        {
+            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                Policy::template MakeShuffledBRegTileDistribution<Problem>());
+            transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
+            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        }

         // Global prefetch [1, PrefetchStages]
         static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
-            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
-            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+            Base::GlobalPrefetch(
+                a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window, a_dram_tile_window_step);
+            Base::GlobalPrefetch(
+                b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window, b_dram_tile_window_step);
         });
// main body
...
...
@@ -247,19 +305,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
             block_sync_lds();

-            Base::LocalPrefill(a_copy_lds_window,
-                               a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                               a_element_func);
-            Base::LocalPrefill(b_copy_lds_window,
-                               b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                               b_element_func);
+            if constexpr(is_a_col_major)
+            {
+                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                    Policy::template MakeShuffledARegTileDistribution<Problem>());
+                transpose_tile2d(
+                    a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(a_copy_lds_window,
+                                   a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                   a_element_func);
+            }
+            if constexpr(is_b_row_major)
+            {
+                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                transpose_tile2d(
+                    b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(b_copy_lds_window,
+                                   b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                   b_element_func);
+            }

             Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
-                                 a_copy_dram_window);
+                                 a_copy_dram_window,
+                                 a_dram_tile_window_step);
             Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
-                                 b_copy_dram_window);
+                                 b_copy_dram_window,
+                                 b_dram_tile_window_step);
         });

         i += PrefetchStages;
...
...
@@ -275,12 +359,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
             block_sync_lds();

-            Base::LocalPrefill(
-                a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
-            Base::LocalPrefill(
-                b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
+            if constexpr(is_a_col_major)
+            {
+                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                    Policy::template MakeShuffledARegTileDistribution<Problem>());
+                transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
+                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(
+                    a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
+            }
+            if constexpr(is_b_row_major)
+            {
+                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
+                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(
+                    b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
+            }
         });

         block_sync_lds();
...
...
@@ -352,11 +456,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
constexpr
bool
is_a_col_major
=
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
...
...
@@ -400,25 +515,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
         tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

+        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
+        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
+
+        constexpr ADramTileWindowStep a_dram_tile_window_step =
+            is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+        constexpr BDramTileWindowStep b_dram_tile_window_step =
+            is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
         // -----------------------------------------------------------------------------------------
         // Gemm pipeline start

         // prefetch
         // global read 0
-        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
-        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
+        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

         // initialize C
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

         // LDS write 0
-        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
-        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        if constexpr(is_a_col_major)
+        {
+            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                Policy::template MakeShuffledARegTileDistribution<Problem>());
+            transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
+            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+        }
+        if constexpr(is_b_row_major)
+        {
+            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                Policy::template MakeShuffledBRegTileDistribution<Problem>());
+            transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
+            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        }

         // Global prefetch [1, PrefetchStages]
         static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
-            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
-            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+            Base::GlobalPrefetch(
+                a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window, a_dram_tile_window_step);
+            Base::GlobalPrefetch(
+                b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window, b_dram_tile_window_step);
         });
// main body
...
...
@@ -432,19 +580,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
             block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
             // no second block_sync_lds because it's interwave

-            Base::LocalPrefill(a_copy_lds_window,
-                               a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                               a_element_func);
-            Base::LocalPrefill(b_copy_lds_window,
-                               b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                               b_element_func);
+            if constexpr(is_a_col_major)
+            {
+                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                    Policy::template MakeShuffledARegTileDistribution<Problem>());
+                transpose_tile2d(
+                    a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(a_copy_lds_window,
+                                   a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                   a_element_func);
+            }
+            if constexpr(is_b_row_major)
+            {
+                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                transpose_tile2d(
+                    b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(b_copy_lds_window,
+                                   b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                   b_element_func);
+            }

             Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
-                                 a_copy_dram_window);
+                                 a_copy_dram_window,
+                                 a_dram_tile_window_step);
             Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
-                                 b_copy_dram_window);
+                                 b_copy_dram_window,
+                                 b_dram_tile_window_step);
         });

         i += PrefetchStages;
...
...
@@ -457,12 +631,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
             block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
             // no second block_sync_lds because it's interwave

-            Base::LocalPrefill(
-                a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
-            Base::LocalPrefill(
-                b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
+            if constexpr(is_a_col_major)
+            {
+                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                    Policy::template MakeShuffledARegTileDistribution<Problem>());
+                transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
+                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(
+                    a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
+            }
+            if constexpr(is_b_row_major)
+            {
+                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
+                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(
+                    b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
+            }
         });

         block_sync_lds();
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <sstream>
#include "ck_tile/core.hpp"
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
+#include "ck_tile/host/concat.hpp"

 namespace ck_tile {
...
...
@@ -31,32 +32,36 @@ struct GemmPipelineAGmemBGmemCRegV1
     static constexpr index_t kNPerBlock = BlockGemmShape::kN;
     static constexpr index_t kKPerBlock = BlockGemmShape::kK;

-    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
-    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
-    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
+    static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
+    static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
+    static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }

     static constexpr bool kPadM = Problem::kPadM;
     static constexpr bool kPadN = Problem::kPadN;
     static constexpr bool kPadK = Problem::kPadK;

-    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
-    {
-        return integer_divide_ceil(
-                   sizeof(ADataType) *
-                       Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
-                   16) *
-                   16 +
-               sizeof(BDataType) *
-                   Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
-    }
+    static constexpr index_t kLdsAlignmentInBytes = 16;
+
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "pipeline_AGmemBGmemCRegV1",
+                      concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
+                      concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
+                      concat('x', kPadM, kPadN, kPadK));
+        // clang-format on
+    }
+
+    // For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
+    static constexpr bool DoubleSmemBuffer = false;
+
+    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        return Policy::template GetSmemSize<Problem>();
+    }
-    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }

     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
...
@@ -86,8 +91,9 @@ struct GemmPipelineAGmemBGmemCRegV1
         auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

         constexpr index_t a_lds_block_space_size_aligned =
-            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
-            16;
+            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
+                                kLdsAlignmentInBytes) *
+            kLdsAlignmentInBytes;

         // B tile in LDS
         BDataType* p_b_lds = static_cast<BDataType*>(
...
...
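The only functional content here is replacing the literal 16 with the named kLdsAlignmentInBytes; the rounding itself is unchanged. A quick numeric check of that rounding, under the assumed integer_divide_ceil semantics (round-up integer division):

// Standalone check of the LDS alignment arithmetic above.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int kLdsAlignmentInBytes = 16;
    // e.g. an fp16 A tile of 130 x 33 elements = 8580 bytes, not 16-aligned:
    constexpr int a_bytes = 2 * 130 * 33;
    constexpr int aligned =
        integer_divide_ceil(a_bytes, kLdsAlignmentInBytes) * kLdsAlignmentInBytes;
    static_assert(aligned == 8592, "rounded up to the next 16-byte boundary");
    // The B tile then starts at this aligned byte offset inside LDS.
    return 0;
}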
@@ -150,7 +156,7 @@ struct GemmPipelineAGmemBGmemCRegV1
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
         {
             auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
-                Policy::template MakeShuffledARegBlockDescriptor<Problem>());
+                Policy::template MakeShuffledARegBlockDistribution<Problem>());
             shuffle_tile(a_shuffle_tmp, a_block_tile);
             const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
             store_tile(a_copy_lds_window, a_block_tile_tmp);
...
...
@@ -164,7 +170,7 @@ struct GemmPipelineAGmemBGmemCRegV1
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
         {
             auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
-                Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                Policy::template MakeShuffledBRegBlockDistribution<Problem>());
             shuffle_tile(b_shuffle_tmp, b_block_tile);
             const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
             store_tile(b_copy_lds_window, b_block_tile_tmp);
...
...
@@ -201,7 +207,7 @@ struct GemmPipelineAGmemBGmemCRegV1
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
         {
             auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
-                Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                Policy::template MakeShuffledBRegBlockDistribution<Problem>());
             shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
             store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -16,39 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     static constexpr auto I1 = number<1>{};
     static constexpr auto I2 = number<2>{};

-    static constexpr bool TransposeC = true;
-
-#if 0
-    // 2d
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-
-        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-
-        constexpr auto a_lds_block_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
-
-        return a_lds_block_desc;
-    }
-
-    // 2d
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-
-        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-
-        constexpr auto b_lds_block_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
-
-        return b_lds_block_desc;
-    }
-#elif 1
     // 3d + padding
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
...
...
@@ -58,7 +25,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

-        // TODO: this 8 is AK1! should be a policy parameter!
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
             make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
...
...
@@ -127,87 +93,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
     {
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        return Problem::VectorLoadSize / sizeof(ADataType);
+        return Problem::VectorLoadSize;
     }

     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
     {
         using BDataType = remove_cvref_t<typename Problem::BDataType>;
-        return Problem::VectorLoadSize / sizeof(BDataType);
+        return Problem::VectorLoadSize;
     }
-#elif 1
-    // fake XOR
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
-
-        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-
-        constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
-            make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
-            number<kKPerBlock>{});
-
-        constexpr index_t kK1 = 16 / sizeof(ADataType);
-
-        constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
-            a_lds_block_desc_d1_d2_d3,
-            make_tuple(make_xor_transform(
-                           make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
-                       make_pass_through_transform(2)),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}));
-
-        constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
-            a_lds_block_desc_d4_d5_d6,
-            make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
-                       make_pass_through_transform(kKPerBlock)),
-            make_tuple(sequence<0, 1>{}, sequence<2>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-
-        return a_lds_block_desc_m_k;
-    }
-
-    // fake XOR
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
-
-        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-
-        constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
-            make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
-            number<kKPerBlock>{});
-
-        constexpr index_t kK1 = 16 / sizeof(BDataType);
-
-        constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
-            b_lds_block_desc_d1_d2_d3,
-            make_tuple(make_xor_transform(
-                           make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
-                       make_pass_through_transform(2)),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}));
-
-        constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
-            b_lds_block_desc_d4_d5_d6,
-            make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
-                       make_pass_through_transform(kKPerBlock)),
-            make_tuple(sequence<0, 1>{}, sequence<2>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-
-        return b_lds_block_desc_n_k;
-    }
-#endif
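The removed "fake XOR" descriptors folded an XOR of the row parity into the k coordinate so that threads reading the same column of the tile land in different LDS banks. A loose standalone illustration of that swizzle idea (this is plain arithmetic, not the exact make_xor_transform semantics):

// Standalone sketch of an XOR bank swizzle; offsets and masks are illustrative.
#include <cstdio>

int main()
{
    constexpr int kK1 = 8; // 16 bytes / sizeof(fp16): elements per swizzle pack
    // Without the swizzle, every row would place pack p at column p * kK1; with
    // the XOR, odd rows swap neighbouring packs, spreading same-column accesses
    // across different LDS banks.
    for(int row = 0; row < 4; ++row)
    {
        for(int pack = 0; pack < 4; ++pack)
            std::printf("%4d", (pack ^ (row % 2)) * kK1);
        std::printf("\n");
    }
}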
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
...
...
@@ -273,7 +166,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         static_assert(M0 * M1 * M2 == MPerBlock,
                       "Incorrect M0, M2, M1 configuration! "
                       "M0, M1, M2 must cover whole MPerBlock!");
-
         return make_static_tile_distribution(
             tile_distribution_encoding<sequence<1>,
                                        tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
...
...
@@ -394,7 +286,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution()
     {
         using BLayout   = remove_cvref_t<typename Problem::BLayout>;
         using BDataType = remove_cvref_t<typename Problem::BDataType>;
...
...
@@ -442,11 +334,11 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
     {
         using ALayout   = remove_cvref_t<typename Problem::ALayout>;
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
+        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);

         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
...
...
@@ -489,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         }
     }

-    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
-
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
...
...
@@ -503,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
                                                   WarpTile::at(I0),
                                                   WarpTile::at(I1),
                                                   WarpTile::at(I2),
-                                                  TransposeC>;
+                                                  Problem::TransposeC>;

         using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                                       typename Problem::BDataType,
                                                                       typename Problem::CDataType,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
+#include "ck_tile/host/concat.hpp"

 namespace ck_tile {
...
...
@@ -25,6 +26,15 @@ struct GemmPipelineAGmemBGmemCRegV2
     static constexpr index_t kNPerBlock = BlockGemmShape::kN;
     static constexpr index_t kKPerBlock = BlockGemmShape::kK;

+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "pipeline_AGmemBGmemCRegV2",
+                      concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
+        // clang-format on
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
+
     CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
     {
         return integer_divide_ceil(
...
...
@@ -36,8 +46,6 @@ struct GemmPipelineAGmemBGmemCRegV2
             Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
     }

-    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
-
     template <typename ADramBlockWindowTmp,
               typename BDramBlockWindowTmp,
               typename AElementFunction,
...
...