Commit 4100d1d8 authored by Alan Turner

Merge remote-tracking branch 'origin/develop' into migx-flash-attn

parents 48717006 c8a8385f
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
// FP16
void add_device_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
......@@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances(
void add_device_normalization_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_normalization_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
......@@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances(
void add_device_normalization_rank_5_3_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
#endif
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
......@@ -65,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
{
......@@ -82,8 +83,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
......@@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f32_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
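
For orientation (not part of the diff): a minimal sketch of how a caller queries this factory, assuming the F16/F32 and PassThrough aliases from the instance namespace shown above, and an instance-header path that is a guess.

#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" // assumed path

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance; // F16, F32, PassThrough aliases

    using Factory = DeviceOperationInstanceFactory<
        DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>;

    // Empty unless CK_ENABLE_FP16 was defined when the instance library was built.
    for(const auto& op : Factory::GetInstances())
        std::cout << op->GetTypeString() << '\n'; // GetTypeString(): assumed base-class query
    return 0;
}
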
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto InOutRank = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>>
{
using DeviceOp = DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
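
A hedged usage sketch against the factory above (not in the diff; the alias namespace is an assumption): requesting a max-pool that returns indices resolves, at compile time, to the *_index_* instance list.

using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::instance; // F16/I32 aliases, assumed

using MaxPool2d = DevicePoolFwd<4 /*InOutRank*/, 2 /*WindowRank*/,
                                F16, F16, I32,
                                ck::ReduceTensorOp::MAX,
                                true /*OutputIndex*/>;

// Populated by add_device_pool2d_fwd_nhwc_index_f16_instances; AVG or
// MAX-without-index requests fall through to the non-index lists instead.
auto op_ptrs = DeviceOperationInstanceFactory<MaxPool2d>::GetInstances();
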
......@@ -22,38 +22,41 @@ static constexpr auto WindowRank = 3;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef CK_ENABLE_FP16
// FP16
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename InLayout,
typename OutLayout,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
......@@ -61,6 +64,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>>
{
......@@ -69,36 +74,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
#endif
}
return op_ptrs;
......
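
Net effect of this hunk: the tensor layouts become part of DevicePoolFwd's type, the layout check wraps both data-type branches, and each branch is now individually guarded by its CK_ENABLE_* macro. A sketch against the updated signature (hedged: InOutRank is assumed to be 5 here, and NDHWC is assumed to come from ck::tensor_layout::convolution):

using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::instance; // F32/I32 aliases, assumed
using ck::tensor_layout::convolution::NDHWC;             // assumed layout tag

using AvgPool3d = DevicePoolFwd<5 /*InOutRank, assumed*/, 3 /*WindowRank*/,
                                F32, F32, I32,
                                NDHWC, NDHWC,
                                ck::ReduceTensorOp::AVG,
                                false /*OutputIndex*/>;

// Non-empty only for the (NDHWC, NDHWC) layout pair, and only when
// CK_ENABLE_FP32 was defined for the instance library.
auto op_ptrs = DeviceOperationInstanceFactory<AvgPool3d>::GetInstances();
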
......@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_INT8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef DL_KERNELS
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
......@@ -76,7 +76,7 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
#endif
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
......@@ -181,7 +181,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
}
}
......@@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
}
}
......@@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
}
}
......@@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
}
}
......@@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
\ No newline at end of file
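
The same two-level guard recurs in each conv2d quantization header below: CK_ENABLE_INT8 gates the whole translation unit, DL_KERNELS gates only the DL instance lists, and the XDL lists remain unconditional, so callers always get a non-empty candidate set. A distilled, self-contained illustration of that registration pattern (plain C++, not the CK API):

#include <memory>
#include <string>
#include <vector>

struct Kernel { std::string name; };

std::vector<std::unique_ptr<Kernel>> get_instances()
{
    std::vector<std::unique_ptr<Kernel>> ops;
#ifdef DL_KERNELS
    // Contributed only when the DL backend was compiled in.
    ops.push_back(std::make_unique<Kernel>(Kernel{"dl_c_shuffle_i8"}));
#endif
    // The XDL path is unconditional, so the list is never empty.
    ops.push_back(std::make_unique<Kernel>(Kernel{"xdl_c_shuffle_i8"}));
    return ops;
}
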
......@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_INT8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
std::vector<
......@@ -64,7 +64,7 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
#endif
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -163,12 +163,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
}
}
......@@ -229,7 +233,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, TanH>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
}
}
......@@ -243,3 +249,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_INT8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
std::vector<
......@@ -63,7 +63,7 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
#endif
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -161,12 +161,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
}
}
......@@ -227,7 +231,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, TanH>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
}
}
......@@ -241,3 +247,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_INT8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
PassThrough,
Activation_Mul2_Clamp<Relu>>>>&
instances);
#endif
void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
......@@ -128,12 +128,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_perchannel_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_perchannel_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(op_ptrs);
}
}
......@@ -147,3 +151,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_INT8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
PassThrough,
Activation_Mul_Clamp<Relu>>>>&
instances);
#endif
void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
......@@ -125,12 +125,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_perlayer_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_perlayer_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>)
{
#ifdef DL_KERNELS
add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(op_ptrs);
#endif
add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(op_ptrs);
}
}
......@@ -144,3 +148,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -89,13 +89,13 @@ void add_device_reduce_instance_blockwise(
{
static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
[&](auto i) {
using cfg1 = remove_cvref_t<decltype(
std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
reduce_configuration_1_instances_blockwise{}))>;
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
[&](auto j) {
using cfg2 = remove_cvref_t<decltype(
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
reduce_configuration_2_instances_blockwise{}))>;
using ReduceOpInstance =
DeviceReduceMultiBlock<InDataType,
......
......@@ -90,14 +90,14 @@ void add_device_reduce_instance_multiblock_atomic_add(
static_for<0,
std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value,
1>{}([&](auto i) {
using cfg1 = remove_cvref_t<decltype(
std::get<i.value>(reduce_configuration_1_instances_multiblock_atomic_add{}))>;
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
reduce_configuration_1_instances_multiblock_atomic_add{}))>;
static_for<0,
std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value,
1>{}([&](auto j) {
using cfg2 = remove_cvref_t<decltype(
std::get<j.value>(reduce_configuration_2_instances_multiblock_atomic_add{}))>;
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
reduce_configuration_2_instances_multiblock_atomic_add{}))>;
using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
AccDataType,
......
......@@ -77,8 +77,8 @@ void add_device_reduce_instance_threadwise(
static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
[&](auto j) {
using cfg2 = remove_cvref_t<decltype(
std::get<j.value>(reduce_configuration_2_instances_threadwise{}))>;
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
reduce_configuration_2_instances_threadwise{}))>;
using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
AccDataType,
......
......@@ -9,64 +9,89 @@
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>&);
void add_device_softmax_f16_f16_rank4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>&);
void add_device_softmax_f32_f32_rank3_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>&);
void add_device_softmax_f32_f32_rank4_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>&);
void add_device_softmax_i8_i8_rank3_instances(
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3>>&);
void add_device_softmax_i8_i8_rank4_instances(
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4>>&);
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>>
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>>
{
using DeviceOp =
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
using DeviceOp = DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F16>)
{
if constexpr(Rank == 3)
add_device_softmax_f16_f16_rank3_instances(op_ptrs);
{
if constexpr(NumReduceDim == 1)
add_device_softmax_f16_f16_rank3_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f16_f16_rank3_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f16_f16_rank3_reduce3_instances(op_ptrs);
}
else if constexpr(Rank == 4)
add_device_softmax_f16_f16_rank4_instances(op_ptrs);
{
if constexpr(NumReduceDim == 1)
add_device_softmax_f16_f16_rank4_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f16_f16_rank4_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f16_f16_rank4_reduce3_instances(op_ptrs);
else if constexpr(NumReduceDim == 4)
add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs);
}
}
else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F32>)
{
if constexpr(Rank == 3)
add_device_softmax_f32_f32_rank3_instances(op_ptrs);
{
if constexpr(NumReduceDim == 1)
add_device_softmax_f32_f32_rank3_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f32_f32_rank3_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f32_f32_rank3_reduce3_instances(op_ptrs);
}
else if constexpr(Rank == 4)
add_device_softmax_f32_f32_rank4_instances(op_ptrs);
{
if constexpr(NumReduceDim == 1)
add_device_softmax_f32_f32_rank4_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f32_f32_rank4_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f32_f32_rank4_reduce3_instances(op_ptrs);
else if constexpr(NumReduceDim == 4)
add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs);
}
}
else if constexpr(std::is_same_v<InDataType, I8> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, I8>)
{
if constexpr(Rank == 3)
add_device_softmax_i8_i8_rank3_instances(op_ptrs);
else if constexpr(Rank == 4)
add_device_softmax_i8_i8_rank4_instances(op_ptrs);
}
#endif
return op_ptrs;
}
};
......
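
With NumReduceDim now part of DeviceSoftmax's type, each (Rank, NumReduceDim) pair maps to its own instance list, as the renamed declarations below confirm. A hedged usage sketch (alias namespace assumed):

using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::instance; // F16/F32 aliases, assumed

using Softmax4d2r = DeviceSoftmax<F16, F32, F16,
                                  PassThrough, PassThrough,
                                  4 /*Rank*/, 2 /*NumReduceDim*/>;

// Dispatches to add_device_softmax_f16_f16_rank4_reduce2_instances at
// compile time; empty unless CK_ENABLE_FP16 was defined for the library.
auto op_ptrs = DeviceOperationInstanceFactory<Softmax4d2r>::GetInstances();
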
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
void add_device_softmax_f16_f16_rank4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -14,7 +14,7 @@ namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank3_reduce1_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>>& instances);
} // namespace instance
} // namespace device
......
......@@ -14,7 +14,7 @@ namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank3_reduce2_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>>& instances);
} // namespace instance
} // namespace device
......
......@@ -14,7 +14,7 @@ namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank3_reduce3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>>& instances);
} // namespace instance
} // namespace device
......
......@@ -14,7 +14,7 @@ namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank4_reduce1_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 1>>& instances);
} // namespace instance
} // namespace device
......
......@@ -14,7 +14,7 @@ namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank4_reduce2_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 2>>& instances);
} // namespace instance
} // namespace device
......
......@@ -14,7 +14,7 @@ namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank4_reduce3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 3>>& instances);
} // namespace instance
} // namespace device
......
......@@ -14,7 +14,7 @@ namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank4_reduce4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 4>>& instances);
} // namespace instance
} // namespace device
......