Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b8d11559
Unverified
Commit
b8d11559
authored
Feb 17, 2025
by
amd-khushbu
Committed by
GitHub
Feb 17, 2025
Browse files
Merge branch 'develop' into ck_profiler_m_instances
parents
7f3fe4e7
3b230208
Changes
174
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
715 additions
and
94 deletions
+715
-94
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+0
-1
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+46
-19
example/ck_tile/13_moe_sorting/moe_sorting.cpp
example/ck_tile/13_moe_sorting/moe_sorting.cpp
+57
-6
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+82
-0
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
+2
-1
example/ck_tile/13_moe_sorting/script/smoke_test.sh
example/ck_tile/13_moe_sorting/script/smoke_test.sh
+8
-0
example/ck_tile/15_fused_moe/README.md
example/ck_tile/15_fused_moe/README.md
+1
-1
example/ck_tile/15_fused_moe/fused_moe.hpp
example/ck_tile/15_fused_moe/fused_moe.hpp
+11
-8
example/ck_tile/15_fused_moe/fused_moesorting.hpp
example/ck_tile/15_fused_moe/fused_moesorting.hpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+82
-0
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+35
-25
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+5
-2
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+1
-1
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
+1
-1
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+354
-19
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+7
-4
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+9
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
...n/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+9
-1
No files found.
example/ck_tile/03_gemm/script/benchmark_basic.sh
View file @
b8d11559
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
1
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
b8d11559
...
@@ -34,8 +34,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -34,8 +34,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#endif
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE
_V3
)
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
@@ -48,6 +50,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -48,6 +50,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
true
;
#endif
#endif
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
...
@@ -70,8 +90,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -70,8 +90,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
kPadN
,
kPadK
,
DoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
...
@@ -99,8 +125,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -99,8 +125,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>
;
tail_number_v
>
;
using
GemmPipeline
=
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
GEMM_PIPELINE
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CDataType
,
...
@@ -140,7 +165,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -140,7 +165,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if
(
has_hot_loop
)
if
(
has_hot_loop
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE
_V3
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
Run
(
ck_tile
::
bool_constant
<
true
>
{},
...
@@ -215,24 +240,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -215,24 +240,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
}
}
#endif
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
}
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
else
{
// Tail number always Full - #PrefetchStages
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
{
Run
(
ck_tile
::
bool_constant
<
fals
e
>
{},
Run
(
ck_tile
::
bool_constant
<
tru
e
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
}
else
else
{
{
std
::
ostringstream
err
;
Run
(
ck_tile
::
bool_constant
<
true
>
{},
err
<<
"When there's no hot loop, this tail number
\"
"
<<
tail_num
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
<<
"
\"
is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
}
#endif
}
else
{
std
::
ostringstream
err
;
err
<<
"Num K loop must be larger than number of prefetech stages."
<<
"
\n
PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
}
return
ave_time
;
return
ave_time
;
...
...
example/ck_tile/13_moe_sorting/moe_sorting.cpp
View file @
b8d11559
...
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
...
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"moe_buf_size"
,
"0"
,
"moe_buf_size"
)
.
insert
(
"moe_buf_size"
,
"0"
,
"moe_buf_size"
)
.
insert
(
"local_eid"
,
"-1"
,
"a list of experts enabled as local expert. e.g.
\"
0,1,4,5
\"\n
"
"please make sure eid is in ascending order!"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
...
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int
kname
=
args
.
get_int
(
"kname"
);
int
kname
=
args
.
get_int
(
"kname"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
int
max_output_ids
=
int
max_output_ids
=
ck_tile
::
integer_least_multiple
(
topk
*
tokens
+
num_experts
*
unit_size
-
topk
,
unit_size
);
ck_tile
::
integer_least_multiple
(
topk
*
tokens
+
num_experts
*
unit_size
-
topk
,
unit_size
);
...
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return
false
;
return
false
;
}
}
bool
local_expert_masking
=
args
.
get_str
(
"local_eid"
)
!=
"-1"
;
// Host-side local-expert mask, built once by an immediately-invoked lambda.
//
// When `local_expert_masking` is set, the result holds `num_experts` entries,
// all zero except for a 1 at every expert id listed in the "local_eid"
// argument. Otherwise a one-element tensor is returned and its content is not
// initialized here.
//
// Throws std::runtime_error if a listed expert id falls outside
// [0, num_experts).
auto local_expert_masking_host = [&]() {
    if(local_expert_masking)
    {
        auto local_eid = args.get_int_vec("local_eid");
        ck_tile::HostTensor<IndexType> v_{{num_experts}};
        v_.SetZero();
        for(auto eid : local_eid)
        {
            // Guard both ends of the range before indexing mData: a negative id
            // would otherwise write in front of the buffer.
            if(eid < 0)
            {
                throw std::runtime_error("local_eid must not be negative, please check");
            }
            if(eid >= num_experts)
            {
                throw std::runtime_error(
                    "local_eid larger than number of expert, please check");
            }
            v_.mData[eid] = 1;
        }
        return v_;
    }
    else
        return ck_tile::HostTensor<IndexType>{{1}};
}();
// tokens already considered batch size
// tokens already considered batch size
ck_tile
::
HostTensor
<
IndexType
>
topk_ids_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
IndexType
>
topk_ids_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
WeightType
>
weights_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
WeightType
>
weights_host
({
tokens
,
topk
},
{
topk
,
1
});
...
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_host
.
get_element_space_size_in_bytes
());
sorted_expert_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
moe_buf_dev
(
moe_buf_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
moe_buf_dev
(
moe_buf_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
local_expert_masking_dev
(
local_expert_masking_host
.
get_element_space_size_in_bytes
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
...
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{
{
moe_buf_dev
.
ToDevice
(
moe_buf_host
.
data
());
moe_buf_dev
.
ToDevice
(
moe_buf_host
.
data
());
}
}
if
(
local_expert_masking
)
local_expert_masking_dev
.
ToDevice
(
local_expert_masking_host
.
data
());
moe_sorting_trait
trait
{
index_prec
,
weight_prec
};
moe_sorting_trait
trait
{
index_prec
,
weight_prec
,
local_expert_masking
};
moe_sorting_args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
moe_sorting_args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
weights_dev
.
GetDeviceBuffer
(),
weights_dev
.
GetDeviceBuffer
(),
local_expert_masking
?
local_expert_masking_dev
.
GetDeviceBuffer
()
:
nullptr
,
sorted_ids_dev
.
GetDeviceBuffer
(),
sorted_ids_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
sorted_expert_ids_dev
.
GetDeviceBuffer
(),
sorted_expert_ids_dev
.
GetDeviceBuffer
(),
...
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup
,
warmup
,
repeat
};
repeat
};
auto
ms
=
moe_sorting
(
trait
,
karg
,
sc
);
auto
ms
=
moe_sorting
(
trait
,
karg
,
sc
);
printf
(
"[%s|%s]tokens:%d, num_experts:%d, topk:%d,
ms:%f ,
"
,
printf
(
"[%s|%s]tokens:%d, num_experts:%d, topk:%d, "
,
index_prec
.
c_str
(),
index_prec
.
c_str
(),
weight_prec
.
c_str
(),
weight_prec
.
c_str
(),
tokens
,
tokens
,
num_experts
,
num_experts
,
topk
,
topk
);
ms
);
if
(
local_expert_masking
)
{
printf
(
"local_eid:%s, "
,
args
.
get_str
(
"local_eid"
).
c_str
());
}
if
(
ms
<
0
)
if
(
ms
<
0
)
printf
(
"not supported
\n
"
);
printf
(
"not supported
\n
"
);
else
printf
(
"ms:%f, "
,
ms
);
fflush
(
stdout
);
fflush
(
stdout
);
if
(
ms
<
0
)
if
(
ms
<
0
)
{
{
...
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int32_t
ref_total_tokens_post_pad
=
0
;
int32_t
ref_total_tokens_post_pad
=
0
;
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
weights_host
,
weights_host
,
local_expert_masking_host
,
sorted_ids_ref
,
sorted_ids_ref
,
sorted_weights_ref
,
sorted_weights_ref
,
sorted_expert_ids_ref
,
sorted_expert_ids_ref
,
ref_total_tokens_post_pad
,
ref_total_tokens_post_pad
,
num_experts
,
num_experts
,
unit_size
);
unit_size
,
local_expert_masking
);
rtn
&=
ck_tile
::
check_err
(
rtn
&=
ck_tile
::
check_err
(
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
1e-6
,
1e-6
);
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
1e-6
,
1e-6
);
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
...
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_buf_host
,
moe_buf_ref
,
std
::
string
(
"OUT Error: Incorrect zero buf!"
),
0
,
0
);
moe_buf_host
,
moe_buf_ref
,
std
::
string
(
"OUT Error: Incorrect zero buf!"
),
0
,
0
);
}
}
rtn
&=
ref_total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
rtn
&=
ref_total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
printf
(
"total_tokens_post_pad:%d(%d), "
,
ref_total_tokens_post_pad
,
sorted_id_cnt_host
.
mData
[
0
]);
}
}
printf
(
"valid:%s
\n
"
,
rtn
?
"y"
:
"n"
);
printf
(
"valid:%s"
,
rtn
?
"y"
:
"n"
);
fflush
(
stdout
);
if
(
!
rtn
)
printf
(
", (%d)"
,
seed
);
printf
(
"
\n
"
);
fflush
(
stdout
);
fflush
(
stdout
);
return
rtn
;
return
rtn
;
}
}
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
b8d11559
...
@@ -3,6 +3,12 @@
...
@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp"
#include "moe_sorting_api.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
@@ -17,6 +23,67 @@
...
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#else
// Instantiates and launches one MoeSortingKernel specialization.
//
// The three macro arguments become compile-time constants that, together with
// `index_t` and `ms_weight_type`, select the ck_tile::MoeSortingProblemEx
// instantiation. Kernel arguments, grid size, block size and LDS byte count are
// all derived from `a`; the launch uses stream config `s`.
//
// The expansion site must provide `a`, `s`, `index_t` and `ms_weight_type`.
// The expansion ends with `return ave_time;`, so it leaves the enclosing
// function. It also declares locals, so each use needs its own scope.
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
// Picks the sub-token tile size (8, 4, 2 or 1) as the largest of those values
// that evenly divides `row_`, then forwards to MOE_SORTING_DISPATCH_ together
// with the two boolean template selectors.
//
// `row_` is parenthesized so that an expression argument (e.g. `a + b`) is
// evaluated as a whole before the modulo is applied.
// Every branch expands MOE_SORTING_DISPATCH_, which returns from the enclosing
// function.
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
    if((row_) % 8 == 0)                                                                 \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else if((row_) % 4 == 0)                                                            \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else if((row_) % 2 == 0)                                                            \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else                                                                                \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_);             \
    }
// Turns the runtime flag `is_sub_token_onshot` (must be in scope at the
// expansion site) into the compile-time `true`/`false` argument handed to
// MOE_SORTING_DISPATCH_SUB_TOKEN_. `row_` and `local_expert_masking_` are
// passed through unchanged.
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
// Outermost dispatch step: turns the runtime flag `is_local_expert_masking`
// (must be in scope at the expansion site) into the compile-time `true`/`false`
// argument handed to MOE_SORTING_DISPATCH_SUBTO_. `row_` is passed through
// unchanged.
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
if(a.num_experts <= 8) \
{ \
{ \
...
@@ -38,11 +105,13 @@
...
@@ -38,11 +105,13 @@
{ \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
}
#endif
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
{
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
{
{
#if !MOE_SORTING_USE_EX_KERNEL
if
(
a
.
num_experts
>
127
)
if
(
a
.
num_experts
>
127
)
{
{
printf
(
"lds size exceed, only support experts <127
\n
"
);
printf
(
"lds size exceed, only support experts <127
\n
"
);
...
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
...
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
MOE_SORTING_DISPATCH
(
4
);
MOE_SORTING_DISPATCH
(
4
);
}
}
}
}
#else
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
auto
[
r_
,
c_
]
=
ck_tile
::
moe_sorting_get_smem_row_col
(
a
.
tokens
,
a
.
num_experts
);
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_local_expert_masking
=
t
.
local_expert_masking
;
(
void
)
c_
;
MOE_SORTING_DISPATCH_EMASK_
(
r_
);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
}
return
-
1
;
return
-
1
;
}
}
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
View file @
b8d11559
...
@@ -10,7 +10,8 @@
...
@@ -10,7 +10,8 @@
struct
moe_sorting_trait
struct
moe_sorting_trait
{
{
std
::
string
index_type
;
std
::
string
index_type
;
std
::
string
weight_type
;
// currently always float
std
::
string
weight_type
;
// currently always float
bool
local_expert_masking
;
// if mask experts as local expert
};
};
struct
moe_sorting_args
:
public
ck_tile
::
MoeSortingHostArgs
struct
moe_sorting_args
:
public
ck_tile
::
MoeSortingHostArgs
...
...
example/ck_tile/13_moe_sorting/script/smoke_test.sh
View file @
b8d11559
...
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
...
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE
-t
=
1
-e
=
1
-k
=
1
$EXE
-t
=
1
-e
=
1
-k
=
1
$EXE
-t
=
99
-e
=
2
-k
=
1
$EXE
-t
=
99
-e
=
2
-k
=
1
$EXE
-t
=
333
-e
=
99
-k
=
13
$EXE
-t
=
333
-e
=
99
-k
=
13
$EXE
-t
=
11
-e
=
256
-k
=
5
$EXE
-t
=
64
-e
=
455
-k
=
8
$EXE
-t
=
777
-e
=
802
-k
=
99
$EXE
-t
=
4097
-e
=
906
-k
=
51
$EXE
-t
=
128
-e
=
32
-k
=
5
-moe_buf_size
=
262144
$EXE
-t
=
128
-e
=
32
-k
=
5
-moe_buf_size
=
262144
$EXE
-t
=
13
-e
=
64
-k
=
3
-local_eid
=
4,5,6,7,8,9,10,11
$EXE
-t
=
99
-e
=
33
-k
=
9
-local_eid
=
6,10,11,15,19
$EXE
-t
=
80
-e
=
99
-k
=
10
-local_eid
=
0,8,12,33
$EXE
-t
=
11
-e
=
256
-k
=
5
-local_eid
=
99,110,129
example/ck_tile/15_fused_moe/README.md
View file @
b8d11559
...
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
...
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
//
// max_num_tokens_padded : topk * input_tokens + num_experts *
(
M_a -
1
)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a -
topk (updated
)
// * this could be larger than actual, since actual tokens are on GPU
// * this could be larger than actual, since actual tokens are on GPU
//
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
example/ck_tile/15_fused_moe/fused_moe.hpp
View file @
b8d11559
...
@@ -8,14 +8,15 @@
...
@@ -8,14 +8,15 @@
struct
fused_moe_args
struct
fused_moe_args
{
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
local_expert_mask_ptr
;
// [e], local_expert_mask_ptr for EP
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
topk_ids_ptr
;
// [tokens, topk]
const
void
*
topk_ids_ptr
;
// [tokens, topk]
const
void
*
topk_weight_ptr
;
// [tokens, topk]
const
void
*
topk_weight_ptr
;
// [tokens, topk]
...
@@ -48,6 +49,8 @@ struct fused_moe_traits
...
@@ -48,6 +49,8 @@ struct fused_moe_traits
int
activation
;
// 0:gelu, 1:silu
int
activation
;
// 0:gelu, 1:silu
int
gate_only
;
// 0:g1u0, 1:g1u1
int
gate_only
;
// 0:g1u0, 1:g1u1
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
bool
local_expert_masking
;
// if mask experts as local expert
};
};
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/15_fused_moe/fused_moesorting.hpp
View file @
b8d11559
...
@@ -10,7 +10,8 @@
...
@@ -10,7 +10,8 @@
struct
fused_moesorting_trait
struct
fused_moesorting_trait
{
{
std
::
string
index_type
;
std
::
string
index_type
;
std
::
string
weight_type
;
// currently always float
std
::
string
weight_type
;
// currently always float
bool
local_expert_masking
;
// if mask experts as local expert
};
};
struct
fused_moesorting_args
:
public
ck_tile
::
MoeSortingHostArgs
struct
fused_moesorting_args
:
public
ck_tile
::
MoeSortingHostArgs
...
...
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
View file @
b8d11559
...
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
...
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
return
1
;
return
1
;
}();
}();
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
};
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
,
t
.
local_expert_masking
};
auto
a0
=
fused_moesorting_args
{
auto
a0
=
fused_moesorting_args
{
a
.
topk_ids_ptr
,
// const void* p_topk_ids;
a
.
topk_ids_ptr
,
// const void* p_topk_ids;
a
.
topk_weight_ptr
,
// const void* p_weights;
a
.
topk_weight_ptr
,
// const void* p_weights;
a
.
local_expert_mask_ptr
,
// const void* p_local_expert_mask;
a
.
sorted_token_ids_ptr
,
// void* p_sorted_token_ids;
a
.
sorted_token_ids_ptr
,
// void* p_sorted_token_ids;
a
.
sorted_weight_ptr
,
// void* p_sorted_weights;
a
.
sorted_weight_ptr
,
// void* p_sorted_weights;
a
.
sorted_expert_ids_ptr
,
// void* p_sorted_expert_ids;
a
.
sorted_expert_ids_ptr
,
// void* p_sorted_expert_ids;
...
...
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
View file @
b8d11559
...
@@ -3,6 +3,12 @@
...
@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"
#include "fused_moesorting.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
@@ -17,6 +23,67 @@
...
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#else
// Instantiates and launches one MoeSortingKernel specialization.
//
// The three macro arguments become compile-time constants that, together with
// `index_t` and `ms_weight_type`, select the ck_tile::MoeSortingProblemEx
// instantiation. Kernel arguments, grid size, block size and LDS byte count are
// all derived from `a`; the launch uses stream config `s`.
//
// The expansion site must provide `a`, `s`, `index_t` and `ms_weight_type`.
// The expansion ends with `return ave_time;`, so it leaves the enclosing
// function. It also declares locals, so each use needs its own scope.
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
// Picks the sub-token tile size (8, 4, 2 or 1) as the largest of those values
// that evenly divides `row_`, then forwards to MOE_SORTING_DISPATCH_ together
// with the two boolean template selectors.
//
// `row_` is parenthesized so that an expression argument (e.g. `a + b`) is
// evaluated as a whole before the modulo is applied.
// Every branch expands MOE_SORTING_DISPATCH_, which returns from the enclosing
// function.
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
    if((row_) % 8 == 0)                                                                 \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else if((row_) % 4 == 0)                                                            \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else if((row_) % 2 == 0)                                                            \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_);             \
    }                                                                                   \
    else                                                                                \
    {                                                                                   \
        MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_);             \
    }
// Turns the runtime flag `is_sub_token_onshot` (must be in scope at the
// expansion site) into the compile-time `true`/`false` argument handed to
// MOE_SORTING_DISPATCH_SUB_TOKEN_. `row_` and `local_expert_masking_` are
// passed through unchanged.
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
// Outermost dispatch step: turns the runtime flag `is_local_expert_masking`
// (must be in scope at the expansion site) into the compile-time `true`/`false`
// argument handed to MOE_SORTING_DISPATCH_SUBTO_. `row_` is passed through
// unchanged.
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
if(a.num_experts <= 8) \
{ \
{ \
...
@@ -38,11 +105,13 @@
...
@@ -38,11 +105,13 @@
{ \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
}
#endif
float
fused_moesorting
(
fused_moesorting_trait
t
,
fused_moesorting_args
a
,
ck_tile
::
stream_config
s
)
float
fused_moesorting
(
fused_moesorting_trait
t
,
fused_moesorting_args
a
,
ck_tile
::
stream_config
s
)
{
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
{
{
#if !MOE_SORTING_USE_EX_KERNEL
if
(
a
.
num_experts
>
127
)
if
(
a
.
num_experts
>
127
)
{
{
printf
(
"lds size exceed, only support experts <127
\n
"
);
printf
(
"lds size exceed, only support experts <127
\n
"
);
...
@@ -83,6 +152,19 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
...
@@ -83,6 +152,19 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
MOE_SORTING_DISPATCH
(
4
);
MOE_SORTING_DISPATCH
(
4
);
}
}
}
}
#else
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
auto
[
r_
,
c_
]
=
ck_tile
::
moe_sorting_get_smem_row_col
(
a
.
tokens
,
a
.
num_experts
);
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_local_expert_masking
=
t
.
local_expert_masking
;
(
void
)
c_
;
MOE_SORTING_DISPATCH_EMASK_
(
r_
);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
}
return
-
1
;
return
-
1
;
}
}
example/ck_tile/15_fused_moe/main.cpp
View file @
b8d11559
...
@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
activation
=
arg_parser
.
get_int
(
"act"
);
ck_tile
::
index_t
activation
=
arg_parser
.
get_int
(
"act"
);
if
(
stride
<
0
)
if
(
stride
<
0
)
stride
=
hidden_size
;
stride
=
hidden_size
;
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
bool
local_expert_masking
=
false
;
// TODO...
// w0 (Gate+Up or Gate only, N size)
// w0 (Gate+Up or Gate only, N size)
ck_tile
::
index_t
shared_intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
ck_tile
::
index_t
shared_intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
...
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
IndexDataType
>
local_expert_mask_host
({
experts
});
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
...
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
local_expert_mask_buf
(
local_expert_mask_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
topk_ids_buf
(
topk_ids_host
);
ck_tile
::
DeviceMem
topk_ids_buf
(
topk_ids_host
);
...
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m
,
block_m
,
activation
,
activation
,
gate_only
,
gate_only
,
fused_quant
};
fused_quant
,
local_expert_masking
};
fused_moe_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_moe_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
...
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
local_expert_masking
?
local_expert_mask_buf
.
GetDeviceBuffer
()
:
nullptr
,
o_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
topk_ids_buf
.
GetDeviceBuffer
(),
topk_ids_buf
.
GetDeviceBuffer
(),
topk_weight_buf
.
GetDeviceBuffer
(),
topk_weight_buf
.
GetDeviceBuffer
(),
...
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_ids_host
,
topk_weight_host
,
topk_weight_host
,
local_expert_mask_host
,
sorted_token_ids_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
.
mData
[
0
],
experts
,
experts
,
block_m
);
block_m
,
local_expert_masking
);
if
(
activation
==
0
)
if
(
activation
==
0
)
{
{
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
...
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_ids_host
,
topk_weight_host
,
topk_weight_host
,
local_expert_mask_host
,
sorted_token_ids_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
.
mData
[
0
],
experts
,
experts
,
block_m
);
block_m
,
local_expert_masking
);
// done, preparing GPU buffer
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
a_buf
(
a_host
);
...
...
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
b8d11559
...
@@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel with args: "
<<
Kernel
::
GetName
()
<<
'\n'
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
"shape: "
<<
CodegenGemmShape
::
GetName
()
<<
'\n'
<<
"problem: "
<<
CodegenPipelineProblem
::
GetName
()
<<
'\n'
<<
"pipeline: "
<<
CodegenGemmPipeline
::
GetName
()
<<
'\n'
<<
"grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
}
}
...
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
b8d11559
...
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
{
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
View file @
b8d11559
...
@@ -118,7 +118,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
...
@@ -118,7 +118,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel
: "
<<
GroupedGemmKernel
::
GetName
()
<<
"
with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
...
...
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
View file @
b8d11559
...
@@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
...
@@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
<<
std
::
endl
;
}
}
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
return
pass
;
return
pass
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
b8d11559
...
@@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
true
;
return
true
;
}
}
/// Vector load/store feasibility check on the raw problem sizes.
///
/// For each of A, B, B1 and C one raw extent is selected by the tensor's layout and must be
/// divisible by the corresponding scalar-per-vector setting:
///   - A : row-major -> KRaw_,      column-major -> MRaw_  (ABlockTransferSrcScalarPerVector)
///   - B : row-major -> NRaw_,      column-major -> KRaw_  (BBlockTransferSrcScalarPerVector)
///   - B1: row-major -> Gemm1NRaw_, column-major -> NRaw_  (B1BlockTransferSrcScalarPerVector)
///   - C : row-major -> Gemm1NRaw_, column-major -> MRaw_
///         (CShuffleBlockTransferScalarPerVector_NPerBlock)
///
/// @return true only if every layout is row- or column-major and all four divisibility
///         checks pass; false otherwise.
static constexpr bool
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
{
    using RowLayout = ck::tensor_layout::gemm::RowMajor;
    using ColLayout = ck::tensor_layout::gemm::ColumnMajor;

    constexpr bool a_is_row  = is_same_v<ALayout, RowLayout>;
    constexpr bool b_is_row  = is_same_v<BLayout, RowLayout>;
    constexpr bool b1_is_row = is_same_v<B1Layout, RowLayout>;
    constexpr bool c_is_row  = is_same_v<CLayout, RowLayout>;

    // Any layout that is neither row- nor column-major is rejected outright.
    constexpr bool layouts_recognized = (a_is_row || is_same_v<ALayout, ColLayout>) &&
                                        (b_is_row || is_same_v<BLayout, ColLayout>) &&
                                        (b1_is_row || is_same_v<B1Layout, ColLayout>) &&
                                        (c_is_row || is_same_v<CLayout, ColLayout>);

    if constexpr(!layouts_recognized)
    {
        return false;
    }
    else
    {
        // Pick, per tensor, the raw extent that the vector width has to divide.
        const index_t a_extent  = a_is_row ? KRaw_ : MRaw_;
        const index_t b_extent  = b_is_row ? NRaw_ : KRaw_;
        const index_t b1_extent = b1_is_row ? Gemm1NRaw_ : NRaw_;
        const index_t c_extent  = c_is_row ? Gemm1NRaw_ : MRaw_;

        // Checked in the order A, B, B1, C; evaluation stops at the first failure.
        return a_extent % ABlockTransferSrcScalarPerVector == 0 &&
               b_extent % BBlockTransferSrcScalarPerVector == 0 &&
               b1_extent % B1BlockTransferSrcScalarPerVector == 0 &&
               c_extent % CShuffleBlockTransferScalarPerVector_NPerBlock == 0;
    }
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
ck
::
is_xdl_supported
())
if
(
!
ck
::
is_xdl_supported
())
...
@@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
?
KRaw
:
MRaw
;
const
auto
b_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
?
NRaw
:
KRaw
;
const
auto
b1_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
?
Gemm1NRaw
:
NRaw
;
const
auto
c_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>
?
Gemm1NRaw
:
MRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
)
and
IsSupported
(
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
);
}
}
// polymorphic
// polymorphic
...
@@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
str
.
str
();
return
str
.
str
();
}
}
template
<
class
ADesc
,
class
BDesc
,
class
B1Desc
,
class
CDesc
>
struct
Descriptor
{
template
<
class
AGridDescriptor
>
static
constexpr
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
AGridDescriptor
&
a_grid_desc
)
{
const
auto
a_grid_desc_m_k
=
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
class
BGridDescriptor
>
static
constexpr
auto
MakeBGridDescriptor_BK0_N_BK1
(
const
BGridDescriptor
&
b_grid_desc
)
{
const
auto
b_grid_desc_n_k
=
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
class
B1GridDescriptor
>
static
constexpr
auto
MakeB1GridDescriptor_BK0_N_BK1
(
const
B1GridDescriptor
&
b1_grid_desc
)
{
const
auto
b1_grid_desc_n_k
=
DeviceOp
::
matrix_padder
.
PadB1Descriptor_N_K
(
b1_grid_desc
);
const
auto
N
=
b1_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b1_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
B1K0
=
K
/
B1K1
;
return
transform_tensor_descriptor
(
b1_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B1K0
,
B1K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
class
CGridDescriptor
>
static
constexpr
auto
MakeCGridDescriptor_M_N
(
const
CGridDescriptor
&
c_grid_desc
)
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
c_grid_desc
);
}
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
ADesc
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
BDesc
{}))
>
;
using
B1GridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
(
B1Desc
{}))
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
(
CDesc
{}))
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
true
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
false
,
B1BlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
matrix_padder
.
PadN
,
MaskOutUpperTriangle
>
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
;
CGridDesc_M_N
c_grid_desc_m_n
;
C0MatrixMask
c0_matrix_mask
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock
;
// element-wise op
AElementwiseOperation
a_element_op
;
BElementwiseOperation
b_element_op
;
B1ElementwiseOperation
b1_element_op
;
CElementwiseOperation
c_element_op
;
bool
has_main_k_block_loop
=
true
;
bool
is_valid
=
false
;
constexpr
Descriptor
(
ADesc
a
,
BDesc
b
,
B1Desc
b1
,
CDesc
c
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
B1ElementwiseOperation
b1_element_op_
,
CElementwiseOperation
c_element_op_
)
:
a_grid_desc_ak0_m_ak1
{
MakeAGridDescriptor_AK0_M_AK1
(
a
)},
b_grid_desc_bk0_n_bk1
{
MakeBGridDescriptor_BK0_N_BK1
(
b
)},
b1_grid_desc_bk0_n_bk1
{
MakeB1GridDescriptor_BK0_N_BK1
(
b1
)},
c_grid_desc_m_n
{
MakeCGridDescriptor_M_N
(
c
)},
block_2_ctile_map
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock
{
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
)},
has_main_k_block_loop
{
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))},
c0_matrix_mask
{
c
.
GetLength
(
I1
)},
a_element_op
{
a_element_op_
},
b_element_op
{
b_element_op_
},
b1_element_op
{
b1_element_op_
},
c_element_op
{
c_element_op_
},
is_valid
{
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
,
block_2_ctile_map
)
and
IsSupported
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
),
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
),
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
),
b1_grid_desc_bk0_n_bk1
.
GetLength
(
I1
))}
{
}
constexpr
bool
IsValid
()
const
{
return
is_valid
;
}
};
/// Factory for Descriptor: deduces the four tensor-descriptor types from the arguments and
/// forwards everything to the Descriptor constructor. Each element-wise operator defaults
/// to a value-initialized instance of its type.
template <class ADesc, class BDesc, class B1Desc, class CDesc>
static constexpr auto
make_descriptor(ADesc a,
                BDesc b,
                B1Desc b1,
                CDesc c,
                AElementwiseOperation a_element_op   = AElementwiseOperation{},
                BElementwiseOperation b_element_op   = BElementwiseOperation{},
                B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                CElementwiseOperation c_element_op   = CElementwiseOperation{})
{
    using descriptor_type = Descriptor<ADesc, BDesc, B1Desc, CDesc>;

    return descriptor_type(
        a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
}
/// Device-side entry point: runs Desc::GridwiseGemm on the given global pointers using the
/// descriptors, tile map, mask and element-wise operators stored in `desc`.
///
/// @param desc      descriptor bundle; asserted valid unless __HIPCC_RTC__ is defined
/// @param scale     value used to construct the AccElementwiseOperation
/// @param p_a_grid  pointer to A
/// @param p_b_grid  pointer to B
/// @param p_b1_grid pointer to B1
/// @param p_c_grid  pointer to C (output)
template <class Desc>
__device__ static void Run(const Desc& desc,
                           const float scale,
                           const ADataType* __restrict__ p_a_grid,
                           const ADataType* __restrict__ p_b_grid,
                           const ADataType* __restrict__ p_b1_grid,
                           CDataType* __restrict__ p_c_grid)
{
#ifndef __HIPCC_RTC__
    assert(desc.is_valid);
#endif
    using DescGridwiseGemm = typename Desc::GridwiseGemm;

    // Shared-memory scratch buffer, sized by the gridwise GEMM itself.
    __shared__ char lds_buffer[DescGridwiseGemm::GetSharedMemoryNumberOfByte()];

    AccElementwiseOperation acc_op{scale};

    // The main-K-loop flag is a template argument, so each value needs its own call.
    if(!desc.has_main_k_block_loop)
    {
        DescGridwiseGemm::template Run<false>(
            p_a_grid,
            p_b_grid,
            p_b1_grid,
            p_c_grid,
            lds_buffer,
            desc.a_element_op,
            desc.b_element_op,
            acc_op,
            desc.b1_element_op,
            desc.c_element_op,
            desc.a_grid_desc_ak0_m_ak1,
            desc.b_grid_desc_bk0_n_bk1,
            desc.b1_grid_desc_bk0_n_bk1,
            desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
            desc.block_2_ctile_map,
            desc.c0_matrix_mask);
        return;
    }

    DescGridwiseGemm::template Run<true>(
        p_a_grid,
        p_b_grid,
        p_b1_grid,
        p_c_grid,
        lds_buffer,
        desc.a_element_op,
        desc.b_element_op,
        acc_op,
        desc.b1_element_op,
        desc.c_element_op,
        desc.a_grid_desc_ak0_m_ak1,
        desc.b_grid_desc_bk0_n_bk1,
        desc.b1_grid_desc_bk0_n_bk1,
        desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
        desc.block_2_ctile_map,
        desc.c0_matrix_mask);
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
b8d11559
...
@@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
...
@@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// if workspace is not allocated
// if workspace is not allocated
if
(
!
arg
.
p_workspace_
)
if
(
!
arg
.
p_workspace_
)
{
{
std
::
cerr
<<
"Warning: Workspace for "
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_LOGGING
)))
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
{
"allocated, use SetWorkSpacePointer."
std
::
cout
<<
"Warning: Workspace for "
<<
std
::
endl
;
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
"allocated, use SetWorkSpacePointer."
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
if
(
!
ck
::
is_xdl_supported
())
if
(
!
ck
::
is_xdl_supported
())
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
b8d11559
...
@@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
ABDataType
,
half_t
>::
value
||
is_same
<
ABDataType
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
math
::
max
(
lcm_AK1_BK1
,
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
ABDataType
,
MPerXdl
,
NPerXdl
,
ABDataType
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @
b8d11559
...
@@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
...
@@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
// sanity check
constexpr
auto
lcm_AK1_BK1
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
bool
is_single_rate_mfma
=
((
is_same
<
FloatAB
,
half_t
>::
value
||
is_same
<
FloatAB
,
bhalf_t
>::
value
)
&&
lcm_AK1_BK1
<=
4
)
?
true
:
false
;
constexpr
index_t
KPack
=
math
::
max
(
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
lcm_AK1_BK1
,
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
,
FloatAB
,
is_single_rate_mfma
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
BlockSize
,
...
...
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment