Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e21f36fc
Commit
e21f36fc
authored
Feb 09, 2025
by
coderfeli
Browse files
moegemm2 ok
parent
12301455
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
711 additions
and
8 deletions
+711
-8
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+2
-1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
+9
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+4
-3
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
.../thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
+696
-0
No files found.
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
e21f36fc
...
...
@@ -122,6 +122,7 @@ static constexpr ck::index_t MPerBlock = 32;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
//todo fix this constraint
static
constexpr
ck
::
index_t
CShuffleMXDLPerWave
=
MPerBlock
/
32
;
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
EVec
=
16
/
sizeof
(
EDataType
);
...
...
@@ -154,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
EVec
,
EVec
,
1
>
,
CShuffleMXDLPerWave
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
EVec
,
EVec
,
1
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
View file @
e21f36fc
...
...
@@ -7,7 +7,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
...
...
@@ -42,6 +42,7 @@ template <typename ThreadGroup,
index_t
DstScalarPerVector
,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
,
index_t
ScatterDim
=
1
,
index_t
NumThreadScratch
=
1
>
struct
ThreadGroupTensorSliceTransfer_v7r3
{
...
...
@@ -55,18 +56,21 @@ struct ThreadGroupTensorSliceTransfer_v7r3
using
Index
=
MultiIndex
<
nDim
>
;
static
constexpr
auto
thread_slice_lengths
=
SliceLengths
{}
/
ThreadClusterLengths
{};
static
constexpr
index_t
scatter_num
=
thread_slice_lengths
.
At
(
Number
<
ScatterDim
>
{});
__device__
constexpr
ThreadGroupTensorSliceTransfer_v7r3
(
const
SrcDescs
&
src_descs
,
const
StaticallyIndexedArray
<
Index
,
nSrc
>&
src_block_slice_origins
,
const
DstDescs
&
dst_descs
,
const
StaticallyIndexedArray
<
Index
,
nDst
>&
dst_block_slice_origins
,
const
ElementwiseOperation
&
element_op
)
const
ElementwiseOperation
&
element_op
,
const
StaticallyIndexedArray
<
index_t
,
scatter_num
>
&
scatter_offsets
)
:
threadwise_transfer_
(
src_descs
,
StaticallyIndexedArray
<
Index
,
nSrc
>
{},
dst_descs
,
StaticallyIndexedArray
<
Index
,
nDst
>
{},
element_op
)
element_op
,
scatter_offsets
)
{
static_assert
(
nSrc
==
SrcDatas
::
Size
()
&&
nSrc
==
SrcDescs
::
Size
()
&&
nSrc
==
ThreadTransferSrcResetCoordinateAfterRunFlags
::
Size
()
&&
...
...
@@ -197,7 +201,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v7r3
<
SrcDatas
,
ThreadwiseTensorSliceTransfer_v7r3_scatter
<
SrcDatas
,
DstDatas
,
SrcDescs
,
DstDescs
,
...
...
@@ -212,6 +216,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
DstScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
,
ScatterDim
,
NumThreadScratch
>
;
ThreadwiseTransfer
threadwise_transfer_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
e21f36fc
...
...
@@ -1392,7 +1392,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr
auto
EMThreads
=
CDEBlockTransferCluster
{}.
At
(
I0
)
*
CDEBlockTransferCluster
{}.
At
(
I1
);
constexpr
auto
EMRepeats
=
MPerBlock
/
EMThreads
;
constexpr
auto
ENThreads
=
CDEBlockTransferCluster
{}.
At
(
I2
)
*
CDEBlockTransferCluster
{}.
At
(
I3
);
static_assert
(
EMRepeats
==
1
,
"only support 1 line per thread now!"
);
// static_assert(EMRepeats == 1, "only support 1 line per thread now!");
const
index_t
token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
ENThreads
*
EMRepeats
;
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[token_pos];
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
...
...
@@ -1431,10 +1431,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
idx_c_ds_block_begin
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
0
,
0
,
block_n_id
,
0
)),
c_element_op
};
c_element_op
,
scatter_offsets
};
// if(threadIdx.x== 0)
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
+
scatter_offsets
(
I0
)
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
()
-
scatter_offsets
(
I0
)
);
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
0 → 100644
View file @
e21f36fc
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment