Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
43de07ba
Commit
43de07ba
authored
Feb 13, 2025
by
Jakub Piasecki
Browse files
resolved conflicts
parents
d58f09a0
0e5e29c4
Changes
105
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
305 additions
and
127 deletions
+305
-127
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+4
-1
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+11
-5
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+14
-3
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+0
-1
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+42
-6
example/ck_tile/15_fused_moe/fused_moe.hpp
example/ck_tile/15_fused_moe/fused_moe.hpp
+11
-8
example/ck_tile/15_fused_moe/fused_moesorting.hpp
example/ck_tile/15_fused_moe/fused_moesorting.hpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+58
-50
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+35
-25
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+9
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
...n/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
...tched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
+25
-7
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
..._batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
+19
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
...pu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
...grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
...gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+10
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
...pu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
+18
-4
No files found.
example/ck_tile/03_gemm/CMakeLists.txt
View file @
43de07ba
...
@@ -4,4 +4,7 @@ add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
...
@@ -4,4 +4,7 @@ add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_dependencies
(
examples_ck_tile tile_example_gemm_basic
)
add_dependencies
(
examples_ck_tile tile_example_gemm_basic
)
rocm_install
(
TARGETS tile_example_gemm_basic COMPONENT examples_ck_tile
)
rocm_install
(
TARGETS tile_example_gemm_basic COMPONENT examples_ck_tile
)
add_dependencies
(
examples_ck_tile tile_example_gemm_universal
)
add_dependencies
(
examples_ck_tile tile_example_gemm_universal
)
rocm_install
(
TARGETS tile_example_gemm_universal COMPONENT examples_ck_tile
)
rocm_install
(
TARGETS tile_example_gemm_universal COMPONENT examples_ck_tile
)
\ No newline at end of file
target_compile_options
(
tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
43de07ba
...
@@ -11,21 +11,26 @@
...
@@ -11,21 +11,26 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_COMPUTE
_V3
1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
_V3
#endif
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
COMPUTE
)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_
MEMORY
)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE
_V3
)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
#endif
...
@@ -126,7 +131,8 @@ auto create_args(int argc, char* argv[])
...
@@ -126,7 +131,8 @@ auto create_args(int argc, char* argv[])
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"split_k"
,
"1"
,
"splitK value"
);
.
insert
(
"split_k"
,
"1"
,
"splitK value"
)
.
insert
(
"init"
,
"0"
,
"0:random, 1:linear, 2:constant(1)"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
43de07ba
...
@@ -110,6 +110,7 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -110,6 +110,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
ck_tile
::
index_t
init_method
=
arg_parser
.
get_int
(
"init"
);
stride_A
=
ck_tile
::
get_default_stride
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
));
stride_A
=
ck_tile
::
get_default_stride
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
));
stride_B
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
));
stride_B
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
));
...
@@ -122,9 +123,19 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -122,9 +123,19 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{})));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{})));
// TODO: add different init types
if
(
init_method
==
0
)
{
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
}
else
if
(
init_method
==
1
)
{
ck_tile
::
FillMonotonicSeq
<
ADataType
>
{}(
a_m_k
);
ck_tile
::
FillMonotonicSeq
<
BDataType
>
{}(
b_k_n
);
}
else
if
(
init_method
==
2
)
{
ck_tile
::
FillConstant
<
ADataType
>
{
static_cast
<
ADataType
>
(
1
)}(
a_m_k
);
ck_tile
::
FillConstant
<
BDataType
>
{
static_cast
<
BDataType
>
(
1
)}(
b_k_n
);
}
else
{
a_m_k
.
SetZero
();
b_k_n
.
SetZero
();
}
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
...
...
example/ck_tile/03_gemm/script/benchmark_basic.sh
View file @
43de07ba
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
1
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
43de07ba
...
@@ -34,8 +34,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -34,8 +34,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#endif
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE
_V3
)
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
@@ -48,6 +50,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -48,6 +50,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
true
;
#endif
#endif
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
...
@@ -70,8 +90,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -70,8 +90,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
kPadN
,
kPadK
,
DoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
...
@@ -99,8 +125,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -99,8 +125,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>
;
tail_number_v
>
;
using
GemmPipeline
=
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
GEMM_PIPELINE
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CDataType
,
...
@@ -140,7 +165,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -140,7 +165,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if
(
has_hot_loop
)
if
(
has_hot_loop
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE
_V3
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
Run
(
ck_tile
::
bool_constant
<
true
>
{},
...
@@ -215,6 +240,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -215,6 +240,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
else
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
#endif
#endif
}
}
else
else
...
...
example/ck_tile/15_fused_moe/fused_moe.hpp
View file @
43de07ba
...
@@ -8,14 +8,15 @@
...
@@ -8,14 +8,15 @@
struct
fused_moe_args
struct
fused_moe_args
{
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
local_expert_mask_ptr
;
// [e], local_expert_mask_ptr for EP
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
topk_ids_ptr
;
// [tokens, topk]
const
void
*
topk_ids_ptr
;
// [tokens, topk]
const
void
*
topk_weight_ptr
;
// [tokens, topk]
const
void
*
topk_weight_ptr
;
// [tokens, topk]
...
@@ -48,6 +49,8 @@ struct fused_moe_traits
...
@@ -48,6 +49,8 @@ struct fused_moe_traits
int
activation
;
// 0:gelu, 1:silu
int
activation
;
// 0:gelu, 1:silu
int
gate_only
;
// 0:g1u0, 1:g1u1
int
gate_only
;
// 0:g1u0, 1:g1u1
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
bool
local_expert_masking
;
// if mask experts as local expert
};
};
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/15_fused_moe/fused_moesorting.hpp
View file @
43de07ba
...
@@ -10,7 +10,8 @@
...
@@ -10,7 +10,8 @@
struct
fused_moesorting_trait
struct
fused_moesorting_trait
{
{
std
::
string
index_type
;
std
::
string
index_type
;
std
::
string
weight_type
;
// currently always float
std
::
string
weight_type
;
// currently always float
bool
local_expert_masking
;
// if mask experts as local expert
};
};
struct
fused_moesorting_args
:
public
ck_tile
::
MoeSortingHostArgs
struct
fused_moesorting_args
:
public
ck_tile
::
MoeSortingHostArgs
...
...
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
View file @
43de07ba
...
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
...
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
return
1
;
return
1
;
}();
}();
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
};
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
,
t
.
local_expert_masking
};
auto
a0
=
fused_moesorting_args
{
auto
a0
=
fused_moesorting_args
{
a
.
topk_ids_ptr
,
// const void* p_topk_ids;
a
.
topk_ids_ptr
,
// const void* p_topk_ids;
a
.
topk_weight_ptr
,
// const void* p_weights;
a
.
topk_weight_ptr
,
// const void* p_weights;
a
.
local_expert_mask_ptr
,
// const void* p_local_expert_mask;
a
.
sorted_token_ids_ptr
,
// void* p_sorted_token_ids;
a
.
sorted_token_ids_ptr
,
// void* p_sorted_token_ids;
a
.
sorted_weight_ptr
,
// void* p_sorted_weights;
a
.
sorted_weight_ptr
,
// void* p_sorted_weights;
a
.
sorted_expert_ids_ptr
,
// void* p_sorted_expert_ids;
a
.
sorted_expert_ids_ptr
,
// void* p_sorted_expert_ids;
...
...
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
View file @
43de07ba
...
@@ -24,20 +24,63 @@
...
@@ -24,20 +24,63 @@
return ave_time;
return ave_time;
#else
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
using ms_problem = \
constexpr bool sub_token_onshot = sub_token_onshot_; \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
constexpr bool local_expert_masking = local_expert_masking_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
auto kargs = kernel::MakeKargs(a); \
ms_weight_type, \
const dim3 grids = kernel::GridSize(a); \
sub_token_tile, \
const dim3 blocks = kernel::BlockSize(a); \
sub_token_onshot, \
const auto lds_bytes = kernel::GetSmemSize(a); \
local_expert_masking>; \
float ave_time = ck_tile::launch_kernel( \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#if !MOE_SORTING_USE_EX_KERNEL
...
@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
...
@@ -116,45 +159,10 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto
sub_token_
=
r_
-
2
;
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_local_expert_masking
=
t
.
local_expert_masking
;
(
void
)
c_
;
(
void
)
c_
;
if
(
is_sub_token_onshot
)
{
MOE_SORTING_DISPATCH_EMASK_
(
r_
);
if
(
r_
%
8
==
0
)
{
MOE_SORTING_DISPATCH_
(
8
,
true
);
}
else
if
(
r_
%
4
==
0
)
{
MOE_SORTING_DISPATCH_
(
4
,
true
);
}
else
if
(
r_
%
2
==
0
)
{
MOE_SORTING_DISPATCH_
(
2
,
true
);
}
else
{
MOE_SORTING_DISPATCH_
(
1
,
true
);
}
}
else
{
if
(
r_
%
8
==
0
)
{
MOE_SORTING_DISPATCH_
(
8
,
false
);
}
else
if
(
r_
%
4
==
0
)
{
MOE_SORTING_DISPATCH_
(
4
,
false
);
}
else
if
(
r_
%
2
==
0
)
{
MOE_SORTING_DISPATCH_
(
2
,
false
);
}
else
{
MOE_SORTING_DISPATCH_
(
1
,
false
);
}
}
// MOE_SORTING_DISPATCH_ETILE(0, 0);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
#endif
}
}
...
...
example/ck_tile/15_fused_moe/main.cpp
View file @
43de07ba
...
@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
activation
=
arg_parser
.
get_int
(
"act"
);
ck_tile
::
index_t
activation
=
arg_parser
.
get_int
(
"act"
);
if
(
stride
<
0
)
if
(
stride
<
0
)
stride
=
hidden_size
;
stride
=
hidden_size
;
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
bool
local_expert_masking
=
false
;
// TODO...
// w0 (Gate+Up or Gate only, N size)
// w0 (Gate+Up or Gate only, N size)
ck_tile
::
index_t
shared_intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
ck_tile
::
index_t
shared_intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
...
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
IndexDataType
>
local_expert_mask_host
({
experts
});
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
...
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
local_expert_mask_buf
(
local_expert_mask_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
topk_ids_buf
(
topk_ids_host
);
ck_tile
::
DeviceMem
topk_ids_buf
(
topk_ids_host
);
...
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m
,
block_m
,
activation
,
activation
,
gate_only
,
gate_only
,
fused_quant
};
fused_quant
,
local_expert_masking
};
fused_moe_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_moe_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
...
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
local_expert_masking
?
local_expert_mask_buf
.
GetDeviceBuffer
()
:
nullptr
,
o_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
topk_ids_buf
.
GetDeviceBuffer
(),
topk_ids_buf
.
GetDeviceBuffer
(),
topk_weight_buf
.
GetDeviceBuffer
(),
topk_weight_buf
.
GetDeviceBuffer
(),
...
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_ids_host
,
topk_weight_host
,
topk_weight_host
,
local_expert_mask_host
,
sorted_token_ids_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
.
mData
[
0
],
experts
,
experts
,
block_m
);
block_m
,
local_expert_masking
);
if
(
activation
==
0
)
if
(
activation
==
0
)
{
{
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
...
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_ids_host
,
topk_weight_host
,
topk_weight_host
,
local_expert_mask_host
,
sorted_token_ids_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
.
mData
[
0
],
experts
,
experts
,
block_m
);
block_m
,
local_expert_masking
);
// done, preparing GPU buffer
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
a_buf
(
a_host
);
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
43de07ba
...
@@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ABDataType
,
half_t
>::
value
||
is_same
<
ABDataType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
,
ABDataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @
43de07ba
...
@@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
...
@@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
View file @
43de07ba
...
@@ -361,10 +361,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -361,10 +361,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
constexpr
bool
is_single_rate_mfma
=
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
;
((
is_same
<
A0B0B1DataType
,
half_t
>::
value
||
is_same
<
A0B0B1DataType
,
bhalf_t
>::
value
)
&&
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
math
::
lcm
(
A0K1
,
B0K1
)
<=
4
)
constexpr
auto
N5
=
mfma
.
group_size
;
?
true
:
false
;
constexpr
auto
mfma
=
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
A0B0B1DataType
,
is_single_rate_mfma
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
...
@@ -643,9 +651,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -643,9 +651,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
constexpr
auto
lcm_A0K1_B0K1
=
math
::
lcm
(
A0K1
,
B0K1
);
math
::
lcm
(
A0K1
,
B0K1
),
constexpr
bool
is_single_rate_mfma
=
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
>::
selected_mfma
.
k_per_blk
);
((
is_same
<
A0B0B1DataType
,
half_t
>::
value
||
is_same
<
A0B0B1DataType
,
bhalf_t
>::
value
)
&&
lcm_A0K1_B0K1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
lcm_A0K1_B0K1
,
MfmaSelector
<
A0B0B1DataType
,
Gemm0MPerXdl
,
Gemm0NPerXdl
,
A0B0B1DataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm0
=
BlockwiseGemmXdlops_v2
<
auto
blockwise_gemm0
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
43de07ba
...
@@ -343,10 +343,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
...
@@ -343,10 +343,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
bool
is_single_rate_mfma
=
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
constexpr
auto
N4
=
mfma
.
num_input_blks
;
math
::
lcm
(
AK1
,
BK1
)
<=
4
)
constexpr
auto
N5
=
mfma
.
group_size
;
?
true
:
false
;
constexpr
auto
mfma
=
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
;
constexpr
auto
N3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
...
@@ -552,8 +558,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
...
@@ -552,8 +558,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
43de07ba
...
@@ -469,8 +469,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -469,8 +469,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
View file @
43de07ba
...
@@ -498,8 +498,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -498,8 +498,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
43de07ba
...
@@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
View file @
43de07ba
...
@@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
...
@@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
index_t
KPack
=
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
constexpr
bool
is_single_rate_mfma
=
MfmaSelector
<
AComputeType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
((
is_same
<
AComputeType
,
half_t
>::
value
||
is_same
<
AComputeType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
AComputeType
,
MPerXdl
,
NPerXdl
,
AComputeType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
43de07ba
...
@@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
View file @
43de07ba
...
@@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
...
@@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ABDataType
,
half_t
>::
value
||
is_same
<
ABDataType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
,
ABDataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
@@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
...
@@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ABDataType
,
half_t
>::
value
||
is_same
<
ABDataType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
,
ABDataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment