gaoqiong / composable_kernel_ROCM / Commits / abd2755a

Commit abd2755a, authored Jan 06, 2025 by ThomasNing

Merge branch 'develop' into moe_cross_reduce

Parents: b74918bc, 888317e6
Changes: 166. Showing 20 changed files with 489 additions and 173 deletions (+489, -173).
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py (+2, -4)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py (+4, -10)
example/ck_tile/01_fmha/fmha_fwd.cpp (+10, -9)
example/ck_tile/01_fmha/fmha_fwd.hpp (+12, -2)
example/ck_tile/02_layernorm2d/generate.py (+93, -67)
example/ck_tile/02_layernorm2d/script/smoke_test.sh (+2, -1)
example/ck_tile/03_gemm/gemm_basic.hpp (+3, -3)
example/ck_tile/03_gemm/run_gemm_example.inc (+4, -4)
example/ck_tile/03_gemm/universal_gemm.cpp (+7, -13)
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp (+34, -19)
example/ck_tile/13_moe_sorting/script/smoke_test.sh (+2, -1)
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp (+34, -19)
example/ck_tile/16_batched_gemm/batched_gemm.cpp (+9, -4)
example/ck_tile/16_batched_gemm/batched_gemm.hpp (+2, -1)
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc (+4, -0)
include/ck/config.h.in (+2, -2)
include/ck/library/utility/host_tensor.hpp (+66, -12)
include/ck/library/utility/host_tensor_generator.hpp (+30, -0)
include/ck/tensor/static_tensor.hpp (+2, -2)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp (+167, -0)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
 using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
     fmha_pipeline_problem_{F_idx}>;
-using fmha_kernel_{F_idx} =
-    ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
-                                   fmha_pipeline_{F_idx}>;
+using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
 using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
     {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
@@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
         _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
         for kernel in kernels:
             f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
-        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
+        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
@@ -96,9 +96,7 @@ using fmha_epilogue =
     {F_spad}, {F_dvpad}>>;

 using fmha_kernel =
-    ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
-                                  fmha_pipeline,
-                                  fmha_epilogue>;
+    ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;

 static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
@@ -176,11 +174,7 @@ using fmha_epilogue =
     false, false>>;

 using fmha_kernel =
-    ck_tile::FmhaFwdSplitKVCombineKernel<
-        ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
-            fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
-        fmha_pipeline,
-        fmha_epilogue>;
+    ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;

 static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
@@ -261,7 +255,7 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F
             static_assert({F_bn1} % 32 == 0);
             if (t.has_lse) {{
-                if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
+                if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
                     return -1;
                 }} else {{
                     using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
@@ -614,7 +608,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
-            '64'  : FmhaFwdSplitKVCombineTileSize(32, -1),
+            '64'  : FmhaFwdSplitKVCombineTileSize(32, -1),
            '128' : FmhaFwdSplitKVCombineTileSize(32, -1),
            '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
        }
example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         // NOTE: use gpu to do validation
         ck_tile::naive_attention_fwd_traits naive_t;
-        naive_t.q_type    = data_type;
-        naive_t.k_type    = data_type;
-        naive_t.v_type    = data_type;
-        naive_t.o_type    = data_type;
-        naive_t.q_layout  = i_perm == 1 ? "bhsd" : "bshd";
-        naive_t.k_layout  = i_perm == 1 ? "bhsd" : "bshd";
-        naive_t.v_layout  = i_perm == 1 ? "bhsd" : "bshd";
-        naive_t.o_layout  = o_perm == 1 ? "bhsd" : "bshd";
-        naive_t.variation = 0; // TODO?
+        naive_t.q_type     = data_type;
+        naive_t.k_type     = data_type;
+        naive_t.v_type     = data_type;
+        naive_t.o_type     = data_type;
+        naive_t.q_layout   = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.k_layout   = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.v_layout   = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.o_layout   = o_perm == 1 ? "bhsd" : "bshd";
+        naive_t.variation  = 0; // TODO?
+        naive_t.quant_algo = 0;
         ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -400,8 +400,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
         }
     }();

-    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
-    return ck_tile::make_tuple(kargs, grids);
+    if constexpr(FmhaKernel::kIsGroupMode)
+    {
+        dim3 grids = FmhaKernel::GridSize(
+            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
+        return ck_tile::make_tuple(kargs, grids);
+    }
+    else
+    {
+        dim3 grids = FmhaKernel::GridSize(
+            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
+        return ck_tile::make_tuple(kargs, grids);
+    }
 }

 template <typename Kernel>
example/ck_tile/02_layernorm2d/generate.py (diff collapsed)
example/ck_tile/02_layernorm2d/script/smoke_test.sh
@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
-#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
+$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120
+$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
 done
 done
example/ck_tile/03_gemm/gemm_basic.hpp
@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
-    arg_parser.insert("b", "1", "batch size")
-        .insert("m", "3840", "m dimension")
+    arg_parser.insert("m", "3840", "m dimension")
         .insert("n", "4096", "n dimension")
         .insert("k", "2048", "k dimension")
         .insert("a_layout", "R", "A tensor data layout - Row by default")
@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("split_k", "1", "splitK value");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
example/ck_tile/03_gemm/run_gemm_example.inc
@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
     ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

-    ck_tile::index_t batch_size = arg_parser.get_int("b");
-    int n_warmup                = arg_parser.get_int("warmup");
-    int n_repeat                = arg_parser.get_int("repeat");
+    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
+    int n_warmup            = arg_parser.get_int("warmup");
+    int n_repeat            = arg_parser.get_int("repeat");

     using namespace ck_tile::literals;
@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
                                   stride_A,
                                   stride_B,
                                   stride_C,
-                                  batch_size,
+                                  kbatch,
                                   n_warmup,
                                   n_repeat);
example/ck_tile/03_gemm/universal_gemm.cpp
@@ -22,7 +22,7 @@
 #endif

 template <typename ALayout, typename BLayout, typename CLayout>
-float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
 #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
     // Memory friendly for Interwave scheduler
@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 #endif
             ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

-    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
+    const ck_tile::index_t k_grain  = args.k_batch * K_Tile;
+    const ck_tile::index_t K_split  = (args.K + k_grain - 1) / k_grain * K_Tile;
+    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
     const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
     const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
                                                            has_hot_loop_v,
                                                            tail_number_v>>;
         using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
-        auto kargs = Kernel::MakeKargs(args.p_a,
-                                       args.p_b,
-                                       args.p_c,
-                                       args.M,
-                                       args.N,
-                                       args.K,
-                                       args.stride_A,
-                                       args.stride_B,
-                                       args.stride_C);
-        const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+        auto kargs = Kernel::MakeKernelArgs(args);
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
         constexpr dim3 blocks = Kernel::BlockSize();

         if(!Kernel::IsSupportedArgument(kargs))
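Note: the split-K change in universal_gemm.cpp above derives the pipeline loop count from a K dimension rounded up to a whole number of k_batch * K_Tile grains instead of passing args.K directly. Below is a minimal standalone sketch of that rounding arithmetic with made-up values for K, K_Tile and k_batch; the num_loop line is only a stand-in for what TilePartitioner::GetLoopNum would compute, not the ck_tile API itself.

#include <cstdio>

int main()
{
    // Hypothetical sizes, chosen only to illustrate the rounding.
    const int K       = 2048; // GEMM reduction dimension
    const int K_Tile  = 64;   // per-block K tile
    const int k_batch = 3;    // split-K factor

    // Same arithmetic as the added lines in the diff above.
    const int k_grain = k_batch * K_Tile;                      // one "grain" covers every split once
    const int K_split = (K + k_grain - 1) / k_grain * K_Tile;  // per-split K, rounded up

    // Stand-in for TilePartitioner::GetLoopNum(K_split): K_split is walked in K_Tile steps.
    const int num_loop = K_split / K_Tile;

    std::printf("k_grain=%d K_split=%d num_loop=%d\n", k_grain, K_split, num_loop);
    return 0;
}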
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
@@ -3,18 +3,42 @@
 #include "moe_sorting_api.hpp"

-#define MOE_SORTING_DISPATCH(unroll_num_)                                               \
-    constexpr ck_tile::index_t unroll_num = unroll_num_;                                \
-    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
-    using kernel     = ck_tile::MoeSortingKernel<ms_problem>;                           \
-    auto kargs       = kernel::MakeKargs(a);                                            \
-    const dim3 grids     = kernel::GridSize(a);                                         \
-    const dim3 blocks    = kernel::BlockSize(a);                                        \
-    const auto lds_bytes = kernel::GetSmemSize(a);                                      \
-    float ave_time = ck_tile::launch_kernel(                                            \
-        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));            \
+#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                           \
+    constexpr ck_tile::index_t unroll_num  = unroll_num_;                               \
+    constexpr ck_tile::index_t expert_tile = expert_tile_;                              \
+    using ms_problem =                                                                  \
+        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>;   \
+    using kernel     = ck_tile::MoeSortingKernel<ms_problem>;                           \
+    auto kargs       = kernel::MakeKargs(a);                                            \
+    const dim3 grids     = kernel::GridSize(a);                                         \
+    const dim3 blocks    = kernel::BlockSize(a);                                        \
+    const auto lds_bytes = kernel::GetSmemSize(a);                                      \
+    float ave_time = ck_tile::launch_kernel(                                            \
+        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));            \
     return ave_time;

+#define MOE_SORTING_DISPATCH(unroll_num_)           \
+    if(a.num_experts <= 8)                          \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)  \
+    }                                               \
+    else if(a.num_experts <= 16)                    \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
+    }                                               \
+    else if(a.num_experts <= 32)                    \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
+    }                                               \
+    else if(a.num_experts <= 64)                    \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
+    }                                               \
+    else                                            \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)  \
+    }

 float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
 {
     if(t.weight_type == "fp32" && t.index_type == "int32")
@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
         case(6): {
             MOE_SORTING_DISPATCH(6);
         }
         case(7): {
             MOE_SORTING_DISPATCH(7);
         }
         case(8): {
             MOE_SORTING_DISPATCH(8);
         }
-        case(9): {
-            MOE_SORTING_DISPATCH(9);
-        }
-        case(10): {
-            MOE_SORTING_DISPATCH(10);
-        }
-        case(11): {
-            MOE_SORTING_DISPATCH(11);
-        }
         default: {
             MOE_SORTING_DISPATCH(4);
         }
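Note: the rewritten MOE_SORTING_DISPATCH above picks an expert-tile bucket (8, 16, 32, 64, or 0 for the untiled fallback) from the runtime num_experts and forwards it, together with the unroll factor, to MOE_SORTING_DISPATCH_ETILE as a compile-time parameter. A minimal sketch of just the bucket selection, written as an ordinary function with a hypothetical name (the real code keeps this in a macro so each branch instantiates a different kernel type):

#include <cstdio>

// Mirrors the threshold chain of the new MOE_SORTING_DISPATCH macro:
// the smallest supported tile that covers num_experts, or 0 as the generic fallback.
int select_expert_tile(int num_experts)
{
    if(num_experts <= 8)  { return 8; }
    if(num_experts <= 16) { return 16; }
    if(num_experts <= 32) { return 32; }
    if(num_experts <= 64) { return 64; }
    return 0; // fall back to the untiled path
}

int main()
{
    for(int e : {1, 8, 9, 33, 64, 99})
        std::printf("num_experts=%d -> expert_tile=%d\n", e, select_expert_tile(e));
    return 0;
}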
example/ck_tile/13_moe_sorting/script/smoke_test.sh
@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
 $EXE -t=71 -e=11 -k=11
 $EXE -t=1 -e=1 -k=1
 $EXE -t=99 -e=2 -k=1
-$EXE -t=333 -e=99 -k=13
\ No newline at end of file
+$EXE -t=333 -e=99 -k=13
+$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
@@ -3,18 +3,42 @@
 #include "fused_moesorting.hpp"

-#define MOE_SORTING_DISPATCH(unroll_num_)                                               \
-    constexpr ck_tile::index_t unroll_num = unroll_num_;                                \
-    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
-    using kernel     = ck_tile::MoeSortingKernel<ms_problem>;                           \
-    auto kargs       = kernel::MakeKargs(a);                                            \
-    const dim3 grids     = kernel::GridSize(a);                                         \
-    const dim3 blocks    = kernel::BlockSize(a);                                        \
-    const auto lds_bytes = kernel::GetSmemSize(a);                                      \
-    float ave_time = ck_tile::launch_kernel(                                            \
-        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));            \
+#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                           \
+    constexpr ck_tile::index_t unroll_num  = unroll_num_;                               \
+    constexpr ck_tile::index_t expert_tile = expert_tile_;                              \
+    using ms_problem =                                                                  \
+        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>;   \
+    using kernel     = ck_tile::MoeSortingKernel<ms_problem>;                           \
+    auto kargs       = kernel::MakeKargs(a);                                            \
+    const dim3 grids     = kernel::GridSize(a);                                         \
+    const dim3 blocks    = kernel::BlockSize(a);                                        \
+    const auto lds_bytes = kernel::GetSmemSize(a);                                      \
+    float ave_time = ck_tile::launch_kernel(                                            \
+        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));            \
     return ave_time;

+#define MOE_SORTING_DISPATCH(unroll_num_)           \
+    if(a.num_experts <= 8)                          \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)  \
+    }                                               \
+    else if(a.num_experts <= 16)                    \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
+    }                                               \
+    else if(a.num_experts <= 32)                    \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
+    }                                               \
+    else if(a.num_experts <= 64)                    \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
+    }                                               \
+    else                                            \
+    {                                               \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)  \
+    }

 float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
 {
     if(t.weight_type == "fp32" && t.index_type == "int32")
@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
         case(6): {
             MOE_SORTING_DISPATCH(6);
         }
         case(7): {
             MOE_SORTING_DISPATCH(7);
         }
         case(8): {
             MOE_SORTING_DISPATCH(8);
         }
-        case(9): {
-            MOE_SORTING_DISPATCH(9);
-        }
-        case(10): {
-            MOE_SORTING_DISPATCH(10);
-        }
-        case(11): {
-            MOE_SORTING_DISPATCH(11);
-        }
         default: {
             MOE_SORTING_DISPATCH(4);
         }
example/ck_tile/16_batched_gemm/batched_gemm.cpp
@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
     using CodegenGemmTraits =
         ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using CodegenPipelineProblem = ck_tile::
         GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
+    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
+    using CodegenGemmPipeline =
+        ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
     auto kargs = Kernel::MakeKernelArgs(args);

-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
     constexpr dim3 blocks = Kernel::BlockSize();

     if(!Kernel::IsSupportedArgument(kargs))
     {
         throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
     }

     if(s.log_level_ > 0)
     {
         std::cout << "Launching kernel with args:"
example/ck_tile/16_batched_gemm/batched_gemm.hpp
@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("split_k", "1", "splitK value");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           ck_tile::index_t batch_stride_B,
                           ck_tile::index_t batch_stride_C,
                           ck_tile::index_t batch_count,
+                          ck_tile::index_t kbatch,
                           int n_warmup,
                           int n_repeat)
 {
@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
     args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
     args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
     args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+    args.k_batch = kbatch;
     args.M = M;
     args.N = N;
     args.K = K;
@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
     ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
     ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
     ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
+    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");
@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
                                           batch_stride_B,
                                           batch_stride_C,
                                           batch_count,
+                                          kbatch,
                                           n_warmup,
                                           n_repeat);
include/ck/config.h.in
@@ -115,8 +115,8 @@
 #cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
 #endif

-#ifndef DCK_USE_OCP_FP8
-#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@
+#ifndef CK_USE_OCP_FP8
+#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
 #endif

 #ifndef CK_USE_FNUZ_FP8
include/ck/library/utility/host_tensor.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
         else
             os << delim;

-        if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
+        using RangeType = ck::remove_cvref_t<decltype(v)>;
+        if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
+                     std::is_same_v<RangeType, ck::bhalf_t>)
         {
             os << ck::type_convert<float>(v);
         }
+        else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
+        {
+            const auto packed_floats = ck::type_convert<ck::float2_t>(v);
+            const ck::vector_type<float, 2> vector_of_floats{packed_floats};
+            os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
+               << vector_of_floats.template AsType<float>()[ck::Number<1>{}];
+        }
         else
         {
             os << static_cast<T>(v);
@@ -266,18 +275,18 @@ struct Tensor
     using Data = std::vector<T>;

     template <typename X>
-    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
     {
     }

     template <typename X, typename Y>
     Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
+        : mDesc(lens, strides), mData(GetElementSpaceSize())
     {
     }

     template <typename Lengths>
-    Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
     {
     }
@@ -287,7 +296,7 @@ struct Tensor
     {
     }

-    Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
+    Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}

     template <typename OutT>
     Tensor<OutT> CopyAsType() const
@@ -322,7 +331,17 @@ struct Tensor
     std::size_t GetElementSize() const { return mDesc.GetElementSize(); }

-    std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+    std::size_t GetElementSpaceSize() const
+    {
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return (mDesc.GetElementSpaceSize() + 1) / 2;
+        }
+        else
+        {
+            return mDesc.GetElementSpaceSize();
+        }
+    }

     std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
@@ -469,29 +488,64 @@ struct Tensor
     template <typename... Is>
     std::size_t GetOffsetFromMultiIndex(Is... is) const
     {
-        return mDesc.GetOffsetFromMultiIndex(is...);
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mDesc.GetOffsetFromMultiIndex(is...) / 2;
+        }
+        else
+        {
+            return mDesc.GetOffsetFromMultiIndex(is...);
+        }
     }

     template <typename... Is>
     T& operator()(Is... is)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        }
     }

     template <typename... Is>
     const T& operator()(Is... is) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        }
     }

     T& operator()(std::vector<std::size_t> idx)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        }
     }

     const T& operator()(std::vector<std::size_t> idx) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        }
     }

     typename Data::iterator begin() { return mData.begin(); }
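Note: the host_tensor.hpp changes above make the host-side Tensor aware of ck::pk_i4_t, which stores two 4-bit values per byte: GetElementSpaceSize is halved (rounded up) and every linear offset is divided by two before indexing the backing vector. A standalone sketch of that size/offset arithmetic with hypothetical helper names (plain C++, no ck types):

#include <cassert>
#include <cstddef>

// Logical element count -> number of bytes of backing storage when two 4-bit
// values share one byte (mirrors the (size + 1) / 2 in GetElementSpaceSize).
std::size_t packed_i4_space(std::size_t logical_elements)
{
    return (logical_elements + 1) / 2;
}

// Logical element offset -> byte offset into the packed buffer
// (mirrors the "offset / 2" added to operator() and GetOffsetFromMultiIndex).
std::size_t packed_i4_byte_offset(std::size_t logical_offset)
{
    return logical_offset / 2;
}

int main()
{
    assert(packed_i4_space(8) == 4);       // 8 nibbles fit in 4 bytes
    assert(packed_i4_space(9) == 5);       // odd counts round up
    assert(packed_i4_byte_offset(7) == 3); // elements 6 and 7 share byte 3
    return 0;
}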
include/ck/library/utility/host_tensor_generator.hpp
@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
     }
 };

+template <>
+struct GeneratorTensor_1<ck::pk_i4_t>
+{
+    int8_t value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int t         = value + 8;
+        ck::pk_i4_t r = ((t << 4) + t) & 0xff;
+        return r;
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_2
 {
@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
     }
 };

+template <>
+struct GeneratorTensor_2<ck::pk_i4_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int hi        = std::rand() % (max_value - min_value) + min_value + 8;
+        int lo        = std::rand() % (max_value - min_value) + min_value + 8;
+        ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
+        return r;
+    }
+};
+
 #if defined CK_ENABLE_FP8
 template <>
 struct GeneratorTensor_2<ck::f8_t>
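Note: both new generator specializations above build a pk_i4_t by biasing each 4-bit value by +8 (so small signed values land in an unsigned nibble) and packing two nibbles into one byte, high nibble first. A standalone sketch of that packing with a hypothetical helper (the second demo pair below is arbitrary, not what the generators' defaults would draw):

#include <cstdint>
#include <cstdio>

// Pack two small signed values into one byte the way the new GeneratorTensor
// specializations do: bias each by +8, put the first in the high nibble.
std::uint8_t pack_i4x2(int hi, int lo)
{
    const int h = hi + 8; // bias into 0..15
    const int l = lo + 8;
    return static_cast<std::uint8_t>(((h << 4) + l) & 0xff);
}

int main()
{
    std::printf("packed(1, 1)  = 0x%02x\n", pack_i4x2(1, 1));  // 0x99, as GeneratorTensor_1 with value = 1
    std::printf("packed(0, -1) = 0x%02x\n", pack_i4x2(0, -1)); // 0x87, arbitrary demo pair
    return 0;
}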
include/ck/tensor/static_tensor.hpp
@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer
     // Idx is for S, not X. Idx should be aligned with X
     template <typename X,
               typename Idx,
-              typename enable_if<has_same_scalar_type<S, X>::value &&
+              typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
                                      is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                  bool>::type = false>
     __host__ __device__ constexpr X GetAsType(Idx) const
@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer
     // Idx is for S, not X. Idx should be aligned with X
     template <typename X,
               typename Idx,
-              typename enable_if<has_same_scalar_type<S, X>::value &&
+              typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
                                      is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                  bool>::type = false>
     __host__ __device__ constexpr void SetAsType(Idx, X x)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"

namespace ck {

enum struct BlockGemmPipelineVersion
{
    v1, // Naive
    v2, // Mem
    v3, // Comp
    v4, // Comp, double lds buffer
    v5, // Comp, double global prefetch register buffer
};

template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
constexpr auto BlockGemmPipeline_Selector()
{
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
    {
        return BlockwiseGemmXdlops_pipeline_v1_b_scale<BlkGemmPipeSche, BlockSize,
            ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
    {
        return BlockwiseGemmXdlops_pipeline_v2_b_scale<BlkGemmPipeSche, BlockSize,
            ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
        return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche, BlockSize,
            ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
    {
        return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche, BlockSize,
            ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
    {
        return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche, BlockSize,
            ADataType, BDataType, ComputeDataType, AccDataType,
            ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
            ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
            MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else
    {
        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
    }
}

} // namespace ck
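Note: the new selector header follows a recurring pattern in this codebase: a compile-time enum plus an if constexpr chain whose return type encodes the chosen pipeline, so callers can recover the type with decltype. A minimal self-contained sketch of the same pattern using toy types (illustration only; the names below are not the ck pipelines):

#include <type_traits>

enum struct PipelineVersion { v1, v2, v3 };

// Toy stand-ins for the real pipeline classes.
struct PipelineV1 { static constexpr int prefetch_stages = 1; };
struct PipelineV2 { static constexpr int prefetch_stages = 2; };
struct PipelineV3 { static constexpr int prefetch_stages = 3; };

// Same shape as BlockGemmPipeline_Selector: branch at compile time on the
// enum and return a value whose type carries the selection.
template <PipelineVersion Ver>
constexpr auto pipeline_selector()
{
    if constexpr(Ver == PipelineVersion::v1)      { return PipelineV1{}; }
    else if constexpr(Ver == PipelineVersion::v2) { return PipelineV2{}; }
    else                                          { return PipelineV3{}; }
}

// Callers of such selectors typically recover the chosen type with decltype.
using Chosen = decltype(pipeline_selector<PipelineVersion::v2>());
static_assert(std::is_same_v<Chosen, PipelineV2>);
static_assert(Chosen::prefetch_stages == 2);

int main() { return 0; }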