Commit 19c18624 authored by fsx950223

add multi-embeddings support

parent 10c72ace
@@ -9,7 +9,7 @@
#include <ctime>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -18,13 +18,6 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
// using EmbType = float;
// using IndexType = int64_t;
// using GammaDataType = float;
// using BetaDataType = float;
// using AccDataType = float;
// using OutType = float;
using EmbType = ck::half_t;
using IndexType = int64_t;
using GammaDataType = ck::half_t;
@@ -32,47 +25,172 @@ using BetaDataType = ck::half_t;
using AccDataType = float;
using OutType = ck::half_t;
// clang-format off
// BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize
using DeviceInstance_fp32_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>;
using DeviceInstance_fp32_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 1>;
using DeviceInstance_fp32_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>;
using DeviceInstance_fp32_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 1>;
using DeviceInstance_fp32_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 1>;
using DeviceInstance_fp32_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 4>;
using DeviceInstance_fp32_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 4>;
using DeviceInstance_fp32_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 4>;
using DeviceInstance_fp32_e16384 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 16384, 1, 4>;
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>;
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 2>;
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>;
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 2>;
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 2>;
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 2>;
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 8>;
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 8>;
template<typename emb_type, ck::index_t dim> struct emb_kernel{};
template<> struct emb_kernel<float, 256> { using kernel_type = DeviceInstance_fp32_e256; };
template<> struct emb_kernel<float, 512> { using kernel_type = DeviceInstance_fp32_e512; };
template<> struct emb_kernel<float, 768> { using kernel_type = DeviceInstance_fp32_e768; };
template<> struct emb_kernel<float, 1024> { using kernel_type = DeviceInstance_fp32_e1024;};
template<> struct emb_kernel<float, 1536> { using kernel_type = DeviceInstance_fp32_e1536;};
template<> struct emb_kernel<float, 2048> { using kernel_type = DeviceInstance_fp32_e2048;};
template<> struct emb_kernel<float, 4096> { using kernel_type = DeviceInstance_fp32_e4096;};
template<> struct emb_kernel<float, 8192> { using kernel_type = DeviceInstance_fp32_e8192;};
template<> struct emb_kernel<float, 16384>{ using kernel_type = DeviceInstance_fp32_e16384;};
template<> struct emb_kernel<ck::half_t, 256> { using kernel_type = DeviceInstance_fp16_e256; };
template<> struct emb_kernel<ck::half_t, 512> { using kernel_type = DeviceInstance_fp16_e512; };
template<> struct emb_kernel<ck::half_t, 768> { using kernel_type = DeviceInstance_fp16_e768; };
template<> struct emb_kernel<ck::half_t, 1024> { using kernel_type = DeviceInstance_fp16_e1024; };
template<> struct emb_kernel<ck::half_t, 1536> { using kernel_type = DeviceInstance_fp16_e1536; };
template<> struct emb_kernel<ck::half_t, 2048> { using kernel_type = DeviceInstance_fp16_e2048; };
template<> struct emb_kernel<ck::half_t, 4096> { using kernel_type = DeviceInstance_fp16_e4096; };
template<> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInstance_fp16_e8192; };
using DeviceInstance_fp16_e256 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
256,
1,
1,
3>;
using DeviceInstance_fp16_e512 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
512,
1,
2,
3>;
using DeviceInstance_fp16_e768 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
768,
1,
1,
3>;
using DeviceInstance_fp16_e1024 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
1024,
1,
2,
3>;
using DeviceInstance_fp16_e1536 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
1536,
1,
2,
3>;
using DeviceInstance_fp16_e2048 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
2048,
1,
2,
3>;
using DeviceInstance_fp16_e4096 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
4096,
1,
8,
3>;
using DeviceInstance_fp16_e8192 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
8192,
1,
8,
3>;
template <typename emb_type, ck::index_t dim>
struct emb_kernel
{
};
template <>
struct emb_kernel<ck::half_t, 256>
{
using kernel_type = DeviceInstance_fp16_e256;
};
template <>
struct emb_kernel<ck::half_t, 512>
{
using kernel_type = DeviceInstance_fp16_e512;
};
template <>
struct emb_kernel<ck::half_t, 768>
{
using kernel_type = DeviceInstance_fp16_e768;
};
template <>
struct emb_kernel<ck::half_t, 1024>
{
using kernel_type = DeviceInstance_fp16_e1024;
};
template <>
struct emb_kernel<ck::half_t, 1536>
{
using kernel_type = DeviceInstance_fp16_e1536;
};
template <>
struct emb_kernel<ck::half_t, 2048>
{
using kernel_type = DeviceInstance_fp16_e2048;
};
template <>
struct emb_kernel<ck::half_t, 4096>
{
using kernel_type = DeviceInstance_fp16_e4096;
};
template <>
struct emb_kernel<ck::half_t, 8192>
{
using kernel_type = DeviceInstance_fp16_e8192;
};
// clang-format on
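
A note for orientation (not part of the commit): compared with the old DeviceSparseEmbedding3ForwardLayernorm aliases, each instance above carries one extra trailing template argument, NumEmbeddings (3 here), and the emb_kernel trait maps an (element type, embedding dimension) pair to a concrete instance at compile time. A minimal sketch of how such a trait is typically consumed; print_kernel_for_dim is a hypothetical helper used only for illustration and assumes <iostream> is available as in this example:

template <ck::index_t Dim>
void print_kernel_for_dim()
{
    // resolve the device instance for this compile-time embedding dimension
    using Kernel = typename emb_kernel<ck::half_t, Dim>::kernel_type;
    std::cout << "dim " << Dim << " -> " << Kernel{}.GetTypeString() << std::endl;
}
// print_kernel_for_dim<1024>(); // resolves to DeviceInstance_fp16_e1024
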
@@ -152,19 +270,19 @@ int main()
beta_dev.ToDevice(beta.mData.data());
auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(),
emb_a_dev.GetDeviceBuffer(),
emb_b_dev.GetDeviceBuffer(),
emb_c_dev.GetDeviceBuffer(),
index_a_dev.GetDeviceBuffer(),
index_b_dev.GetDeviceBuffer(),
index_c_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
num_rows,
current_dim,
index_length,
epsilon);
auto argument_ptr = device_instance.MakeArgumentPointer(
out_dev.GetDeviceBuffer(),
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
current_dim,
index_length,
epsilon);
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
<< std::endl
<< std::flush;
......
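
Because the operation is now parameterized on NumEmbeddings, a caller is not tied to exactly three tables. A hypothetical two-table variant, sketched with the type aliases and device buffers from the example above (illustration only, not part of the commit):

using DeviceInstance_fp16_e1024_x2 =
    ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<
        EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType,
        256, 1, 256, 1, 1024, 1, 2,
        2>; // same launch parameters as DeviceInstance_fp16_e1024, but NumEmbeddings = 2

// auto argument_ptr = DeviceInstance_fp16_e1024_x2{}.MakeArgumentPointer(
//     out_dev.GetDeviceBuffer(),
//     {ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
//      ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer())},
//     {ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
//      ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer())},
//     gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
//     /*EmbeddingDim=*/1024, index_length, epsilon);
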
@@ -12,7 +12,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"
namespace ck {
namespace tensor_operation {
@@ -30,10 +30,10 @@ template <typename EmbType,
ck::index_t DimPerBlock,
ck::index_t RowPerBlock,
ck::index_t DimThreadSize,
ck::index_t RowVectorSize>
struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
ck::index_t RowVectorSize,
ck::index_t NumEmbeddings>
struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
{
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
{
return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
@@ -42,28 +42,18 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
struct Argument : public BaseArgument
{
Argument(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const ck::Array<EmbType*, NumEmbeddings>& p_embs,
const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const ck::index_t NumRows,
const ck::index_t EmbeddingDim,
const ck::index_t IndexLength,
const AccDataType epsilon)
: p_out_(p_out),
p_emb_a_(p_emb_a),
p_emb_b_(p_emb_b),
p_emb_c_(p_emb_c),
p_index_a_(p_index_a),
p_index_b_(p_index_b),
p_index_c_(p_index_c),
p_embs_(p_embs),
p_indexs_(p_indexs),
p_gamma_(p_gamma),
p_beta_(p_beta),
NumRows_(NumRows),
EmbeddingDim_(EmbeddingDim),
IndexLength_(IndexLength),
epsilon_(epsilon)
@@ -72,15 +62,10 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
}
OutType* p_out_;
const EmbType* p_emb_a_;
const EmbType* p_emb_b_;
const EmbType* p_emb_c_;
const IndexType* p_index_a_;
const IndexType* p_index_b_;
const IndexType* p_index_c_;
ck::Array<EmbType*, NumEmbeddings> p_embs_;
ck::Array<IndexType*, NumEmbeddings> p_indexs_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
ck::index_t NumRows_;
ck::index_t EmbeddingDim_;
ck::index_t IndexLength_;
AccDataType epsilon_;
@@ -88,37 +73,28 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
size_t grid_size_;
};
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
const void* p_emb_a,
const void* p_emb_b,
const void* p_emb_c,
const void* p_index_a,
const void* p_index_b,
const void* p_index_c,
const void* p_gamma,
const void* p_beta,
ck::index_t NumRows,
ck::index_t EmbeddingDim,
ck::index_t IndexLength,
const AccDataType epsilon)
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_out,
const ck::Array<EmbType*, NumEmbeddings>& p_embs,
const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
const void* p_gamma,
const void* p_beta,
ck::index_t EmbeddingDim,
ck::index_t IndexLength,
const AccDataType epsilon)
{
return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
reinterpret_cast<const EmbType*>(p_emb_a),
reinterpret_cast<const EmbType*>(p_emb_b),
reinterpret_cast<const EmbType*>(p_emb_c),
reinterpret_cast<const IndexType*>(p_index_a),
reinterpret_cast<const IndexType*>(p_index_b),
reinterpret_cast<const IndexType*>(p_index_c),
p_embs,
p_indexs,
reinterpret_cast<const GammaDataType*>(p_gamma),
reinterpret_cast<const BetaDataType*>(p_beta),
NumRows,
EmbeddingDim,
IndexLength,
epsilon);
}
using GridwiseSparseEmbedding =
GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
GridwiseSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
@@ -131,7 +107,8 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
DimPerBlock,
RowPerBlock,
DimThreadSize,
RowVectorSize>;
RowVectorSize,
NumEmbeddings>;
struct Invoker : public BaseInvoker
{
@@ -139,14 +116,15 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
{
auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
const auto kernel_main =
kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
kernel_sparse_embeddings_forward_layernorm<GridwiseSparseEmbedding,
EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(out_desc)>;
decltype(out_desc),
NumEmbeddings>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
@@ -154,12 +132,8 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
dim3(BlockSize),
0,
arg.p_out_,
arg.p_emb_a_,
arg.p_emb_b_,
arg.p_emb_c_,
arg.p_index_a_,
arg.p_index_b_,
arg.p_index_c_,
arg.p_embs_,
arg.p_indexs_,
arg.p_gamma_,
arg.p_beta_,
out_desc,
@@ -177,7 +151,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
static bool IsSupportedArgument(const Argument* p_arg)
{
return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
return (RowPerBlock == p_arg->EmbeddingDim_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
@@ -195,7 +169,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
auto str = std::stringstream();
// clang-format off
str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" <<
str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" <<
DimClusterSize << "x" << RowClusterSize << "_" <<
DimPerBlock << "x" << RowPerBlock << "_" <<
DimThreadSize << "x" << RowVectorSize;
......
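
Reading the device and gridwise code, the op gathers one row from each of the NumEmbeddings tables per output index, sums them, and applies layernorm over the embedding dimension (RowPerBlock must equal EmbeddingDim). A self-contained host-side sketch of that computation, under the assumption of row-major [num_rows x dim] tables and plain fp32-style accumulation (illustration only, not the library's reference implementation):

#include <cmath>
#include <cstddef>
#include <vector>

// For output row i: x = sum_k embs[k][idxs[k][i], :], then layernorm(x) with gamma/beta.
template <typename Emb, typename Acc, typename Index>
void reference_row(const std::vector<const Emb*>& embs,   // per-table base pointers
                   const std::vector<const Index*>& idxs, // per-table index arrays
                   const Emb* gamma,
                   const Emb* beta,
                   Emb* out,
                   std::size_t i,
                   std::size_t dim,
                   Acc epsilon)
{
    std::vector<Acc> x(dim, Acc{0});
    for(std::size_t k = 0; k < embs.size(); ++k)
        for(std::size_t d = 0; d < dim; ++d)
            x[d] += static_cast<Acc>(embs[k][static_cast<std::size_t>(idxs[k][i]) * dim + d]);

    Acc mean = 0, var = 0;
    for(Acc v : x)
        mean += v;
    mean /= static_cast<Acc>(dim);
    for(Acc v : x)
        var += (v - mean) * (v - mean);
    var /= static_cast<Acc>(dim);

    for(std::size_t d = 0; d < dim; ++d)
        out[i * dim + d] = static_cast<Emb>(static_cast<Acc>(gamma[d]) * (x[d] - mean) /
                                                std::sqrt(var + epsilon) +
                                            static_cast<Acc>(beta[d]));
}
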
@@ -17,33 +17,21 @@ template <typename GridwiseSparseEmbedding,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename OutGridDesc>
typename OutGridDesc,
ck::index_t NumEmbeddings>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
__global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon)
__global__ void kernel_sparse_embeddings_forward_layernorm(
OutType* p_out,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon)
{
GridwiseSparseEmbedding::Run(p_out,
p_emb_a,
p_emb_b,
p_emb_c,
p_index_a,
p_index_b,
p_index_c,
p_gamma,
p_beta,
out_grid_desc,
epsilon);
GridwiseSparseEmbedding::Run(p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon);
}
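
The ck::Array arguments are passed to the kernel by value, and the gridwise code below iterates over them with ck::static_for, so each per-table access uses a compile-time index and the loop over the tables unrolls. A minimal sketch of that pattern (illustration only, not part of the commit):

#include "ck/utility/common_header.hpp"

// ck::static_for hands the functor a compile-time index, so tables[i.value]
// is a constant-indexed access and the loop over NumEmbeddings fully unrolls.
template <ck::index_t NumEmbeddings>
__device__ float sum_first_elements(const ck::Array<const float*, NumEmbeddings>& tables)
{
    float acc = 0.f;
    ck::static_for<0, NumEmbeddings, 1>{}([&](auto i) { acc += tables[i.value][0]; });
    return acc;
}
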
template <typename EmbType,
@@ -59,8 +47,9 @@ template <typename EmbType,
ck::index_t DimPerBlock, // Row x Dim, along Dim
ck::index_t RowPerBlock, // Row x Dim, along Row
ck::index_t DimThreadSize, // this is actually not vector, but number of registers
ck::index_t RowVectorSize>
struct GridwiseSparseEmbedding3ForwardLayernorm
ck::index_t RowVectorSize,
ck::index_t NumEmbeddings>
struct GridwiseSparseEmbeddingsForwardLayernorm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -85,8 +74,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
using ThreadwiseWolfordDescReduce = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
@@ -97,12 +86,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
__device__ static void Run(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc,
@@ -111,9 +96,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
// const auto index_length = out_grid_desc.GetLength(I0);
// const auto emb_dim = out_grid_desc.GetLength(I1);
constexpr auto thread_cluster_desc =
make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
@@ -141,13 +123,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr auto gamma_beta_buf_desc =
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true>,
NumEmbeddings>
in_thread_bufs;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings>
index_bufs;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
@@ -160,42 +140,30 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
using src_vector_t = typename decltype(emb_vector_a)::type;
ck::Array<vector_type_maker_t<EmbType, RowVectorSize>, NumEmbeddings> emb_vectors;
auto emb_a = emb_vectors[0];
using src_vector_t = typename decltype(emb_a)::type;
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
IndexType index_a = index_buf_a[Number<current_dim>{}];
IndexType index_b = index_buf_b[Number<current_dim>{}];
IndexType index_c = index_buf_c[Number<current_dim>{}];
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(EmbType) * RowVectorSize;
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
IndexType index = index_bufs[i_embedding_.value][Number<current_dim>{}];
int32x4_t emb_res_a =
make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
int32x4_t emb_res_b =
make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
int32x4_t emb_res_c =
make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
emb_vector_a.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
emb_vector_b.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
emb_vector_c.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
int32x4_t emb_res = make_wave_buffer_resource_with_default_range(
p_embs[i_embedding_.value] + index * RowPerBlock);
emb_vectors(i_embedding_.value).template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
});
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
in_thread_buf_a(Number<register_offset>{}) =
emb_vector_a.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_b(Number<register_offset>{}) =
emb_vector_b.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_c(Number<register_offset>{}) =
emb_vector_c.template AsType<EmbType>()[i_row_vec_];
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
in_thread_bufs(i_embedding_.value)(Number<register_offset>{}) =
emb_vectors[i_embedding_.value].template AsType<EmbType>()[i_row_vec_];
});
});
});
};
@@ -205,14 +173,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
AccDataType va =
ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
AccDataType vb =
ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
AccDataType vc =
ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
acc_thread_buf(Number<register_offset>{}) += ck::type_convert<AccDataType>(
in_thread_bufs(i_embedding_.value)(Number<register_offset>{}));
});
});
});
};
@@ -273,9 +237,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
// first load index
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
// prefer use s_load
index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
index_bufs(i_embedding_.value)(i_idx_) =
p_indexes[i_embedding_.value][index_start + i_idx_.value];
});
});
// load gamma/beta
@@ -329,7 +294,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
});
......