Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
45d1c52e
Commit
45d1c52e
authored
Feb 18, 2025
by
coderfeli
Browse files
hotfix moegemm2 nswizzle
parent
bca3f14c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
8 deletions
+17
-8
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
+16
-7
No files found.
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
45d1c52e
...
@@ -169,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
...
@@ -169,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave
,
1
,
S
<
1
,
CShuffleMLane
,
1
,
CShuffleNLane
>
,
S
<
EVec
,
D0Vec
,
D1Vec
,
D2Vec
>
,
CShuffleMXDLPerWave
,
1
,
S
<
1
,
CShuffleMLane
,
1
,
CShuffleNLane
>
,
S
<
EVec
,
D0Vec
,
D1Vec
,
D2Vec
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
fals
e
,
false
,
A0DataType
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
tru
e
,
false
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
View file @
45d1c52e
...
@@ -1158,13 +1158,22 @@ struct GridwiseMoeGemm
...
@@ -1158,13 +1158,22 @@ struct GridwiseMoeGemm
const
index_t
expert_block_id
=
blockIdx
.
x
/
problem
.
NBlock
;
const
index_t
expert_block_id
=
blockIdx
.
x
/
problem
.
NBlock
;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
p_sorted_expert_ids
[
expert_block_id
]);
const
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
p_sorted_expert_ids
[
expert_block_id
]);
const
index_t
es
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
expert_block_id
+
1
]);
const
auto
block_mn
=
[
&
]()
->
std
::
pair
<
int
,
int
>
{
const
index_t
expert_swizzle
=
es
>
0
?
es
:
1
;
//p_max_token_id[expert_id + 1];
if
constexpr
(
NSwizzle
)
const
index_t
expert_block_swizzle
=
expert_block_id
/
expert_swizzle
;
{
const
index_t
b_block_id_swizzle
=
blockIdx
.
x
%
(
problem
.
NBlock
*
expert_swizzle
);
const
index_t
es
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
expert_block_id
+
1
]);
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
b_block_id_swizzle
%
8
+
b_block_id_swizzle
/
(
8
*
expert_swizzle
)
*
8
);
const
index_t
expert_swizzle
=
es
>
0
?
es
:
1
;
//p_max_token_id[expert_id + 1];
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
expert_block_swizzle
*
expert_swizzle
+
b_block_id_swizzle
/
8
%
expert_swizzle
);
const
index_t
expert_block_swizzle
=
expert_block_id
/
expert_swizzle
;
const
index_t
b_block_id_swizzle
=
blockIdx
.
x
%
(
problem
.
NBlock
*
expert_swizzle
);
const
index_t
nid
=
__builtin_amdgcn_readfirstlane
(
b_block_id_swizzle
%
8
+
b_block_id_swizzle
/
(
8
*
expert_swizzle
)
*
8
);
const
index_t
mid
=
__builtin_amdgcn_readfirstlane
(
expert_block_swizzle
*
expert_swizzle
+
b_block_id_swizzle
/
8
%
expert_swizzle
);
return
{
nid
,
mid
};
}
else
{
return
{
blockIdx
.
x
,
blockIdx
.
y
};
}
}();
const
index_t
block_n_id
=
block_mn
.
first
;
const
index_t
block_m_id
=
block_mn
.
second
;
// if (threadIdx.x==0) {
// if (threadIdx.x==0) {
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
// }
// }
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment