Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
12301455
Commit
12301455
authored
Feb 09, 2025
by
coderfeli
Browse files
gemm2 result ok
parent
7ba5bff4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
10 deletions
+12
-10
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+8
-6
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+2
-2
No files found.
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
12301455
...
...
@@ -33,8 +33,8 @@ using F32 = float;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
A0DataType
=
F
16
;
using
B0DataType
=
F
16
;
using
A0DataType
=
F
8
;
using
B0DataType
=
F
8
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
F32
;
...
...
@@ -172,10 +172,10 @@ int main(int argc, char* argv[])
// experts = 8
// per expert:
// GEMM shape
ck
::
index_t
N
=
128
;
ck
::
index_t
K
=
1024
;
ck
::
index_t
N
=
6144
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
2
;
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
tokens
=
32
;
...
...
@@ -341,6 +341,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
}
if
(
time_kernel
)
{
// not result correct here because output buf not setzero
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
SORTED_SIZE
*
N
*
K
;
...
...
@@ -357,9 +358,10 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
//gemm2 use atomic, so need to reinit outputs
e_device_buf
.
ToDevice
(
e_t_n_device_result
.
mData
.
data
());
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
0
,
0
,
1
});
// e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor
<
CShuffleDataType
>
c_t_n
({
tokens
,
N
});
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMoeGemm2
<
A0DataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
12301455
...
...
@@ -279,7 +279,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
...
...
@@ -289,7 +289,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
12301455
...
...
@@ -1170,7 +1170,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
true
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
...
...
@@ -1397,7 +1397,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[token_pos];
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
scatter_offsets
(
m0
)
=
(
p_sorted_token_ids
[
token_pos
+
m0
]
&
0xffffff
)
*
problem
.
N
;
printf
(
"init off tid %d m %d off %d
\n
"
,
threadIdx
.
x
,
m0
(),
g
at
h
er_offsets
(
m0
));
//
printf("init off
bid %d
tid %d m %d off %d\n",
blockIdx.y,
threadIdx.x, m0(),
sc
at
t
er_offsets(m0));
});
// printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment