Commit 87df7683 authored by fsx950223
Browse files

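Change to elementwise op: rename the ReduceOperation template parameter to ElementwiseOperation, switch the example from element_wise::Add to element_wise::AddAdd, and have the gridwise kernel convert embeddings to AccDataType on load and apply the op once across all embeddings.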

parent 935422a4
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp" #include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -25,16 +25,16 @@ using GammaDataType = ck::half_t; ...@@ -25,16 +25,16 @@ using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using OutType = ck::half_t; using OutType = ck::half_t;
using ReduceOperation = ck::tensor_operation::element_wise::Add; using ElementwiseOperation = ck::tensor_operation::element_wise::AddAdd;
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 256, 1, 1, 3>; using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 256, 1, 1, 3>;
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 512, 1, 2, 3>; using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 512, 1, 2, 3>;
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 768, 1, 1, 3>; using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 768, 1, 1, 3>;
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 1024, 1, 2, 3>; using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 1024, 1, 2, 3>;
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 1536, 1, 2, 3>; using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 1536, 1, 2, 3>;
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 2048, 1, 2, 3>; using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 2048, 1, 2, 3>;
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 4096, 1, 8, 3>; using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 4096, 1, 8, 3>;
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ReduceOperation, 256, 1, 256, 1, 8192, 1, 8, 3>; using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, ElementwiseOperation, 256, 1, 256, 1, 8192, 1, 8, 3>;
template<typename emb_type, ck::index_t dim> struct emb_kernel{}; template<typename emb_type, ck::index_t dim> struct emb_kernel{};
...@@ -137,7 +137,7 @@ int main() ...@@ -137,7 +137,7 @@ int main()
current_dim, current_dim,
index_length, index_length,
epsilon, epsilon,
ReduceOperation{}); ElementwiseOperation{});
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
<< std::endl << std::endl
<< std::flush; << std::flush;
......
...@@ -24,7 +24,7 @@ template <typename EmbType, ...@@ -24,7 +24,7 @@ template <typename EmbType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename AccDataType,
typename OutType, typename OutType,
typename ReduceOperation, typename ElementwiseOperation,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t DimClusterSize, ck::index_t DimClusterSize,
ck::index_t RowClusterSize, ck::index_t RowClusterSize,
...@@ -50,7 +50,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -50,7 +50,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
const ck::index_t EmbeddingDim, const ck::index_t EmbeddingDim,
const ck::index_t IndexLength, const ck::index_t IndexLength,
const AccDataType epsilon, const AccDataType epsilon,
const ReduceOperation reduce_op) const ElementwiseOperation elementwise_op)
: p_out_(p_out), : p_out_(p_out),
p_embs_(p_embs), p_embs_(p_embs),
p_indexs_(p_indexs), p_indexs_(p_indexs),
...@@ -59,7 +59,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -59,7 +59,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
EmbeddingDim_(EmbeddingDim), EmbeddingDim_(EmbeddingDim),
IndexLength_(IndexLength), IndexLength_(IndexLength),
epsilon_(epsilon), epsilon_(epsilon),
reduce_op_(reduce_op) reduce_op_(elementwise_op)
{ {
grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize; grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
} }
...@@ -72,7 +72,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -72,7 +72,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
ck::index_t EmbeddingDim_; ck::index_t EmbeddingDim_;
ck::index_t IndexLength_; ck::index_t IndexLength_;
AccDataType epsilon_; AccDataType epsilon_;
ReduceOperation reduce_op_; ElementwiseOperation reduce_op_;
size_t grid_size_; size_t grid_size_;
}; };
...@@ -86,7 +86,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -86,7 +86,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
ck::index_t EmbeddingDim, ck::index_t EmbeddingDim,
ck::index_t IndexLength, ck::index_t IndexLength,
const AccDataType epsilon, const AccDataType epsilon,
const ReduceOperation reduce_op) const ElementwiseOperation elementwise_op)
{ {
return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out), return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
p_embs, p_embs,
...@@ -96,7 +96,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -96,7 +96,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
EmbeddingDim, EmbeddingDim,
IndexLength, IndexLength,
epsilon, epsilon,
reduce_op); elementwise_op);
} }
using GridwiseSparseEmbedding = using GridwiseSparseEmbedding =
...@@ -107,7 +107,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -107,7 +107,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
AccDataType, AccDataType,
OutType, OutType,
decltype(MakeOutputDescriptor(1, 1)), decltype(MakeOutputDescriptor(1, 1)),
ReduceOperation, ElementwiseOperation,
BlockSize, BlockSize,
DimClusterSize, DimClusterSize,
RowClusterSize, RowClusterSize,
...@@ -131,7 +131,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator ...@@ -131,7 +131,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
AccDataType, AccDataType,
OutType, OutType,
decltype(out_desc), decltype(out_desc),
ReduceOperation, ElementwiseOperation,
NumEmbeddings>; NumEmbeddings>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
......
...@@ -18,7 +18,7 @@ template <typename GridwiseSparseEmbedding, ...@@ -18,7 +18,7 @@ template <typename GridwiseSparseEmbedding,
typename AccDataType, typename AccDataType,
typename OutType, typename OutType,
typename OutGridDesc, typename OutGridDesc,
typename ReduceOperation, typename ElementwiseOperation,
ck::index_t NumEmbeddings> ck::index_t NumEmbeddings>
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -31,9 +31,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ...@@ -31,9 +31,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const BetaDataType* p_beta, const BetaDataType* p_beta,
const OutGridDesc out_grid_desc, const OutGridDesc out_grid_desc,
const AccDataType epsilon, const AccDataType epsilon,
const ReduceOperation reduce_op) const ElementwiseOperation elementwise_op)
{ {
GridwiseSparseEmbedding::Run(p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, reduce_op); GridwiseSparseEmbedding::Run(
p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, elementwise_op);
} }
template <typename EmbType, template <typename EmbType,
...@@ -43,7 +44,7 @@ template <typename EmbType, ...@@ -43,7 +44,7 @@ template <typename EmbType,
typename AccDataType, typename AccDataType,
typename OutType, typename OutType,
typename OutGridDesc, typename OutGridDesc,
typename ReduceOperation, typename ElementwiseOperation,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t DimClusterSize, ck::index_t DimClusterSize,
ck::index_t RowClusterSize, ck::index_t RowClusterSize,
...@@ -95,7 +96,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -95,7 +96,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
const BetaDataType* p_beta, const BetaDataType* p_beta,
const OutGridDesc, const OutGridDesc,
const AccDataType epsilon, const AccDataType epsilon,
const ReduceOperation reduce_op) const ElementwiseOperation elementwise_op)
{ {
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -127,7 +128,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -127,7 +128,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
constexpr auto gamma_beta_buf_desc = constexpr auto gamma_beta_buf_desc =
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize)); make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true>, ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true>,
NumEmbeddings> NumEmbeddings>
in_thread_bufs; in_thread_bufs;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings> ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings>
...@@ -166,7 +167,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -166,7 +167,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
in_thread_bufs(i_embedding_)(Number<register_offset>{}) = in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]; ck::type_convert<AccDataType>(
emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
}); });
}); });
}); });
...@@ -177,10 +179,17 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -177,10 +179,17 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset( constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { auto in_data_refs = generate_tie(
reduce_op(acc_thread_buf(Number<register_offset>{}), acc_thread_buf(Number<register_offset>{}), ck::type_convert<AccDataType>( [&](auto i_embedding_) -> const auto& {
in_thread_bufs(i_embedding_)(Number<register_offset>{}))); return in_thread_bufs(i_embedding_)(Number<register_offset>{});
}); },
Number<NumEmbeddings>{});
auto out_data_refs = generate_tie(
[&](auto output_index_) -> auto& {
return acc_thread_buf(Number<register_offset>{});
},
Number<1>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
}); });
}); });
}; };
...@@ -210,7 +219,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -210,7 +219,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
constexpr auto mean_var_offset = constexpr auto mean_var_offset =
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon); auto divisor =
1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset( constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment