Unverified Commit a69aa2a1 authored by rocking, committed by GitHub

layernorm and groupnorm backward data (#1083)

* rename folder

* Add type string

* Remove typo

* Add deviceOp to backward x

* Add comment to describe the behavior of backward normalization

* Add kernel function, prepare to implement

* implement generic kernel

* Check vector size

* Add sweep once pipeline for small reduce size

* Fix bug of KRaw_ error

* Fix bug of dx stride

* sanity check for mean and rstd

* backward x for groupnorm

* Add bwd x instance

* add layernorm 2d bwd gamma beta instances

* Change save mean var type from f32 to f16 in f16 mode

* Change the example to f16

* Add groupnorm bwd gamma beta instance

* Add groupnorm bwd x instance

* Fix naming

* Add layernorm bwd x ckprofiler

* Add groupnorm bwd x profiler

* clang format

* Rename bwd x to bwd data

* Fix bug of verification in profiler

* Add test of layernorm and groupnorm bwd data

* Add missing cmake

* Add layernorm2d bwd data

* rename fwd example

* Add groupnorm client example

* Fix typo. replace Invarient with Invariant

* Add checking before running the best instance
parent ad0a8e4c
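For orientation, the backward-data ("bwd x") kernels added by this commit compute the gradient of the normalization output with respect to its input, re-using the mean and inverse standard deviation saved by the forward kernels (see the SaveMeanInvStd type changes further down). A minimal sketch of the standard formula this corresponds to, assuming the usual layernorm/groupnorm definitions (the kernels' exact evaluation order is not shown in this diff):

\hat{x}_i = (x_i - \mu)\,\mathrm{rstd}, \qquad g_i = \gamma_i \, dy_i,

dx_i = \mathrm{rstd}\left( g_i - \frac{1}{K}\sum_{k} g_k - \hat{x}_i \cdot \frac{1}{K}\sum_{k} g_k \hat{x}_k \right),

where K is the number of elements in one normalization group: all non-batch elements for layernorm, and the H*W*C elements of a single (n, g) group for groupnorm.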
set(DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES
device_groupnorm_bwd_gamma_beta_f32_instance.cpp
device_layernorm2d_bwd_gamma_beta_f16_instance.cpp
device_layernorm2d_bwd_gamma_beta_f32_instance.cpp)
add_instance_library(device_normalization_bwd_gamma_beta_instance ${DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_gamma_beta_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_groupnorm_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_groupnorm_bwd_gamma_beta_f32_instances{});
add_device_operation_instances(instances,
device_groupnorm_bwd_gamma_beta_f32_generic_instance{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_gamma_beta_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f16_generic_instance<2, 1>{});
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f16_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_gamma_beta_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f32_generic_instance<2, 1>{});
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f32_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f16_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize>
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 256, 1, 256, 2, 1, false, 2, false, 2, true, 1, 2, 2>,
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 256, 1, 256, 4, 1, false, 4, false, 4, true, 1, 4, 4>,
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 256, 1, 256, 8, 1, false, 8, false, 8, true, 1, 8, 8>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 64, 1, 64, 1, 1, false, 1, false, 1, true, 1, 1, 1>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize>
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 2, 1, false, 2, false, 2, true, 1, 2, 2>,
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 4, 1, false, 4, false, 4, true, 1, 4, 4>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 64, 1, 64, 1, 1, false, 1, false, 1, true, 1, 1, 1>
// clang-format on
>;
using device_groupnorm_bwd_gamma_beta_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize>
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 2, 1, false, 2, false, 2, false, 1, 2, 2>,
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 4, 1, false, 4, false, 4, false, 1, 4, 4>
// clang-format on
>;
using device_groupnorm_bwd_gamma_beta_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, 5, 3, 64, 1, 64, 1, 1, false, 1, false, 1, false, 1, 1, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
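The tuple rows in this header are dense, so the first f16 instance is restated below with one template argument per line, matched against the names in the comment row above. This is an illustrative annotation only, assuming Rank = 2 and NumReduceDim = 1 as in the layernorm2d registration earlier; the alias name is hypothetical and only compiles if placed inside this header's instance namespace.

// Annotated restatement of the first entry of device_layernorm_bwd_gamma_beta_f16_instances<2, 1>.
// Illustration only; the PR registers instances via the tuples above.
using annotated_layernorm2d_bwd_gamma_beta_f16_example = DeviceNormalizationBwdGammaBetaImpl<
    F16,   // DYDataType
    F16,   // XDataType
    F16,   // MeanInvStdDataType
    F32,   // ComputeDataType
    F16,   // DGammaDataType
    F16,   // DBetaDataType
    2,     // Rank
    1,     // NumReduceDim
    256,   // BlockSize
    1,     // MThreadClusterSize
    256,   // KThreadClusterSize
    2,     // MThreadSliceSize
    1,     // KThreadSliceSize
    false, // IsDYFastestDimReduced
    2,     // DYSrcVectorSize
    false, // IsXFastestDimReduced
    2,     // XSrcVectorSize
    true,  // IsMeanInvStdFastestDimReduced
    1,     // MeanInvStdSrcVectorSize
    2,     // DGammaDstVectorSize
    2>;    // DBetaDstVectorSize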
@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,
......
@@ -11,7 +11,7 @@ namespace instance {
using Swish = ck::tensor_operation::element_wise::Swish;
void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,
......
@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 2, 1>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
......
@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_4_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 4, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 4, 3>>>&
instances)
{
add_device_operation_instances(instances,
......
@@ -23,24 +23,24 @@ using device_normalization_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
@@ -49,31 +49,31 @@ using device_normalization_splitk_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
namespace ck {
namespace profiler {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DXDataType>
bool profile_groupnorm_bwd_data_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
{
// DGamma and DBeta are not needed for backward data; they are only required by the reference class
using DGammaDataType = DXDataType;
using DBetaDataType = DXDataType;
if(length.size() != 5)
return false;
index_t N = length[0];
index_t G = length[3];
index_t C = length[4];
std::vector<index_t> reduce_dim = {1, 2, 4};
std::vector<index_t> gammaLength = {G, C};
Tensor<DYDataType> dy(length);
Tensor<XDataType> x(length);
Tensor<GammaDataType> gamma({G, C});
Tensor<MeanInvStdDataType> mean({N, G});
Tensor<MeanInvStdDataType> inv_std({N, G});
Tensor<DXDataType> dx(length);
Tensor<DXDataType> host_dx(length);
Tensor<DGammaDataType> host_dgamma({G, C});
Tensor<DBetaDataType> host_dbeta({G, C});
std::vector<index_t> strideDy =
std::vector<ck::index_t>{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<index_t> strideX = strideDy;
std::vector<index_t> strideDx = strideDy;
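// Broadcast via zero strides: gamma has shape (G, C), so its rank-5 stride vector is zero for
// the N/H/W dimensions; mean and inv_std have shape (N, G), so their stride vector is zero for
// the H/W/C dimensions. Each tensor is indexed only by the dimensions it actually owns.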
std::vector<index_t> strideGamma = {0, 0, 0, C, 1};
std::vector<index_t> strideMeanInvStd = {G, 0, 0, 1, 0};
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_1<DYDataType>{});
x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
gamma.GenerateTensorValue(GeneratorTensor_1<GammaDataType>{});
mean.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
inv_std.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
dx.GenerateTensorValue(GeneratorTensor_1<DXDataType>{});
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_2<DYDataType>{-5, 5});
x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5});
gamma.GenerateTensorValue(GeneratorTensor_2<GammaDataType>{-5, 5});
mean.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
inv_std.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
dx.GenerateTensorValue(GeneratorTensor_2<DXDataType>{-5, 5});
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0, 1});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1});
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-0.5, 0.5});
mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
dx.GenerateTensorValue(GeneratorTensor_3<DXDataType>{-0.5, 0.5});
}
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
// add device normalization instances
using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
5,
3>;
// get device op instances
const auto instance_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
if(do_verification)
{
using ReferenceInstance =
ck::tensor_operation::host::ReferenceGroupnormBwd<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
DXDataType,
ComputeDataType>;
ReferenceInstance ref;
auto ref_argument =
ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
}
int num_kernel = 0;
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
strideDy,
strideX,
strideGamma,
strideMeanInvStd,
strideMeanInvStd,
strideDx,
reduce_dim,
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
}
else
{
if(time_kernel)
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
}
continue;
}
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
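// Bandwidth estimate: one full read of dy, x, gamma, mean and inv_std plus one write of dx;
// num_bytes / 1e6 divided by avg_time in ms gives GB/s.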
std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) +
x.mDesc.GetElementSize() * sizeof(XDataType) +
gamma.mDesc.GetElementSize() * sizeof(GammaDataType) +
mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
dx.mDesc.GetElementSize() * sizeof(DXDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time)
{
best_instance_name = inst_ptr->GetTypeString();
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
dx_dev.FromDevice(dx.mData.data());
bool pass = ck::utils::check_err(
dx.mData, host_dx.mData, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log)
{
LogRangeAsType<float>(std::cout << "dy : ", dy.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "host_dx : ", host_dx.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "dx : ", dx.mData, ",") << std::endl;
}
if(!pass)
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
}
else
{
if(time_kernel)
std::cout << "pass" << std::endl;
}
}
}
if(time_kernel)
{
LogRange(std::cout << "length = ", length, ",") << ", ";
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s,"
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
namespace ck {
namespace profiler {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DXDataType,
index_t Rank>
bool profile_layernorm_bwd_data_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
{
// DGamma and DBeta are not needed for backward data; they are only required by the reference class
using DGammaDataType = DXDataType;
using DBetaDataType = DXDataType;
if(length.size() != Rank || Rank < 2)
return false;
// Assume normalization is performed over all dimensions except the batch (first) dimension
std::vector<index_t> reduce_length{length.begin() + 1, length.end()};
std::vector<index_t> reduce_dim;
for(int i = 1; i < Rank; ++i)
reduce_dim.push_back(i);
Tensor<DYDataType> dy(length);
Tensor<XDataType> x(length);
Tensor<GammaDataType> gamma(reduce_length);
Tensor<MeanInvStdDataType> mean({length[0]});
Tensor<MeanInvStdDataType> inv_std({length[0]});
Tensor<DXDataType> dx(length);
Tensor<DXDataType> host_dx(length);
Tensor<DGammaDataType> host_dgamma(reduce_length);
Tensor<DBetaDataType> host_dbeta(reduce_length);
std::vector<index_t> strideDy =
std::vector<ck::index_t>{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<index_t> strideX = strideDy;
std::vector<index_t> strideDx = strideDy;
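// gamma reuses dy's strides with the batch stride zeroed, broadcasting the
// (Rank - 1)-dimensional gamma tensor across the batch dimension.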
std::vector<index_t> strideGamma = strideDy;
strideGamma[0] = 0;
std::vector<index_t> strideMeanInvStd(Rank, 0); // one stride entry per dimension, all zero
strideMeanInvStd[0] = 1;                        // mean/inv_std vary along the batch dimension only
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_1<DYDataType>{});
x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
gamma.GenerateTensorValue(GeneratorTensor_1<GammaDataType>{});
mean.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
inv_std.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
dx.GenerateTensorValue(GeneratorTensor_1<DXDataType>{});
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_2<DYDataType>{-5, 5});
x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5});
gamma.GenerateTensorValue(GeneratorTensor_2<GammaDataType>{-5, 5});
mean.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
inv_std.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
dx.GenerateTensorValue(GeneratorTensor_2<DXDataType>{-5, 5});
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0, 1});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1});
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-0.5, 0.5});
mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
dx.GenerateTensorValue(GeneratorTensor_3<DXDataType>{-0.5, 0.5});
}
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
constexpr int NumReduceDim = Rank - 1;
// add device normalization instances
using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>;
// get device op instances
const auto instance_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
if(do_verification)
{
using ReferenceInstance =
ck::tensor_operation::host::ReferenceLayernormBwd<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
DXDataType,
ComputeDataType>;
ReferenceInstance ref;
auto ref_argument =
ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
}
int num_kernel = 0;
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
strideDy,
strideX,
strideGamma,
strideMeanInvStd,
strideMeanInvStd,
strideDx,
reduce_dim,
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
}
else
{
if(time_kernel)
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
}
continue;
}
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) +
x.mDesc.GetElementSize() * sizeof(XDataType) +
gamma.mDesc.GetElementSize() * sizeof(GammaDataType) +
mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
dx.mDesc.GetElementSize() * sizeof(DXDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time)
{
best_instance_name = inst_ptr->GetTypeString();
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
dx_dev.FromDevice(dx.mData.data());
bool pass = ck::utils::check_err(
dx.mData, host_dx.mData, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log)
{
LogRangeAsType<float>(std::cout << "dy : ", dy.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "host_dx : ", host_dx.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "dx : ", dx.mData, ",") << std::endl;
}
if(!pass)
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
}
else
{
if(time_kernel)
std::cout << "pass" << std::endl;
}
}
}
if(time_kernel)
{
LogRange(std::cout << "length = ", length, ",") << ", ";
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s,"
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
}
} // namespace profiler
} // namespace ck
@@ -16,7 +16,9 @@ set(PROFILER_SOURCES
profile_grouped_conv_fwd.cpp
profile_grouped_conv_bwd_weight.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
profile_layernorm_bwd_data.cpp
profile_layernorm_fwd.cpp
profile_max_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
@@ -78,6 +80,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_w
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_groupnorm_bwd_data_impl.hpp"
#include "profiler_operation_registry.hpp"
using ck::index_t;
struct groupnormBwdDataArgParser
{
std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};
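// parse_opt: when argv[i] matches "--<key>", collect every following integer token until the
// next option (a token starting with '-') or the end of the argument list.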
bool parse_opt(int argc, char* argv[], const std::string& key, int i)
{
if(std::string("--") + key == argv[i])
{
int pos = i;
while(++i < argc && argv[i][0] != '-') {}
int end = i;
for(int j = pos + 1; j < end; j++)
{
long_opts[key].push_back(std::stoi(argv[j]));
}
return true;
}
return false;
}
void operator()(int argc, char* argv[])
{
for(auto& kv : long_opts)
{
for(int i = 1; i < argc; i++)
{
if(parse_opt(argc, argv, kv.first, i))
break;
}
}
}
};
void print_help_groupnorm_bwd_data()
{
// e.g.: ckProfiler groupnorm_bwd_data 1 0 2 0 1 --length 1 16 16 32 40
std::cout << "arg1: data type (0: fp16; 1: fp32)\n"
<< "arg2: verification (0: no; 1: yes)\n"
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg4: print tensor value (0: no; 1: yes)\n"
<< "arg5: time kernel (0=no, 1=yes)\n"
<< "--length: tensor extents (e.g, --length 1 16 16 32 40) \n"
<< std::endl;
}
int profile_groupnorm_bwd_data(int argc, char* argv[])
{
if(argc <= 2)
{
print_help_groupnorm_bwd_data();
return 0;
}
groupnormBwdDataArgParser arg_parser;
// short unnamed options
const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
// parse the long options
arg_parser(argc, argv);
const std::vector<index_t> length = arg_parser.long_opts["length"];
using F32 = float;
if(length.size() == 5)
{
if(data_type == ck::DataTypeEnum::Float)
{
ck::profiler::profile_groupnorm_bwd_data_impl<F32, F32, F32, F32, F32, F32>(
do_verification, init_method, do_log, time_kernel, length);
}
else
{
throw std::runtime_error("not implemented yet");
}
}
else
{
throw std::runtime_error("length should be 5");
}
return 0;
}
REGISTER_PROFILER_OPERATION("groupnorm_bwd_data",
"Group Normalization",
profile_groupnorm_bwd_data);
@@ -98,7 +98,7 @@ int profile_groupnorm(int argc, char* argv[])
}
else if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F32, false>(
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F16, false>(
do_verification, init_method, do_log, time_kernel, length);
}
else
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_layernorm_bwd_data_impl.hpp"
#include "profiler_operation_registry.hpp"
using ck::index_t;
struct layernormBwdDataArgParser
{
std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};
bool parse_opt(int argc, char* argv[], const std::string& key, int i)
{
if(std::string("--") + key == argv[i])
{
int pos = i;
while(++i < argc && argv[i][0] != '-') {}
int end = i;
for(int j = pos + 1; j < end; j++)
{
long_opts[key].push_back(std::stoi(argv[j]));
}
return true;
}
return false;
}
void operator()(int argc, char* argv[])
{
for(auto& kv : long_opts)
{
for(int i = 1; i < argc; i++)
{
if(parse_opt(argc, argv, kv.first, i))
break;
}
}
}
};
void print_help_layernorm_bwd_data()
{
// e.g.: ckProfiler layernorm_bwd_data 0 0 2 0 1 --length 1502 4096
std::cout << "arg1: data type (0: fp16; 1: fp32)\n"
<< "arg2: verification (0: no; 1: yes)\n"
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg4: print tensor value (0: no; 1: yes)\n"
<< "arg5: time kernel (0=no, 1=yes)\n"
<< "--length: tensor extents (e.g, --length 1024 1024) \n"
<< std::endl;
}
int profile_layernorm_bwd_data(int argc, char* argv[])
{
if(argc <= 2)
{
print_help_layernorm_bwd_data();
return 0;
}
layernormBwdDataArgParser arg_parser;
// short unnamed options
const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
// parse the long options
arg_parser(argc, argv);
const std::vector<index_t> length = arg_parser.long_opts["length"];
using F16 = ck::half_t;
using F32 = float;
if(length.size() == 2)
{
constexpr int rank = 2;
if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_layernorm_bwd_data_impl<F16, F16, F16, F16, F32, F16, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else if(data_type == ck::DataTypeEnum::Float)
{
ck::profiler::profile_layernorm_bwd_data_impl<F32, F32, F32, F32, F32, F32, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else
{
throw std::runtime_error("not implemented yet");
}
}
else
{
throw std::runtime_error("not implemented yet");
}
return 0;
}
REGISTER_PROFILER_OPERATION("layernorm_bwd_data",
"Layer Normalization",
profile_layernorm_bwd_data);
@@ -104,7 +104,7 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>(
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F16, false, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else if(data_type == ck::DataTypeEnum::Float)
@@ -125,4 +125,4 @@ int profile_layernorm(int argc, char* argv[])
return 0;
}
REGISTER_PROFILER_OPERATION("layernorm", "Layer Normalization", profile_layernorm);
REGISTER_PROFILER_OPERATION("layernorm_fwd", "Layer Normalization", profile_layernorm);
@@ -140,6 +140,7 @@ add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization_fwd)
add_subdirectory(normalization_bwd_data)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
......
add_custom_target(test_normalization_bwd_data)
add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32)
endif()
add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/profile_groupnorm_bwd_data_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using ck::index_t;
template <typename Tuple>
class TestgroupnormBwdData : public ::testing::Test
{
protected:
using DYDataType = std::tuple_element_t<0, Tuple>;
using XDataType = std::tuple_element_t<1, Tuple>;
using GammaDataType = std::tuple_element_t<2, Tuple>;
using MeanInvStdDataType = std::tuple_element_t<3, Tuple>;
using ComputeDataType = std::tuple_element_t<4, Tuple>;
using DXDataType = std::tuple_element_t<5, Tuple>;
void Run()
{
// Bwd data: [N, H, W, G, C], reduce H, W, C
std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
{1, 2, 3, 4, 5},
{256, 9, 9, 9, 9},
{1, 64, 64, 32, 10},
{1, 32, 32, 32, 20},
{1, 16, 16, 32, 40}};
for(auto length : lengths)
{
bool success = ck::profiler::profile_groupnorm_bwd_data_impl<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType>(
true, 2, false, false, length);
EXPECT_TRUE(success);
}
}
};
using KernelTypes = ::testing::Types<
// DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType
std::tuple<F32, F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestgroupnormBwdData, KernelTypes);
TYPED_TEST(TestgroupnormBwdData, Test_FP32) { this->Run(); }