Commit 72752420 authored by coderfeli

merge gemm1 and gemm2 together; runs ok

parent 66cff910
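Note on the change, in outline: the previously separate gather GEMM (gemm1) and scatter GEMM (gemm2) are folded into one device op, DeviceMoeGemm, which picks its gridwise implementation at compile time. The sketch below is only a conceptual illustration of that selection pattern; the struct and kernel names are stand-ins, not the real CK signatures.

#include <type_traits>

// Illustrative stand-ins for the two gridwise paths merged by this commit.
struct GatherGemmKernel  { static void Run() { /* gather rows, GEMM, plain store */ } };
struct ScatterGemmKernel { static void Run() { /* GEMM, scatter rows, accumulate */ } };

// One device op, selected at compile time by a single flag, mirroring the
// std::conditional_t<IsGatherGemm, ...> alias used in device_moe_gemm.hpp.
template <bool IsGatherGemm>
struct MoeGemmSketch
{
    using Kernel = std::conditional_t<IsGatherGemm, GatherGemmKernel, ScatterGemmKernel>;
    static void Run() { Kernel::Run(); }
};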
...@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
...@@ -127,7 +127,7 @@ static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
 static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
 static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
-using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
+using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
 // clang-format off
 ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 ///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...@@ -156,7 +156,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
 // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1>,
-ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType>;
+ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>;
 // kernel 2: 128->32x128x128
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......
...@@ -7,7 +7,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp"
 #include "ck/utility/is_detected.hpp"
 namespace ck {
...@@ -42,35 +42,30 @@ template <typename ThreadGroup,
           index_t DstScalarPerVector,
           typename ThreadTransferSrcResetCoordinateAfterRunFlags,
           typename ThreadTransferDstResetCoordinateAfterRunFlags,
-          index_t ScatterDim = 1,
           index_t NumThreadScratch = 1>
 struct ThreadGroupTensorSliceTransfer_v7r3
 {
     static constexpr index_t nDim =
         remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
-    static constexpr index_t mod_num = ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix
     static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
     static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
     using Index = MultiIndex<nDim>;
     static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
-    static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
     __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3(
         const SrcDescs& src_descs,
         const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
         const DstDescs& dst_descs,
         const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
-        const ElementwiseOperation& element_op,
-        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets)
+        const ElementwiseOperation& element_op)
         : threadwise_transfer_(src_descs,
                                StaticallyIndexedArray<Index, nSrc>{},
                                dst_descs,
                                StaticallyIndexedArray<Index, nDst>{},
-                               element_op,
-                               scatter_offsets)
+                               element_op)
     {
         static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                       nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
...@@ -105,16 +100,17 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                 make_multi_index(ThreadGroup::GetThreadId()));
+            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
             const auto src_thread_slice_origins = generate_tuple(
-                [&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
+                [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
                 Number<nSrc>{});
-            const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId() % mod_num));
             const auto dst_thread_slice_origins = generate_tuple(
-                [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
+                [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
                 Number<nDst>{});
             threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
...@@ -201,7 +197,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
+        ThreadwiseTensorSliceTransfer_v7r3<SrcDatas,
                                            DstDatas,
                                            SrcDescs,
                                            DstDescs,
...@@ -216,7 +212,6 @@ struct ThreadGroupTensorSliceTransfer_v7r3
                                            DstScalarPerVector,
                                            ThreadTransferSrcResetCoordinateAfterRunFlags,
                                            ThreadTransferDstResetCoordinateAfterRunFlags,
-                                           ScatterDim,
                                            NumThreadScratch>;
     ThreadwiseTransfer threadwise_transfer_;
......
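The hunk above strips the scatter support back out of ThreadGroupTensorSliceTransfer_v7r3: source and destination origins are again derived from one shared thread-cluster index instead of the removed GetThreadId() % mod_num hack. A minimal host-side sketch of that shared mapping, with made-up 2D cluster and slice lengths (plain C++, not the CK descriptor types):

#include <array>
#include <cstddef>

// Illustrative thread-cluster and per-thread slice lengths (not taken from the commit).
constexpr std::array<std::size_t, 2> cluster_lengths{8, 16};
constexpr std::array<std::size_t, 2> slice_lengths{4, 8};

// Decompose the linear thread id into a 2D cluster index (row-major) and scale it by
// the per-thread slice lengths; both the source and the destination origin use this.
constexpr std::array<std::size_t, 2> thread_data_idx_begin(std::size_t tid)
{
    const std::size_t row = tid / cluster_lengths[1];
    const std::size_t col = tid % cluster_lengths[1];
    return {row * slice_lengths[0], col * slice_lengths[1]};
}

static_assert(thread_data_idx_begin(17)[0] == 4 && thread_data_idx_begin(17)[1] == 8,
              "thread 17 sits at cluster index (1, 1)");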
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
#include "ck/utility/is_detected.hpp"

namespace ck {

// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
//   1. All sources and destinations are DynamicBuffer
//   2. Same VectorDim and ScalarPerVector for all sources and destinations
//   3. DstInMemOps are per destination tensor
//   4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
//   5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does the following to avoid the scratch-memory issue:
//   1. Passes tensor descriptors by reference (or tuple of references)
//   2. Does not keep a reference to the tensor descriptor
//   3. Does not construct a new tensor coordinate when calling Run()
template <typename ThreadGroup,
          typename SrcDatas,
          typename DstDatas,
          typename SrcDescs,
          typename DstDescs,
          typename ElementwiseOperation,
          typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          typename SrcScalarPerVectors,
          index_t DstScalarPerVector,
          typename ThreadTransferSrcResetCoordinateAfterRunFlags,
          typename ThreadTransferDstResetCoordinateAfterRunFlags,
          index_t ScatterDim       = 1,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
{
    static constexpr index_t nDim =
        remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();

    static constexpr index_t mod_num =
        ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix

    static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
    static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();

    using Index = MultiIndex<nDim>;

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
    static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(
        const SrcDescs& src_descs,
        const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
        const DstDescs& dst_descs,
        const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
        const ElementwiseOperation& element_op,
        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets)
        : threadwise_transfer_(src_descs,
                               StaticallyIndexedArray<Index, nSrc>{},
                               dst_descs,
                               StaticallyIndexedArray<Index, nDst>{},
                               element_op,
                               scatter_offsets)
    {
        static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                          nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
                          nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
                          nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
                      "wrong!");

        static_for<0, nSrc, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_for<0, nDst, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_assert(nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto src_thread_slice_origins = generate_tuple(
                [&](auto i) {
                    return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths;
                },
                Number<nSrc>{});

            const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId() % mod_num));

            const auto dst_thread_slice_origins = generate_tuple(
                [&](auto i) {
                    return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths;
                },
                Number<nDst>{});

            threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
            threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
        }
    }
    template <typename SrcBuffers, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDescs& src_descs,
                            const SrcBuffers& src_bufs,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
        }
    }

    template <typename T>
    using is_tuple = decltype(std::declval<T&>().IsTuple());

    template <typename DstBuffers, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDescs& dst_descs,
                             DstBuffers dst_bufs,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
                threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
            else
                threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
        }
    }

    template <typename SrcBuffers, typename DstBuffers>
    __device__ void Run(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        const DstDescs& dst_descs,
                        DstBuffers dst_bufs)
    {
        RunRead(src_descs, src_bufs);
        RunWrite(dst_descs, dst_bufs);
    }

    template <index_t ISrc>
    __device__ void
    MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
    {
        static_for<0, SrcDescs::Size(), 1>{}(
            [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
    }

    template <index_t IDst>
    __device__ void
    MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
    {
        static_for<0, DstDescs::Size(), 1>{}(
            [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
                                                   DstDatas,
                                                   SrcDescs,
                                                   DstDescs,
                                                   ElementwiseOperation,
                                                   DstInMemOps,
                                                   decltype(thread_slice_lengths),
                                                   SrcDimAccessOrder,
                                                   DstDimAccessOrder,
                                                   SrcVectorDim,
                                                   DstVectorDim,
                                                   SrcScalarPerVectors,
                                                   DstScalarPerVector,
                                                   ThreadTransferSrcResetCoordinateAfterRunFlags,
                                                   ThreadTransferDstResetCoordinateAfterRunFlags,
                                                   ScatterDim,
                                                   NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
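The new header above carries the scatter-aware copy: per-row scatter_offsets are passed into the constructor and forwarded to ThreadwiseTensorSliceTransfer_v7r3_scatter. A simplified host-side sketch of the write pattern this enables (plain C++, not the CK device API): each locally produced row lands at the destination row named by its scatter offset, and contributions are accumulated rather than overwritten.

#include <cstddef>
#include <vector>

// Scatter rows of `src` into `dst`: row r of src goes to row offsets[r] of dst.
// The += mirrors the AtomicAdd-style accumulation used on the device side.
void scatter_rows(const std::vector<float>& src,           // rows produced by this block
                  const std::vector<std::size_t>& offsets, // destination row per source row
                  std::vector<float>& dst,                 // scattered output buffer
                  std::size_t row_size)
{
    for(std::size_t r = 0; r < offsets.size(); ++r)
        for(std::size_t c = 0; c < row_size; ++c)
            dst[offsets[r] * row_size + c] += src[r * row_size + c];
}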
...@@ -85,7 +85,7 @@ struct DeviceMoeGemm
                                    CElementwiseOperation>
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
-    using GridwiseGemm = std::conditional_t< IsGatherGemm,
+    using GridwiseGemm = std::conditional_t<IsGatherGemm,
         GridwiseMoeGemmGather<
             ALayout,
             BLayout,
...@@ -218,7 +218,7 @@ struct DeviceMoeGemm
         const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
-        const auto Run = [&](const auto& kernel) {
+        const auto RunKernel = [&](const auto& kernel) {
             if(stream_config.flush_cache)
             {
...@@ -301,6 +301,7 @@ struct DeviceMoeGemm
             // Tail number always full
             if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
             {
+                using meme
                 // if(arg.KBatch > 1)
                 // {
                 //     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...@@ -311,7 +312,7 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Odd>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 //     else
                 //     {
...@@ -321,30 +322,50 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Even>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 // }
                 // else
                 {
                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                     {
-                        const auto kernel = kernel_moe_gemm_gather<
-                            GridwiseGemm,
-                            true,
-                            ScatterOutput? InMemoryDataOperationEnum::AtomicAdd : InMemoryDataOperationEnum::Set,
-                            minimum_occupancy,
-                            TailNumber::Odd>;
-                        Run(kernel);
+                        if constexpr (IsGatherGemm) {
+                            const auto kernel = kernel_moe_gemm_gather<
+                                GridwiseGemm,
+                                true,
+                                InMemoryDataOperationEnum::Set,
+                                minimum_occupancy,
+                                TailNumber::Odd>;
+                            RunKernel(kernel);
+                        } else {
+                            const auto kernel = kernel_moe_gemm_scatter<
+                                GridwiseGemm,
+                                true,
+                                InMemoryDataOperationEnum::AtomicAdd,
+                                minimum_occupancy,
+                                TailNumber::Odd>;
+                            RunKernel(kernel);
+                        }
                     }
                     else
                     {
-                        const auto kernel = kernel_moe_gemm_gather<
-                            GridwiseGemm,
-                            true,
-                            ScatterOutput? InMemoryDataOperationEnum::AtomicAdd : InMemoryDataOperationEnum::Set,
-                            minimum_occupancy,
-                            TailNumber::Even>;
-                        Run(kernel);
+                        if constexpr (IsGatherGemm) {
+                            const auto kernel = kernel_moe_gemm_gather<
+                                GridwiseGemm,
+                                true,
+                                InMemoryDataOperationEnum::Set,
+                                minimum_occupancy,
+                                TailNumber::Even>;
+                            RunKernel(kernel);
+                        } else {
+                            const auto kernel = kernel_moe_gemm_scatter<
+                                GridwiseGemm,
+                                true,
+                                InMemoryDataOperationEnum::AtomicAdd,
+                                minimum_occupancy,
+                                TailNumber::Even>;
+                            RunKernel(kernel);
+                        }
                     }
                 }
             }
...@@ -361,7 +382,7 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Odd>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 //     else
                 //     {
...@@ -372,7 +393,7 @@ struct DeviceMoeGemm
                 //             InMemoryDataOperationEnum::AtomicAdd,
                 //             minimum_occupancy,
                 //             TailNumber::Even>;
-                //         Run(kernel);
+                //         RunKernel(kernel);
                 //     }
                 // }
                 // else
...@@ -383,10 +404,10 @@ struct DeviceMoeGemm
                 //     kernel_moe_gemm_gather_2lds<
                 //         GridwiseGemm,
                 //         true,
-                //         ScatterOutput? InMemoryDataOperationEnum::AtomicAdd : InMemoryDataOperationEnum::Set,
+                //         IsGatherGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
                 //         minimum_occupancy,
                 //         TailNumber::Odd>;
-                //     Run(kernel);
+                //     RunKernel(kernel);
                 // }
                 // else
                 // {
...@@ -394,10 +415,10 @@ struct DeviceMoeGemm
                 //     kernel_moe_gemm_gather_2lds<
                 //         GridwiseGemm,
                 //         true,
-                //         ScatterOutput? InMemoryDataOperationEnum::AtomicAdd : InMemoryDataOperationEnum::Set,
+                //         IsGatherGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
                 //         minimum_occupancy,
                 //         TailNumber::Even>;
-                //     Run(kernel);
+                //     RunKernel(kernel);
                 // }
                 // }
                 // }
...@@ -414,7 +435,7 @@ struct DeviceMoeGemm
     float Run(const BaseArgument* p_arg,
               const StreamConfig& stream_config = StreamConfig{}) override
     {
-        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        return -1; // Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
     }
 };
......
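The tail-number dispatch above now branches at compile time: the gather kernel writes with InMemoryDataOperationEnum::Set while the scatter kernel writes with AtomicAdd, presumably because scattered token rows can receive contributions from more than one block, whereas gathered rows are written exactly once. A minimal sketch of that compile-time choice, using a stand-in enum rather than the real CK one:

// Stand-in for InMemoryDataOperationEnum; illustrative only.
enum class MemOp { Set, AtomicAdd };

// Mirrors the `if constexpr(IsGatherGemm)` branches added in device_moe_gemm.hpp:
// gather -> plain store, scatter -> accumulate.
template <bool IsGatherGemm>
constexpr MemOp tail_mem_op()
{
    if constexpr(IsGatherGemm)
        return MemOp::Set;
    else
        return MemOp::AtomicAdd;
}

static_assert(tail_mem_op<true>() == MemOp::Set, "gather path stores");
static_assert(tail_mem_op<false>() == MemOp::AtomicAdd, "scatter path accumulates");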
...@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
 #define DEBUG_LOG 0
...@@ -1404,7 +1404,7 @@ struct GridwiseMoeGemmScatter
         });
         // printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
+        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
             ThisThreadBlock,
             decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
             Tuple<EDataType>,
......