Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9ff2394e
Commit
9ff2394e
authored
Feb 18, 2025
by
coderfeli
Committed by
mtgu0705
Feb 18, 2025
Browse files
fix swizzle = false
parent
a61084f4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
3 deletions
+3
-3
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
+2
-2
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
9ff2394e
...
@@ -139,7 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
...
@@ -139,7 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
128
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
KPerBlock
=
128
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
Nswizzle
=
tru
e
;
static
constexpr
ck
::
index_t
Nswizzle
=
fals
e
;
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
EVec
=
16
/
sizeof
(
EDataType
);
static
constexpr
ck
::
index_t
EVec
=
16
/
sizeof
(
EDataType
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
View file @
9ff2394e
...
@@ -1173,12 +1173,11 @@ struct GridwiseMoeGemm
...
@@ -1173,12 +1173,11 @@ struct GridwiseMoeGemm
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
index_t
max_token_id
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
0
]);
const
index_t
max_token_id
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
0
]);
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
const
index_t
expert_block_id
=
blockIdx
.
x
/
problem
.
NBlock
;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
p_sorted_expert_ids
[
expert_block_id
]);
const
auto
block_mn
=
[
&
]()
->
std
::
pair
<
int
,
int
>
{
const
auto
block_mn
=
[
&
]()
->
std
::
pair
<
int
,
int
>
{
if
constexpr
(
NSwizzle
)
if
constexpr
(
NSwizzle
)
{
{
const
index_t
expert_block_id
=
blockIdx
.
x
/
problem
.
NBlock
;
const
index_t
es
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
expert_block_id
+
1
]);
const
index_t
es
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
expert_block_id
+
1
]);
const
index_t
expert_swizzle
=
es
>
0
?
es
:
1
;
//p_max_token_id[expert_id + 1];
const
index_t
expert_swizzle
=
es
>
0
?
es
:
1
;
//p_max_token_id[expert_id + 1];
const
index_t
expert_block_swizzle
=
expert_block_id
/
expert_swizzle
;
const
index_t
expert_block_swizzle
=
expert_block_id
/
expert_swizzle
;
...
@@ -1192,6 +1191,7 @@ struct GridwiseMoeGemm
...
@@ -1192,6 +1191,7 @@ struct GridwiseMoeGemm
}();
}();
const
index_t
block_n_id
=
block_mn
.
first
;
const
index_t
block_n_id
=
block_mn
.
first
;
const
index_t
block_m_id
=
block_mn
.
second
;
const
index_t
block_m_id
=
block_mn
.
second
;
const
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
p_sorted_expert_ids
[
block_m_id
]);
// if (threadIdx.x==0) {
// if (threadIdx.x==0) {
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
// }
// }
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment