Commit 4b8f1249 authored by fsx950223's avatar fsx950223
Browse files

fix format

parent 19c18624
......@@ -25,172 +25,25 @@ using BetaDataType = ck::half_t;
using AccDataType = float;
using OutType = ck::half_t;
using DeviceInstance_fp16_e256 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
256,
1,
1,
3>;
using DeviceInstance_fp16_e512 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
512,
1,
2,
3>;
using DeviceInstance_fp16_e768 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
768,
1,
1,
3>;
using DeviceInstance_fp16_e1024 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
1024,
1,
2,
3>;
using DeviceInstance_fp16_e1536 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
1536,
1,
2,
3>;
using DeviceInstance_fp16_e2048 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
2048,
1,
2,
3>;
using DeviceInstance_fp16_e4096 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
4096,
1,
8,
3>;
using DeviceInstance_fp16_e8192 =
ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
256,
1,
256,
1,
8192,
1,
8,
3>;
template <typename emb_type, ck::index_t dim>
struct emb_kernel
{
};
template <>
struct emb_kernel<ck::half_t, 256>
{
using kernel_type = DeviceInstance_fp16_e256;
};
template <>
struct emb_kernel<ck::half_t, 512>
{
using kernel_type = DeviceInstance_fp16_e512;
};
template <>
struct emb_kernel<ck::half_t, 768>
{
using kernel_type = DeviceInstance_fp16_e768;
};
template <>
struct emb_kernel<ck::half_t, 1024>
{
using kernel_type = DeviceInstance_fp16_e1024;
};
template <>
struct emb_kernel<ck::half_t, 1536>
{
using kernel_type = DeviceInstance_fp16_e1536;
};
template <>
struct emb_kernel<ck::half_t, 2048>
{
using kernel_type = DeviceInstance_fp16_e2048;
};
template <>
struct emb_kernel<ck::half_t, 4096>
{
using kernel_type = DeviceInstance_fp16_e4096;
};
template <>
struct emb_kernel<ck::half_t, 8192>
{
using kernel_type = DeviceInstance_fp16_e8192;
};
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1, 3>;
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 2, 3>;
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1, 3>;
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 2, 3>;
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 2, 3>;
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 2, 3>;
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 8, 3>;
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 8, 3>;
template<typename emb_type, ck::index_t dim> struct emb_kernel{};
template<> struct emb_kernel<ck::half_t, 256> { using kernel_type = DeviceInstance_fp16_e256; };
template<> struct emb_kernel<ck::half_t, 512> { using kernel_type = DeviceInstance_fp16_e512; };
template<> struct emb_kernel<ck::half_t, 768> { using kernel_type = DeviceInstance_fp16_e768; };
template<> struct emb_kernel<ck::half_t, 1024> { using kernel_type = DeviceInstance_fp16_e1024; };
template<> struct emb_kernel<ck::half_t, 1536> { using kernel_type = DeviceInstance_fp16_e1536; };
template<> struct emb_kernel<ck::half_t, 2048> { using kernel_type = DeviceInstance_fp16_e2048; };
template<> struct emb_kernel<ck::half_t, 4096> { using kernel_type = DeviceInstance_fp16_e4096; };
template<> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInstance_fp16_e8192; };
// clang-format on
......@@ -270,19 +123,18 @@ int main()
beta_dev.ToDevice(beta.mData.data());
auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
auto argument_ptr = device_instance.MakeArgumentPointer(
out_dev.GetDeviceBuffer(),
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
current_dim,
index_length,
epsilon);
auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(),
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
current_dim,
index_length,
epsilon);
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
<< std::endl
<< std::flush;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment