Commit e6f44a47 authored by rocking
Browse files

Change the example to f16

parent 7df98b04
add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp)
add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
......@@ -19,13 +19,13 @@
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
using DYDataType = ck::half_t;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using DYDataType = float;
using XDataType = float;
using GammaDataType = float;
using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t;
using DBetaDataType = ck::half_t;
using DXDataType = ck::half_t;
using DGammaDataType = float;
using DBetaDataType = float;
using DXDataType = float;
using ComputeDataType = float;
constexpr int Rank = 2;
......@@ -64,17 +64,17 @@ using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdXImp
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
8, // KThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
8, // DYSrcVectorSize
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
8, // XSrcVectorSize
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
8, // GammaSrcVectorSize
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
8>; // DXDstVectorSize
4>; // DXDstVectorSize
using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
DYDataType,
......@@ -88,16 +88,16 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
8, // MThreadSliceSize
4, // MThreadSliceSize
1, // KThreadSliceSize
false, // IsDYFastestDimReduced
8, // DYSrcVectorSize
4, // DYSrcVectorSize
false, // IsXFastestDimReduced
8, // XSrcVectorSize
4, // XSrcVectorSize
true, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
8, // DGammaDstVectorSize
8>; // DBetaDstVectorSize
4, // DGammaDstVectorSize
4>; // DBetaDstVectorSize
int main()
{
......
add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp)
add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp)
......@@ -19,13 +19,13 @@
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
using DYDataType = ck::half_t;
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using DYDataType = float;
using XDataType = float;
using GammaDataType = float;
using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t;
using DBetaDataType = ck::half_t;
using DXDataType = ck::half_t;
using DGammaDataType = float;
using DBetaDataType = float;
using DXDataType = float;
using ComputeDataType = float;
constexpr int Rank = 5;
......@@ -54,17 +54,17 @@ using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdXImp
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
8, // KThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
8, // DYSrcVectorSize
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
8, // XSrcVectorSize
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
8, // GammaSrcVectorSize
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
8>; // DXDstVectorSize
4>; // DXDstVectorSize
// kernel 2: M , K
// dy: N, H, W, G, C -> G * C, N * H * W
......@@ -89,16 +89,16 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
256, // BlockSize
8, // ClusterInvarient
32, // ClusterReduce
8, // SliceInvarient
4, // SliceInvarient
1, // SliceReduce
false, // IsDYFastestDimReduced
8, // DYSrcVectorSize
4, // DYSrcVectorSize
false, // IsXFastestDimReduced
8, // XSrcVectorSize
4, // XSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
8, // DGammaDstVectorSize
8>; // DBetaDstVectorSize
4, // DGammaDstVectorSize
4>; // DBetaDstVectorSize
int main()
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment