Commit 66cff910 authored by coderfeli

merge gemm1 and gemm2

parents aa15c49a 2e53f972
@@ -5,4 +5,5 @@
 # target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
 add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
 add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
-add_example_executable(example_moe_gemm_fp16 moe_gemm_fp16.cpp)
+add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
+add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
@@ -7,7 +7,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_description/cluster_descriptor.hpp"
-#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
 #include "ck/utility/is_detected.hpp"

 namespace ck {
@@ -42,30 +42,35 @@ template <typename ThreadGroup,
           index_t DstScalarPerVector,
           typename ThreadTransferSrcResetCoordinateAfterRunFlags,
           typename ThreadTransferDstResetCoordinateAfterRunFlags,
+          index_t ScatterDim = 1,
           index_t NumThreadScratch = 1>
 struct ThreadGroupTensorSliceTransfer_v7r3
 {
     static constexpr index_t nDim =
         remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
+    static constexpr index_t mod_num = ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix

     static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
     static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();

     using Index = MultiIndex<nDim>;

     static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
+    static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});

     __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3(
         const SrcDescs& src_descs,
         const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
         const DstDescs& dst_descs,
         const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
-        const ElementwiseOperation& element_op)
+        const ElementwiseOperation& element_op,
+        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets)
         : threadwise_transfer_(src_descs,
                                StaticallyIndexedArray<Index, nSrc>{},
                                dst_descs,
                                StaticallyIndexedArray<Index, nDst>{},
-                               element_op)
+                               element_op,
+                               scatter_offsets)
     {
         static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                       nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
@@ -100,17 +105,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+            const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                 make_multi_index(ThreadGroup::GetThreadId()));

-            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
-
             const auto src_thread_slice_origins = generate_tuple(
-                [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
+                [&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
                 Number<nSrc>{});

+            const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(ThreadGroup::GetThreadId() % mod_num));
+
             const auto dst_thread_slice_origins = generate_tuple(
-                [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
+                [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
                 Number<nDst>{});

             threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
@@ -197,7 +201,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v7r3<SrcDatas,
+        ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
                                            DstDatas,
                                            SrcDescs,
                                            DstDescs,
@@ -212,6 +216,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
                                            DstScalarPerVector,
                                            ThreadTransferSrcResetCoordinateAfterRunFlags,
                                            ThreadTransferDstResetCoordinateAfterRunFlags,
+                                           ScatterDim,
                                            NumThreadScratch>;

     ThreadwiseTransfer threadwise_transfer_;
...
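The change above threads a scatter capability through the block-level transfer: ScatterDim selects which slice dimension is redirected, scatter_num is the per-thread element count along that dimension, and the constructor now forwards a scatter_offsets array to the new ThreadwiseTensorSliceTransfer_v7r3_scatter. The mod_num constant (flagged as a hack in the code itself) appears to fold the thread id for the destination index so several thread rows map to the same destination cluster. The sketch below is a standalone illustration of the scatter idea only, not CK code: all names, sizes, and offset values are made up, and it scatters whole rows for simplicity instead of going through CK's descriptor machinery.

// Standalone C++ sketch: writing a tile through an offset table ("scatter")
// instead of to contiguous destination rows. Everything here is illustrative.
#include <array>
#include <cstdio>

int main()
{
    constexpr int tile_rows = 4;
    constexpr int row_len   = 3;

    float src[tile_rows][row_len] = {{0, 1, 2}, {10, 11, 12}, {20, 21, 22}, {30, 31, 32}};
    float dst[8][row_len]         = {}; // destination buffer is larger than the tile

    // Hypothetical per-row destinations, e.g. sorted token ids produced by MoE routing.
    const std::array<int, tile_rows> scatter_offsets = {5, 0, 7, 2};

    for(int i = 0; i < tile_rows; ++i)
        for(int j = 0; j < row_len; ++j)
            dst[scatter_offsets[i]][j] = src[i][j]; // row i lands at row scatter_offsets[i]

    for(int r = 0; r < 8; ++r)
        std::printf("dst[%d] = {%g, %g, %g}\n", r, dst[r][0], dst[r][1], dst[r][2]);
    return 0;
}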
@@ -279,7 +279,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
                     const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
                         GridwiseGemm,
                         true,
-                        InMemoryDataOperationEnum::Set,
+                        InMemoryDataOperationEnum::AtomicAdd,
                         minimum_occupancy,
                         TailNumber::Odd>;
                     Run(kernel);
@@ -289,7 +289,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
                     const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
                         GridwiseGemm,
                         true,
-                        InMemoryDataOperationEnum::Set,
+                        InMemoryDataOperationEnum::AtomicAdd,
                         minimum_occupancy,
                         TailNumber::Even>;
                     Run(kernel);
...
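Switching the C-tensor write from InMemoryDataOperationEnum::Set to AtomicAdd turns overwrite semantics into accumulate semantics: when several tiles (for example, several experts' partial results for the same token row) target the same output element, their contributions are summed instead of the last writer winning, which is only correct if the output buffer is zeroed beforehand. The sketch below is a plain-CPU analogy of the two write modes on one element, not the CK kernel path; the values are arbitrary.

// CPU analogy of the two write modes on a single output element.
#include <cstdio>

int main()
{
    const float partial[2] = {1.5f, 2.5f}; // two contributors to the same C element

    float c_set = 0.0f;
    for(float p : partial)
        c_set = p; // Set-like: each write replaces the previous value

    float c_add = 0.0f; // accumulate mode requires a zero-initialized output
    for(float p : partial)
        c_add += p; // AtomicAdd-like: contributions are summed

    std::printf("set: %g  add: %g\n", c_set, c_add); // set: 2.5  add: 4
    return 0;
}

On the GPU the adds come from different workgroups running concurrently, which is why a hardware atomic update is needed rather than a plain read-modify-write.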
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm2 : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ck::index_t>& sorted_token_ids,
                 const Tensor<ck::index_t>& expert_ids,
                 const index_t sorted_tile_size,
                 const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_e_n_k,
                 Tensor<CDataType>& c_t_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : sorted_token_ids_{sorted_token_ids},
              expert_ids_{expert_ids},
              sorted_tile_size_{sorted_tile_size},
              a_m_k_{a_m_k},
              b_e_n_k_{b_e_n_k},
              c_t_n_{c_t_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ck::index_t>& sorted_token_ids_;
        const Tensor<ck::index_t>& expert_ids_;
        index_t sorted_tile_size_;
        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_e_n_k_;
        Tensor<CDataType>& c_t_n_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };
    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceMoeGemm2::Argument;

        float Run(const Argument& arg)
        {
            arg.c_t_n_.SetZero();

            auto f_mk_kn_mn = [&](auto m, auto n) {
                const int K = arg.a_m_k_.mDesc.GetLengths()[1];

                AccDataType v_acc{0};
                ComputeTypeA v_a{0};
                ComputeTypeB v_b{0};

                const int t         = arg.sorted_token_ids_(m);
                const int e         = arg.expert_ids_(m / arg.sorted_tile_size_);
                const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0];

                if(t < token_cnt)
                {
                    for(int k = 0; k < K; ++k)
                    {
                        // use PassThrough instead of ConvertBF16RTN for reference calculation
                        if constexpr(is_same_v<AElementwiseOperation,
                                               ck::tensor_operation::element_wise::ConvertBF16RTN>)
                        {
                            ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
                        }
                        else
                        {
                            arg.a_element_op_(v_a, arg.a_m_k_(m, k));
                        }

                        // same for the B matrix
                        if constexpr(is_same_v<BElementwiseOperation,
                                               ck::tensor_operation::element_wise::ConvertBF16RTN>)
                        {
                            ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_e_n_k_(e, n, k));
                        }
                        else
                        {
                            arg.b_element_op_(v_b, arg.b_e_n_k_(e, n, k));
                        }

                        v_acc +=
                            ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                    }

                    CDataType v_c{0};
                    arg.c_element_op_(v_c, v_acc);
                    arg.c_t_n_(t, n) += v_c;
                }
            };

            make_ParallelTensorFunctor(
                f_mk_kn_mn, arg.a_m_k_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])(1);

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
                             const Tensor<ck::index_t>& expert_ids,
                             const index_t sorted_tile_size,
                             const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_e_n_k,
                             Tensor<CDataType>& c_t_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{sorted_token_ids,
                        expert_ids,
                        sorted_tile_size,
                        a_m_k,
                        b_e_n_k,
                        c_t_n,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceMoeGemm2"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
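The new header above is the host-side reference for the second MoE GEMM: row m of A is a sorted/padded token slot, sorted_token_ids maps it back to a real token row of C, expert_ids selects the B expert for each tile of sorted_tile_size rows, and results are accumulated into c_t_n (slots whose token id is out of range are skipped). A hedged usage sketch follows; the include path, the tensor shapes in the comments, and the float data types are assumptions for illustration, while MakeArgument, MakeInvoker, Run, and the element-wise PassThrough op come from the code above.

// Hedged usage sketch for ReferenceMoeGemm2 (shapes and include path are assumptions).
#include "ck/library/utility/host_tensor.hpp"
// assumed location of the new header added in this commit:
#include "reference_moe_gemm2.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RefGemm2    = ck::tensor_operation::host::
    ReferenceMoeGemm2<float, float, float, float, PassThrough, PassThrough, PassThrough>;

void run_reference(const Tensor<ck::index_t>& sorted_token_ids, // [sorted_M]
                   const Tensor<ck::index_t>& expert_ids,       // [sorted_M / sorted_tile_size]
                   ck::index_t sorted_tile_size,
                   const Tensor<float>& a_m_k,                  // [sorted_M, K]
                   const Tensor<float>& b_e_n_k,                // [E, N, K]
                   Tensor<float>& c_t_n)                        // [tokens, N]
{
    auto argument = RefGemm2::MakeArgument(sorted_token_ids,
                                           expert_ids,
                                           sorted_tile_size,
                                           a_m_k,
                                           b_e_n_k,
                                           c_t_n,
                                           PassThrough{},
                                           PassThrough{},
                                           PassThrough{});
    auto invoker  = RefGemm2::MakeInvoker();
    invoker.Run(argument); // zeroes c_t_n, then accumulates one row per routed token
}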
@@ -17,7 +17,7 @@ fi
 cmake \
 -D CMAKE_PREFIX_PATH=/opt/rocm \
 -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
--D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
+-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O1 -g --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
 -D CMAKE_BUILD_TYPE=Release \
 -D BUILD_DEV=ON \
 -D GPU_TARGETS=$GPU_TARGETS \
...