Commit 8a5bb9f3 authored by coderfeli

add files, build and run ok

parent bd64a30b
CMakeLists.txt:

@@ -5,4 +5,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_m
 # target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
 add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
 add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
-add_example_executable(example_moe_gemm_fp16 moe_gemm_fp16.cpp)
+add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
+add_example_executable(example_moe_gemm2 moe_gemm2.cpp)

New file: moe_gemm2.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
// using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F16;
using B0DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(EDataType& e,
const float& c,
const float& d0,
const float& d1) const
{
// const float x0_f = c * d0 * d1;
const float x0_f = c;
// printf("epi %f\n", c);
e = ck::type_convert<EDataType>(x0_f);
}
// template <>
// __host__ __device__ constexpr void operator()<BF16, float, float, float>(BF16& e,
// const float& c,
// const float& d0,
// const float& d1) const
// {
// const float x0_f = c;
// // const float x0_f = c * d0 * d1;
// e = ck::type_convert<BF16>(x0_f);
// }
};
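// The device epilogue invokes this element-op once per output element, mirroring the
// host-side call cde_element_op(e, c, d0, d1) in the verification loop below. A sketch
// of the full MoE epilogue (disabled above during bring-up) would apply both D scales:
//   e = ck::type_convert<EDataType>(c * d0 * d1);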
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
    const int KPack = 16 / sizeof(B0DataType); // elements per 16-byte vector
    const int NLane = NXdl;
    const int KLane = 64 / NLane; // lanes of a 64-wide wavefront left over for K
    const int K0    = K / (KLane * KPack);
    // K -> K0 KLane KPack
    // N -> N0 NLane
    // N, K -> N0 K0 KLane NLane KPack
    for(int n = 0; n < N; ++n)
    {
        for(int k = 0; k < K; ++k)
        {
            const int n0    = n / NLane;
            const int n1    = n % NLane;
            const int k0    = k / (KLane * KPack);
            const int tempk = k % (KLane * KPack);
            const int k1    = tempk / KPack;
            const int k2    = tempk % KPack;
            const int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                                    k1 * KPack * NLane + n1 * KPack + k2;
            dst[outputIndex] = src[n * K + k];
        }
    }
}
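// Illustrative mapping, assuming F16 and NXdl = 32 (so KPack = 8, KLane = 2, K0 = K / 16):
// source element (n, k) of the row-major N x K weight tile lands at flattened offset
//   ((n / 32) * K0 + k / 16) * 512 + ((k % 16) / 8) * 256 + (n % 32) * 8 + (k % 8),
// i.e. K splits into K0 x KLane x KPack and N into N0 x NLane, stored as
// [N0][K0][KLane][NLane][KPack] so each lane reads a contiguous KPack-wide vector.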
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; // TODO: fix this constraint
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
///###### RCR
// kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>;
    < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
      AElementOp, BElementOp, CDEElementOp, GemmSpec,
      // thread num, m block, n block, k block
      256, MPerBlock, 128, KPerBlock,
      // ak1, bk1
      AK1, BK1,
      // m/n per xdl
      32, 32,
      // m/n xdl per wave
      MXDLPerWave, 1,
      // a, b: load-transfer cluster, cluster order, src order, vec dim, src per vec, dst per vec, lds extra
      // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
      // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
      S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
      S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
      // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
      // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
      // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
      1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1>,
      ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType>;
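// Derived values for A0DataType = B0DataType = F16: KPerBlock = 256 / 2 = 128,
// AK1 = BK1 = 16 / 2 = 8 (8 halves per 16-byte vector load), EVec = 16 / 2 = 8,
// and MXDLPerWave = 32 / 32 = 1.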
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
    // default problem: tokens = 1, topk = 1, experts = 1
    // per-expert GEMM shape
ck::index_t N = 128;
ck::index_t K = 1024;
ck::index_t experts = 1;
ck::index_t sorted_tile_num = 1;
ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 1;
if(argc == 1)
{
// use default case
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 5: N, K\n");
exit(0);
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
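    // StrideD == 0 means every token row of D0/D1 aliases the same N-length vector,
    // i.e. the two D operands broadcast per-channel scales across tokens.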
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
// const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1}));
    for(int i = 0; i < sorted_tile_num; i++)
    {
        expert_ids.mData[i] = i;
    }
    int token_per_tile = tokens / sorted_tile_num;
    int tokenid        = 0;
    // sorted_token_ids.mData[0] = 0;
    for(int i = 0; i < SORTED_SIZE; i++)
    {
        int tile_off = i % sorted_tile_size;
        if(tile_off < token_per_tile)
            sorted_token_ids.mData[i] = tokenid++;
        else
            sorted_token_ids.mData[i] = tokens;
    }
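    // Worked example (hypothetical): tokens = 4, sorted_tile_num = 2, sorted_tile_size = 32
    // gives token_per_tile = 2 and SORTED_SIZE = 64, so
    //   sorted_token_ids = { 0, 1, 4, 4, ..., 2, 3, 4, 4, ... }
    // where the sentinel value `tokens` (= 4) marks the padded slots within each tile.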
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
// Tensor<B0DataType> b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
// Tensor<B0DataType> b0_preshuffled(
// f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
Tensor<D0DataType> d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_t_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break;
case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_t_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_t_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_m_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_t_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto I0 = ck::Number<0>{};
// do GEMM
auto device_op = DeviceOpInstance{};
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
SORTED_SIZE,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
    if(time_kernel)
    {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K;
std::size_t num_btype =
sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
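        // ave_time is in ms: flop / 1e9 / ms yields TFLOP/s, bytes / 1e6 / ms yields GB/s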
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
}
if(do_verification)
{
        invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<CShuffleDataType> c_t_n({tokens, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
        for(int m = 0; m < SORTED_SIZE; ++m)
        {
            const int t = sorted_token_ids(m);
            if(t >= tokens) // skip padded slots (sentinel id == tokens); they have no output row
                continue;
            for(int n = 0; n < N; ++n)
            {
                cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_t_n(t, n));
            }
        }
e_t_n_device_result.savetxt("out.txt");
e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
}
return 0;
}
@@ -48,6 +48,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
     static constexpr index_t nDim =
         remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
+    static constexpr index_t mod_num = ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix

     static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
     static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
@@ -101,7 +102,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId()));
+                make_multi_index(ThreadGroup::GetThreadId() % mod_num));

             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
...
@@ -9,7 +9,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp"
-#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -1109,12 +1109,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
     {
         ignore = b_element_op;
         const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
-            problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
+            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
         const auto b_grid_desc_bpreshuffled =
             MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
-            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
+            problem.NumTokens, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
         // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
         //        problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
@@ -1125,19 +1125,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
         const index_t expert_id  = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
-        // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
-        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
-        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
-        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
-        constexpr auto AKThreads  = AK0Threads * AK1Threads;
-        constexpr auto AMRepeats  = MPerBlock / AMThreads;
-        // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
-        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
-        StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
-        static_for<0, AMRepeats, 1>{}([&](auto m0) {
-            gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K;
-            // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
-        });
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
         const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
@@ -1153,10 +1140,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         // if(threadIdx.x==0)
         //     printf("tid %d eid %d expert_stride %d bufsize %d\n",
         //            threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -1166,7 +1149,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                 AElementwiseOperation,
                 ck::tensor_operation::element_wise::PassThrough,
                 InMemoryDataOperationEnum::Set,
@@ -1187,15 +1170,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
                 1,
                 AThreadTransferSrcResetCoordinateAfterRun,
                 true,
-                1,
                 BlockwiseGemmPipe::GlobalBufferNum>(
                 a_grid_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
                 a_element_op,
                 a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
-                ck::tensor_operation::element_wise::PassThrough{},
-                gather_offsets);
+                ck::tensor_operation::element_wise::PassThrough{});
         // Thread-wise copy
         // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
@@ -1406,10 +1387,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
         const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
             c_grid_desc_mblock_mperblock_nblock_nperblock;
-        using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
+        using CDEBlockTransferCluster =
             CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
+        constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
+        constexpr auto EMRepeats = MPerBlock / EMThreads;
+        static_assert(EMRepeats == 1, "only support 1 line per thread now!");
+        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / EMThreads * EMRepeats;
+        StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos];
+        static_for<0, EMRepeats, 1>{}([&](auto m0) {
+            scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N;
+            // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
+        });
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
             ThisThreadBlock,
             decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1423,7 +1414,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
             1,
             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            CDEBlockTransferCluster,
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
             Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
             Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
@@ -1439,9 +1430,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
             {c_ds_desc_refs,
              idx_c_ds_block_begin,
              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-             make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
+             make_tuple(make_multi_index(0, 0, block_n_id, 0)),
              c_element_op};
+        // if(threadIdx.x == 0)
+        //     printf("offset %d size %d\n", scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c_grid + scatter_offsets(I0),
+            c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() - scatter_offsets(I0));
         // space filling curve for threadwise C in VGPR
         constexpr auto sfc_c_vgpr =
             SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...
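The gridwise changes above move the sorted-token indirection from the A-tile load (the removed per-thread gather_offsets fed into ThreadGroupTensorSliceTransfer_v4r1_mod8) to the C-tile store (the new scatter_offsets folded into the output buffer base pointer). A minimal standalone sketch of the per-thread scatter offset, assuming EMRepeats == 1 as the added static_assert requires (names are illustrative):

    // each group of EMThreads lanes handles one row of the MPerBlock output tile
    const int token_pos  = block_m_id * MPerBlock + thread_id / EMThreads;
    // the low 24 bits of the sorted id are treated as the token index here
    const int token_id   = p_sorted_token_ids[token_pos] & 0xffffff;
    const int row_offset = token_id * N; // element offset of this token's row in the E grid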
New file: reference_moe_gemm2.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm2 : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: sorted_token_ids_{sorted_token_ids},
expert_ids_{expert_ids},
sorted_tile_size_{sorted_tile_size},
a_m_k_{a_m_k},
b_e_n_k_{b_e_n_k},
c_t_n_{c_t_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
        // members declared in the same order as the constructor initializer list
        const Tensor<ck::index_t>& sorted_token_ids_;
        const Tensor<ck::index_t>& expert_ids_;
        index_t sorted_tile_size_;
        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_e_n_k_;
        Tensor<CDataType>& c_t_n_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceMoeGemm2::Argument;
float Run(const Argument& arg)
{
arg.c_t_n_.SetZero();
            auto f_mk_kn_mn = [&](auto m, auto n) {
                const int K         = arg.a_m_k_.mDesc.GetLengths()[1];
                const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0];
                AccDataType v_acc{0};
                ComputeTypeA v_a{0};
                ComputeTypeB v_b{0};
                const int t = arg.sorted_token_ids_(m);
                const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
                // padded slots carry the sentinel id == token count; skip them so the
                // write to c_t_n_ below never goes out of bounds
                if(t >= token_cnt)
                    return;
for(int k = 0; k < K; ++k)
{
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_e_n_k_(e, n, k));
}
else
{
arg.b_element_op_(v_b, arg.b_e_n_k_(e, n, k));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                }
CDataType v_c{0};
arg.c_element_op_(v_c, v_acc);
arg.c_t_n_(t, n) += v_c;
};
            // iterate over all sorted slots (the rows of A), not just the token count;
            // each real slot maps to a token row of C via sorted_token_ids_
            make_ParallelTensorFunctor(f_mk_kn_mn,
                                       arg.sorted_token_ids_.mDesc.GetLengths()[0],
                                       arg.c_t_n_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceMoeGemm2"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck