Commit 9b7093aa authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge branch 'develop' into lwpck-791

parents 4cb63955 bba085d2
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
endif()
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
endif()
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
endif()
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
......@@ -7,20 +7,114 @@ add_custom_target(examples)
function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example source file ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
return(PROPAGATE result)
endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_dependencies(examples ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_dependencies(examples ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
return(PROPAGATE result)
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
# add all example subdir
......
......@@ -273,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
M_raw_{M},
N_raw_{N},
K_raw_{K},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
......@@ -314,6 +317,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
index_t M01_;
index_t N01_;
index_t M_raw_;
index_t N_raw_;
index_t K_raw_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -485,6 +492,50 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
// Make sure that the M, N, K dimensions before padding are divisible by respective vector
// lengths.
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto A_K_vec_length =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
if(arg.K_raw_ % A_K_vec_length != 0)
{
return false;
}
}
else
{
constexpr auto A_M_vec_lenght =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
if(arg.M_raw_ % A_M_vec_lenght != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
constexpr auto B_N_vec_lenght =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
if(arg.N_raw_ % B_N_vec_lenght != 0)
{
return false;
}
}
else
{
constexpr auto B_K_vec_length =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
if(arg.K_raw_ % B_K_vec_length != 0)
{
return false;
}
}
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
......
......@@ -5,6 +5,8 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck {
......@@ -242,4 +244,5 @@ __host__ __device__ Y cast_from_f8(X x)
}
} // namespace ck::utils
#endif
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
......@@ -85,6 +85,19 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -92,20 +105,33 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -113,14 +139,20 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#endif
}
#endif
......@@ -129,6 +161,19 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -136,20 +181,33 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -157,14 +215,20 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#endif
}
#endif
......@@ -234,30 +298,47 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
#endif
......@@ -266,21 +347,38 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
......@@ -290,6 +388,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
#endif
......
......@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace device {
namespace instance {
// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
#endif
template <typename XDataType,
typename DxDataType,
typename DyDataType,
......@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
......@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace device {
namespace instance {
// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
#endif
template <typename XDataType,
typename YDataType,
typename AccDataType,
......@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
......@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,38 +16,38 @@ namespace tensor_operation {
namespace device {
namespace instance {
// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_infer_rank_4_f16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F16, F32, F32, F16, F16>,
ck::Tuple<F16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_infer_rank_4_f32_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F32, F32, F32, F32, F32>,
ck::Tuple<F32>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_infer_rank_4_bf16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<BF16, F32, F32, BF16, BF16>,
ck::Tuple<BF16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_infer_rank_4_f64_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F64, F64, F64, F64, F64>,
ck::Tuple<F64>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
#endif
template <typename XDataType,
typename YDataType,
typename ScaleDataType,
......@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
is_same_v<MeanVarDataType, F32>)
......@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -36,7 +36,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
......@@ -56,8 +57,8 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -129,7 +130,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
......@@ -154,6 +155,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -178,7 +181,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -228,7 +232,6 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,6 +16,7 @@ namespace device {
namespace instance {
// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
......@@ -29,7 +30,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
......@@ -43,7 +45,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
......@@ -57,7 +60,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
......@@ -71,7 +75,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
......@@ -85,7 +90,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
......@@ -99,8 +105,9 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
......@@ -114,7 +121,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
......@@ -128,7 +136,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
......@@ -142,7 +151,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
......@@ -156,7 +166,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
......@@ -170,7 +181,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
......@@ -184,7 +196,7 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
......@@ -230,42 +242,54 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
#endif
}
}
else if constexpr(NumDimSpatial == 3)
......@@ -274,46 +298,58 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
#endif
}
}
......
......@@ -2,13 +2,20 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp"
......@@ -18,39 +25,10 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp"
......@@ -60,17 +38,38 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp"
......
function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
#make an exception for reduction kernels
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 0)
endif()
#message("add_instance_library returns ${result}")
return(PROPAGATE result)
endfunction(add_instance_library INSTANCE_NAME)
......@@ -15,33 +63,49 @@ IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
#message("int8 instance found!")
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES)
#message("instance should be built for all types!")
set(add_inst 1)
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
NOT "${cmake_instance}" MATCHES "_f8" OR
NOT "${cmake_instance}" MATCHES "_fp16" OR
NOT "${cmake_instance}" MATCHES "_f16" OR
NOT "${cmake_instance}" MATCHES "_fp32" OR
NOT "${cmake_instance}" MATCHES "_f32" OR
NOT "${cmake_instance}" MATCHES "_fp64" OR
NOT "${cmake_instance}" MATCHES "_f64" OR
NOT "${cmake_instance}" MATCHES "_bf16" OR
NOT "${cmake_instance}" MATCHES "_int8" OR
NOT "${cmake_instance}" MATCHES "_i8" OR
NOT "${cmake_instance}" MATCHES "_int4" OR
NOT DEFINED DTYPES)
message("instance should be built for all types!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
......
set(DEVICE_AVGPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp
device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp
device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
set(BATCHED_GEMM_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
endif()
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
endif()
\ No newline at end of file
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
)
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment