gaoqiong / composable_kernel_ROCM · Commits

Commit 7dc420a8, authored Feb 12, 2025 by ThomasNing
Solve merge conflict and add the gtest for compv4
Parents: 884a2f7c, ef2b53a9

Changes: 68 files in the commit; this page shows 20 changed files, with 641 additions and 32 deletions (+641 −32).
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp (+2 −1)
example/ck_tile/13_moe_sorting/script/smoke_test.sh (+8 −0)
example/ck_tile/15_fused_moe/README.md (+1 −1)
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp (+74 −0)
example/ck_tile/16_batched_gemm/batched_gemm.cpp (+5 −2)
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc (+1 −1)
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp (+1 −1)
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp (+354 −19)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp (+7 −4)
include/ck_tile/host.hpp (+1 −0)
include/ck_tile/host/concat.hpp (+122 −0)
include/ck_tile/host/reference/reference_moe_sorting.hpp (+24 −2)
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp (+1 −0)
include/ck_tile/ops/batched_transpose.hpp (+1 −0)
include/ck_tile/ops/common.hpp (+1 −0)
include/ck_tile/ops/common/utils.hpp (+34 −0)
include/ck_tile/ops/elementwise.hpp (+1 −0)
include/ck_tile/ops/epilogue.hpp (+1 −0)
include/ck_tile/ops/flatmm.hpp (+1 −0)
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
@@ -10,7 +10,8 @@
struct moe_sorting_trait
{
    std::string index_type;
    std::string weight_type;   // currently always float
    bool local_expert_masking; // if mask experts as local expert
};

struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
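The new local_expert_masking flag extends the trait's aggregate layout. A minimal, hypothetical construction (values illustrative, not taken from this commit) would look like:

// Hypothetical: field order follows the declaration above
// (index_type, weight_type, local_expert_masking).
moe_sorting_trait trait{"int32", "fp32", /*local_expert_masking=*/true};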
example/ck_tile/13_moe_sorting/script/smoke_test.sh
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
example/ck_tile/15_fused_moe/README.md
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
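As a hedged illustration of the updated bound (the numbers below are assumed, not taken from the README's example):

// Hypothetical sizes: 5 input tokens, topk = 2, 6 experts, unit tile M_a = 4.
// Old bound : topk * tokens + experts * (M_a - 1)  = 2*5 + 6*3     = 28
// New bound : topk * tokens + experts * M_a - topk = 2*5 + 6*4 - 2 = 32
constexpr int topk = 2, tokens = 5, experts = 6, M_a = 4;
constexpr int max_num_tokens_padded = topk * tokens + experts * M_a - topk;
static_assert(max_num_tokens_padded == 32, "worked example of the updated bound");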
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"

#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif

#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)  \
    constexpr ck_tile::index_t unroll_num  = unroll_num_;      \
    constexpr ck_tile::index_t expert_tile = expert_tile_;     \

@@ -17,6 +23,24 @@
        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));                     \
    return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_)                               \
    constexpr ck_tile::index_t sub_token_tile = sub_token_tile_;                                \
    constexpr bool sub_token_onshot           = sub_token_onshot_;                              \
    using ms_problem =                                                                           \
        ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
    using kernel = ck_tile::MoeSortingKernel<ms_problem>;                                       \
    auto kargs = kernel::MakeKargs(a);                                                          \
    const dim3 grids     = kernel::GridSize(a);                                                 \
    const dim3 blocks    = kernel::BlockSize(a);                                                \
    const auto lds_bytes = kernel::GetSmemSize(a);                                              \
    float ave_time = ck_tile::launch_kernel(                                                    \
        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));                     \
    return ave_time;
#endif

#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
    if(a.num_experts <= 8)                \
    {                                     \

@@ -38,11 +62,13 @@
    {                                              \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
    }
#endif

float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
#if !MOE_SORTING_USE_EX_KERNEL
        if(a.num_experts > 127)
        {
            printf("lds size exceed, only support experts <127\n");

@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
                MOE_SORTING_DISPATCH(4);
            }
        }
#else
        using index_t        = ck_tile::index_t;
        using ms_weight_type = float;
        auto [r_, c_]   = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
        auto sub_token_ = r_ - 2;
        r_              = (r_ - 2) / 8;
        bool is_sub_token_onshot = a.tokens <= sub_token_;
        (void)c_;
        if(is_sub_token_onshot)
        {
            if(r_ % 8 == 0)      { MOE_SORTING_DISPATCH_(8, true); }
            else if(r_ % 4 == 0) { MOE_SORTING_DISPATCH_(4, true); }
            else if(r_ % 2 == 0) { MOE_SORTING_DISPATCH_(2, true); }
            else                 { MOE_SORTING_DISPATCH_(1, true); }
        }
        else
        {
            if(r_ % 8 == 0)      { MOE_SORTING_DISPATCH_(8, false); }
            else if(r_ % 4 == 0) { MOE_SORTING_DISPATCH_(4, false); }
            else if(r_ % 2 == 0) { MOE_SORTING_DISPATCH_(2, false); }
            else                 { MOE_SORTING_DISPATCH_(1, false); }
        }
        // MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
    }
    return -1;
}
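The new ex-kernel path picks the sub-token tile as the largest of {8, 4, 2, 1} that divides the reduced row count r_. A minimal, self-contained sketch of that selection rule (the helper name is hypothetical and only mirrors the dispatch chain above):

// Sketch only: reproduces the if/else-if chain as a standalone function.
#include <cstdint>

static constexpr std::int32_t pick_sub_token_tile(std::int32_t r)
{
    return (r % 8 == 0) ? 8 : (r % 4 == 0) ? 4 : (r % 2 == 0) ? 2 : 1;
}

static_assert(pick_sub_token_tile(24) == 8);
static_assert(pick_sub_token_tile(12) == 4);
static_assert(pick_sub_token_tile(10) == 2);
static_assert(pick_sub_token_tile(7) == 1);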
example/ck_tile/16_batched_gemm/batched_gemm.cpp
@@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
    if(s.log_level_ > 0)
    {
-        std::cout << "Launching kernel with args:"
-                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+        std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                  << "shape: " << CodegenGemmShape::GetName() << '\n'
+                  << "problem: " << CodegenPipelineProblem::GetName() << '\n'
+                  << "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
+                  << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
-        std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
+        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    else if(arg_parser.get_int("v") == 2)
    {
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
@@ -118,7 +118,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
    if(s.log_level_ > 0)
    {
-        std::cout << "Launching kernel with args:"
+        std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
@@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
    }
-    std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
+    std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    return pass;
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
@@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return true;
    }

    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        // check vector load of A
        if constexpr(is_same_v<ALayout, Row>)
        {
            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) { return false; }
        }
        else if constexpr(is_same_v<ALayout, Col>)
        {
            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) { return false; }
        }
        else
        {
            return false;
        }

        // check vector load of B
        if constexpr(is_same_v<BLayout, Row>)
        {
            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) { return false; }
        }
        else if constexpr(is_same_v<BLayout, Col>)
        {
            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) { return false; }
        }
        else
        {
            return false;
        }

        // check vector load of B1
        if constexpr(is_same_v<B1Layout, Row>)
        {
            if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) { return false; }
        }
        else if constexpr(is_same_v<B1Layout, Col>)
        {
            if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) { return false; }
        }
        else
        {
            return false;
        }

        // check vector load of C
        if constexpr(is_same_v<CLayout, Row>)
        {
            if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { return false; }
        }
        else if constexpr(is_same_v<CLayout, Col>)
        {
            if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { return false; }
        }
        else
        {
            return false;
        }

        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())

@@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
        const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];

-        // Check scalar per vector requirement
-        const auto a_extent_lowest  = is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-        const auto b_extent_lowest  = is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-        const auto b1_extent_lowest = is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-        const auto c_extent_lowest  = is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
-
-        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-        {
-            return false;
-        }
-
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
                                           arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+                                           arg.block_2_ctile_map_) and
+               IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
    }

    // polymorphic

@@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return str.str();
    }

    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    struct Descriptor
    {
        template <class AGridDescriptor>
        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
        {
            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
            const auto M   = a_grid_desc_m_k.GetLength(I0);
            const auto K   = a_grid_desc_m_k.GetLength(I1);
            const auto AK0 = K / AK1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class BGridDescriptor>
        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
        {
            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
            const auto N   = b_grid_desc_n_k.GetLength(I0);
            const auto K   = b_grid_desc_n_k.GetLength(I1);
            const auto BK0 = K / BK1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class B1GridDescriptor>
        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
        {
            const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
            const auto N    = b1_grid_desc_n_k.GetLength(I0);
            const auto K    = b1_grid_desc_n_k.GetLength(I1);
            const auto B1K0 = K / B1K1;

            return transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class CGridDescriptor>
        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
        {
            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
        }

        using AGridDesc_AK0_M_AK1  = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
        using BGridDesc_BK0_N_BK1  = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
        using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
        using CGridDesc_M_N        = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

        // GridwiseGemm
        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
            ADataType, // TODO: distinguish A/B datatype
            GemmAccDataType, CShuffleDataType, CDataType,
            AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation,
            B1ElementwiseOperation, CElementwiseOperation,
            InMemoryDataOperationEnum::Set,
            AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N,
            NumGemmKPrefetchStage, BlockSize,
            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock,
            AK1, BK1, B1K1,
            MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, Gemm1NXdlPerWave,
            ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1,
            true, ABlockLdsExtraM,
            BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1,
            true, BBlockLdsExtraN,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder,
            B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim,
            B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1,
            false, B1BlockLdsExtraN,
            CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            CShuffleBlockTransferScalarPerVector_NPerBlock,
            LoopSched, matrix_padder.PadN, MaskOutUpperTriangle>;

        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
        CGridDesc_M_N c_grid_desc_m_n;
        C0MatrixMask c0_matrix_mask;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_descriptor_mblock_mperblock_nblock_nperblock;

        // element-wise op
        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        B1ElementwiseOperation b1_element_op;
        CElementwiseOperation c_element_op;

        bool has_main_k_block_loop = true;
        bool is_valid              = false;

        constexpr Descriptor(ADesc a,
                             BDesc b,
                             B1Desc b1,
                             CDesc c,
                             AElementwiseOperation a_element_op_,
                             BElementwiseOperation b_element_op_,
                             B1ElementwiseOperation b1_element_op_,
                             CElementwiseOperation c_element_op_)
            : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              c0_matrix_mask{c.GetLength(I1)},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_m_n,
                                                   block_2_ctile_map) and
                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
        {
        }

        constexpr bool IsValid() const { return is_valid; }
    };

    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    static constexpr auto make_descriptor(ADesc a,
                                          BDesc b,
                                          B1Desc b1,
                                          CDesc c,
                                          AElementwiseOperation a_element_op   = AElementwiseOperation{},
                                          BElementwiseOperation b_element_op   = BElementwiseOperation{},
                                          B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                                          CElementwiseOperation c_element_op   = CElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
            a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
    }

    template <class Desc>
    __device__ static void Run(const Desc& desc,
                               const float scale,
                               const ADataType* __restrict__ p_a_grid,
                               const ADataType* __restrict__ p_b_grid,
                               const ADataType* __restrict__ p_b1_grid,
                               CDataType* __restrict__ p_c_grid)
    {
#ifndef __HIPCC_RTC__
        assert(desc.is_valid);
#endif
        __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];

        AccElementwiseOperation acc_element_op{scale};

        if(desc.has_main_k_block_loop)
        {
            Desc::GridwiseGemm::template Run<true>(
                p_a_grid, p_b_grid, p_b1_grid, p_c_grid, p_shared_block,
                desc.a_element_op, desc.b_element_op, acc_element_op,
                desc.b1_element_op, desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1, desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map, desc.c0_matrix_mask);
        }
        else
        {
            Desc::GridwiseGemm::template Run<false>(
                p_a_grid, p_b_grid, p_b1_grid, p_c_grid, p_shared_block,
                desc.a_element_op, desc.b_element_op, acc_element_op,
                desc.b1_element_op, desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1, desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map, desc.c0_matrix_mask);
        }
    }
};

} // namespace device
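Per operand, the new IsSupported() check reduces to a divisibility test on the extent along the vector-access dimension. A standalone sketch of that rule (names and values are illustrative, not part of this header):

// Sketch: the extent along the contiguous dimension must be a multiple of the
// configured ScalarPerVector, otherwise the instance is rejected.
#include <cstdint>

static constexpr bool extent_supports_vector_access(std::int64_t extent_lowest,
                                                    std::int64_t scalar_per_vector)
{
    return extent_lowest % scalar_per_vector == 0;
}

static_assert(extent_supports_vector_access(4096, 8));  // e.g. K = 4096 with 8-wide loads: ok
static_assert(!extent_supports_vector_access(100, 8));  // e.g. K = 100 with 8-wide loads: rejected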
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
@@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
        // if workspace is not allocated
        if(!arg.p_workspace_)
        {
-            std::cerr << "Warning: Workspace for "
-                         "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
-                         "allocated, use SetWorkSpacePointer."
-                      << std::endl;
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                std::cout << "Warning: Workspace for "
+                             "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
+                             "allocated, use SetWorkSpacePointer."
+                          << std::endl;
+            }
            return false;
        }

        if(!ck::is_xdl_supported())
include/ck_tile/host.hpp
@@ -5,6 +5,7 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
+#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
include/ck_tile/host/concat.hpp
New file (mode 100644):

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

template <typename T>
struct IsCharArray : std::false_type
{
};

template <std::size_t N>
struct IsCharArray<char[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<const char[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<char (&)[N]> : std::true_type
{
};

template <std::size_t N>
struct IsCharArray<const char (&)[N]> : std::true_type
{
};

template <typename... Ts>
inline constexpr bool AllConvertibleToStringView =
    ((std::is_convertible_v<Ts, std::string_view> || IsCharArray<Ts>::value ||
      std::is_same_v<Ts, char>) &&
     ...);

template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
    -> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
{
    using ::operator<<;
    thread_local std::ostringstream oss;
    oss.str("");
    (oss << ... << xs);
    return oss.str();
}

template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept
{
    return N;
}

template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept
{
    return N;
}

[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept
{
    const char* end = s;
    while(*end++ != 0) {}
    return end - s - 1;
}

[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; }

[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); }

[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept
{
    return s.size();
}

template <typename... Ts>
auto concatInto(std::string& result, const Ts&... xs)
    -> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
{
    const std::size_t space = (1 + ... + getSize(xs));
    result.reserve(result.size() + space);
    ((result += xs), ...);
}

template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
    -> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
{
    std::string result;
    concatInto(result, xs...);
    return result;
}

// Function for types convertible to std::string_view
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
    -> std::enable_if_t<AllConvertibleToStringView<First, Rest...>, std::string>
{
    std::string result;
    result += first;
    ((result += sep, result += rest), ...);
    return result;
}

// Function for other types
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
    -> std::enable_if_t<!AllConvertibleToStringView<First, Rest...>, std::string>
{
    using ::operator<<;
    thread_local std::ostringstream oss;
    oss.str("");
    oss << first;
    ((oss << sep << rest), ...);
    return oss.str();
}

} // namespace ck_tile
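A hedged usage sketch of the helpers above (the calls and values are illustrative, not taken from this commit):

#include <iostream>
#include <string>
#include "ck_tile/host/concat.hpp"

int main()
{
    // Separator form: joins the remaining arguments with '_', the pattern
    // kernel GetName() implementations tend to use.
    std::string kernel_name = ck_tile::concat('_', "moe_sorting", "fp32", "int32");

    // concatInto() appends into an existing string, reserving space up front
    // instead of going through an ostringstream.
    std::string log;
    ck_tile::concatInto(log, "kernel: ", kernel_name);

    std::cout << log << '\n';
    return 0;
}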
include/ck_tile/host/reference/reference_moe_sorting.hpp
@@ -14,12 +14,15 @@ namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        const HostTensor<IndexType>& local_expert_mask,
                                        HostTensor<IndexType>& p_sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
-                                        const index_t unit_size)
+                                        const index_t unit_size,
+                                        bool local_expert_masking,
+                                        bool skip_experts_with_zero_token = true)
{
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];

@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
    std::vector<std::vector<WeightType>> expert_token_weights(experts,
                                                              std::vector<WeightType>(unit_size, 0));
    // count number of unit-size slices in this expert
    std::vector<IndexType> expert_slices(experts, 1);
    // count the tokens used in this expert
    std::vector<IndexType> expert_slice_idxs(experts, 0);
    // TODO: above 2 buffer seems duplicated
    for(index_t t = 0; t < num_token; t++)
    {

@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
    IndexType* out_tokens    = p_sorted_token_ids.data();
    WeightType* out_weights  = sorted_weight.data();
    IndexType* out_expert_id = sorted_expert_ids.data();
    int curr_expert_id       = 0;
    for(index_t e = 0; e < experts; e++)
    {
        if(local_expert_masking)
        {
            if(local_expert_mask(e) == 0)
                continue;
        }
        if(skip_experts_with_zero_token)
        {
            if(expert_slice_idxs[e] == 0)
            {
                curr_expert_id++;
                continue;
            }
        }
        memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;
        memcpy(out_weights,

@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
        for(index_t s = 0; s < expert_slices[e]; s++)
        {
-            out_expert_id[s] = e;
+            out_expert_id[s] = curr_expert_id;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
        curr_expert_id++;
    }
    unit_cnt *= unit_size;
    return;
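With local_expert_masking enabled, the loop above skips masked-out experts without consuming a local id, so sorted_expert_ids ends up holding ids local to the unmasked set. A standalone sketch of that remapping (illustrative only, not taken from this commit):

#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> local_expert_mask = {0, 1, 1, 0, 1}; // 1 = expert lives on this rank
    int curr_expert_id = 0;
    for(int e = 0; e < static_cast<int>(local_expert_mask.size()); ++e)
    {
        if(local_expert_mask[e] == 0)
            continue;                  // masked: no local id consumed
        std::printf("global expert %d -> local id %d\n", e, curr_expert_id);
        ++curr_expert_id;              // unmasked: takes the next local id
    }
    return 0;
}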
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
@@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/batched_transpose.hpp
@@ -9,3 +9,4 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common.hpp
@@ -5,3 +5,4 @@
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common/utils.hpp
New file (mode 100644):

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"

namespace ck_tile {

// clang-format off
template <typename T> struct typeToStr;

template <> struct typeToStr<float>  { static constexpr const char* name = "fp32"; };
template <> struct typeToStr<fp16_t> { static constexpr const char* name = "fp16"; };
template <> struct typeToStr<bf16_t> { static constexpr const char* name = "bf16"; };
template <> struct typeToStr<fp8_t>  { static constexpr const char* name = "fp8"; };
template <> struct typeToStr<bf8_t>  { static constexpr const char* name = "bf8"; };
template <> struct typeToStr<int8_t> { static constexpr const char* name = "int8"; };
// clang-format on

template <typename ADataType_, typename BDataType_>
std::string gemm_prec_str()
{
    std::string base_str = std::string(typeToStr<ADataType_>::name);
    if(!std::is_same_v<ADataType_, BDataType_>)
    {
        base_str += "_" + std::string(typeToStr<BDataType_>::name);
    }
    return base_str;
}

} // namespace ck_tile
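A hedged usage sketch of gemm_prec_str() (the expected strings in the comments follow the mapping defined above; nothing here is taken from this commit):

#include <string>
#include "ck_tile/ops/common/utils.hpp"

void precision_tag_example()
{
    // Same A/B type: a single precision tag.
    std::string s1 = ck_tile::gemm_prec_str<ck_tile::fp16_t, ck_tile::fp16_t>(); // "fp16"

    // Mixed A/B types: the two tags joined with '_'.
    std::string s2 = ck_tile::gemm_prec_str<ck_tile::fp8_t, ck_tile::fp16_t>();  // "fp8_fp16"
}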
include/ck_tile/ops/elementwise.hpp
@@ -6,3 +6,4 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/epilogue.hpp
@@ -8,3 +8,4 @@
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/flatmm.hpp
@@ -9,3 +9,4 @@
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"