"...composable_kernel_rocm.git" did not exist on "4957d5a399a1c3f6bcf812c9e2fa104ed0ea7742"
Commit 0cb58fac authored by rocking's avatar rocking
Browse files

Add bwd x instance

parent 7a272c85
set(DEVICE_NORMALIZATION_BWD_X_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_BWD_X_INSTANCES
device_layernorm2d_bwd_x_f16_instance.cpp
device_layernorm2d_bwd_x_f32_instance.cpp)
add_instance_library(device_normalization_bwd_x_instance ${DEVICE_NORMALIZATION_BWD_X_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_x_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_x_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdX<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_x_f16_generic_instance<2, 1>{});
add_device_operation_instances(instances, device_layernorm_bwd_x_f16_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_x_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_x_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdX<F32, F32, F32, F32, F32, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_x_f32_generic_instance<2, 1>{});
add_device_operation_instances(instances, device_layernorm_bwd_x_f32_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_x_f16_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize>
DeviceNormalizationBwdXImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 2, true, 2, true, 2, true, 2, false, 1, true, 2>,
DeviceNormalizationBwdXImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 4, true, 4, true, 4, true, 4, false, 1, true, 4>,
DeviceNormalizationBwdXImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 8, true, 8, true, 8, true, 8, false, 1, true, 8>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_x_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdXImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 64, 1, 64, 1, 1, true, 1, true, 1, true, 1, false, 1, true, 1>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_x_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize>
DeviceNormalizationBwdXImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 1, 2, true, 2, true, 2, true, 2, false, 1, true, 2>,
DeviceNormalizationBwdXImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 1, 4, true, 4, true, 4, true, 4, false, 1, true, 4>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_x_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdXImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 64, 1, 64, 1, 1, true, 1, true, 1, true, 1, false, 1, true, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment