Commit e6f44a47 authored by rocking
Browse files

Change the example to f16

parent 7df98b04
add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp) add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
...@@ -19,13 +19,13 @@ ...@@ -19,13 +19,13 @@
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
using DYDataType = ck::half_t; using DYDataType = float;
using XDataType = ck::half_t; using XDataType = float;
using GammaDataType = ck::half_t; using GammaDataType = float;
using MeanInvStdDataType = float; using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t; using DGammaDataType = float;
using DBetaDataType = ck::half_t; using DBetaDataType = float;
using DXDataType = ck::half_t; using DXDataType = float;
using ComputeDataType = float; using ComputeDataType = float;
constexpr int Rank = 2; constexpr int Rank = 2;
...@@ -64,17 +64,17 @@ using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdXImp ...@@ -64,17 +64,17 @@ using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdXImp
8, // MThreadClusterSize 8, // MThreadClusterSize
32, // KThreadClusterSize 32, // KThreadClusterSize
1, // MThreadSliceSize 1, // MThreadSliceSize
8, // KThreadSliceSize 4, // KThreadSliceSize
true, // IsDYFastestDimReduced true, // IsDYFastestDimReduced
8, // DYSrcVectorSize 4, // DYSrcVectorSize
true, // IsXFastestDimReduced true, // IsXFastestDimReduced
8, // XSrcVectorSize 4, // XSrcVectorSize
true, // IsGammaFastestDimReduced true, // IsGammaFastestDimReduced
8, // GammaSrcVectorSize 4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize 1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced true, // IsDXFastestDimReduced
8>; // DXDstVectorSize 4>; // DXDstVectorSize
using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
DYDataType, DYDataType,
...@@ -88,16 +88,16 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio ...@@ -88,16 +88,16 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
256, // BlockSize 256, // BlockSize
8, // MThreadClusterSize 8, // MThreadClusterSize
32, // KThreadClusterSize 32, // KThreadClusterSize
8, // MThreadSliceSize 4, // MThreadSliceSize
1, // KThreadSliceSize 1, // KThreadSliceSize
false, // IsDYFastestDimReduced false, // IsDYFastestDimReduced
8, // DYSrcVectorSize 4, // DYSrcVectorSize
false, // IsXFastestDimReduced false, // IsXFastestDimReduced
8, // XSrcVectorSize 4, // XSrcVectorSize
true, // IsMeanInvStdFastestDimReduced true, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize 1, // MeanInvStdSrcVectorSize
8, // DGammaDstVectorSize 4, // DGammaDstVectorSize
8>; // DBetaDstVectorSize 4>; // DBetaDstVectorSize
int main() int main()
{ {
......
add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp) add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp)
...@@ -19,13 +19,13 @@ ...@@ -19,13 +19,13 @@
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
using DYDataType = ck::half_t; using DYDataType = float;
using XDataType = ck::half_t; using XDataType = float;
using GammaDataType = ck::half_t; using GammaDataType = float;
using MeanInvStdDataType = float; using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t; using DGammaDataType = float;
using DBetaDataType = ck::half_t; using DBetaDataType = float;
using DXDataType = ck::half_t; using DXDataType = float;
using ComputeDataType = float; using ComputeDataType = float;
constexpr int Rank = 5; constexpr int Rank = 5;
...@@ -54,17 +54,17 @@ using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdXImp ...@@ -54,17 +54,17 @@ using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdXImp
8, // MThreadClusterSize 8, // MThreadClusterSize
32, // KThreadClusterSize 32, // KThreadClusterSize
1, // MThreadSliceSize 1, // MThreadSliceSize
8, // KThreadSliceSize 4, // KThreadSliceSize
true, // IsDYFastestDimReduced true, // IsDYFastestDimReduced
8, // DYSrcVectorSize 4, // DYSrcVectorSize
true, // IsXFastestDimReduced true, // IsXFastestDimReduced
8, // XSrcVectorSize 4, // XSrcVectorSize
true, // IsGammaFastestDimReduced true, // IsGammaFastestDimReduced
8, // GammaSrcVectorSize 4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize 1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced true, // IsDXFastestDimReduced
8>; // DXDstVectorSize 4>; // DXDstVectorSize
// kernel 2: M , K // kernel 2: M , K
// dy: N, H, W, G, C -> G * C, N * H * W // dy: N, H, W, G, C -> G * C, N * H * W
...@@ -89,16 +89,16 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio ...@@ -89,16 +89,16 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
256, // BlockSize 256, // BlockSize
8, // ClusterInvarient 8, // ClusterInvarient
32, // ClusterReduce 32, // ClusterReduce
8, // SliceInvarient 4, // SliceInvarient
1, // SliceReduce 1, // SliceReduce
false, // IsDYFastestDimReduced false, // IsDYFastestDimReduced
8, // DYSrcVectorSize 4, // DYSrcVectorSize
false, // IsXFastestDimReduced false, // IsXFastestDimReduced
8, // XSrcVectorSize 4, // XSrcVectorSize
false, // IsMeanInvStdFastestDimReduced false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize 1, // MeanInvStdSrcVectorSize
8, // DGammaDstVectorSize 4, // DGammaDstVectorSize
8>; // DBetaDstVectorSize 4>; // DBetaDstVectorSize
int main() int main()
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment