Commit 935422a4 authored by fsx950223
Browse files

add reduce operation

parent 3679054a
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp" #include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -24,15 +25,16 @@ using GammaDataType = ck::half_t; ...@@ -24,15 +25,16 @@ using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using OutType = ck::half_t; using OutType = ck::half_t;
using ReduceOperation = ck::tensor_operation::element_wise::Add;
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1, 3>; using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 256, 1, 1, 3>;
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 2, 3>; using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 512, 1, 2, 3>;
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1, 3>; using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 768, 1, 1, 3>;
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 2, 3>; using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 1024, 1, 2, 3>;
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 2, 3>; using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 1536, 1, 2, 3>;
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 2, 3>; using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 2048, 1, 2, 3>;
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 8, 3>; using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 4096, 1, 8, 3>;
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 8, 3>; using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 8192, 1, 8, 3>;
template<typename emb_type, ck::index_t dim> struct emb_kernel{}; template<typename emb_type, ck::index_t dim> struct emb_kernel{};
...@@ -134,7 +136,8 @@ int main() ...@@ -134,7 +136,8 @@ int main()
beta_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
current_dim, current_dim,
index_length, index_length,
epsilon); epsilon,
ReduceOperation{});
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
<< std::endl << std::endl
<< std::flush; << std::flush;
......
...@@ -24,6 +24,7 @@ template <typename EmbType, ...@@ -24,6 +24,7 @@ template <typename EmbType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename AccDataType,
typename OutType, typename OutType,
typename ReduceOperation,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t DimClusterSize, ck::index_t DimClusterSize,
ck::index_t RowClusterSize, ck::index_t RowClusterSize,
...@@ -48,7 +49,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -48,7 +49,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
const BetaDataType* p_beta, const BetaDataType* p_beta,
const ck::index_t EmbeddingDim, const ck::index_t EmbeddingDim,
const ck::index_t IndexLength, const ck::index_t IndexLength,
const AccDataType epsilon) const AccDataType epsilon,
const ReduceOperation reduce_op)
: p_out_(p_out), : p_out_(p_out),
p_embs_(p_embs), p_embs_(p_embs),
p_indexs_(p_indexs), p_indexs_(p_indexs),
...@@ -56,7 +58,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -56,7 +58,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
p_beta_(p_beta), p_beta_(p_beta),
EmbeddingDim_(EmbeddingDim), EmbeddingDim_(EmbeddingDim),
IndexLength_(IndexLength), IndexLength_(IndexLength),
epsilon_(epsilon) epsilon_(epsilon),
reduce_op_(reduce_op)
{ {
grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize; grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
} }
...@@ -69,6 +72,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -69,6 +72,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
ck::index_t EmbeddingDim_; ck::index_t EmbeddingDim_;
ck::index_t IndexLength_; ck::index_t IndexLength_;
AccDataType epsilon_; AccDataType epsilon_;
ReduceOperation reduce_op_;
size_t grid_size_; size_t grid_size_;
}; };
...@@ -81,7 +85,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -81,7 +85,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
const void* p_beta, const void* p_beta,
ck::index_t EmbeddingDim, ck::index_t EmbeddingDim,
ck::index_t IndexLength, ck::index_t IndexLength,
const AccDataType epsilon) const AccDataType epsilon,
const ReduceOperation reduce_op)
{ {
return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out), return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
p_embs, p_embs,
...@@ -90,7 +95,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -90,7 +95,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
reinterpret_cast<const BetaDataType*>(p_beta), reinterpret_cast<const BetaDataType*>(p_beta),
EmbeddingDim, EmbeddingDim,
IndexLength, IndexLength,
epsilon); epsilon,
reduce_op);
} }
using GridwiseSparseEmbedding = using GridwiseSparseEmbedding =
...@@ -101,6 +107,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -101,6 +107,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
AccDataType, AccDataType,
OutType, OutType,
decltype(MakeOutputDescriptor(1, 1)), decltype(MakeOutputDescriptor(1, 1)),
ReduceOperation,
BlockSize, BlockSize,
DimClusterSize, DimClusterSize,
RowClusterSize, RowClusterSize,
...@@ -124,6 +131,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -124,6 +131,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
AccDataType, AccDataType,
OutType, OutType,
decltype(out_desc), decltype(out_desc),
ReduceOperation,
NumEmbeddings>; NumEmbeddings>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -137,7 +145,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -137,7 +145,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
arg.p_gamma_, arg.p_gamma_,
arg.p_beta_, arg.p_beta_,
out_desc, out_desc,
arg.epsilon_); arg.epsilon_,
arg.reduce_op_);
return (avg_time); return (avg_time);
} }
......
...@@ -18,6 +18,7 @@ template <typename GridwiseSparseEmbedding, ...@@ -18,6 +18,7 @@ template <typename GridwiseSparseEmbedding,
typename AccDataType, typename AccDataType,
typename OutType, typename OutType,
typename OutGridDesc, typename OutGridDesc,
typename ReduceOperation,
ck::index_t NumEmbeddings> ck::index_t NumEmbeddings>
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -29,9 +30,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ...@@ -29,9 +30,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
const OutGridDesc out_grid_desc, const OutGridDesc out_grid_desc,
const AccDataType epsilon) const AccDataType epsilon,
const ReduceOperation reduce_op)
{ {
GridwiseSparseEmbedding::Run(p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon); GridwiseSparseEmbedding::Run(p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, reduce_op);
} }
template <typename EmbType, template <typename EmbType,
...@@ -41,6 +43,7 @@ template <typename EmbType, ...@@ -41,6 +43,7 @@ template <typename EmbType,
typename AccDataType, typename AccDataType,
typename OutType, typename OutType,
typename OutGridDesc, typename OutGridDesc,
typename ReduceOperation,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t DimClusterSize, ck::index_t DimClusterSize,
ck::index_t RowClusterSize, ck::index_t RowClusterSize,
...@@ -91,7 +94,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -91,7 +94,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
const OutGridDesc, const OutGridDesc,
const AccDataType epsilon) const AccDataType epsilon,
const ReduceOperation reduce_op)
{ {
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -149,11 +153,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -149,11 +153,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(EmbType) * RowVectorSize; sizeof(EmbType) * RowVectorSize;
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
IndexType index = index_bufs[i_embedding_.value][Number<current_dim>{}]; IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
int32x4_t emb_res = make_wave_buffer_resource_with_default_range( int32x4_t emb_res = make_wave_buffer_resource_with_default_range(
p_embs[i_embedding_.value] + index * RowPerBlock); p_embs[i_embedding_] + index * RowPerBlock);
emb_vectors(i_embedding_.value).template AsType<src_vector_t>()(I0) = emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0); amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
}); });
...@@ -161,8 +165,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -161,8 +165,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
constexpr auto register_offset = thread_buf_desc.CalculateOffset( constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
in_thread_bufs(i_embedding_.value)(Number<register_offset>{}) = in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
emb_vectors[i_embedding_.value].template AsType<EmbType>()[i_row_vec_]; emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_];
}); });
}); });
}); });
...@@ -174,8 +178,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -174,8 +178,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
constexpr auto register_offset = thread_buf_desc.CalculateOffset( constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
acc_thread_buf(Number<register_offset>{}) += ck::type_convert<AccDataType>( reduce_op(acc_thread_buf(Number<register_offset>{}), acc_thread_buf(Number<register_offset>{}), ck::type_convert<AccDataType>(
in_thread_bufs(i_embedding_.value)(Number<register_offset>{})); in_thread_bufs(i_embedding_)(Number<register_offset>{})));
}); });
}); });
}); });
...@@ -237,8 +241,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -237,8 +241,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
// prefer use s_load // prefer use s_load
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
index_bufs(i_embedding_.value)(i_idx_) = index_bufs(i_embedding_)(i_idx_) =
p_indexes[i_embedding_.value][index_start + i_idx_.value]; p_indexes[i_embedding_][index_start + i_idx_.value];
}); });
}); });
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment