Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1687fc98
"configs/multimodal/vscode:/vscode.git/clone" did not exist on "2f1949e7a1ef908dad9454a88d21472f9ab8dbc7"
Commit
1687fc98
authored
Feb 17, 2025
by
coderfeli
Browse files
chage ktile
parent
4404984a
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
476 additions
and
9 deletions
+476
-9
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+5
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+29
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+442
-0
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
1687fc98
...
@@ -133,12 +133,12 @@ using BElementOp = PassThrough;
...
@@ -133,12 +133,12 @@ using BElementOp = PassThrough;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
MXDLPerWave
=
2
;
static
constexpr
ck
::
index_t
NXDLPerWave
=
2
;
static
constexpr
ck
::
index_t
BLOCKSIZE
=
256
;
static
constexpr
ck
::
index_t
BLOCKSIZE
=
256
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
CShuffleMXDLPerWave
=
MPerBlock
/
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
128
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
//todo fix this constraint
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
EVec
=
16
/
sizeof
(
EDataType
);
static
constexpr
ck
::
index_t
EVec
=
16
/
sizeof
(
EDataType
);
...
@@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
...
@@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// mn_perxdl
// mn_perxdl
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
// mn_xdlperwave
// mn_xdlperwave
2
,
2
,
MXDLPerWave
,
NXDLPerWave
,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
...
@@ -173,7 +173,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
...
@@ -173,7 +173,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
EVec
,
D0Vec
,
D1Vec
>
,
MXDLPerWave
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
EVec
,
D0Vec
,
D1Vec
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
true
,
A0DataType
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
true
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
1687fc98
...
@@ -296,7 +296,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -296,7 +296,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
}
}
}
}
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
...
@@ -351,7 +352,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -351,7 +352,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
}
else
else
{
{
throw
std
::
runtime_error
(
"todo: only v1
&
v
2
support now"
);
throw
std
::
runtime_error
(
"todo: only v1
v2 and
v
3
support now"
);
}
}
}
}
#if 0
#if 0
...
@@ -359,6 +360,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -359,6 +360,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
#if 0
if(arg.KBatch > 1)
if(arg.KBatch > 1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...
@@ -405,8 +407,29 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -405,8 +407,29 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
Run(kernel);
Run(kernel);
}
}
}
}
#endif
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2
|| BlkGemmPipelineVer == BlockGemmPipelineVersion::v3
)
{
{
if(arg.KBatch > 1)
if(arg.KBatch > 1)
{
{
...
@@ -602,7 +625,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -602,7 +625,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{
BlockGemmPipelineScheduler
::
Interwave
,
"Interwave"
}};
{
BlockGemmPipelineScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
BlockGemmPipelineVersion
,
std
::
string
>
BlkGemmPipelineVersionToString
{
std
::
map
<
BlockGemmPipelineVersion
,
std
::
string
>
BlkGemmPipelineVersionToString
{
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
}};
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
},
{
BlockGemmPipelineVersion
::
v3
,
"v3"
}};
// clang-format off
// clang-format off
str
<<
"DeviceGemmXdlUniversal"
str
<<
"DeviceGemmXdlUniversal"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
1687fc98
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment