Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6cb91035
"...composable_kernel_rocm.git" did not exist on "ba6f79a75e65610871fd5139311817642292085c"
Commit
6cb91035
authored
Dec 04, 2024
by
letaoqin
Browse files
add fp16 to test
parent
4dd77195
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
330 additions
and
114 deletions
+330
-114
example/ck_tile/16_fused_moe_general/fused_moegemm.hpp
example/ck_tile/16_fused_moe_general/fused_moegemm.hpp
+16
-0
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
...tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
+1
-1
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp
...used_moe_general/instances/fused_moegemm_api_internal.hpp
+2
-2
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp
..._fused_moe_general/instances/fused_moegemm_api_traits.hpp
+1
-1
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
...16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
+1
-1
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+4
-4
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+57
-74
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+46
-31
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+202
-0
No files found.
example/ck_tile/16_fused_moe_general/fused_moegemm.hpp
View file @
6cb91035
...
...
@@ -13,6 +13,22 @@
template
<
typename
I
,
typename
W
,
typename
O
,
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FusedMoeGemmTypeConfig
;
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ST
,
SW
,
SQ
,
KW
>
{
using
ADataType
=
ck_tile
::
fp16_t
;
using
GDataType
=
ck_tile
::
fp16_t
;
using
DDataType
=
ck_tile
::
fp16_t
;
using
AccDataType
=
float
;
using
ODataType
=
ck_tile
::
fp16_t
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
GScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
DScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
IndexDataType
=
ck_tile
::
index_t
;
};
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
{
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
View file @
6cb91035
...
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
32
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
// clang-format on
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp
View file @
6cb91035
...
...
@@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpTile_0
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
WarpPerBlock_
0
,
typename
Ts_
::
WarpTile_
0
>
;
typename
Ts_
::
WarpPerBlock_
1
,
typename
Ts_
::
WarpTile_
1
>
;
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
typename
Ts_
::
GDataType
,
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp
View file @
6cb91035
...
...
@@ -49,7 +49,7 @@ struct fmoe_ // traits, ugly name, only used for internal
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_1
=
ck_tile
::
sequence
<
1
,
1
,
4
>
;
//
ck_tile::remove_cvref_t<WarpPerBlock_>;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
View file @
6cb91035
...
...
@@ -8,7 +8,7 @@
// clang-format off
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
fmoe_
<
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
32
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
example/ck_tile/16_fused_moe_general/main.cpp
View file @
6cb91035
...
...
@@ -252,8 +252,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
}(
topk_weight_host
);
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
//
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
//
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if
(
balance
)
...
...
@@ -287,7 +287,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// std::cout << num_sorted_tiles_host << std::endl;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
//
std::cout << topk_weight_host << std::endl;
std
::
cout
<<
topk_weight_host
<<
std
::
endl
;
// std::cout << sorted_weight_host << std::endl;
...
...
@@ -431,7 +431,7 @@ int main(int argc, char* argv[])
// no dynamic quant case
if
(
prec_i
==
"bf16"
&&
prec_w
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_kw
==
"fp32"
)
{
return
run
<
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
float
,
float
,
float
,
float
>
(
return
run
<
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
float
,
float
,
float
,
float
>
(
arg_parser
)
?
0
:
-
2
;
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
6cb91035
...
...
@@ -88,6 +88,32 @@ struct FusedMoeGemmPipeline_General
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
template
<
typename
T
>
CK_TILE_HOST_DEVICE
static
void
PrintMem
(
T
&
tensor
)
{
constexpr
auto
spans
=
T
::
get_distributed_spans
();
int
counter
=
0
;
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idxn
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idxk
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idxn
,
idxk
);
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
tensor
.
get_tile_distribution
(),
i_j_idx
);
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
const
auto
row
=
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
tile_idx
.
at
(
number
<
1
>
{});
printf
(
"in G row is %d , col is %d, counter is %d, value is: %f"
"
\n
"
,
row
,
col
,
counter
,
ck_tile
::
type_convert
<
float
>
(
tensor
(
i_j_idx
)));
counter
=
counter
+
1
;
}
});
});
}
template
<
typename
AWindow
,
typename
GWindow
,
typename
DWindow
,
typename
OWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
AWindow
&
a_window_
,
const
GWindow
&
g_window_
,
...
...
@@ -131,56 +157,13 @@ struct FusedMoeGemmPipeline_General
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
store_tile
(
a_lds_win
,
a_dram_block
);
#if 0
{
// check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idk_0 = idxk.impl_.at(0);
printf("in A idm is %d , idk_ is %d , counter is %d, value is: %f \n",
idm_0,
idk_0,
counter,
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
}
});
});
}
PrintMem(a_dram_block);
#endif
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
#if 0
{
constexpr auto g_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(g_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(g_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx = get_x_indices_from_distributed_indices(
g_dram_block.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
printf("in G row is %d , col is %d, counter is %d, value is: %f"
" \n",
row,
col,
counter,
ck_tile::type_convert<float>(g_dram_block(i_j_idx)));
}
});
});
}
PrintMem(g_dram_block);
#endif
clear_tile
(
s_acc
);
// initialize C
...
...
@@ -215,32 +198,8 @@ struct FusedMoeGemmPipeline_General
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
// });
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
#if 1
{
constexpr
auto
a_spans
=
decltype
(
s_acc
)
::
get_distributed_spans
();
int
counter
=
0
;
// a_spans[0] = 1;
sweep_tile_span
(
a_spans
[
number
<
0
>
{}],
[
&
](
auto
idxm
)
{
sweep_tile_span
(
a_spans
[
number
<
1
>
{}],
[
&
](
auto
idxn
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idxm
,
idxn
);
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
g_dram_block
.
get_tile_distribution
(),
i_j_idx
);
if
(
threadIdx
.
x
==
1
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
counter
=
counter
+
1
;
const
auto
row
=
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
tile_idx
.
at
(
number
<
1
>
{});
printf
(
"in c row is %d , col is %d, counter is %d, value is: "
"%f
\n
"
,
row
,
col
,
counter
,
ck_tile
::
type_convert
<
float
>
(
s_acc
(
i_j_idx
)));
}
});
});
}
#if 0
PrintMem(s_acc);
#endif
// move sacc to LDS
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
...
@@ -249,15 +208,30 @@ struct FusedMoeGemmPipeline_General
make_tile_window
(
bridge_lds_view
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
{
0
,
0
});
// cast data to YDataType
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// //y_pre.get_thread_buffer()(i) = type_convert<YDataType>(s_acc.get_thread_buffer()[i]);
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// printf("soure value: %f to value: %f\n",
// s_acc.get_thread_buffer()[i],
// type_convert<float>(y_pre.get_thread_buffer()[i]));
// }
// });
#if 1
PrintMem
(
y_pre
);
#endif
// save to lds
store_tile
(
bridge_slds_win
,
y_pre
);
block_sync_lds
();
// gemm down
constexpr
auto
gemm_1
=
Policy
::
template
GetBlockGemm1
<
Problem
>();
using
S
accBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
auto
o_acc
=
S
accBlockTileType
{};
using
O
accBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
auto
o_acc
=
O
accBlockTileType
{};
// y data
auto
bridge_llds_win
=
make_tile_window
(
bridge_lds_view
,
...
...
@@ -265,6 +239,7 @@ struct FusedMoeGemmPipeline_General
{
0
,
0
},
Policy
::
template
MakeYTileDistribution
<
Problem
>());
auto
y
=
load_tile
(
bridge_llds_win
);
// d data
auto
d_global_to_dram_window
=
make_tile_window
(
d_window_
.
get_bottom_tensor_view
(),
...
...
@@ -278,6 +253,7 @@ struct FusedMoeGemmPipeline_General
index_t
iCounter1
=
n1_loops
-
1
;
while
(
iCounter1
>
0
)
{
clear_tile
(
o_acc
);
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
block_sync_lds
();
...
...
@@ -292,9 +268,16 @@ struct FusedMoeGemmPipeline_General
}
// tail
{
clear_tile
(
o_acc
);
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_window_
,
o
);
}
#if 0
PrintMem(o_acc);
#endif
// store_tile(o_window_, a_dram_block);
}
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
6cb91035
...
...
@@ -13,7 +13,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v
1
.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v
2
.hpp"
namespace
ck_tile
{
...
...
@@ -230,7 +230,28 @@ struct FusedMoeGemmPipelineGeneralPolicy
typename
S_
::
WarpPerBlock_1
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
// this is used as A matrix for 2nd gemm
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeYTileDistribution
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{};
constexpr
auto
y_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
y_outer_dstr_enc
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
y_block_dstr
=
make_static_tile_distribution
(
y_block_dstr_encode
);
return
y_block_dstr
;
}
template
<
typename
Problem
>
...
...
@@ -240,12 +261,12 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
d_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_M
1
>
,
tuple
<
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
1
>>
{};
constexpr
auto
d_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
d_outer_dstr_enc
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
...
...
@@ -326,7 +347,15 @@ struct FusedMoeGemmPipelineGeneralPolicy
// TODO: this is ugly
constexpr
auto
wg_ctrl
=
WGAttrCtlEnum
::
Raw_avv
;
// TODO: ugly
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
fp16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
wg_ctrl
>
,
1
>>
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
{
...
...
@@ -358,7 +387,15 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
S_
=
typename
Problem
::
BlockShape
;
constexpr
auto
wg_ctrl
=
WGAttrCtlEnum
::
Raw_avv
;
// TODO: ugly
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
fp16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
<
wg_ctrl
>
,
1
>>
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
{
...
...
@@ -383,27 +420,5 @@ struct FusedMoeGemmPipelineGeneralPolicy
2
>>
{};
}
}
// this is used as A matrix for 2nd gemm
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeYTileDistribution
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
// TODO: all waves a along different N, but same M
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_N1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
y_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
y_outer_dstr_enc
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
y_block_dstr
=
make_static_tile_distribution
(
y_block_dstr_encode
);
return
y_block_dstr
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
0 → 100644
View file @
6cb91035
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBRegCRegV1DefaultPolicy
>
struct
BlockGemmARegBRegCRegV2
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensor
&
a_block_tensor
,
const
BBlockTensor
&
b_block_tensor
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
// M->N Warp
// constexpr auto a_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<NWarp>,
// tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
// tuple<sequence<1, 0>>,
// tuple<sequence<1, 0>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto b_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<MWarp>,
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
// constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
// constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
// check ABC-block-distribution
// static_assert(
// std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
// remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
// .get_static_tile_distribution_encoding())>>,
// "A distribution is wrong!");
// static_assert(
// std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
// remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
// .get_static_tile_distribution_encoding())>>,
// "B distribution is wrong!");
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
c_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
CBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"C distribution is wrong!"
);
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
BWarpDstr
=
typename
WG
::
BWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
BWarpTensor
=
typename
WG
::
BWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
b_warp_y_lengths
=
to_sequence
(
BWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
b_warp_y_index_zeros
=
uniform_sequence_gen_t
<
BWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A Block window
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B block tensor
BWarpTensor
b_warp_tensor
;
b_warp_tensor
.
get_thread_buffer
()
=
b_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
nIter
,
kIter
>
{},
b_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
b_warp_y_lengths
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
// C = A * B
template
<
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensor
&
a_block_tensor
,
const
BBlockTensor
&
b_block_tensor
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor
,
b_block_tensor
);
return
c_block_tensor
;
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment