Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e4ca61f9
Commit
e4ca61f9
authored
Feb 11, 2025
by
coderfeli
Browse files
moe gemm2 scales ok
parent
66d08ea3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
12 deletions
+17
-12
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+8
-5
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
...k/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
+4
-2
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
.../thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
+3
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
...ry/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
+2
-3
No files found.
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
e4ca61f9
...
...
@@ -132,7 +132,6 @@ static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static
constexpr
ck
::
index_t
D0Vec
=
1
;
static
constexpr
ck
::
index_t
D1Vec
=
1
;
static
constexpr
ck
::
index_t
D2Vec
=
1
;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceMoeGemm
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
...
@@ -185,7 +184,7 @@ int main(int argc, char* argv[])
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
tokens
=
32
;
ck
::
index_t
tokens
=
64
;
if
(
argc
==
1
)
{
...
...
@@ -234,14 +233,13 @@ int main(int argc, char* argv[])
else
sorted_token_ids
.
mData
[
i
]
=
tokens
;
}
expert_ids
.
savetxt
(
"expert_ids.txt"
,
"int"
);
sorted_token_ids
.
savetxt
(
"sorted_token_ids.txt"
,
"int"
);
Tensor
<
A0DataType
>
a0_m_k
(
HostTensorDescriptor
({
SORTED_SIZE
,
K
},
{
K
,
1
}));
Tensor
<
B0DataType
>
b0_e_n_k
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_preshuffled
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
D0DataType
>
d0_m_n
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
StrideDs
[
0
],
0
}));
Tensor
<
D1DataType
>
d1_e_n
(
HostTensorDescriptor
({
experts
,
N
},
{
1
,
StrideDs
[
1
]}));
Tensor
<
D2DataType
>
d2_e_n
(
HostTensorDescriptor
({
experts
,
N
},
{
1
,
0
}));
Tensor
<
D2DataType
>
d2_e_n
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
1
,
0
}));
Tensor
<
EDataType
>
e_t_n_host_result
(
HostTensorDescriptor
({
tokens
,
N
},
{
N
,
1
}));
Tensor
<
EDataType
>
e_t_n_device_result
(
HostTensorDescriptor
({
tokens
,
N
},
{
N
,
1
}));
e_t_n_device_result
.
SetZero
();
...
...
@@ -285,6 +283,11 @@ int main(int argc, char* argv[])
DeviceMem
d2_device_buf
(
sizeof
(
D2DataType
)
*
d2_e_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_t_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a0_m_k
.
savetxt
(
"a.txt"
);
expert_ids
.
savetxt
(
"expert_ids.txt"
,
"int"
);
sorted_token_ids
.
savetxt
(
"sorted_token_ids.txt"
,
"int"
);
d0_m_n
.
savetxt
(
"d0_m_n.txt"
,
"int"
);
d1_e_n
.
savetxt
(
"d1_e_n.txt"
,
"int"
);
d2_e_n
.
savetxt
(
"d2_e_n.txt"
,
"int"
);
sorted_token_ids_dev
.
ToDevice
(
sorted_token_ids
.
mData
.
data
());
expert_ids_dev
.
ToDevice
(
expert_ids
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
View file @
e4ca61f9
...
...
@@ -1382,10 +1382,12 @@ struct GridwiseMoeGemmScatter
// ascale M, 1; bscale E, N, 1, move ptr to E
if
(
i
.
value
==
1
)
{
ptr_
+=
expert_id
*
problem
.
StrideDs
[
1
]
*
problem
.
N
;
ptr_
+=
expert_id
*
(
problem
.
StrideDs
[
1
]
?
problem
.
StrideDs
[
1
]
*
problem
.
N
:
1
);
// if ( threadIdx.x ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
}
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p
_ds_grid
[
i
]
,
ds_grid_desc_m_n
[
i
].
GetElementSpaceSize
());
p
tr_
,
ds_grid_desc_m_n
[
i
].
GetElementSpaceSize
());
},
Number
<
NumDTensor
>
{});
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
View file @
e4ca61f9
...
...
@@ -174,7 +174,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_
[
i
]);
oob_val
=
oob_val
&
is_src_valid
;
if
(
i
.
value
==
2
)
if
(
i
.
value
==
3
)
{
static_assert
(
SrcScalarPerVectors
{}[
Number
<
2
>
{}]
==
1
,
"scatter weight dim, should only one vec"
);
constexpr
auto
iScatter
=
SrcSpaceFillingCurve
::
GetIndex
(
iAccess
)(
Number
<
ScatterDim
>
{});
...
...
@@ -187,7 +187,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using
DataType
=
remove_cvref_t
<
decltype
(
data_types
[
i
])
>
;
const
auto
tmp
=
src_bufs
[
i
].
template
Get
<
DataType
>(
src_coords_
[
i
].
GetOffset
(),
true
);
// printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
// if(i.value == 2)
// printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}(
[
&
](
auto
j
)
{
src_vectors
(
i
).
template
AsType
<
DataType
>()(
j
)
=
tmp
;
});
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
View file @
e4ca61f9
...
...
@@ -89,6 +89,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const
int
t
=
arg
.
sorted_token_ids_
(
m
);
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size_
);
const
int
token_cnt
=
arg
.
c_t_n_
.
mDesc
.
GetLengths
()[
0
];
D2DataType
v_topk_w
=
arg
.
d2_
(
m
,
0
);
//expert
if
(
t
<
token_cnt
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
...
...
@@ -120,9 +121,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType
v_c
{
0
};
D0DataType
v_d0
=
arg
.
d0_
(
m
,
n
);
// a
D0DataType
v_d1
=
arg
.
d1_
(
e
,
n
);
// b
D0DataType
v_d2
=
arg
.
d2_
(
e
,
0
);
//expert
arg
.
c_element_op_
(
v_c
,
v_acc
,
v_d0
,
v_d1
,
v_d2
);
arg
.
c_element_op_
(
v_c
,
v_acc
,
v_d0
,
v_d1
,
v_topk_w
);
arg
.
c_t_n_
(
t
,
n
)
+=
v_c
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment