Unverified Commit 03b8119e authored by rocking's avatar rocking Committed by GitHub
Browse files

Add Normalization splitk instances (#829)

* Add normalization splitK to layernorm and groupnorm instances

* Fix bug of GetKPerThread()

* Refine naming

* clang format
parent a5343db0
...@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st ...@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{}; static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
__device__ static int __device__ static int
GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
{ {
bool is_rightmost_block = block_k_cluster_id == kGridSize - 1; bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
if(is_rightmost_block) if(is_rightmost_block)
{ {
int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock; int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
int kPerThread = int kPerThread = kRightmostBlock < K_BlockTileSize
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); ? 0
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
if(kPerBlockTail > 0) if(kPerBlockTail > 0)
{ {
...@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st ...@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
} }
else else
{ {
int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); int kPerBlock = math::integer_divide_ceil(k, kGridSize);
return KThreadSliceSize * (kPerBlock / K_BlockTileSize); return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
} }
} }
...@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st ...@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st
auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford(); auto threadwise_welford = ThreadwiseWelford();
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
threadwise_welford.max_count_ = threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id); kRaw,
k_grid_size,
block_k_cluster_id,
thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f); mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f16_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f16_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f16_generic_instance<Pass, 5, 3>{}); device_normalization_f16_generic_instance<Pass, 5, 3>{});
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{}); add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f16_instances<Pass, 5, 3>{});
} }
} // namespace instance } // namespace instance
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f32_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f32_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Pass, 5, 3>{}); device_normalization_f32_generic_instance<Pass, 5, 3>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 5, 3>{}); add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Pass, 5, 3>{});
} }
} // namespace instance } // namespace instance
......
...@@ -18,6 +18,8 @@ void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( ...@@ -18,6 +18,8 @@ void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
instances, device_normalization_f16_f32_f32_f16_generic_instance<Swish, 5, 3>{}); instances, device_normalization_f16_f32_f32_f16_generic_instance<Swish, 5, 3>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f16_f32_f32_f16_instances<Swish, 5, 3>{}); device_normalization_f16_f32_f32_f16_instances<Swish, 5, 3>{});
add_device_operation_instances(
instances, device_normalization_splitk_f16_f32_f32_f16_instances<Swish, 5, 3>{});
} }
} // namespace instance } // namespace instance
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f16_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f16_generic_instance<Swish, 5, 3>{}); device_normalization_f16_generic_instance<Swish, 5, 3>{});
add_device_operation_instances(instances, device_normalization_f16_instances<Swish, 5, 3>{}); add_device_operation_instances(instances, device_normalization_f16_instances<Swish, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f16_instances<Swish, 5, 3>{});
} }
} // namespace instance } // namespace instance
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f32_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f32_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Swish, 5, 3>{}); device_normalization_f32_generic_instance<Swish, 5, 3>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Swish, 5, 3>{}); add_device_operation_instances(instances, device_normalization_f32_instances<Swish, 5, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Swish, 5, 3>{});
} }
} // namespace instance } // namespace instance
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f16_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f16_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f16_generic_instance<Pass, 2, 1>{}); device_normalization_f16_generic_instance<Pass, 2, 1>{});
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 2, 1>{}); add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 2, 1>{});
add_device_operation_instances(instances,
device_normalization_splitk_f16_instances<Pass, 2, 1>{});
} }
} // namespace instance } // namespace instance
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f32_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f32_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Pass, 2, 1>{}); device_normalization_f32_generic_instance<Pass, 2, 1>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 2, 1>{}); add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 2, 1>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Pass, 2, 1>{});
} }
} // namespace instance } // namespace instance
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f16_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f16_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f16_generic_instance<Pass, 4, 3>{}); device_normalization_f16_generic_instance<Pass, 4, 3>{});
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 4, 3>{}); add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 4, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f16_instances<Pass, 4, 3>{});
} }
} // namespace instance } // namespace instance
......
...@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f32_instances( ...@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f32_instances(
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_normalization_f32_generic_instance<Pass, 4, 3>{}); device_normalization_f32_generic_instance<Pass, 4, 3>{});
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 4, 3>{}); add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 4, 3>{});
add_device_operation_instances(instances,
device_normalization_splitk_f32_instances<Pass, 4, 3>{});
} }
} // namespace instance } // namespace instance
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...@@ -43,6 +44,32 @@ using device_normalization_f16_instances = ...@@ -43,6 +44,32 @@ using device_normalization_f16_instances =
// clang-format on // clang-format on
>; >;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce> template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_generic_instance = std::tuple< using device_normalization_f16_generic_instance = std::tuple<
// clang-format off // clang-format off
...@@ -76,6 +103,32 @@ using device_normalization_f32_instances = std::tuple< ...@@ -76,6 +103,32 @@ using device_normalization_f32_instances = std::tuple<
// clang-format on // clang-format on
>; >;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce> template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f32_generic_instance = std::tuple< using device_normalization_f32_generic_instance = std::tuple<
// clang-format off // clang-format off
...@@ -109,6 +162,32 @@ using device_normalization_f16_f32_f32_f16_instances = std::tuple< ...@@ -109,6 +162,32 @@ using device_normalization_f16_f32_f32_f16_instances = std::tuple<
// clang-format on // clang-format on
>; >;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce> template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple< using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
// clang-format off // clang-format off
......
...@@ -139,6 +139,10 @@ bool profile_groupnorm_impl(int do_verification, ...@@ -139,6 +139,10 @@ bool profile_groupnorm_impl(int do_verification,
continue; continue;
} }
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer(); auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
......
...@@ -155,6 +155,10 @@ bool profile_layernorm_impl(int do_verification, ...@@ -155,6 +155,10 @@ bool profile_layernorm_impl(int do_verification,
continue; continue;
} }
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer(); auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment