Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3b945fc9
Unverified
Commit
3b945fc9
authored
Dec 23, 2024
by
M.Emin Ozturk
Committed by
GitHub
Dec 23, 2024
Browse files
Merge branch 'develop' into gemm_bf16_sk_muozturk
parents
6ef3acec
3d15f364
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
831 additions
and
118 deletions
+831
-118
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+0
-2
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+210
-37
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
...de/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
+9
-4
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
...ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
+29
-15
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
+29
-15
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+16
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+259
-44
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+274
-0
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+4
-0
profiler/include/profiler/profile_grouped_gemm_impl.hpp
profiler/include/profiler/profile_grouped_gemm_impl.hpp
+1
-1
No files found.
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
3b945fc9
...
...
@@ -43,8 +43,6 @@ struct TileFmhaShape
static
constexpr
index_t
NumWarps
=
max
(
NumGemm0Warps
,
NumGemm1Warps
);
static_assert
(
std
::
is_same_v
<
Gemm0WarpTile
,
Gemm1WarpTile
>
);
static
constexpr
index_t
kM0
=
BlockTile
::
at
(
number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
at
(
number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kK0
=
BlockTile
::
at
(
number
<
2
>
{});
// tile size along qk gemm unroll
...
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
View file @
3b945fc9
...
...
@@ -130,7 +130,8 @@ struct MoeSortingKernel
CK_TILE_HOST
static
constexpr
auto
GetSmemSize
(
const
Hargs
&
h
)
{
const
auto
blocks
=
BlockSize
(
h
);
return
((
blocks
.
x
+
1
)
*
h
.
num_experts
+
(
h
.
num_experts
+
1
))
*
sizeof
(
index_t
);
// usually num_experts is power of 2, we pad 1 dword here for the row-size
return
((
blocks
.
x
+
1
)
*
(
h
.
num_experts
+
1
)
+
(
h
.
num_experts
+
1
))
*
sizeof
(
index_t
);
}
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
...
...
@@ -154,6 +155,75 @@ struct MoeSortingKernel
return
k
;
}
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
template
<
typename
data_t
,
int
wave_size
>
__device__
inline
void
wave_cumsum
(
data_t
&
thread_data
)
const
{
// wave_size must be power of 2
constexpr
int
row_mask
=
0xf
;
constexpr
int
bank_mask
=
0xf
;
constexpr
bool
bound_ctrl
=
true
;
// ! out-of-bound is zero !
auto
reduce_op
=
[
&
](
auto
x_
,
auto
y_
)
{
return
x_
+
y_
;
};
if
constexpr
(
wave_size
>
1
)
{
thread_data
=
reduce_op
(
thread_data
,
__builtin_bit_cast
(
data_t
,
__builtin_amdgcn_mov_dpp
(
__builtin_bit_cast
(
int
,
thread_data
),
0x111
,
row_mask
,
bank_mask
,
bound_ctrl
)));
// row_shr:1
}
if
constexpr
(
wave_size
>
2
)
{
thread_data
=
reduce_op
(
thread_data
,
__builtin_bit_cast
(
data_t
,
__builtin_amdgcn_mov_dpp
(
__builtin_bit_cast
(
int
,
thread_data
),
0x112
,
row_mask
,
bank_mask
,
bound_ctrl
)));
// row_shr:2
}
if
constexpr
(
wave_size
>
4
)
{
thread_data
=
reduce_op
(
thread_data
,
__builtin_bit_cast
(
data_t
,
__builtin_amdgcn_mov_dpp
(
__builtin_bit_cast
(
int
,
thread_data
),
0x114
,
row_mask
,
bank_mask
,
bound_ctrl
)));
// row_shr:4
}
if
constexpr
(
wave_size
>
8
)
{
thread_data
=
reduce_op
(
thread_data
,
__builtin_bit_cast
(
data_t
,
__builtin_amdgcn_mov_dpp
(
__builtin_bit_cast
(
int
,
thread_data
),
0x118
,
row_mask
,
bank_mask
,
bound_ctrl
)));
// row_shr:8
}
if
constexpr
(
wave_size
>
16
)
{
// now row-0, row-0+row-1, row-1+row-2, row-2+row-3
int
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(((
__lane_id
()
&
0x30
)
-
1
)
<<
2
,
__builtin_bit_cast
(
int
,
thread_data
));
v_remote_tmp
=
__lane_id
()
>=
16
?
v_remote_tmp
:
0
;
thread_data
=
reduce_op
(
thread_data
,
__builtin_bit_cast
(
data_t
,
v_remote_tmp
));
}
if
constexpr
(
wave_size
>
32
)
{
// lane-id 48...63->31
int
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(((
__lane_id
()
&
0x30
)
-
17
)
<<
2
,
__builtin_bit_cast
(
int
,
thread_data
));
v_remote_tmp
=
__lane_id
()
>=
32
?
v_remote_tmp
:
0
;
thread_data
=
reduce_op
(
thread_data
,
__builtin_bit_cast
(
data_t
,
v_remote_tmp
));
}
}
CK_TILE_DEVICE
index_t
calc_index
(
index_t
total_col
,
index_t
row
,
index_t
col
)
const
{
return
row
*
total_col
+
col
;
...
...
@@ -187,48 +257,124 @@ struct MoeSortingKernel
index_t
*
shared_mem
=
reinterpret_cast
<
index_t
*>
(
smem
);
index_t
*
tokens_cnts
=
shared_mem
;
// 2d: (blockDim.x + 1, num_experts)
index_t
*
cumsum
=
shared_mem
+
(
blockDim
.
x
+
1
)
*
num_experts
;
// 1: (num_experts + 1)
index_t
*
cumsum
=
shared_mem
+
(
blockDim
.
x
+
1
)
*
(
num_experts
+
1
);
// 1: (num_experts + 1)
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
calc_index
(
num_experts
,
tid
+
1
,
i
)]
=
0
;
tokens_cnts
[
calc_index
(
num_experts
+
1
,
tid
+
1
,
i
)]
=
0
;
}
#pragma unroll Problem_::InternalLoadUnroll
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
calc_index
(
num_experts
,
tid
+
1
,
topk_id
[
i
])];
++
tokens_cnts
[
calc_index
(
num_experts
+
1
,
tid
+
1
,
topk_id
[
i
])];
}
__syncthreads
();
#if 1
if
(
tid
<
num_experts
)
{
tokens_cnts
[
calc_index
(
num_experts
,
0
,
tid
)]
=
0
;
for
(
int
i
=
1
;
i
<=
static_cast
<
index_t
>
(
blockDim
.
x
);
++
i
)
tokens_cnts
[
calc_index
(
num_experts
+
1
,
0
,
tid
)]
=
0
;
index_t
local_c
[
8
];
index_t
prev_c
=
0
;
// TODO: manually unroll. pragma unroll does not work well when we have dependency
for
(
int
i
=
1
;
i
<=
static_cast
<
index_t
>
(
blockDim
.
x
);
i
+=
8
)
{
tokens_cnts
[
calc_index
(
num_experts
,
i
,
tid
)]
+=
tokens_cnts
[
calc_index
(
num_experts
,
i
-
1
,
tid
)];
local_c
[
0
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
0
,
tid
)];
local_c
[
1
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
1
,
tid
)];
local_c
[
2
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
2
,
tid
)];
local_c
[
3
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
3
,
tid
)];
local_c
[
4
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
4
,
tid
)];
local_c
[
5
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
5
,
tid
)];
local_c
[
6
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
6
,
tid
)];
local_c
[
7
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
7
,
tid
)];
local_c
[
0
]
+=
prev_c
;
local_c
[
1
]
+=
local_c
[
0
];
local_c
[
2
]
+=
local_c
[
1
];
local_c
[
3
]
+=
local_c
[
2
];
local_c
[
4
]
+=
local_c
[
3
];
local_c
[
5
]
+=
local_c
[
4
];
local_c
[
6
]
+=
local_c
[
5
];
local_c
[
7
]
+=
local_c
[
6
];
prev_c
=
local_c
[
7
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
0
,
tid
)]
=
local_c
[
0
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
1
,
tid
)]
=
local_c
[
1
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
2
,
tid
)]
=
local_c
[
2
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
3
,
tid
)]
=
local_c
[
3
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
4
,
tid
)]
=
local_c
[
4
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
5
,
tid
)]
=
local_c
[
5
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
6
,
tid
)]
=
local_c
[
6
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
i
+
7
,
tid
)]
=
local_c
[
7
];
}
}
// __syncthreads();
if
(
tid
==
0
)
#else
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
if
(
tid
<
num_experts
)
tokens_cnts
[
calc_index
(
num_experts
+
1
,
0
,
tid
)]
=
0
;
for
(
int
i
=
0
;
i
<
num_experts
;
i
+=
8
)
{
index_t
local_c
[
8
];
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
local_c
[
j
]
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
tid
+
1
,
i
+
j
)];
}
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
wave_cumsum
<
int
,
64
>
(
local_c
[
j
]);
}
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
tokens_cnts
[
calc_index
(
num_experts
+
1
,
tid
+
1
,
i
+
j
)]
=
local_c
[
j
];
}
}
}
#endif
__syncthreads
();
if
constexpr
(
Problem
::
ExpertTile
==
0
)
{
if
(
tid
==
0
)
{
auto
current_units
=
[
&
]()
{
index_t
x_
=
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)]
+
unit_size_mdiv
.
divisor
-
1
;
index_t
y_
=
unit_size_mdiv
.
div
(
x_
);
return
max
(
y_
,
1
)
*
unit_size_mdiv
.
divisor
;
}();
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
current_units
;
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
auto
current_units
=
[
&
]()
{
index_t
x_
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
blockDim
.
x
,
i
-
1
)]
+
unit_size_mdiv
.
divisor
-
1
;
index_t
y_
=
unit_size_mdiv
.
div
(
x_
);
return
max
(
y_
,
1
)
*
unit_size_mdiv
.
divisor
;
}();
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
current_units
;
}
*
p_total_tokens_post_pad
=
cumsum
[
num_experts
];
}
}
else
{
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
// for simplicity, not check experts here.
int
local_cnt
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
blockDim
.
x
,
tid
)];
int
blocks_pers_expert
=
unit_size_mdiv
.
div
(
local_cnt
+
unit_size_mdiv
.
divisor
-
1
);
int
padded_tokens_per_expert
=
max
(
blocks_pers_expert
,
1
)
*
unit_size_mdiv
.
divisor
;
int
local_cumsum
=
padded_tokens_per_expert
;
wave_cumsum
<
int
,
64
>
(
local_cumsum
);
if
(
tid
==
(
num_experts
-
1
))
{
cumsum
[
0
]
=
0
;
*
p_total_tokens_post_pad
=
local_cumsum
;
}
if
(
tid
<
num_experts
)
{
cumsum
[
tid
+
1
]
=
local_cumsum
;
}
*
p_total_tokens_post_pad
=
cumsum
[
num_experts
];
}
__syncthreads
();
if
(
tid
<
num_experts
)
{
for
(
int
i
=
cumsum
[
tid
];
i
<
cumsum
[
tid
+
1
];
i
+=
unit_size_mdiv
.
divisor
)
int
e_start
=
cumsum
[
tid
];
int
e_end
=
cumsum
[
tid
+
1
];
for
(
int
i
=
e_start
;
i
<
e_end
;
i
+=
unit_size_mdiv
.
divisor
)
{
p_sorted_expert_ids
[
unit_size_mdiv
.
div
(
i
)]
=
tid
;
}
...
...
@@ -238,8 +384,8 @@ struct MoeSortingKernel
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
index_t
expert_id
=
topk_id
[
i
];
index_t
rank_post_pad
=
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)]
+
cumsum
[
expert_id
];
index_t
local_cnt
=
tokens_cnts
[
calc_index
(
num_experts
+
1
,
tid
,
expert_id
)];
index_t
rank_post_pad
=
local_cnt
+
cumsum
[
expert_id
];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t
curr_token_id
,
curr_topk_id
;
topk_mdiv
.
divmod
(
i
,
curr_token_id
,
curr_topk_id
);
...
...
@@ -247,27 +393,54 @@ struct MoeSortingKernel
#else
p_sorted_token_ids
[
rank_post_pad
]
=
topk_mdiv
.
div
(
i
);
#endif
p_sorted_weights
[
rank_post_pad
]
=
weights
[
i
];
++
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)];
p_sorted_weights
[
rank_post_pad
]
=
weights
[
i
];
tokens_cnts
[
calc_index
(
num_experts
+
1
,
tid
,
expert_id
)]
=
local_cnt
+
1
;
}
const
index_t
prefill_token
=
topk_mdiv
.
div
(
numel
);
if
(
tid
<
num_experts
)
{
index_t
expert_offset
=
cumsum
[
tid
]
+
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
tid
)];
while
(
expert_offset
<
cumsum
[
tid
+
1
])
if
constexpr
(
Problem
::
ExpertTile
==
0
)
{
const
index_t
prefill_token
=
topk_mdiv
.
div
(
numel
);
if
(
tid
<
num_experts
)
{
index_t
expert_offset
=
cumsum
[
tid
]
+
tokens_cnts
[
calc_index
(
num_experts
+
1
,
blockDim
.
x
,
tid
)];
index_t
expert_end
=
cumsum
[
tid
+
1
];
while
(
expert_offset
<
expert_end
)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids
[
expert_offset
]
=
MOE_SORTING_MOCK_ID
(
prefill_token
,
topk_mdiv
.
divisor
);
p_sorted_token_ids
[
expert_offset
]
=
MOE_SORTING_MOCK_ID
(
prefill_token
,
topk_mdiv
.
divisor
);
#else
p_sorted_token_ids
[
expert_offset
]
=
prefill_token
;
p_sorted_token_ids
[
expert_offset
]
=
prefill_token
;
#endif
p_sorted_weights
[
expert_offset
]
=
static_cast
<
WeightType
>
(
0.0
);
expert_offset
++
;
p_sorted_weights
[
expert_offset
]
=
static_cast
<
WeightType
>
(
0.0
);
expert_offset
++
;
}
}
}
else
{
const
index_t
prefill_token
=
topk_mdiv
.
div
(
numel
);
// TODO: only support expert-tile like 8, 16, 32
static
constexpr
index_t
experts_per_wave
=
warpSize
/
Problem
::
ExpertTile
;
{
index_t
eid
=
tid
/
experts_per_wave
;
index_t
expert_offset
=
cumsum
[
eid
]
+
tokens_cnts
[
calc_index
(
num_experts
+
1
,
blockDim
.
x
,
eid
)]
+
tid
%
experts_per_wave
;
index_t
expert_end
=
cumsum
[
eid
+
1
];
if
(
eid
<
num_experts
)
{
while
(
expert_offset
<
expert_end
)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids
[
expert_offset
]
=
MOE_SORTING_MOCK_ID
(
prefill_token
,
topk_mdiv
.
divisor
);
#else
p_sorted_token_ids
[
expert_offset
]
=
prefill_token
;
#endif
p_sorted_weights
[
expert_offset
]
=
static_cast
<
WeightType
>
(
0.0
);
expert_offset
+=
experts_per_wave
;
}
}
}
}
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
...
...
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
View file @
3b945fc9
...
...
@@ -9,15 +9,20 @@
namespace
ck_tile
{
template
<
typename
IndexType_
,
typename
WeightType_
,
index_t
InternalLoadUnroll_
>
template
<
typename
IndexType_
,
typename
WeightType_
,
index_t
InternalLoadUnroll_
,
index_t
ExpertTile_
=
0
>
struct
MoeSortingProblem
{
// TODO: this kernel only support warp per row
using
WeightType
=
remove_cvref_t
<
WeightType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static
constexpr
index_t
WarpsPerBlock
=
1
;
static
constexpr
index_t
InternalLoadUnroll
=
InternalLoadUnroll_
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static
constexpr
index_t
WarpsPerBlock
=
1
;
static
constexpr
index_t
InternalLoadUnroll
=
InternalLoadUnroll_
;
// TODO: need better design(like tile size)
static
constexpr
index_t
ExpertTile
=
ExpertTile_
;
// TODO: only used in store out
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
View file @
3b945fc9
...
...
@@ -65,14 +65,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1
const
index_t
iNWarp
=
0
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
>
,
sequence
<
NIterPerWarp
>>
,
...
...
@@ -81,19 +73,14 @@ struct BlockGemmARegBSmemCRegOneWarpV1
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto
a_block_tensor
=
m
ake
_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
a_block_dstr
);
auto
a_block_tensor
=
make_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
M
ake
ABlockTileDistribution
()
);
a_block_tensor
.
get_thread_buffer
()
=
a_block_tensor_tmp
.
get_thread_buffer
();
...
...
@@ -187,6 +174,33 @@ struct BlockGemmARegBSmemCRegOneWarpV1
});
}
template
<
index_t
MPerBlock
=
BlockGemmShape
::
kM
,
index_t
KPerBlock
=
BlockGemmShape
::
kK
>
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockTileDistribution
()
{
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
return
make_static_tile_distribution
(
a_block_dstr_encode
);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
View file @
3b945fc9
...
...
@@ -59,14 +59,6 @@ struct BlockGemmARegBSmemCRegV2
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
...
...
@@ -75,19 +67,14 @@ struct BlockGemmARegBSmemCRegV2
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto
a_block_tensor
=
m
ake
_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
a_block_dstr
);
auto
a_block_tensor
=
make_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
M
ake
ABlockTileDistribution
()
);
a_block_tensor
.
get_thread_buffer
()
=
a_block_tensor_tmp
.
get_thread_buffer
();
...
...
@@ -182,6 +169,33 @@ struct BlockGemmARegBSmemCRegV2
});
}
template
<
index_t
MPerBlock
=
BlockGemmShape
::
kM
,
index_t
KPerBlock
=
BlockGemmShape
::
kK
>
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockTileDistribution
()
{
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
return
make_static_tile_distribution
(
a_block_dstr_encode
);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
3b945fc9
...
...
@@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaF16F16F32M4N64K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M4N64K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
using
WarpGemmMfmaF16F16F32M64N4K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M64N4K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
// bf16
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
...
...
@@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M4N64K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
using
WarpGemmMfmaBf16Bf16F32M64N4K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
<
WGAttrCtlEnum
::
Default_
>
,
4
>>
;
// fp8
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
3b945fc9
...
...
@@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
static_assert
(
Impl
::
kAMBlock
==
1
||
Impl
::
kBNBlock
==
1
,
"Multi-block on both M & N directions is not supported"
);
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_awarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// each M blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kBNBlock
>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMBlock
,
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_bwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNBlock
,
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// each N blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kAMBlock
>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
CK_TILE_DEVICE
static
constexpr
auto
get_cwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kBNBlock
*
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kAMBlock
*
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
{};
}
}
using
AWarpDstrEncoding
=
decltype
(
get_awarp_dstr_encoding
());
using
BWarpDstrEncoding
=
decltype
(
get_bwarp_dstr_encoding
());
using
CWarpDstrEncoding
=
decltype
(
get_cwarp_dstr_encoding
());
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
...
...
@@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
static_assert
(
Impl
::
kAMBlock
==
1
||
Impl
::
kBNBlock
==
1
,
"Multi-block on both M & N directions is not supported"
);
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_awarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNBlock
,
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// each N blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kAMBlock
>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
CK_TILE_DEVICE
static
constexpr
auto
get_bwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
// each M blocks share the same data
return
tile_distribution_encoding
<
sequence
<
Impl
::
kBNBlock
>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
0
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
// single block to multi-block thread mapping
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMBlock
,
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
1
,
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
1
>>
,
sequence
<
2
>
,
sequence
<
1
>>
{};
}
}
CK_TILE_DEVICE
static
constexpr
auto
get_cwarp_dstr_encoding
()
{
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
Impl
::
kAMBlock
==
1
&&
1
<
Impl
::
kBNBlock
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNBlock
*
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
{};
}
else
if
constexpr
(
1
<
Impl
::
kAMBlock
&&
Impl
::
kBNBlock
==
1
)
{
return
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kAMBlock
*
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
{};
}
}
using
AWarpDstrEncoding
=
decltype
(
get_awarp_dstr_encoding
());
using
BWarpDstrEncoding
=
decltype
(
get_bwarp_dstr_encoding
());
using
CWarpDstrEncoding
=
decltype
(
get_cwarp_dstr_encoding
());
template
<
bool
post_nop_
=
false
>
// c_vec += a_vec * b_vec
...
...
@@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
...
...
@@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
static_assert
(
Impl
::
kAMBlock
==
1
&&
Impl
::
kBNBlock
==
1
,
"Multi-block WarpGemmAttributeMfmaImpl is not supported"
);
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
/
(
Impl
::
kCMLane
*
SFactor
*
Impl
::
kCM1PerLane
),
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
3b945fc9
...
...
@@ -78,6 +78,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
@@ -138,6 +141,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
...
...
@@ -182,6 +188,134 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M4N64K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
4
;
static
constexpr
index_t
kN
=
64
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
16
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplF16F16F32M64N4K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
64
;
static
constexpr
index_t
kN
=
4
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
16
;
static
constexpr
index_t
kBNBlock
=
1
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4f16"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// Bf16
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
...
...
@@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
@@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
...
...
@@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
4
;
static
constexpr
index_t
kN
=
64
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
16
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4bf16_1k"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
64
;
static
constexpr
index_t
kN
=
4
;
static
constexpr
index_t
kK
=
4
;
static
constexpr
index_t
kAMBlock
=
16
;
static
constexpr
index_t
kBNBlock
=
1
;
// we only write down single block (4 threads) thread mapping here
static
constexpr
index_t
kAMLane
=
4
;
static
constexpr
index_t
kBNLane
=
4
;
static
constexpr
index_t
kABKLane
=
1
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
1
;
static
constexpr
index_t
kCNLane
=
4
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_f32_4x4x4bf16_1k"
,
Ctrl
)
else
{
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx9__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// FP8
template
<
typename
AType_
,
typename
BType_
,
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
...
...
@@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
@@ -568,6 +839,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMBlock
=
1
;
static
constexpr
index_t
kBNBlock
=
1
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
3b945fc9
...
...
@@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
4
,
64
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M4N64K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
64
,
4
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M64N4K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
...
...
@@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
4
,
64
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M4N64K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
64
,
4
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M64N4K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
...
...
profiler/include/profiler/profile_grouped_gemm_impl.hpp
View file @
3b945fc9
...
...
@@ -77,7 +77,7 @@ bool profile_grouped_gemm_impl(int do_verification,
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_host_results
;
std
::
vector
<
Tensor
<
CDataType
>>
c_m_n_device_results
;
ComputeDataTyp
e
max_abs_in_val
=
0.
f
;
doubl
e
max_abs_in_val
=
0.
f
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
a_m_k
.
push_back
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment