gaoqiong / composable_kernel_ROCM · Commits

Commit 72752420, authored Feb 10, 2025 by coderfeli
    merge gemm1 gemm2 together and run ok

Parent: 66cff910
Showing 5 changed files with 286 additions and 45 deletions (+286 −45):
    example/65_gemm_multiply_multiply/moe_gemm2.cpp                                               +3   −3
    include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp             +10  −15
    include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp     +225 −0
    include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp                               +46  −25
    include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp                            +2   −2
example/65_gemm_multiply_multiply/moe_gemm2.cpp  (view file @ 72752420)

@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

@@ -127,7 +127,7 @@ static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
 static constexpr ck::index_t BK1  = 16 / sizeof(B0DataType);
 static constexpr ck::index_t EVec = 16 / sizeof(EDataType);

 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
-using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
+using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
 // clang-format off
 ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 ///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|

@@ -156,7 +156,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
 //            MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //             PerShuffle|  PerShuffle| _NBlock_NWaveNPerXdl|   _NWaveNPerXdl|
     CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1>,
-    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType>;
+    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>;

 // kernel 2: 128->32x128x128
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
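(Note: AK1, BK1, and EVec in the example above are all computed as 16 / sizeof(T), i.e. the number of elements that fit in a 16-byte vector access. A minimal standalone sketch of that arithmetic, assuming hypothetical 2-byte element types rather than the example's actual A0DataType/B0DataType/EDataType:)

    #include <cstdint>
    #include <cstdio>

    // Hypothetical 2-byte stand-ins for the example's element types; only the
    // 16 / sizeof(T) vector-width arithmetic from moe_gemm2.cpp is illustrated.
    using A0DataType = std::uint16_t;
    using B0DataType = std::uint16_t;
    using EDataType  = std::uint16_t;

    static constexpr int AK1  = 16 / sizeof(A0DataType); // 8 elements per 16-byte vector
    static constexpr int BK1  = 16 / sizeof(B0DataType); // 8
    static constexpr int EVec = 16 / sizeof(EDataType);  // 8

    int main() { std::printf("AK1=%d BK1=%d EVec=%d\n", AK1, BK1, EVec); }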
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp  (view file @ 72752420)

@@ -7,7 +7,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp"
 #include "ck/utility/is_detected.hpp"

 namespace ck {

@@ -42,35 +42,30 @@ template <typename ThreadGroup,
           index_t DstScalarPerVector,
           typename ThreadTransferSrcResetCoordinateAfterRunFlags,
           typename ThreadTransferDstResetCoordinateAfterRunFlags,
-          index_t ScatterDim       = 1,
           index_t NumThreadScratch = 1>
 struct ThreadGroupTensorSliceTransfer_v7r3
 {
     static constexpr index_t nDim =
         remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
-    static constexpr index_t mod_num = ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix

     static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
     static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();

     using Index = MultiIndex<nDim>;

     static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
-    static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});

     __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3(
         const SrcDescs& src_descs,
         const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
         const DstDescs& dst_descs,
         const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
-        const ElementwiseOperation& element_op,
-        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets)
+        const ElementwiseOperation& element_op)
         : threadwise_transfer_(src_descs,
                                StaticallyIndexedArray<Index, nSrc>{},
                                dst_descs,
                                StaticallyIndexedArray<Index, nDst>{},
-                               element_op,
-                               scatter_offsets)
+                               element_op)
     {
         static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                       nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&

@@ -105,16 +100,17 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                 make_multi_index(ThreadGroup::GetThreadId()));

+            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

             const auto src_thread_slice_origins = generate_tuple(
-                [&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
+                [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
                 Number<nSrc>{});

-            const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId() % mod_num));

             const auto dst_thread_slice_origins = generate_tuple(
-                [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
+                [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
                 Number<nDst>{});

             threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);

@@ -201,7 +197,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
+        ThreadwiseTensorSliceTransfer_v7r3<SrcDatas,
                                            DstDatas,
                                            SrcDescs,
                                            DstDescs,

@@ -216,7 +212,6 @@ struct ThreadGroupTensorSliceTransfer_v7r3
                                            DstScalarPerVector,
                                            ThreadTransferSrcResetCoordinateAfterRunFlags,
                                            ThreadTransferDstResetCoordinateAfterRunFlags,
-                                           ScatterDim,
                                            NumThreadScratch>;

     ThreadwiseTransfer threadwise_transfer_;
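(For orientation: the refactor above folds the per-thread origin computation into a single thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths and drops the mod_num scatter hack from this header, which now lives only in the new _scatter variant. A minimal standalone sketch of that index mapping, using plain arrays instead of CK's cluster descriptors; the 4x16 cluster and 8x128 slice are made-up numbers:)

    #include <array>
    #include <cstdio>

    // Hypothetical 2-D thread cluster (4 x 16 threads) covering an 8 x 128 slice,
    // so each thread owns a 2 x 8 sub-slice.
    constexpr std::array<int, 2> cluster_lengths{4, 16};
    constexpr std::array<int, 2> slice_lengths{8, 128};
    constexpr std::array<int, 2> thread_slice_lengths{slice_lengths[0] / cluster_lengths[0],
                                                      slice_lengths[1] / cluster_lengths[1]};

    int main()
    {
        for(int tid = 0; tid < cluster_lengths[0] * cluster_lengths[1]; ++tid)
        {
            // Analogue of thread_cluster_desc_.CalculateBottomIndex(): row-major decode of the thread id.
            const int idx0 = tid / cluster_lengths[1];
            const int idx1 = tid % cluster_lengths[1];

            // thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths
            const int begin0 = idx0 * thread_slice_lengths[0];
            const int begin1 = idx1 * thread_slice_lengths[1];

            if(tid < 3)
                std::printf("tid %d -> slice origin (%d, %d)\n", tid, begin0, begin1);
        }
    }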
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp  (new file, 0 → 100644, view file @ 72752420)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
#include "ck/utility/is_detected.hpp"

namespace ck {

// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
//   1. All sources and destinations are DynamicBuffer
//   2. Same VectorDim and ScalerPerVector for all sources and destinations
//   3. DstInMemOps are per destination tensor
//   4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
//   5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
//   1. Pass tensor descritpors by reference (or tuple of references)
//   2. Does not keep reference to tensor descriptor
//   3. Does not construct new tensor coordinate when call Run()
template <typename ThreadGroup,
          typename SrcDatas,
          typename DstDatas,
          typename SrcDescs,
          typename DstDescs,
          typename ElementwiseOperation,
          typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          typename SrcScalarPerVectors,
          index_t DstScalarPerVector,
          typename ThreadTransferSrcResetCoordinateAfterRunFlags,
          typename ThreadTransferDstResetCoordinateAfterRunFlags,
          index_t ScatterDim       = 1,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
{
    static constexpr index_t nDim =
        remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();

    static constexpr index_t mod_num = ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix

    static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
    static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();

    using Index = MultiIndex<nDim>;

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(
        const SrcDescs& src_descs,
        const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
        const DstDescs& dst_descs,
        const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
        const ElementwiseOperation& element_op,
        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets)
        : threadwise_transfer_(src_descs,
                               StaticallyIndexedArray<Index, nSrc>{},
                               dst_descs,
                               StaticallyIndexedArray<Index, nDst>{},
                               element_op,
                               scatter_offsets)
    {
        static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                          nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
                          nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
                          nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
                      "wrong!");

        static_for<0, nSrc, 1>{}([&](auto i) {
            static_assert(nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
                          "wrong!");
        });

        static_for<0, nDst, 1>{}([&](auto i) {
            static_assert(nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
                          "wrong!");
        });

        static_assert(nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto src_thread_slice_origins = generate_tuple(
                [&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
                Number<nSrc>{});

            const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId() % mod_num));

            const auto dst_thread_slice_origins = generate_tuple(
                [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
                Number<nDst>{});

            threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
            threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
        }
    }

    template <typename SrcBuffers, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDescs& src_descs,
                            const SrcBuffers& src_bufs,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
        }
    }

    template <typename T>
    using is_tuple = decltype(std::declval<T&>().IsTuple());

    template <typename DstBuffers, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDescs& dst_descs,
                             DstBuffers dst_bufs,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
                threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
            else
                threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
        }
    }

    template <typename SrcBuffers, typename DstBuffers>
    __device__ void Run(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        const DstDescs& dst_descs,
                        DstBuffers dst_bufs)
    {
        RunRead(src_descs, src_bufs);
        RunWrite(dst_descs, dst_bufs);
    }

    template <index_t ISrc>
    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
    {
        static_for<0, SrcDescs::Size(), 1>{}([&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
    }

    template <index_t IDst>
    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
    {
        static_for<0, DstDescs::Size(), 1>{}([&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
                                                   DstDatas,
                                                   SrcDescs,
                                                   DstDescs,
                                                   ElementwiseOperation,
                                                   DstInMemOps,
                                                   decltype(thread_slice_lengths),
                                                   SrcDimAccessOrder,
                                                   DstDimAccessOrder,
                                                   SrcVectorDim,
                                                   DstVectorDim,
                                                   SrcScalarPerVectors,
                                                   DstScalarPerVector,
                                                   ThreadTransferSrcResetCoordinateAfterRunFlags,
                                                   ThreadTransferDstResetCoordinateAfterRunFlags,
                                                   ScatterDim,
                                                   NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
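(The scatter variant differs from the plain v7r3 transfer mainly in the extra scatter_offsets constructor argument: the destination coordinate along ScatterDim comes from a precomputed index array rather than from the block's own position, which is what lets the MoE epilogue drop each output row at its token's location. A much-simplified host-side analogue of that write pattern; illustrative names only, not CK code, and the += mirrors the AtomicAdd path used below in device_moe_gemm.hpp:)

    #include <cstdio>
    #include <vector>

    // Simplified scatter write: source row r lands at destination row scatter_offsets[r].
    void scatter_rows(const std::vector<float>& src,
                      std::vector<float>& dst,
                      const std::vector<int>& scatter_offsets,
                      int row_len)
    {
        const int num_rows = static_cast<int>(scatter_offsets.size());
        for(int r = 0; r < num_rows; ++r)
            for(int c = 0; c < row_len; ++c)
                dst[scatter_offsets[r] * row_len + c] += src[r * row_len + c]; // AtomicAdd analogue
    }

    int main()
    {
        const int row_len = 4;
        const std::vector<float> src{1, 1, 1, 1, 2, 2, 2, 2}; // two expert-output rows
        std::vector<float> dst(3 * row_len, 0.0f);            // three token rows
        const std::vector<int> offsets{2, 0};                 // row 0 -> token 2, row 1 -> token 0
        scatter_rows(src, dst, offsets, row_len);
        std::printf("dst row 0 starts with %g, dst row 2 starts with %g\n", dst[0], dst[8]);
    }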
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp  (view file @ 72752420)

@@ -85,7 +85,7 @@ struct DeviceMoeGemm
                                CElementwiseOperation>
 {
     static constexpr index_t NumDTensor = DsDataType::Size();

     using GridwiseGemm = std::conditional_t<IsGatherGemm,
         GridwiseMoeGemmGather<
             ALayout,
             BLayout,

@@ -218,7 +218,7 @@ struct DeviceMoeGemm
         const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

-        const auto Run = [&](const auto& kernel) {
+        const auto RunKernel = [&](const auto& kernel) {
             if(stream_config.flush_cache)
             {

@@ -301,6 +301,7 @@ struct DeviceMoeGemm
             // Tail number always full
             if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
             {
+                using meme
                 // if(arg.KBatch > 1)
                 // {
                 //     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)

@@ -311,7 +312,7 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Odd>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 //     else
                 //     {

@@ -321,30 +322,50 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Even>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 // }
                 // else
                 {
                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                     {
-                        const auto kernel = kernel_moe_gemm_gather<
-                            GridwiseGemm,
-                            true,
-                            ScatterOutput ? InMemoryDataOperationEnum::AtomicAdd
-                                          : InMemoryDataOperationEnum::Set,
-                            minimum_occupancy,
-                            TailNumber::Odd>;
-                        Run(kernel);
+                        if constexpr(IsGatherGemm)
+                        {
+                            const auto kernel = kernel_moe_gemm_gather<GridwiseGemm,
+                                                                       true,
+                                                                       InMemoryDataOperationEnum::Set,
+                                                                       minimum_occupancy,
+                                                                       TailNumber::Odd>;
+                            RunKernel(kernel);
+                        }
+                        else
+                        {
+                            const auto kernel = kernel_moe_gemm_scatter<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Odd>;
+                            RunKernel(kernel);
+                        }
                     }
                     else
                     {
-                        const auto kernel = kernel_moe_gemm_gather<
-                            GridwiseGemm,
-                            true,
-                            ScatterOutput ? InMemoryDataOperationEnum::AtomicAdd
-                                          : InMemoryDataOperationEnum::Set,
-                            minimum_occupancy,
-                            TailNumber::Even>;
-                        Run(kernel);
+                        if constexpr(IsGatherGemm)
+                        {
+                            const auto kernel = kernel_moe_gemm_gather<GridwiseGemm,
+                                                                       true,
+                                                                       InMemoryDataOperationEnum::Set,
+                                                                       minimum_occupancy,
+                                                                       TailNumber::Even>;
+                            RunKernel(kernel);
+                        }
+                        else
+                        {
+                            const auto kernel = kernel_moe_gemm_scatter<GridwiseGemm,
+                                                                        true,
+                                                                        InMemoryDataOperationEnum::AtomicAdd,
+                                                                        minimum_occupancy,
+                                                                        TailNumber::Even>;
+                            RunKernel(kernel);
+                        }
                     }
                 }
             }
         }

@@ -361,7 +382,7 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Odd>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 //     else
                 //     {

@@ -372,7 +393,7 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Even>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 // }
                 // else

@@ -383,10 +404,10 @@ struct DeviceMoeGemm
                 //         kernel_moe_gemm_gather_2lds<
                 //             GridwiseGemm,
                 //             true,
-                //             ScatterOutput ? InMemoryDataOperationEnum::AtomicAdd : InMemoryDataOperationEnum::Set,
+                //             IsGatherGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Odd>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 //     else
                 //     {

@@ -394,10 +415,10 @@ struct DeviceMoeGemm
                 //         kernel_moe_gemm_gather_2lds<
                 //             GridwiseGemm,
                 //             true,
-                //             ScatterOutput ? InMemoryDataOperationEnum::AtomicAdd : InMemoryDataOperationEnum::Set,
+                //             IsGatherGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Even>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 // }
                 // }

@@ -414,7 +435,7 @@ struct DeviceMoeGemm
     float Run(const BaseArgument* p_arg,
               const StreamConfig& stream_config = StreamConfig{}) override
     {
-        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        return -1; // Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
     }
 };
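(The launch path above now branches at compile time on IsGatherGemm: the gather kernel writes its output with InMemoryDataOperationEnum::Set, the scatter kernel accumulates with AtomicAdd, and both go through the renamed RunKernel lambda. A minimal standalone sketch of that if constexpr dispatch pattern, with hypothetical function names rather than the CK launch machinery:)

    #include <cstdio>

    enum class MemOp { Set, AtomicAdd };

    // Hypothetical stand-ins for the two grid kernels selected in DeviceMoeGemm.
    template <MemOp Op> void run_gather_kernel()  { std::printf("gather kernel, Set stores\n"); }
    template <MemOp Op> void run_scatter_kernel() { std::printf("scatter kernel, AtomicAdd stores\n"); }

    // Compile-time dispatch analogous to the IsGatherGemm branch added in this commit.
    template <bool IsGatherGemm>
    void launch()
    {
        if constexpr(IsGatherGemm)
            run_gather_kernel<MemOp::Set>();        // gather GEMM owns its output rows: plain stores
        else
            run_scatter_kernel<MemOp::AtomicAdd>(); // scatter GEMM accumulates into shared token rows
    }

    int main()
    {
        launch<true>();
        launch<false>();
    }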
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp  (view file @ 72752420)

@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"

 #define DEBUG_LOG 0

@@ -1404,7 +1404,7 @@ struct GridwiseMoeGemmScatter
         });
         // printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
+        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
             ThisThreadBlock,
             decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
             Tuple<EDataType>,