Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
96047cab
Commit
96047cab
authored
Feb 17, 2025
by
coderfeli
Browse files
impl e swizzel
parent
56cc306d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
62 additions
and
42 deletions
+62
-42
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+8
-7
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+14
-6
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
...e/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
+21
-21
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
...ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
+17
-8
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+2
-0
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
96047cab
...
@@ -191,11 +191,11 @@ int main(int argc, char* argv[])
...
@@ -191,11 +191,11 @@ int main(int argc, char* argv[])
// experts = 8
// experts = 8
// per expert:
// per expert:
// GEMM shape
// GEMM shape
ck
::
index_t
N
=
6
14
4
;
ck
::
index_t
N
=
14
336
*
2
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_num
=
16
;
ck
::
index_t
valid_tile_num
=
8
;
ck
::
index_t
valid_tile_num
=
13
;
ck
::
index_t
sorted_size
=
sorted_tile_num
*
MPerBlock
;
ck
::
index_t
sorted_size
=
sorted_tile_num
*
MPerBlock
;
ck
::
index_t
valid_size
=
valid_tile_num
*
MPerBlock
;
ck
::
index_t
valid_size
=
valid_tile_num
*
MPerBlock
;
ck
::
index_t
tokens
=
64
;
ck
::
index_t
tokens
=
64
;
...
@@ -243,10 +243,11 @@ int main(int argc, char* argv[])
...
@@ -243,10 +243,11 @@ int main(int argc, char* argv[])
// const ck::index_t experts = 8;
// const ck::index_t experts = 8;
Tensor
<
ck
::
index_t
>
expert_ids
(
HostTensorDescriptor
({
sorted_tile_num
},
{
1
}));
Tensor
<
ck
::
index_t
>
expert_ids
(
HostTensorDescriptor
({
sorted_tile_num
},
{
1
}));
Tensor
<
ck
::
index_t
>
sorted_token_ids
(
HostTensorDescriptor
({
sorted_size
},
{
1
}));
Tensor
<
ck
::
index_t
>
sorted_token_ids
(
HostTensorDescriptor
({
sorted_size
},
{
1
}));
Tensor
<
ck
::
index_t
>
max_token_id
(
HostTensorDescriptor
({
1
}));
Tensor
<
ck
::
index_t
>
max_token_id
(
HostTensorDescriptor
({
1
+
sorted_tile_num
}));
max_token_id
.
mData
[
0
]
=
valid_size
;
max_token_id
.
mData
=
{
valid_size
,
2
,
2
,
1
,
1
,
2
,
2
,
2
,
2
,
2
,
2
,
1
,
2
,
2
,
0
,
0
,
0
};
int
eids
[]
=
{
0
,
0
,
1
,
2
,
3
,
3
,
4
,
4
,
5
,
5
,
6
,
7
,
7
,
3
,
3
,
3
};
// {2, 1, 1, 2, 2, 2, 1, 2}
for
(
int
i
=
0
;
i
<
sorted_tile_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
sorted_tile_num
;
i
++
)
{
expert_ids
.
mData
[
i
]
=
i
;
expert_ids
.
mData
[
i
]
=
eids
[
i
]
;
}
}
int
token_per_tile
=
tokens
*
topk
/
valid_tile_num
;
int
token_per_tile
=
tokens
*
topk
/
valid_tile_num
;
int
tokenid
=
0
;
int
tokenid
=
0
;
...
...
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
96047cab
...
@@ -186,20 +186,27 @@ int main(int argc, char* argv[])
...
@@ -186,20 +186,27 @@ int main(int argc, char* argv[])
// experts = 8
// experts = 8
// per expert:
// per expert:
// GEMM shape
// GEMM shape
ck
::
index_t
N
=
6144
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
K
=
14336
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
1
0
;
ck
::
index_t
sorted_tile_num
=
1
6
;
ck
::
index_t
valid_tile_num
=
8
;
ck
::
index_t
valid_tile_num
=
13
;
ck
::
index_t
sorted_size
=
sorted_tile_num
*
MPerBlock
;
ck
::
index_t
sorted_size
=
sorted_tile_num
*
MPerBlock
;
ck
::
index_t
valid_size
=
valid_tile_num
*
MPerBlock
;
ck
::
index_t
valid_size
=
valid_tile_num
*
MPerBlock
;
ck
::
index_t
tokens
=
64
;
ck
::
index_t
tokens
=
512
;
ck
::
index_t
topk
=
2
;
ck
::
index_t
topk
=
2
;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
// use default case
// use default case
}
}
else
if
(
argc
==
3
)
{
// use default case
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
else
if
(
argc
==
7
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
...
@@ -233,8 +240,9 @@ int main(int argc, char* argv[])
...
@@ -233,8 +240,9 @@ int main(int argc, char* argv[])
Tensor
<
ck
::
index_t
>
sorted_token_ids
(
HostTensorDescriptor
({
sorted_size
},
{
1
}));
Tensor
<
ck
::
index_t
>
sorted_token_ids
(
HostTensorDescriptor
({
sorted_size
},
{
1
}));
Tensor
<
ck
::
index_t
>
max_token_id
(
HostTensorDescriptor
({
1
}));
Tensor
<
ck
::
index_t
>
max_token_id
(
HostTensorDescriptor
({
1
}));
max_token_id
.
mData
[
0
]
=
valid_size
;
max_token_id
.
mData
[
0
]
=
valid_size
;
int
eids
[]
=
{
0
,
0
,
1
,
2
,
3
,
3
,
4
,
4
,
5
,
5
,
6
,
7
,
7
,
3
,
3
,
3
};
for
(
int
i
=
0
;
i
<
sorted_tile_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
sorted_tile_num
;
i
++
)
{
expert_ids
.
mData
[
i
]
=
i
;
expert_ids
.
mData
[
i
]
=
eids
[
i
]
;
}
}
if
(
tokens
*
topk
>
valid_size
)
if
(
tokens
*
topk
>
valid_size
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
View file @
96047cab
...
@@ -346,27 +346,27 @@ struct DeviceMoeGemm
...
@@ -346,27 +346,27 @@ struct DeviceMoeGemm
// }
// }
// else
// else
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
//
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
//
{
if
constexpr
(
IsGatherGemm
)
{
//
if constexpr (IsGatherGemm) {
const
auto
kernel
=
kernel_moe_gemm_gather
<
//
const auto kernel = kernel_moe_gemm_gather<
GridwiseGemm
,
//
GridwiseGemm,
true
,
//
true,
InMemoryDataOperationEnum
::
Set
,
//
InMemoryDataOperationEnum::Set,
minimum_occupancy
,
//
minimum_occupancy,
TailNumber
::
Odd
>
;
//
TailNumber::Odd>;
RunKernel
(
kernel
);
//
RunKernel(kernel);
}
else
{
//
} else {
const
auto
kernel
=
kernel_moe_gemm_scatter
<
//
const auto kernel = kernel_moe_gemm_scatter<
GridwiseGemm
,
//
GridwiseGemm,
true
,
//
true,
InMemoryDataOperationEnum
::
AtomicAdd
,
//
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy
,
//
minimum_occupancy,
TailNumber
::
Odd
>
;
//
TailNumber::Odd>;
RunKernel
(
kernel
);
//
RunKernel(kernel);
}
//
}
}
//
}
else
//
else
{
{
if
constexpr
(
IsGatherGemm
)
{
if
constexpr
(
IsGatherGemm
)
{
const
auto
kernel
=
kernel_moe_gemm_gather
<
const
auto
kernel
=
kernel_moe_gemm_gather
<
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
View file @
96047cab
...
@@ -197,8 +197,8 @@ struct GridwiseMoeGemmGather
...
@@ -197,8 +197,8 @@ struct GridwiseMoeGemmGather
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
{
return
std
::
make_tuple
(
math
::
integer_divide_ceil
(
N
,
NPerBlock
),
return
std
::
make_tuple
(
math
::
integer_divide_ceil
(
N
,
NPerBlock
)
*
math
::
integer_divide_ceil
(
M
,
MPerBlock
)
,
math
::
integer_divide_ceil
(
M
,
MPerBlock
)
,
1
,
1
);
1
);
}
}
...
@@ -1140,7 +1140,6 @@ struct GridwiseMoeGemmGather
...
@@ -1140,7 +1140,6 @@ struct GridwiseMoeGemmGather
ignore
=
b_element_op
;
ignore
=
b_element_op
;
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bpreshuffled
=
const
auto
b_grid_desc_bpreshuffled
=
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
...
@@ -1150,11 +1149,21 @@ struct GridwiseMoeGemmGather
...
@@ -1150,11 +1149,21 @@ struct GridwiseMoeGemmGather
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
);
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
);
const
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
p_sorted_expert_ids
[
block_m_id
]);
const
index_t
max_token_id
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
0
]);
const
index_t
max_token_id
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
0
]);
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
const
index_t
expert_block_id
=
blockIdx
.
x
/
problem
.
NBlock
;
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
p_sorted_expert_ids
[
expert_block_id
]);
const
index_t
es
=
__builtin_amdgcn_readfirstlane
(
p_max_token_id
[
expert_block_id
+
1
]);
const
index_t
expert_swizzle
=
es
>
0
?
es
:
1
;
//p_max_token_id[expert_id + 1];
const
index_t
expert_block_swizzle
=
expert_block_id
/
expert_swizzle
;
const
index_t
b_block_id_swizzle
=
blockIdx
.
x
%
(
problem
.
NBlock
*
expert_swizzle
);
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
b_block_id_swizzle
%
8
+
b_block_id_swizzle
/
(
8
*
expert_swizzle
)
*
8
);
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
expert_block_swizzle
*
expert_swizzle
+
b_block_id_swizzle
/
8
%
expert_swizzle
);
if
(
threadIdx
.
x
==
0
)
{
printf
(
"bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d
\n
"
,
blockIdx
.
x
,
expert_id
,
expert_swizzle
,
expert_block_swizzle
,
b_block_id_swizzle
,
block_m_id
,
block_n_id
);
}
const
index_t
token0
=
__builtin_amdgcn_readfirstlane
(
p_sorted_token_ids
[
block_m_id
*
MPerBlock
]
&
0xffffff
);
const
index_t
token0
=
__builtin_amdgcn_readfirstlane
(
p_sorted_token_ids
[
block_m_id
*
MPerBlock
]
&
0xffffff
);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
...
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
View file @
96047cab
...
@@ -573,9 +573,11 @@ struct MoeSortingKernel
...
@@ -573,9 +573,11 @@ struct MoeSortingKernel
{
{
int
e_start
=
cumsum
[
tid
];
int
e_start
=
cumsum
[
tid
];
int
e_end
=
cumsum
[
tid
+
1
];
int
e_end
=
cumsum
[
tid
+
1
];
int
e_size
=
unit_size_mdiv
.
div
(
e_end
-
e_start
+
unit_size_mdiv
.
divisor
-
1
);
for
(
int
i
=
e_start
;
i
<
e_end
;
i
+=
unit_size_mdiv
.
divisor
)
for
(
int
i
=
e_start
;
i
<
e_end
;
i
+=
unit_size_mdiv
.
divisor
)
{
{
p_sorted_expert_ids
[
unit_size_mdiv
.
div
(
i
)]
=
tid
;
p_sorted_expert_ids
[
unit_size_mdiv
.
div
(
i
)]
=
tid
;
p_sorted_expert_cnts
[]
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment