Commit 0c823497 authored by muozturk's avatar muozturk
Browse files

merge

parents 334cfe1c 68f2b5e7
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::conv_tensor_rearrange_op;
void add_device_image_to_column_nhwgc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, BF16, BF16, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_BF16
add_device_operation_instances(instances, device_image_to_column_bf16_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_image_to_column_nhwgc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F16, F16, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_image_to_column_f16_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_image_to_column_nhwgc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F32, F32, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_image_to_column_f32_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_image_to_column_nhwgc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, int8_t, int8_t, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_INT8
add_device_operation_instances(instances, device_image_to_column_i8_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::conv_tensor_rearrange_op;
void add_device_image_to_column_nwgc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, BF16, BF16, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_BF16
add_device_operation_instances(instances, device_image_to_column_bf16_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
void add_device_image_to_column_nwgc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F16, F16, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_image_to_column_f16_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
void add_device_image_to_column_nwgc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F32, F32, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_image_to_column_f32_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
void add_device_image_to_column_nwgc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, int8_t, int8_t, ImageToColumn>>>&
instances)
{
#ifdef CK_ENABLE_INT8
add_device_operation_instances(instances, device_image_to_column_i8_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
set(DEVICE_NORMALIZATION_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES
device_layernorm2d_f16_instance.cpp
device_layernorm4d_f16_instance.cpp
device_groupnorm_f16_instance.cpp
device_groupnorm_swish_f16_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
device_layernorm2d_f32_instance.cpp
device_layernorm4d_f32_instance.cpp
device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f32_instance.cpp)
add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_f32_f32_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
set(DEVICE_NORMALIZATION_FWD_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_FWD_INSTANCES
device_layernorm2d_fwd_f16_instance.cpp
device_layernorm4d_fwd_f16_instance.cpp
device_groupnorm_fwd_f16_instance.cpp
device_groupnorm_fwd_swish_f16_instance.cpp
device_groupnorm_fwd_swish_f16_f32_f32_f16_instance.cpp
device_layernorm2d_fwd_f32_instance.cpp
device_layernorm4d_fwd_f32_instance.cpp
device_groupnorm_fwd_f32_instance.cpp
device_groupnorm_fwd_swish_f32_instance.cpp)
add_instance_library(device_normalization_fwd_instance ${DEVICE_NORMALIZATION_FWD_INSTANCES})
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_rank_5_3_f16_instances( void add_device_normalization_fwd_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 5, 3>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_rank_5_3_f32_instances( void add_device_normalization_fwd_rank_5_3_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Pass, 5, 3>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( void add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
void add_device_normalization_rank_5_3_swish_f16_instances( void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
void add_device_normalization_rank_5_3_swish_f32_instances( void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_rank_2_1_f16_instances( void add_device_normalization_fwd_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 2, 1>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 2, 1>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_rank_2_1_f32_instances( void add_device_normalization_fwd_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 2, 1>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Pass, 2, 1>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_rank_4_3_f16_instances( void add_device_normalization_fwd_rank_4_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 4, 3>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 4, 3>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp" #include "normalization_fwd_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -10,8 +10,8 @@ namespace instance { ...@@ -10,8 +10,8 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_rank_4_3_f32_instances( void add_device_normalization_fwd_rank_4_3_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 4, 3>>>& std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Pass, 4, 3>>>&
instances) instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_f32_f32_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
add_instance_library(device_transpose_instance
device_transpose_instances_3d.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_transpose_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 5>>>&
instances)
{
#ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_transpose_f16_instances{});
#else
ignore = instances;
#endif
}
void add_device_transpose_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 5>>>&
instances)
{
#ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_transpose_f32_instances{});
#else
ignore = instances;
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -22,7 +22,7 @@ c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} ...@@ -22,7 +22,7 @@ c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
``` ```
## Profile 2d forward convolution kernels ## Profile 2D forward convolution kernels
```bash ```bash
#arg1: tensor operation (conv=Convolution) #arg1: tensor operation (conv=Convolution)
#arg2: data type (0=fp32, 1=fp16) #arg2: data type (0=fp32, 1=fp16)
...@@ -50,21 +50,23 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ...@@ -50,21 +50,23 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
## Profile contraction kernels ## Profile contraction kernels
```bash ```bash
#arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear) #arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear)
#arg2: data type (0: fp32; 1: f64)\n" #arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)
#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; #arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)
#arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]) # 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
#arg4: verification (0: no; 1: yes) #arg5: verification (0: no; 1: yes)
#arg5: initialization (0: no init; 1: integer value; 2: decimal value) #arg6: initialization (0: no init; 1: integer value; 2: decimal value)
#arg6: print tensor value (0: no; 1: yes) #arg7: print tensor value (0: no; 1: yes)
#arg7: time kernel (0: no, 1: yes) #arg8: time kernel (0: no, 1: yes)
#arg8 and arg9: alpha and beta #arg9: alpha
#arg10 to 15: M0, M1, N0, N1, K0, K1 #arg10: beta
#arg16 to 31: Strides for A, B, D and E (skip for default) #arg11 to 16: M0, M1, N0, N1, K0, K1
#arg17 to 32: Strides for A, B, D and E (skip for default)
################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ################ op datatype compute_datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
./bin/ckProfiler contraction_bilinear 0 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
``` ```
Result (MI100) Result (MI100)
...@@ -115,7 +117,7 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s ...@@ -115,7 +117,7 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s
# arg6: print tensor value (0: no; 1: yes) # arg6: print tensor value (0: no; 1: yes)
# arg7: time kernel (0: no, 1: yes) # arg7: time kernel (0: no, 1: yes)
# Following arguments (depending on number of spatial dims): # Following arguments (depending on number of spatial dims):
# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) # Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D)
# G, N, K, C, # G, N, K, C,
# <filter spatial dimensions>, (ie Y, X for 2D) # <filter spatial dimensions>, (ie Y, X for 2D)
# <input image spatial dimensions>, (ie Hi, Wi for 2D) # <input image spatial dimensions>, (ie Hi, Wi for 2D)
...@@ -147,7 +149,9 @@ GB/s: 127.947 ...@@ -147,7 +149,9 @@ GB/s: 127.947
# arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight) # arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight)
# arg2: data type (0: Input fp32, Weight fp32, Output fp32 # arg2: data type (0: Input fp32, Weight fp32, Output fp32
# 1: Input fp16, Weight fp16, Output fp16 # 1: Input fp16, Weight fp16, Output fp16
# 2: Input bf16, Weight fp32, Output bf16) # 2: Input bf16, Weight fp32, Output bf16
# 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8
# 4: Input int8, Weight int8, Output int8)
# arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo] # arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo]
# 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K] # 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
# 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K] # 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]
...@@ -156,7 +160,7 @@ GB/s: 127.947 ...@@ -156,7 +160,7 @@ GB/s: 127.947
# arg6: print tensor value (0: no; 1: yes) # arg6: print tensor value (0: no; 1: yes)
# arg7: time kernel (0: no, 1: yes) # arg7: time kernel (0: no, 1: yes)
# Following arguments (depending on number of spatial dims): # Following arguments (depending on number of spatial dims):
# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) # Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D)
# G, N, K, C, # G, N, K, C,
# <filter spatial dimensions>, (ie Y, X for 2D) # <filter spatial dimensions>, (ie Y, X for 2D)
# <input image spatial dimensions>, (ie Hi, Wi for 2D) # <input image spatial dimensions>, (ie Hi, Wi for 2D)
...@@ -167,7 +171,7 @@ GB/s: 127.947 ...@@ -167,7 +171,7 @@ GB/s: 127.947
# SplitK # SplitK
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
./bin/ckProfiler grouped_conv_bwd_weight 1 0 1 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1 ./bin/ckProfiler grouped_conv_bwd_weight 1 1 0 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
``` ```
...@@ -192,14 +196,15 @@ Note: This kernel use atomic add, this will cause output buffer to be accumulate ...@@ -192,14 +196,15 @@ Note: This kernel use atomic add, this will cause output buffer to be accumulate
# 1: Input fp16, Weight fp16, Output fp16 # 1: Input fp16, Weight fp16, Output fp16
# 2: Input bf16, Weight bf16, Output bf16 # 2: Input bf16, Weight bf16, Output bf16
# 3: Input int8, Weight int8, Output int8) # 3: Input int8, Weight int8, Output int8)
# arg3: tensor layout (0: Input[N, Hi, Wi, C], Output[N * Ho * Wo, Y * X * C]) # arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Output[G * N * Ho * Wo, Y * X * C],
# 1: Input[N, Hi, Wi, G, C], Output[N * Ho * Wo * G, Y * X * C])
# arg4: verification (0: no, 1: yes) # arg4: verification (0: no, 1: yes)
# arg5: initialization (0: no init, 1: integer value, 2: decimal value) # arg5: initialization (0: no init, 1: integer value, 2: decimal value)
# arg6: print tensor value (0: no; 1: yes) # arg6: print tensor value (0: no; 1: yes)
# arg7: time kernel (0: no, 1: yes) # arg7: time kernel (0: no, 1: yes)
# arg8: operation type (0: ImageToColumn, 1: ColumnToImage) # arg8: operation type (0: ImageToColumn, 1: ColumnToImage)
# Following arguments (depending on number of spatial dims): # Following arguments (depending on number of spatial dims):
# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) # Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D)
# G, N, K, C, # G, N, K, C,
# <filter spatial dimensions>, (ie Y, X for 2D) # <filter spatial dimensions>, (ie Y, X for 2D)
# <input image spatial dimensions>, (ie Hi, Wi for 2D) # <input image spatial dimensions>, (ie Hi, Wi for 2D)
......
...@@ -31,10 +31,14 @@ namespace profiler { ...@@ -31,10 +31,14 @@ namespace profiler {
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using F32 = float;
using F64 = double;
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CDELayout, typename CDELayout,
typename DataType, typename DataType,
typename ComputeDataType,
typename DTupleDataType, typename DTupleDataType,
typename CDElementOp> typename CDElementOp>
int profile_contraction_impl(ck::index_t do_verification, int profile_contraction_impl(ck::index_t do_verification,
...@@ -45,10 +49,10 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -45,10 +49,10 @@ int profile_contraction_impl(ck::index_t do_verification,
const std::vector<ck::index_t>& M, const std::vector<ck::index_t>& M,
const std::vector<ck::index_t>& N, const std::vector<ck::index_t>& N,
const std::vector<ck::index_t>& K, const std::vector<ck::index_t>& K,
const std::vector<ck::index_t>& StridesA, const std::vector<ck::index_t>& StridesA, // [M0, M1, K0, K1]
const std::vector<ck::index_t>& StridesB, const std::vector<ck::index_t>& StridesB, // [N0, N1, K0, K1]
const std::vector<ck::index_t>& StridesE, const std::vector<ck::index_t>& StridesE, // [M0, M1, N0, N1]
const std::vector<ck::index_t>& StridesD) const std::vector<ck::index_t>& StridesD) // [M0, M1, N0, N1]
{ {
bool pass = true; bool pass = true;
...@@ -63,13 +67,13 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -63,13 +67,13 @@ int profile_contraction_impl(ck::index_t do_verification,
}; };
Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA)); Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB)); Tensor<DataType> b_n_k(f_host_tensor_descriptor(N, K, StridesB));
Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD)); Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_n_k: " << b_n_k.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
...@@ -78,12 +82,12 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -78,12 +82,12 @@ int profile_contraction_impl(ck::index_t do_verification,
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); b_n_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); b_n_k.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
} }
...@@ -91,12 +95,12 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -91,12 +95,12 @@ int profile_contraction_impl(ck::index_t do_verification,
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(DataType) * b_n_k.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_n_k.mData.data());
e_device_buf.SetZero(); e_device_buf.SetZero();
d_device_buf.ToDevice(d_m_n.mData.data()); d_device_buf.ToDevice(d_m_n.mData.data());
...@@ -118,7 +122,8 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -118,7 +122,8 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType, DataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDElementOp>; CDElementOp,
ComputeDataType>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -126,6 +131,9 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -126,6 +131,9 @@ int profile_contraction_impl(ck::index_t do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
using AccDataType =
typename std::conditional<std::is_same<ComputeDataType, F64>::value, F64, F32>::type;
// Run reference op // Run reference op
if(do_verification) if(do_verification)
{ {
...@@ -136,7 +144,8 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -136,7 +144,8 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType, DataType,
DataType, DataType,
DataType, DataType,
DataType, AccDataType,
ComputeDataType,
AElementOp, AElementOp,
BElementOp>; BElementOp>;
...@@ -146,7 +155,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -146,7 +155,7 @@ int profile_contraction_impl(ck::index_t do_verification,
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
auto ref_argument = auto ref_argument =
ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
...@@ -272,8 +281,29 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -272,8 +281,29 @@ int profile_contraction_impl(ck::index_t do_verification,
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
float threshold = // Both the kernel and the reference use `AccDataType`, so an absolute error of both
static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon(); // of them is bounded by `nelems_k * std::numeric_limits<AccDataType>::epsilon()`.
// Comparing one to another can result in an absolute error as high as twice that
// value.
double threshold = 2 * nelems_k * std::numeric_limits<AccDataType>::epsilon();
// Handle the possible casting error of either AccDataType -> DataType or
// DataType -> ComputeDataType.
// TODO: Add a generic solution for calculating thresholds in CK.
if constexpr(ck::is_same_v<DataType, ck::bhalf_t> ||
ck::is_same_v<ComputeDataType, ck::bhalf_t>)
{
const double epsilon = std::pow(2, -7);
// Maximum relative casting error when rounding to zero.
threshold += epsilon * 2;
}
else if constexpr(ck::is_same_v<DataType, ck::half_t> ||
ck::is_same_v<ComputeDataType, ck::half_t>)
{
const double epsilon = std::pow(2, -10);
// Maximum relative casting error when rounding to zero.
threshold += epsilon * 2;
}
pass = pass & ck::utils::check_err(e_m_n_device_result, pass = pass & ck::utils::check_err(e_m_n_device_result,
e_m_n_host_result, e_m_n_host_result,
"Error: incorrect results!", "Error: incorrect results!",
...@@ -283,7 +313,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -283,7 +313,7 @@ int profile_contraction_impl(ck::index_t do_verification,
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "b: ", b_n_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",") LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",") LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
......
...@@ -23,8 +23,18 @@ enum struct ContractionMatrixLayout ...@@ -23,8 +23,18 @@ enum struct ContractionMatrixLayout
enum struct ContractionDataType enum struct ContractionDataType
{ {
F32_F32_F32_F32, // 0 F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1 F64_F64_F64_F64, // 1
F16_F16_F16_F16, // 2
BF16_BF16_BF16_BF16, // 3
};
enum struct ContractionComputeDataType
{
F32 = 0,
F64,
F16,
BF16,
}; };
inline void collect_index_params(char* argv[], inline void collect_index_params(char* argv[],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment