"llm/vscode:/vscode.git/clone" did not exist on "646371f56dfadbf47dda4cd71ca7ca574c6130d2"
Commit a3e437ad authored by rocking's avatar rocking
Browse files

Add groupnorm bwd x instance

parent 8e7805c5
set(DEVICE_NORMALIZATION_BWD_X_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_BWD_X_INSTANCES
device_groupnorm_bwd_x_f32_instance.cpp
device_layernorm2d_bwd_x_f16_instance.cpp
device_layernorm2d_bwd_x_f32_instance.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_x_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_groupnorm_bwd_x_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdX<F32, F32, F32, F32, F32, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,
device_groupnorm_bwd_x_f32_generic_instance{});
add_device_operation_instances(instances, device_groupnorm_bwd_x_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -52,6 +52,21 @@ using device_layernorm_bwd_x_f32_generic_instance = std::tuple<
// clang-format on
>;
using device_groupnorm_bwd_x_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize>
DeviceNormalizationBwdXImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 1, 2, true, 2, true, 2, true, 2, false, 1, true, 2>,
DeviceNormalizationBwdXImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 1, 4, true, 4, true, 4, true, 4, false, 1, true, 4>
// clang-format on
>;
using device_groupnorm_bwd_x_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdXImpl<F32, F32, F32, F32, F32, F32, 5, 3, 64, 1, 64, 1, 1, true, 1, true, 1, true, 1, false, 1, true, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment