Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b616b254
Commit
b616b254
authored
Dec 05, 2024
by
letaoqin
Browse files
add debuging code and format
parent
2baf9422
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
185 additions
and
148 deletions
+185
-148
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp
..._fused_moe_general/instances/fused_moegemm_api_traits.hpp
+1
-1
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+146
-112
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+2
-2
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+1
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+10
-9
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+21
-19
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+4
-4
No files found.
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp
View file @
b616b254
...
...
@@ -45,7 +45,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
>
;
using
WarpPerBlock_1
=
ck_tile
::
sequence
<
1
,
1
,
4
>
;
//ck_tile::remove_cvref_t<WarpPerBlock_>;
using
WarpPerBlock_1
=
ck_tile
::
sequence
<
1
,
1
,
4
>
;
//
ck_tile::remove_cvref_t<WarpPerBlock_>;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
...
...
example/ck_tile/17_fused_moe_general/main.cpp
View file @
b616b254
...
...
@@ -83,13 +83,43 @@ void topid_unique_gen(
host_tensor
[
i
]
=
current_v
;
}
}
template
<
typename
IndexType
>
void
output_matrix_2d
(
ck_tile
::
HostTensor
<
IndexType
>&
data
,
int
m
,
int
n
)
{
std
::
cout
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
std
::
cout
<<
"Line "
<<
i
<<
"
\t
"
;
for
(
int
j
=
0
;
j
<
n
;
j
++
)
{
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
i
,
j
))
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
}
template
<
typename
IndexType
>
void
output_matrix_3d
(
ck_tile
::
HostTensor
<
IndexType
>&
data
,
int
M
,
int
N
,
int
J
)
{
std
::
cout
<<
std
::
endl
;
for
(
int
m
=
0
;
m
<
M
;
m
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
std
::
cout
<<
"experts: "
<<
m
<<
" Line: "
<<
n
<<
"
\t
"
;
for
(
int
j
=
0
;
j
<
J
;
j
++
)
{
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
m
,
n
,
j
))
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
}
}
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"t"
,
"128"
,
"num input tokens"
)
.
insert
(
"e"
,
"32"
,
"num of experts"
)
.
insert
(
"k"
,
"
5
"
,
"topk"
)
.
insert
(
"k"
,
"
2
"
,
"topk"
)
.
insert
(
"h"
,
"8192"
,
"hidden_size of this model"
)
.
insert
(
"i"
,
"8192"
,
"intermediate_size between 2 gemms of FFN"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
...
...
@@ -112,7 +142,7 @@ auto create_args(int argc, char* argv[])
"0"
,
"if set to 1, will try balance the expert in topk-ids(convenient for testing)"
)
.
insert
(
"init"
,
"
2
"
,
"
1
"
,
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"normalized(slow)"
)
.
insert
(
"seed"
,
"11939"
,
"seed used to do random"
)
...
...
@@ -176,9 +206,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
return
base_str
;
}();
auto
api_str
=
[
&
]()
{
return
std
::
string
(
"moeg"
);
}();
auto
api_str
=
[
&
]()
{
return
std
::
string
(
"moeg"
);
}();
auto
stride_str
=
[
&
]()
{
if
(
stride
==
hidden_size
)
...
...
@@ -245,7 +273,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
,
seed
,
true
}(
topk_weight_host
);
}
else
if
(
init
==
2
)
...
...
@@ -333,116 +361,122 @@ bool run(const ck_tile::ArgParser& arg_parser)
(
static_cast
<
double
>
(
ms
)
*
1e-3
)
/
1e12
;
};
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_weight_host
,
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_weight_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
experts
,
block_m
);
// output_matrix_2d(a_host, tokens, hidden_size);
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
std
::
cout
<<
num_sorted_tiles_host
<<
std
::
endl
;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
// std::cout << topk_weight_host << std::endl;
// std::cout << sorted_weight_host << std::endl;
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
);
// manually clear output buffer for atomic
o_buf
.
SetZero
();
//
ck_tile
::
DeviceMem
sorted_token_ids_buf
(
sorted_token_ids_host
);
ck_tile
::
DeviceMem
sorted_weight_buf
(
sorted_weight_host
);
ck_tile
::
DeviceMem
sorted_expert_ids_buf
(
sorted_expert_ids_host
);
ck_tile
::
DeviceMem
num_sorted_tiles_buf
(
num_sorted_tiles_host
);
fused_moegemm_traits
traits
{
prec_i
,
prec_w
,
prec_o
,
prec_st
,
prec_sw
,
prec_sq
,
prec_kw
,
block_m
,
gate_only
,
fused_quant
};
fused_moegemm_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
g_perm_buf
.
GetDeviceBuffer
(),
d_perm_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
o_buf
.
GetDeviceBuffer
(),
sorted_token_ids_buf
.
GetDeviceBuffer
(),
sorted_weight_buf
.
GetDeviceBuffer
(),
sorted_expert_ids_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
hidden_size
,
shared_intermediate_size_0
,
tokens
,
experts
,
topk
,
stride
,
max_num_tokens_padded
};
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
cal_tflops
(
ave_time
)
<<
" tflops, "
<<
cal_tbps
(
ave_time
)
<<
" TB/s"
<<
std
::
flush
;
bool
pass
=
true
;
if
(
do_validation
)
{
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Gelu
>
(
a_host
,
g_host
,
d_host
,
sa_host
,
sg_host
,
sd_host
,
sy_host
,
o_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
,
topk_ids_host
,
block_m
,
tokens
,
experts
,
block_m
);
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
);
// manually clear output buffer for atomic
o_buf
.
SetZero
();
//
ck_tile
::
DeviceMem
sorted_token_ids_buf
(
sorted_token_ids_host
);
ck_tile
::
DeviceMem
sorted_weight_buf
(
sorted_weight_host
);
ck_tile
::
DeviceMem
sorted_expert_ids_buf
(
sorted_expert_ids_host
);
ck_tile
::
DeviceMem
num_sorted_tiles_buf
(
num_sorted_tiles_host
);
fused_moegemm_traits
traits
{
prec_i
,
prec_w
,
prec_o
,
prec_st
,
prec_sw
,
prec_sq
,
prec_kw
,
block_m
,
gate_only
,
fused_quant
};
fused_moegemm_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
g_perm_buf
.
GetDeviceBuffer
(),
d_perm_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
o_buf
.
GetDeviceBuffer
(),
sorted_token_ids_buf
.
GetDeviceBuffer
(),
sorted_weight_buf
.
GetDeviceBuffer
(),
sorted_expert_ids_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
hidden_size
,
shared_intermediate_size_0
,
tokens
,
experts
,
topk
,
stride
,
max_num_tokens_padded
};
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
cal_tflops
(
ave_time
)
<<
" tflops, "
<<
cal_tbps
(
ave_time
)
<<
" TB/s"
<<
std
::
flush
;
bool
pass
=
true
;
if
(
do_validation
)
{
ck_tile
::
reference_fused_moe
<
AccDataType
,
ck_tile
::
element_wise
::
Gelu
>
(
a_host
,
g_host
,
d_host
,
sa_host
,
sg_host
,
sd_host
,
sy_host
,
o_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
,
topk_ids_host
,
block_m
,
tokens
,
experts
,
hidden_size
,
shared_intermediate_size_0
,
topk
,
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
;
}
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
return
pass
;
hidden_size
,
shared_intermediate_size_0
,
topk
,
gate_only
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
pass
&=
ck_tile
::
check_err
(
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
;
}
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
return
pass
;
}
int
main
(
int
argc
,
char
*
argv
[])
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
b616b254
...
...
@@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
//constexpr index_t block_m = BlockShape::Block_M0;
//
constexpr index_t block_m = BlockShape::Block_M0;
int
max_num_tokens_padded
=
hargs
.
max_num_tokens_padded
;
//hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
//
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
}
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
b616b254
...
...
@@ -116,7 +116,7 @@ struct FusedMoeGemmHostArgs
index_t
num_experts
;
// number of groups
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
max_num_tokens_padded
;
// size of sorted_token_ids_ptr
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
b616b254
...
...
@@ -124,9 +124,9 @@ struct FusedMoeGemmPipeline_General
index_t
hidden_size
,
index_t
intermediate_size
)
{
ignore
=
d_window_
;
ignore
=
hidden_size
;
ignore
=
intermediate_size
;
ignore
=
d_window_
;
ignore
=
hidden_size
;
ignore
=
intermediate_size
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
...
...
@@ -191,12 +191,13 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
}
// relu
const
auto
activation
=
ck_tile
::
element_wise
::
Gelu
{};
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
#if 0
#if 1
PrintMem
(
s_acc
);
#endif
// relu
const
auto
activation
=
ck_tile
::
element_wise
::
Gelu
{};
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
// move sacc to LDS
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
...
...
@@ -238,7 +239,7 @@ struct FusedMoeGemmPipeline_General
index_t
iCounter1
=
n1_loops
-
1
;
while
(
iCounter1
>
0
)
{
clear_tile
(
o_acc
);
clear_tile
(
o_acc
);
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
block_sync_lds
();
...
...
@@ -253,7 +254,7 @@ struct FusedMoeGemmPipeline_General
}
// tail
{
clear_tile
(
o_acc
);
clear_tile
(
o_acc
);
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
b616b254
...
...
@@ -175,7 +175,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
{
using
WG
=
decltype
(
GetWarpGemm0
<
Problem
>
());
using
S_
=
typename
Problem
::
BlockShape
;
static_assert
(
S_
::
WarpPerBlock_N0
==
4
);
static_assert
(
S_
::
WarpPerBlock_N0
==
4
);
constexpr
auto
g_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_M0
>
,
tuple
<
sequence
<
S_
::
Repeat_N0
,
S_
::
WarpPerBlock_N0
>
,
sequence
<
S_
::
Repeat_K0
>>
,
...
...
@@ -240,13 +240,14 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{};
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{};
constexpr
auto
y_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
y_outer_dstr_enc
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
...
...
@@ -260,13 +261,14 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
d_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{};
constexpr
auto
d_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{};
constexpr
auto
d_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
d_outer_dstr_enc
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
...
...
@@ -356,8 +358,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
1
>>
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
wg_ctrl
>
,
...
...
@@ -396,8 +398,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
1
>>
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
8
)
{
return
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
<
wg_ctrl
>
,
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
b616b254
...
...
@@ -52,16 +52,16 @@ struct BlockGemmARegBRegCRegV2
// M->N Warp
// constexpr auto a_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<NWarp>,
// tuple<sequence<MIterPerWarp, MWarp>,
sequence<KIterPerWarp>>,
// tuple<sequence<1, 0>>,
// tuple<sequence<MIterPerWarp, MWarp>,
//
sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
// tuple<sequence<1, 0>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto b_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<MWarp>,
// tuple<sequence<NIterPerWarp, NWarp>,
sequence<KIterPerWarp>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<NIterPerWarp, NWarp>,
//
sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment