Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ef2b53a9
Commit
ef2b53a9
authored
Feb 12, 2025
by
ThomasNing
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into develop
parents
a9df4183
3c7fef7f
Changes
67
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
641 additions
and
33 deletions
+641
-33
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
+2
-1
example/ck_tile/13_moe_sorting/script/smoke_test.sh
example/ck_tile/13_moe_sorting/script/smoke_test.sh
+8
-0
example/ck_tile/15_fused_moe/README.md
example/ck_tile/15_fused_moe/README.md
+1
-1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+74
-0
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+5
-2
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+1
-1
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
+1
-1
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+354
-19
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+7
-4
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-1
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/concat.hpp
include/ck_tile/host/concat.hpp
+122
-0
include/ck_tile/host/reference/reference_moe_sorting.hpp
include/ck_tile/host/reference/reference_moe_sorting.hpp
+24
-2
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+1
-0
include/ck_tile/ops/batched_transpose.hpp
include/ck_tile/ops/batched_transpose.hpp
+1
-0
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-0
include/ck_tile/ops/common/utils.hpp
include/ck_tile/ops/common/utils.hpp
+34
-0
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+1
-0
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
No files found.
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
View file @
ef2b53a9
...
...
@@ -10,7 +10,8 @@
struct
moe_sorting_trait
{
std
::
string
index_type
;
std
::
string
weight_type
;
// currently always float
std
::
string
weight_type
;
// currently always float
bool
local_expert_masking
;
// if mask experts as local expert
};
struct
moe_sorting_args
:
public
ck_tile
::
MoeSortingHostArgs
...
...
example/ck_tile/13_moe_sorting/script/smoke_test.sh
View file @
ef2b53a9
...
...
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE
-t
=
1
-e
=
1
-k
=
1
$EXE
-t
=
99
-e
=
2
-k
=
1
$EXE
-t
=
333
-e
=
99
-k
=
13
$EXE
-t
=
11
-e
=
256
-k
=
5
$EXE
-t
=
64
-e
=
455
-k
=
8
$EXE
-t
=
777
-e
=
802
-k
=
99
$EXE
-t
=
4097
-e
=
906
-k
=
51
$EXE
-t
=
128
-e
=
32
-k
=
5
-moe_buf_size
=
262144
$EXE
-t
=
13
-e
=
64
-k
=
3
-local_eid
=
4,5,6,7,8,9,10,11
$EXE
-t
=
99
-e
=
33
-k
=
9
-local_eid
=
6,10,11,15,19
$EXE
-t
=
80
-e
=
99
-k
=
10
-local_eid
=
0,8,12,33
$EXE
-t
=
11
-e
=
256
-k
=
5
-local_eid
=
99,110,129
example/ck_tile/15_fused_moe/README.md
View file @
ef2b53a9
...
...
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts *
(
M_a -
1
)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a -
topk (updated
)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
View file @
ef2b53a9
...
...
@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
...
@@ -17,6 +23,24 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
using ms_problem = \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
...
...
@@ -38,11 +62,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float
fused_moesorting
(
fused_moesorting_trait
t
,
fused_moesorting_args
a
,
ck_tile
::
stream_config
s
)
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
{
#if !MOE_SORTING_USE_EX_KERNEL
if
(
a
.
num_experts
>
127
)
{
printf
(
"lds size exceed, only support experts <127
\n
"
);
...
...
@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
MOE_SORTING_DISPATCH
(
4
);
}
}
#else
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
auto
[
r_
,
c_
]
=
ck_tile
::
moe_sorting_get_smem_row_col
(
a
.
tokens
,
a
.
num_experts
);
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
(
void
)
c_
;
if
(
is_sub_token_onshot
)
{
if
(
r_
%
8
==
0
)
{
MOE_SORTING_DISPATCH_
(
8
,
true
);
}
else
if
(
r_
%
4
==
0
)
{
MOE_SORTING_DISPATCH_
(
4
,
true
);
}
else
if
(
r_
%
2
==
0
)
{
MOE_SORTING_DISPATCH_
(
2
,
true
);
}
else
{
MOE_SORTING_DISPATCH_
(
1
,
true
);
}
}
else
{
if
(
r_
%
8
==
0
)
{
MOE_SORTING_DISPATCH_
(
8
,
false
);
}
else
if
(
r_
%
4
==
0
)
{
MOE_SORTING_DISPATCH_
(
4
,
false
);
}
else
if
(
r_
%
2
==
0
)
{
MOE_SORTING_DISPATCH_
(
2
,
false
);
}
else
{
MOE_SORTING_DISPATCH_
(
1
,
false
);
}
}
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return
-
1
;
}
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
ef2b53a9
...
...
@@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
std
::
cout
<<
"Launching kernel with args: "
<<
Kernel
::
GetName
()
<<
'\n'
<<
"shape: "
<<
CodegenGemmShape
::
GetName
()
<<
'\n'
<<
"problem: "
<<
CodegenPipelineProblem
::
GetName
()
<<
'\n'
<<
"pipeline: "
<<
CodegenGemmPipeline
::
GetName
()
<<
'\n'
<<
"grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
}
...
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
ef2b53a9
...
...
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
View file @
ef2b53a9
...
...
@@ -118,7 +118,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel
: "
<<
GroupedGemmKernel
::
GetName
()
<<
"
with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
...
...
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
View file @
ef2b53a9
...
...
@@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
}
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
return
pass
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
ef2b53a9
...
...
@@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
true
;
}
static
constexpr
bool
IsSupported
(
index_t
MRaw_
,
index_t
NRaw_
,
index_t
KRaw_
,
index_t
Gemm1NRaw_
)
{
// check vector load/store
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
)
{
if
(
KRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
)
{
if
(
MRaw_
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector load of B
if
constexpr
(
is_same_v
<
BLayout
,
Row
>
)
{
if
(
NRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
)
{
if
(
KRaw_
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector load of B1
if
constexpr
(
is_same_v
<
B1Layout
,
Row
>
)
{
if
(
Gemm1NRaw_
%
B1BlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
B1Layout
,
Col
>
)
{
if
(
NRaw_
%
B1BlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
// check vector load of C
if
constexpr
(
is_same_v
<
CLayout
,
Row
>
)
{
if
(
Gemm1NRaw_
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
if
constexpr
(
is_same_v
<
CLayout
,
Col
>
)
{
if
(
MRaw_
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
return
false
;
}
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
...
...
@@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
?
KRaw
:
MRaw
;
const
auto
b_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
?
NRaw
:
KRaw
;
const
auto
b1_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
?
Gemm1NRaw
:
NRaw
;
const
auto
c_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>
?
Gemm1NRaw
:
MRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
)
and
IsSupported
(
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
);
}
// polymorphic
...
...
@@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
str
.
str
();
}
template
<
class
ADesc
,
class
BDesc
,
class
B1Desc
,
class
CDesc
>
struct
Descriptor
{
template
<
class
AGridDescriptor
>
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
AGridDescriptor
&
a_grid_desc
)
{
const
auto
a_grid_desc_m_k
=
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
class
BGridDescriptor
>
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
BGridDescriptor
&
b_grid_desc
)
{
const
auto
b_grid_desc_n_k
=
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
class
B1GridDescriptor
>
static
constexpr
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
B1GridDescriptor
&
b1_grid_desc
)
{
const
auto
b1_grid_desc_n_k
=
DeviceOp
::
matrix_padder
.
PadB1Descriptor_N_K
(
b1_grid_desc
);
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B1K0
=
K
/
B1K1
;
return
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
class
CGridDescriptor
>
static
constexpr
auto
MakeCGridDescriptor_M_N
(
const
CGridDescriptor
&
c_grid_desc
)
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
c_grid_desc
);
}
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
ADesc
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
BDesc
{}))
>
;
using
B1GridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
(
B1Desc
{}))
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
(
CDesc
{}))
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
true
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
false
,
B1BlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
matrix_padder
.
PadN
,
MaskOutUpperTriangle
>
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
;
CGridDesc_M_N
c_grid_desc_m_n
;
C0MatrixMask
c0_matrix_mask
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock
;
// element-wise op
AElementwiseOperation
a_element_op
;
BElementwiseOperation
b_element_op
;
B1ElementwiseOperation
b1_element_op
;
CElementwiseOperation
c_element_op
;
bool
has_main_k_block_loop
=
true
;
bool
is_valid
=
false
;
constexpr
Descriptor
(
ADesc
a
,
BDesc
b
,
B1Desc
b1
,
CDesc
c
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
B1ElementwiseOperation
b1_element_op_
,
CElementwiseOperation
c_element_op_
)
:
a_grid_desc_ak0_m_ak1
{
MakeAGridDescriptor_AK0_M_AK1
(
a
)},
b_grid_desc_bk0_n_bk1
{
MakeBGridDescriptor_BK0_N_BK1
(
b
)},
b1_grid_desc_bk0_n_bk1
{
MakeB1GridDescriptor_BK0_N_BK1
(
b1
)},
c_grid_desc_m_n
{
MakeCGridDescriptor_M_N
(
c
)},
block_2_ctile_map
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock
{
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
)},
has_main_k_block_loop
{
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))},
c0_matrix_mask
{
c
.
GetLength
(
I1
)},
a_element_op
{
a_element_op_
},
b_element_op
{
b_element_op_
},
b1_element_op
{
b1_element_op_
},
c_element_op
{
c_element_op_
},
is_valid
{
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
,
block_2_ctile_map
)
and
IsSupported
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
),
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
),
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
),
b1_grid_desc_bk0_n_bk1
.
GetLength
(
I1
))}
{
}
constexpr
bool
IsValid
()
const
{
return
is_valid
;
}
};
template
<
class
ADesc
,
class
BDesc
,
class
B1Desc
,
class
CDesc
>
static
constexpr
auto
make_descriptor
(
ADesc
a
,
BDesc
b
,
B1Desc
b1
,
CDesc
c
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
B1ElementwiseOperation
b1_element_op
=
B1ElementwiseOperation
{},
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
{
return
Descriptor
<
ADesc
,
BDesc
,
B1Desc
,
CDesc
>
(
a
,
b
,
b1
,
c
,
a_element_op
,
b_element_op
,
b1_element_op
,
c_element_op
);
}
template
<
class
Desc
>
__device__
static
void
Run
(
const
Desc
&
desc
,
const
float
scale
,
const
ADataType
*
__restrict__
p_a_grid
,
const
ADataType
*
__restrict__
p_b_grid
,
const
ADataType
*
__restrict__
p_b1_grid
,
CDataType
*
__restrict__
p_c_grid
)
{
#ifndef __HIPCC_RTC__
assert
(
desc
.
is_valid
);
#endif
__shared__
char
p_shared_block
[
Desc
::
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
AccElementwiseOperation
acc_element_op
{
scale
};
if
(
desc
.
has_main_k_block_loop
)
{
Desc
::
GridwiseGemm
::
template
Run
<
true
>(
p_a_grid
,
p_b_grid
,
p_b1_grid
,
p_c_grid
,
p_shared_block
,
desc
.
a_element_op
,
desc
.
b_element_op
,
acc_element_op
,
desc
.
b1_element_op
,
desc
.
c_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
b1_grid_desc_bk0_n_bk1
,
desc
.
c_grid_descriptor_mblock_mperblock_nblock_nperblock
,
desc
.
block_2_ctile_map
,
desc
.
c0_matrix_mask
);
}
else
{
Desc
::
GridwiseGemm
::
template
Run
<
false
>(
p_a_grid
,
p_b_grid
,
p_b1_grid
,
p_c_grid
,
p_shared_block
,
desc
.
a_element_op
,
desc
.
b_element_op
,
acc_element_op
,
desc
.
b1_element_op
,
desc
.
c_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
b_grid_desc_bk0_n_bk1
,
desc
.
b1_grid_desc_bk0_n_bk1
,
desc
.
c_grid_descriptor_mblock_mperblock_nblock_nperblock
,
desc
.
block_2_ctile_map
,
desc
.
c0_matrix_mask
);
}
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
ef2b53a9
...
...
@@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// if workspace is not allocated
if
(
!
arg
.
p_workspace_
)
{
std
::
cerr
<<
"Warning: Workspace for "
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
"allocated, use SetWorkSpacePointer."
<<
std
::
endl
;
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"Warning: Workspace for "
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
"allocated, use SetWorkSpacePointer."
<<
std
::
endl
;
}
return
false
;
}
if
(
!
ck
::
is_xdl_supported
())
...
...
include/ck_tile/core.hpp
View file @
ef2b53a9
...
...
@@ -27,12 +27,12 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
...
...
include/ck_tile/host.hpp
View file @
ef2b53a9
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
...
...
include/ck_tile/host/concat.hpp
0 → 100644
View file @
ef2b53a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
template
<
typename
T
>
struct
IsCharArray
:
std
::
false_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
char
[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
const
char
[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
char
(
&
)[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
IsCharArray
<
const
char
(
&
)[
N
]
>
:
std
::
true_type
{
};
template
<
typename
...
Ts
>
inline
constexpr
bool
AllConvertibleToStringView
=
((
std
::
is_convertible_v
<
Ts
,
std
::
string_view
>
||
IsCharArray
<
Ts
>::
value
||
std
::
is_same_v
<
Ts
,
char
>
)
&&
...);
template
<
typename
...
Ts
>
[[
nodiscard
]]
auto
concat
(
const
Ts
&
...
xs
)
->
std
::
enable_if_t
<!
AllConvertibleToStringView
<
Ts
...
>
,
std
::
string
>
{
using
::
operator
<<
;
thread_local
std
::
ostringstream
oss
;
oss
.
str
(
""
);
(
oss
<<
...
<<
xs
);
return
oss
.
str
();
}
template
<
std
::
size_t
N
>
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
char
(
&
)[
N
])
noexcept
{
return
N
;
}
template
<
std
::
size_t
N
>
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
char
(
&
)[
N
])
noexcept
{
return
N
;
}
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
char
*
s
)
noexcept
{
const
char
*
end
=
s
;
while
(
*
end
++
!=
0
)
{}
return
end
-
s
-
1
;
}
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
char
&
)
noexcept
{
return
1
;
}
[[
nodiscard
]]
inline
std
::
size_t
getSize
(
const
std
::
string
&
s
)
noexcept
{
return
s
.
size
();
}
[[
nodiscard
]]
constexpr
inline
std
::
size_t
getSize
(
const
std
::
string_view
&
s
)
noexcept
{
return
s
.
size
();
}
template
<
typename
...
Ts
>
auto
concatInto
(
std
::
string
&
result
,
const
Ts
&
...
xs
)
->
std
::
enable_if_t
<
AllConvertibleToStringView
<
Ts
...
>
,
void
>
{
const
std
::
size_t
space
=
(
1
+
...
+
getSize
(
xs
));
result
.
reserve
(
result
.
size
()
+
space
);
((
result
+=
xs
),
...);
}
template
<
typename
...
Ts
>
[[
nodiscard
]]
auto
concat
(
const
Ts
&
...
xs
)
->
std
::
enable_if_t
<
AllConvertibleToStringView
<
Ts
...
>
,
std
::
string
>
{
std
::
string
result
;
concatInto
(
result
,
xs
...);
return
result
;
}
// Function for types convertible to std::string_view
template
<
typename
Sep
,
typename
First
,
typename
...
Rest
>
[[
nodiscard
]]
auto
concat
(
Sep
sep
,
const
First
&
first
,
const
Rest
&
...
rest
)
->
std
::
enable_if_t
<
AllConvertibleToStringView
<
First
,
Rest
...
>
,
std
::
string
>
{
std
::
string
result
;
result
+=
first
;
((
result
+=
sep
,
result
+=
rest
),
...);
return
result
;
}
// Function for other types
template
<
typename
Sep
,
typename
First
,
typename
...
Rest
>
[[
nodiscard
]]
auto
concat
(
Sep
sep
,
const
First
&
first
,
const
Rest
&
...
rest
)
->
std
::
enable_if_t
<!
AllConvertibleToStringView
<
First
,
Rest
...
>
,
std
::
string
>
{
using
::
operator
<<
;
thread_local
std
::
ostringstream
oss
;
oss
.
str
(
""
);
oss
<<
first
;
((
oss
<<
sep
<<
rest
),
...);
return
oss
.
str
();
}
}
// namespace ck_tile
include/ck_tile/host/reference/reference_moe_sorting.hpp
View file @
ef2b53a9
...
...
@@ -14,12 +14,15 @@ namespace ck_tile {
template
<
typename
WeightType
,
typename
IndexType
=
index_t
>
CK_TILE_HOST
void
reference_moe_sorting
(
const
HostTensor
<
IndexType
>&
topk_ids
,
const
HostTensor
<
WeightType
>&
weights
,
const
HostTensor
<
IndexType
>&
local_expert_mask
,
HostTensor
<
IndexType
>&
p_sorted_token_ids
,
HostTensor
<
WeightType
>&
sorted_weight
,
HostTensor
<
IndexType
>&
sorted_expert_ids
,
index_t
&
unit_cnt
,
const
index_t
experts
,
const
index_t
unit_size
)
const
index_t
unit_size
,
bool
local_expert_masking
,
bool
skip_experts_with_zero_token
=
true
)
{
const
index_t
num_token
=
topk_ids
.
mDesc
.
get_lengths
()[
0
];
const
index_t
topk
=
topk_ids
.
mDesc
.
get_lengths
()[
1
];
...
...
@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
WeightType
>
(
unit_size
,
0
));
// count number of unit-size slices in this expert
std
::
vector
<
IndexType
>
expert_slices
(
experts
,
1
);
// count the tokens used in this expert
std
::
vector
<
IndexType
>
expert_slice_idxs
(
experts
,
0
);
// TODO: above 2 buffer seems duplicated
for
(
index_t
t
=
0
;
t
<
num_token
;
t
++
)
{
...
...
@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
IndexType
*
out_tokens
=
p_sorted_token_ids
.
data
();
WeightType
*
out_weights
=
sorted_weight
.
data
();
IndexType
*
out_expert_id
=
sorted_expert_ids
.
data
();
int
curr_expert_id
=
0
;
for
(
index_t
e
=
0
;
e
<
experts
;
e
++
)
{
if
(
local_expert_masking
)
{
if
(
local_expert_mask
(
e
)
==
0
)
continue
;
}
if
(
skip_experts_with_zero_token
)
{
if
(
expert_slice_idxs
[
e
]
==
0
)
{
curr_expert_id
++
;
continue
;
}
}
memcpy
(
out_tokens
,
expert_tokens
[
e
].
data
(),
sizeof
(
index_t
)
*
expert_slices
[
e
]
*
unit_size
);
out_tokens
+=
expert_slices
[
e
]
*
unit_size
;
memcpy
(
out_weights
,
...
...
@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
for
(
index_t
s
=
0
;
s
<
expert_slices
[
e
];
s
++
)
{
out_expert_id
[
s
]
=
e
;
out_expert_id
[
s
]
=
curr_expert_id
;
unit_cnt
++
;
}
out_expert_id
+=
expert_slices
[
e
];
curr_expert_id
++
;
}
unit_cnt
*=
unit_size
;
return
;
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
View file @
ef2b53a9
...
...
@@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/batched_transpose.hpp
View file @
ef2b53a9
...
...
@@ -9,3 +9,4 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common.hpp
View file @
ef2b53a9
...
...
@@ -5,3 +5,4 @@
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common/utils.hpp
0 → 100644
View file @
ef2b53a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// clang-format off
template
<
typename
T
>
struct
typeToStr
;
template
<
>
struct
typeToStr
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
typeToStr
<
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
typeToStr
<
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
typeToStr
<
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
typeToStr
<
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
typeToStr
<
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
template
<
typename
ADataType_
,
typename
BDataType_
>
std
::
string
gemm_prec_str
()
{
std
::
string
base_str
=
std
::
string
(
typeToStr
<
ADataType_
>::
name
);
if
(
!
std
::
is_same_v
<
ADataType_
,
BDataType_
>
)
{
base_str
+=
"_"
+
std
::
string
(
typeToStr
<
BDataType_
>::
name
);
}
return
base_str
;
}
}
// namespace ck_tile
include/ck_tile/ops/elementwise.hpp
View file @
ef2b53a9
...
...
@@ -6,3 +6,4 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/epilogue.hpp
View file @
ef2b53a9
...
...
@@ -8,3 +8,4 @@
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment