Commit 0f444e0c authored by mtgu0705's avatar mtgu0705
Browse files

Initial int4 moe, compile pass, function not check.

parent 2d256913
...@@ -2,3 +2,4 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult ...@@ -2,3 +2,4 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
add_example_executable(example_moe_pk_i4_gemm1 moe_pk_i4_gemm1.cpp)
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp"
// #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I4 = ck::pk_i4_t;
using F16 = ck::half_t;
// using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using B0DataType = I4;
using EDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
// for gate, a_scale, b_scale
struct MulABScale
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float>
(EDataType& e,
const float& c,
const float& d0,
const float& d1) const
{
e = ck::type_convert<EDataType>(c * d1 * d0);
}
};
// for gate, a_scale, b_scale, fuse silu,
struct MulABScaleSilu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float>
(EDataType& e,
const float& c,
const float& d0,
const float& d1) const
{
// act
float x0 = 0;
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
e = ck::type_convert<EDataType>(x0);
}
};
// using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
using CDEElementOp = MulABScale;
// using CDEElementOp = MulABScaleSiluMulGate;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, MPerBlock, 128, KPerBlock,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
CShuffleMXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// tokens = 1
// topk = 1
// experts = 8
// per expert:
// GEMM shape
ck::index_t N = 6144;
ck::index_t K = 8192;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 128;
if(argc == 1)
{
// use default case
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf(
"arg4 to 5: N, K\n");
exit(0);
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t batch_stride_B = K * N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
ck::index_t KBatch = 1;
// const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1}));
for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i;
}
int token_per_tile = tokens / sorted_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for (int i = 0; i < SORTED_SIZE; i++) {
int tile_off = i % sorted_tile_size;
if(tile_off < token_per_tile)
sorted_token_ids.mData[i] = tokenid++;
else
sorted_token_ids.mData[i] = tokens;
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
// preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
printf("Start PreShuffle\n");
// weight pre-shuffle
int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
int NLane = device_op.GetPreShuffleParameters();
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int e = 0; e < experts; ++e)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
b0_preshuffled(e * batch_stride_B + outputIndex) =
b0_e_n_k(e * batch_stride_B + n * K + k);
}
}
}
printf("End PreShuffle, and Start vector permute\n");
// vector pk_i4x4 permute
for(int e = 0; e < experts; e++)
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b0_preshuffled(e, j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 6, i) = i4x2;
}
}
}
}
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
printf("End Permute and Start GEMM\n");
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
SORTED_SIZE,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K;
std::size_t num_btype =
sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
}
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({SORTED_SIZE, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < SORTED_SIZE; ++m)
{
const int t = sorted_token_ids(m);
const int e = expert_ids(m / sorted_tile_size);
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
e_m_n_device_result.savetxt("out.txt");
e_m_n_host_result.savetxt("ref.txt");
return ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
}
return 0;
}
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include <fstream>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include <utility> #include <utility>
...@@ -323,6 +324,33 @@ struct Tensor ...@@ -323,6 +324,33 @@ struct Tensor
{ {
} }
void savetxt(std::string file_name, std::string dtype = "float")
{
std::ofstream file(file_name);
if(file.is_open())
{
for(auto& itm : mData)
{
if(dtype == "float")
file << ck::type_convert<float>(itm) << std::endl;
else if(dtype == "int")
file << ck::type_convert<int>(itm) << std::endl;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error
file << ck::type_convert<float>(itm) << std::endl;
}
file.close();
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
decltype(auto) GetLengths() const { return mDesc.GetLengths(); } decltype(auto) GetLengths() const { return mDesc.GetLengths(); }
decltype(auto) GetStrides() const { return mDesc.GetStrides(); } decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp"
namespace ck {
/**
* @brief Blockwise data transfer
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template <typename ThreadGroup,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t GatherDim = 1,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_mod8
{
static constexpr auto I0 = Number<0>{};
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
static constexpr index_t mod_num = ThreadClusterLengths{}.At(I0); // Dirty HACK FELIX, TODO fix
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
dst_element_op,
gather_offsets)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() % mod_num));
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + src_thread_cluster_idx * thread_slice_lengths);
const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + dst_thread_cluster_idx * thread_slice_lengths);
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() % mod_num));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ constexpr auto GetSrcThreadScratchIdx()
{
return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>();
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
GatherDim,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template <typename ThreadGroup,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
typename SrcScalarPerVectors,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
{
static constexpr index_t nDim =
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}) ; // Dirty HACK FELIX, TODO fix
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
using Index = MultiIndex<nDim>;
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op,
const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op,
scatter_offsets,
scatter_weights)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
"wrong!");
static_for<0, nSrc, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
"wrong!");
});
static_for<0, nDst, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
"wrong!");
});
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
Number<nSrc>{});
const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index( OutputScatter ? ThreadGroup::GetThreadId() % mod_num : ThreadGroup::GetThreadId()));
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
}
}
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
}
template <index_t ISrc>
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
}
}
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
{
static_for<0, SrcDescs::Size(), 1>{}(
[&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
}
template <index_t IDst>
__device__ void
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
}
}
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
{
static_for<0, DstDescs::Size(), 1>{}(
[&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVectors,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
ScatterDim,
OutputScatter,
ScatterWeightIdx,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
...@@ -96,6 +96,51 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator ...@@ -96,6 +96,51 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment