Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bca3f14c
Commit
bca3f14c
authored
Feb 18, 2025
by
coderfeli
Browse files
fix nswizzle=0
parent
e78fbf87
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
11 additions
and
5 deletions
+11
-5
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+2
-1
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
+6
-3
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
bca3f14c
...
...
@@ -139,6 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
128
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
Nswizzle
=
true
;
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
EVec
=
16
/
sizeof
(
EDataType
);
...
...
@@ -174,7 +175,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
MXDLPerWave
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
EVec
,
D0Vec
,
D1Vec
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
true
,
A0DataType
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
Nswizzle
,
true
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...
...
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
bca3f14c
...
...
@@ -169,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave
,
1
,
S
<
1
,
CShuffleMLane
,
1
,
CShuffleNLane
>
,
S
<
EVec
,
D0Vec
,
D1Vec
,
D2Vec
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
false
,
A0DataType
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
false
,
false
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...
...
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
View file @
bca3f14c
...
...
@@ -66,6 +66,7 @@ template <typename ALayout,
typename
CDEShuffleBlockTransferScalarPerVectors
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
bool
NSwizzle
=
false
,
bool
IsInputGemm
=
true
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
,
...
...
@@ -133,6 +134,7 @@ struct DeviceMoeGemm
CDEShuffleBlockTransferScalarPerVectors
,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
NSwizzle
,
ComputeTypeA
,
ComputeTypeB
,
LDSTypeA
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
View file @
bca3f14c
...
...
@@ -142,6 +142,7 @@ template <typename ALayout,
typename
CDEShuffleBlockTransferScalarPerVectors
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
bool
NSwizzle
=
false
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ADataType
,
...
...
@@ -197,9 +198,11 @@ struct GridwiseMoeGemm
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
return
std
::
make_tuple
(
math
::
integer_divide_ceil
(
N
,
NPerBlock
)
*
math
::
integer_divide_ceil
(
M
,
MPerBlock
),
1
,
1
);
const
index_t
nblock
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
const
index_t
mblock
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
index_t
gridx
=
NSwizzle
?
nblock
*
mblock
:
nblock
;
const
index_t
gridy
=
NSwizzle
?
1
:
mblock
;
return
std
::
make_tuple
(
gridx
,
gridy
,
1
);
}
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment