Unverified Commit 5c75ff5c authored by zjing14, committed by GitHub

Merge branch 'develop' into ab_blk_copy_multi_d

parents 96417775 48ba6e8a
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
endif()
@@ -7,20 +7,114 @@ add_custom_target(examples)
 function(add_example_executable EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
-    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
-    target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
-    add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
-    add_dependencies(examples ${EXAMPLE_NAME})
-    add_dependencies(check ${EXAMPLE_NAME})
-    rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+    set(result 1)
+    if(DEFINED DTYPES)
+        foreach(source IN LISTS FILE_NAME)
+            set(test 0)
+            foreach(type IN LISTS DTYPES)
+                if(type MATCHES "fp16")
+                    set(type1 "_f16")
+                elseif(type MATCHES "fp32")
+                    set(type1 "_f32")
+                elseif(type MATCHES "fp8")
+                    set(type1 "_f8")
+                elseif(type MATCHES "bf16")
+                    set(type1 "_b16")
+                elseif(type MATCHES "fp64")
+                    set(type1 "_f64")
+                elseif(type MATCHES "int8")
+                    set(type1 "_i8")
+                endif()
+                if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
+                    # if the filename matches any selected type, exit the type loop and do not exclude the file from the list
+                    set(test 0)
+                    break()
+                elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
+                        source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
+                       NOT(source MATCHES type OR source MATCHES type1))
+                    # if the filename contains a type which doesn't match any selected type, mark it for removal
+                    set(test 1)
+                endif()
+            endforeach()
+            if(test EQUAL 1)
+                message("removing example source file ${source}")
+                list(REMOVE_ITEM FILE_NAME "${source}")
+            endif()
+        endforeach()
+    endif()
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
+            message("removing dl example ${source}")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    # only continue if there are some source files left on the list
+    if(FILE_NAME)
+        add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+        target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
+        add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
+        add_dependencies(examples ${EXAMPLE_NAME})
+        add_dependencies(check ${EXAMPLE_NAME})
+        rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+        set(result 0)
+    endif()
+    # message("add_example returns ${result}")
+    return(PROPAGATE result)
 endfunction(add_example_executable EXAMPLE_NAME)

 function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
-    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
-    target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
-    add_dependencies(examples ${EXAMPLE_NAME})
-    rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+    set(result 1)
+    if(DEFINED DTYPES)
+        foreach(source IN LISTS FILE_NAME)
+            set(test 0)
+            foreach(type IN LISTS DTYPES)
+                if(type MATCHES "fp16")
+                    set(type1 "_f16")
+                elseif(type MATCHES "fp32")
+                    set(type1 "_f32")
+                elseif(type MATCHES "fp8")
+                    set(type1 "_f8")
+                elseif(type MATCHES "bf16")
+                    set(type1 "_b16")
+                elseif(type MATCHES "fp64")
+                    set(type1 "_f64")
+                elseif(type MATCHES "int8")
+                    set(type1 "_i8")
+                endif()
+                if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
+                    # if the filename matches any selected type, exit the type loop and do not exclude the file from the list
+                    set(test 0)
+                    break()
+                elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
+                        source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
+                       NOT(source MATCHES type OR source MATCHES type1))
+                    # if the filename contains a type which doesn't match any selected type, mark it for removal
+                    set(test 1)
+                endif()
+            endforeach()
+            if(test EQUAL 1)
+                message("removing example ${source}")
+                list(REMOVE_ITEM FILE_NAME "${source}")
+            endif()
+        endforeach()
+    endif()
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
+            message("removing dl example ${source}")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    # only continue if there are some source files left on the list
+    if(FILE_NAME)
+        add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+        target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
+        add_dependencies(examples ${EXAMPLE_NAME})
+        rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+        set(result 0)
+    endif()
+    # message("add_example returns ${result}")
+    return(PROPAGATE result)
 endfunction(add_example_executable_no_testing EXAMPLE_NAME)

 # add all example subdir
...
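Restating the keep/drop rule implemented by the CMake loops above as a small self-contained C++ predicate (illustration only, not part of the commit; CMake's MATCHES is regex matching, approximated here by substring search): a file is kept when its name matches a selected type token or names no type at all, and dropped only when every type it names was deselected.

#include <array>
#include <iostream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Approximate C++ analogue of the DTYPES filter in add_example_executable.
bool keep_source(std::string_view file, const std::vector<std::string>& selected_types)
{
    // Type token -> filename suffix mapping, mirroring the CMake if/elseif chain.
    static constexpr std::array<std::pair<std::string_view, std::string_view>, 6> tokens{
        {{"fp16", "_f16"},
         {"fp32", "_f32"},
         {"fp8", "_f8"},
         {"bf16", "_b16"},
         {"fp64", "_f64"},
         {"int8", "_i8"}}};

    const auto contains = [&](std::string_view t) { return file.find(t) != file.npos; };

    bool names_any_type = false;
    for(const auto& [type, suffix] : tokens)
    {
        if(contains(type) || contains(suffix))
        {
            names_any_type = true;
            for(const auto& sel : selected_types)
                if(sel == type)
                    return true; // filename matches a selected type: keep
        }
    }
    return !names_any_type; // type-neutral filenames are always kept
}

int main()
{
    const std::vector<std::string> dtypes{"fp32"};
    std::cout << keep_source("gemm_xdl_fp16.cpp", dtypes) // 0: names a deselected type
              << keep_source("gemm_xdl_fp32.cpp", dtypes) // 1: matches a selected type
              << keep_source("put_element.cpp", dtypes)   // 1: names no type
              << '\n';
    return 0;
}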
@@ -9,13 +9,13 @@ namespace ck {
 struct GridwiseGemmPipeline_v2
 {
-    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    __host__ __device__ static constexpr bool IsSupported(const index_t num_loop)
     {
         // TODO: improve applicability
         return num_loop % 2 == 0;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop)
     {
         return (num_loop / 2) > 1;
     }
...
@@ -175,7 +175,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
     }

-    __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); }
+    __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); }

     // Argument
     struct Problem
@@ -369,9 +369,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                       "Invalid tuning param!");

         // check gridwise gemm pipeline
-        const index_t K0       = problem.K / K1Value;
-        const auto num_k_loop  = K0 / K0PerBlock;
+        const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
@@ -1026,8 +1024,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
         }

         // check gridwise gemm pipeline
-        const index_t K0      = problem.K / K1;
-        const auto num_k_loop = K0 / K0PerBlock;
+        const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
...
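The CalculateK0 change above replaces floor with ceiling division, so a K that is not a multiple of K1 now rounds K0 up instead of dropping the tail, and num_k_loop is computed directly from problem.K0 the same way. A minimal sketch with concrete numbers (the two helpers below are stand-ins written for this example, assuming ck::math::integer_divide_floor/integer_divide_ceil are plain integer floor and ceiling division):

#include <cassert>

// Stand-ins for the ck::math helpers (assumed semantics: integer floor/ceil division).
constexpr int integer_divide_floor(int x, int y) { return x / y; }
constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    constexpr int K = 100, K1 = 8, K0PerBlock = 4;

    constexpr int k0_old = integer_divide_floor(K, K1); // 12: the K % K1 tail is dropped
    constexpr int k0_new = integer_divide_ceil(K, K1);  // 13: the tail rounds up into one more K0 slice

    // The pipeline check now also rounds up against K0PerBlock.
    constexpr int num_k_loop = integer_divide_ceil(k0_new, K0PerBlock); // ceil(13 / 4) == 4

    static_assert(k0_old == 12 && k0_new == 13 && num_k_loop == 4);
    return 0;
}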
@@ -16,26 +16,26 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-// FP16
+#ifdef CK_ENABLE_FP16
 void add_device_batchnorm_backward_rank_4_3_f16_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
+#endif

-// FP32
+#ifdef CK_ENABLE_FP32
 void add_device_batchnorm_backward_rank_4_3_f32_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+#endif

-// BF16
+#ifdef CK_ENABLE_BF16
 void add_device_batchnorm_backward_rank_4_3_bf16_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
+#endif

-// FP64
+#ifdef CK_ENABLE_FP64
 void add_device_batchnorm_backward_rank_4_3_f64_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
+#endif

 template <typename XDataType,
           typename DxDataType,
           typename DyDataType,
@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
                      is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
                      is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
                 add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
-                          is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
-                          is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
+                     is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
+                     is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
+                     is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
-                          is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
-                          is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_BF16
+        if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
+                     is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
+                     is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
+                     is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
-                          is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
-                          is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
-                          is_same_v<MeanVarDataType, F64>)
+#endif
+#ifdef CK_ENABLE_FP64
+        if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
+                     is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
+                     is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
+                     is_same_v<MeanVarDataType, F64>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
             }
         }
+#endif

         return op_ptrs;
     }
 };
...
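One way to read the control-flow change in GetInstances above: once every data-type branch sits behind its own #ifdef CK_ENABLE_* guard, an `else if constexpr` chain would leave a dangling `else` whenever the preceding branch is compiled out, so the chain is flattened into independent `if constexpr` blocks. A minimal self-contained sketch of that pattern (stand-in types only, not the actual CK factory):

#include <memory>
#include <type_traits>
#include <vector>

// Pretend build flags; in CK these come from the DTYPES selection in CMake.
#define CK_ENABLE_FP32
// CK_ENABLE_FP16 deliberately left undefined for this sketch.

struct OpBase { virtual ~OpBase() = default; };
struct OpF16 : OpBase {};
struct OpF32 : OpBase {};

template <typename DataType>
std::vector<std::unique_ptr<OpBase>> GetInstances()
{
    std::vector<std::unique_ptr<OpBase>> op_ptrs;
    // Independent `if constexpr` blocks: compiling one out via the
    // preprocessor cannot orphan an `else` belonging to the next block.
#ifdef CK_ENABLE_FP16
    if constexpr(std::is_same_v<DataType, short>) // `short` stands in for half_t
    {
        op_ptrs.push_back(std::make_unique<OpF16>());
    }
#endif
#ifdef CK_ENABLE_FP32
    if constexpr(std::is_same_v<DataType, float>)
    {
        op_ptrs.push_back(std::make_unique<OpF32>());
    }
#endif
    return op_ptrs;
}

int main() { return GetInstances<float>().size() == 1 ? 0 : 1; }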
@@ -16,26 +16,26 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-// FP16
+#ifdef CK_ENABLE_FP16
 void add_device_batchnorm_forward_rank_4_3_f16_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
+#endif

-// FP32
+#ifdef CK_ENABLE_FP32
 void add_device_batchnorm_forward_rank_4_3_f32_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+#endif

-// BF16
+#ifdef CK_ENABLE_BF16
 void add_device_batchnorm_forward_rank_4_3_bf16_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
+#endif

-// FP64
+#ifdef CK_ENABLE_FP64
 void add_device_batchnorm_forward_rank_4_3_f64_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
+#endif

 template <typename XDataType,
           typename YDataType,
           typename AccDataType,
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
                      is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
                      is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<
                 add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
-                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
-                          is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
+                     is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
+                     is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
-                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
-                          is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_BF16
+        if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
+                     is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
+                     is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
-                          is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
-                          is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
+#endif
+#ifdef CK_ENABLE_FP64
+        if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
+                     is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
+                     is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
             }
         }
+#endif

         return op_ptrs;
     }
 };
...
@@ -16,38 +16,38 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-// FP16
+#ifdef CK_ENABLE_FP16
 void add_device_batchnorm_infer_rank_4_f16_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<F16, F32, F32, F16, F16>,
         ck::Tuple<F16>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif

-// FP32
+#ifdef CK_ENABLE_FP32
 void add_device_batchnorm_infer_rank_4_f32_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<F32, F32, F32, F32, F32>,
         ck::Tuple<F32>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif

-// BF16
+#ifdef CK_ENABLE_BF16
 void add_device_batchnorm_infer_rank_4_bf16_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<BF16, F32, F32, BF16, BF16>,
         ck::Tuple<BF16>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif

-// FP64
+#ifdef CK_ENABLE_FP64
 void add_device_batchnorm_infer_rank_4_f64_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<F64, F64, F64, F64, F64>,
         ck::Tuple<F64>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif

 template <typename XDataType,
           typename YDataType,
           typename ScaleDataType,
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
                      is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
                      is_same_v<MeanVarDataType, F32>)
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
                 add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
-                          is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
+                     is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
+                     is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4)
             {
                 add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
-                          is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_BF16
+        if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
+                     is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
+                     is_same_v<MeanVarDataType, F32>)
        {
             if constexpr(Rank == 4)
             {
                 add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
-                          is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
-                          is_same_v<MeanVarDataType, F64>)
+#endif
+#ifdef CK_ENABLE_FP64
+        if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
+                     is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
+                     is_same_v<MeanVarDataType, F64>)
         {
             if constexpr(Rank == 4)
             {
                 add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
             }
         }
+#endif

         return op_ptrs;
     }
 };
...
@@ -312,6 +312,23 @@ void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
         DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
         instances);
 #endif

+#ifdef CK_ENABLE_FP8
+void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
+#endif

 template <typename ALayout,
           typename BLayout,
           typename CLayout,
@@ -505,6 +522,32 @@ struct DeviceOperationInstanceFactory<
 #endif
         }
     }
+#endif
+#ifdef CK_ENABLE_FP8
+        else if constexpr(is_same_v<ADataType, ck::f8_t> && is_same_v<BDataType, ck::f8_t> &&
+                          is_same_v<CDataType, ck::f8_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<CLayout, Row>)
+            {
+                add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(op_ptrs);
+            }
+        }
 #endif

         return op_ptrs;
     }
...
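The added FP8 branch reuses the factory's layout-tag dispatch: the (ALayout, BLayout, CLayout) combination selects which instance list is appended. A simplified, self-contained sketch of that dispatch and of how callers typically enumerate the result (stand-in types and names for illustration only; real code would include the CK instance headers and use ck::f8_t with the DeviceGemm interface):

#include <iostream>
#include <memory>
#include <type_traits>
#include <vector>

// Layout tags standing in for ck::tensor_layout::gemm::RowMajor/ColumnMajor.
struct Row {};
struct Col {};

struct DeviceGemmBase
{
    virtual ~DeviceGemmBase() = default;
    virtual const char* name() const = 0;
};
struct GemmF8_MK_KN : DeviceGemmBase { const char* name() const override { return "f8_mk_kn_mn"; } };
struct GemmF8_MK_NK : DeviceGemmBase { const char* name() const override { return "f8_mk_nk_mn"; } };

template <typename ALayout, typename BLayout>
std::vector<std::unique_ptr<DeviceGemmBase>> GetInstances()
{
    std::vector<std::unique_ptr<DeviceGemmBase>> op_ptrs;
    // Compile-time layout dispatch, mirroring the is_same_v<...> chain above.
    if constexpr(std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>)
        op_ptrs.push_back(std::make_unique<GemmF8_MK_KN>());
    else if constexpr(std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Col>)
        op_ptrs.push_back(std::make_unique<GemmF8_MK_NK>());
    return op_ptrs;
}

int main()
{
    // Callers normally iterate every registered instance, checking argument
    // support and profiling each one before picking the fastest.
    for(const auto& op : GetInstances<Row, Col>())
        std::cout << op->name() << '\n';
    return 0;
}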
@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

-#ifdef CK_ENABLE_FP16
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {

+#ifdef CK_ENABLE_FP16
 void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                     Row,
@@ -68,7 +68,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
+#endif

+#ifdef CK_ENABLE_INT8
 void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                     Row,
@@ -120,7 +121,7 @@ void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
+#endif

 // GEMM + Bilinear
 template <typename ALayout,
           typename BLayout,
@@ -158,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                      is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
         {
@@ -187,8 +188,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                     op_ptrs);
             }
         }
-        else if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
-                          is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
+#endif
+#ifdef CK_ENABLE_INT8
+        if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
+                     is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
@@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                 add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs);
             }
         }
+#endif

         return op_ptrs;
     }
 };
@@ -220,4 +223,3 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
@@ -16,7 +16,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
-
+#ifdef CK_ENABLE_FP16
 void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
     std::vector<std::unique_ptr<
         DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -36,7 +36,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
     std::vector<std::unique_ptr<
         DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
         instances);
+#endif

+#ifdef CK_ENABLE_FP32
 void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
     std::vector<std::unique_ptr<
         DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
@@ -56,8 +57,8 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
     std::vector<std::unique_ptr<
         DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
         instances);
+#endif

-#if defined CK_ENABLE_FP8
+#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
 void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
     std::vector<std::unique_ptr<
         DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -129,7 +130,7 @@ struct DeviceOperationInstanceFactory<
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-
+#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
                      is_same_v<CDataType, float>)
         {
@@ -154,6 +155,8 @@ struct DeviceOperationInstanceFactory<
                 add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
             }
         }
+#endif
+#ifdef CK_ENABLE_FP16
         else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                           is_same_v<CDataType, half_t>)
         {
@@ -178,7 +181,8 @@ struct DeviceOperationInstanceFactory<
                 add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
             }
         }
-#if defined CK_ENABLE_FP8
+#endif
+#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
         else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
                           is_same_v<CDataType, half_t>)
         {
@@ -228,7 +232,6 @@ struct DeviceOperationInstanceFactory<
             }
         }
 #endif
-
         return op_ptrs;
     }
 };
...
@@ -16,6 +16,7 @@ namespace device {
 namespace instance {

 // conv2d backward data
+#ifdef CK_ENABLE_FP16
 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                   GNHWK,
@@ -29,7 +30,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_FP32
 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                   GNHWK,
@@ -43,7 +45,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_BF16
 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                   GNHWK,
@@ -57,7 +60,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_FP16
 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                   NHWGK,
@@ -71,7 +75,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_FP32
 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                   NHWGK,
@@ -85,7 +90,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_BF16
 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                   NHWGK,
@@ -99,8 +105,9 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

 // conv3d backward data
+#ifdef CK_ENABLE_FP16
 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                   GNDHWK,
@@ -114,7 +121,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_FP32
 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                   GNDHWK,
@@ -128,7 +136,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_BF16
 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                   GNDHWK,
@@ -142,7 +151,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_FP16
 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                   NDHWGK,
@@ -156,7 +166,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_FP32
 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                   NDHWGK,
@@ -170,7 +181,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

+#ifdef CK_ENABLE_BF16
 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                   NDHWGK,
@@ -184,7 +196,7 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough>>>& instances);
+#endif

 template <ck::index_t NumDimSpatial,
           typename OutLayout,
           typename WeiLayout,
@@ -230,42 +242,54 @@ struct DeviceOperationInstanceFactory<
         if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
                      is_same_v<OutLayout, GNHWK>)
         {
+#ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                          is_same_v<OutDataType, F16>)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_FP32
             else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                               is_same_v<OutDataType, F32>)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_BF16
             else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                               is_same_v<OutDataType, BF16>)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
                     op_ptrs);
             }
+#endif
         }
         else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
                           is_same_v<OutLayout, NHWGK>)
         {
+#ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                          is_same_v<OutDataType, F16>)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_FP32
             else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                               is_same_v<OutDataType, F32>)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_BF16
             else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                               is_same_v<OutDataType, BF16>)
             {
                 add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
                     op_ptrs);
             }
+#endif
         }
     }
     else if constexpr(NumDimSpatial == 3)
@@ -274,46 +298,58 @@ struct DeviceOperationInstanceFactory<
         if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
                      is_same_v<OutLayout, GNDHWK>)
         {
+#ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                          is_same_v<OutDataType, F16>)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
                     op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_FP32
             else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                               is_same_v<OutDataType, F32>)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
                     op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_BF16
             else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                               is_same_v<OutDataType, BF16>)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
                     op_ptrs);
             }
+#endif
         }
         else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
                           is_same_v<OutLayout, NDHWGK>)
         {
+#ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                          is_same_v<OutDataType, F16>)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
                     op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_FP32
             else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                               is_same_v<OutDataType, F32>)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                     op_ptrs);
             }
+#endif
+#ifdef CK_ENABLE_BF16
             else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                               is_same_v<OutDataType, BF16>)
             {
                 add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
                     op_ptrs);
             }
+#endif
         }
     }
...
...@@ -2,13 +2,20 @@ ...@@ -2,13 +2,20 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp"
...@@ -18,39 +25,10 @@ ...@@ -18,39 +25,10 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp"
...@@ -60,17 +38,38 @@ ...@@ -60,17 +38,38 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp"
...
function(add_instance_library INSTANCE_NAME)
    message("adding instance ${INSTANCE_NAME}")
    set(result 1)
    if(DEFINED DTYPES)
        foreach(source IN LISTS ARGN)
            set(test 0)
            foreach(type IN LISTS DTYPES)
                if(type MATCHES "fp16")
                    set(type1 "_f16")
                elseif(type MATCHES "fp32")
                    set(type1 "_f32")
                elseif(type MATCHES "fp8")
                    set(type1 "_f8")
                elseif(type MATCHES "bf16")
                    set(type1 "_b16")
                elseif(type MATCHES "fp64")
                    set(type1 "_f64")
                elseif(type MATCHES "int8")
                    set(type1 "_i8")
                endif()
                # make an exception for reduction kernels
                if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
                    # if the filename matches any selected type, exit the type loop and do not exclude the file from the list
                    set(test 0)
                    break()
                elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
                        source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
                       NOT(source MATCHES type OR source MATCHES type1))
                    # if the filename contains a type that matches none of the selected types, mark it for removal
                    set(test 1)
                endif()
            endforeach()
            if(test EQUAL 1)
                message("removing instance ${source}")
                list(REMOVE_ITEM ARGN "${source}")
            endif()
        endforeach()
    endif()
    foreach(source IN LISTS ARGN)
        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
            message("removing dl instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    # only continue if there are source files left on the list
    if(ARGN)
        add_library(${INSTANCE_NAME} OBJECT ${ARGN})
        target_compile_features(${INSTANCE_NAME} PUBLIC)
        set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
        clang_tidy_check(${INSTANCE_NAME})
        set(result 0)
    endif()
    #message("add_instance_library returns ${result}")
    return(PROPAGATE result)
endfunction(add_instance_library INSTANCE_NAME)
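A hedged usage sketch for add_instance_library: the target and source names below are hypothetical, DTYPES is the semicolon-separated type list given at configure time (e.g. cmake -DDTYPES="fp16;int8" -DDL_KERNELS=ON ..), and result reaches the caller's scope through return(PROPAGATE result), which requires CMake 3.25 or newer with policy CMP0140 set to NEW.

# Hypothetical instance list; the comments describe the filter outcome
# for DTYPES="fp16;int8" with DL_KERNELS defined.
add_instance_library(device_demo_instance
    device_demo_f16_instance.cpp    # kept: "_f16" matches the fp16 selection
    device_demo_f32_instance.cpp    # removed: contains "_f32", but fp32 is not selected
    device_demo_i8_dl_instance.cpp  # kept: "_i8" matches int8 and DL_KERNELS is set
)
if(result EQUAL 0)
    message(STATUS "device_demo_instance target was created")
endif()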
@@ -15,33 +63,49 @@ IF(IS_DIRECTORY "${subdir_path}")
    set(cmake_instance)
    file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
    set(add_inst 0)
    if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
        message("fp8 instance found!")
        set(add_inst 1)
    endif()
    if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
        message("fp16 instance found!")
        set(add_inst 1)
    endif()
    if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
        message("fp32 instance found!")
        set(add_inst 1)
    endif()
    if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
        message("fp64 instance found!")
        set(add_inst 1)
    endif()
    if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
        message("bf16 instance found!")
        set(add_inst 1)
    endif()
    if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
        message("int8 instance found!")
        set(add_inst 1)
    endif()
    if(NOT "${cmake_instance}" MATCHES "_fp8" OR
       NOT "${cmake_instance}" MATCHES "_f8" OR
       NOT "${cmake_instance}" MATCHES "_fp16" OR
       NOT "${cmake_instance}" MATCHES "_f16" OR
       NOT "${cmake_instance}" MATCHES "_fp32" OR
       NOT "${cmake_instance}" MATCHES "_f32" OR
       NOT "${cmake_instance}" MATCHES "_fp64" OR
       NOT "${cmake_instance}" MATCHES "_f64" OR
       NOT "${cmake_instance}" MATCHES "_bf16" OR
       NOT "${cmake_instance}" MATCHES "_int8" OR
       NOT "${cmake_instance}" MATCHES "_i8" OR
       NOT "${cmake_instance}" MATCHES "_int4" OR
       NOT DEFINED DTYPES)
        message("instance should be built for all types!")
        set(add_inst 1)
    endif()
    if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
        message("quantization instances will not be built!")
        set(add_inst 0)
    endif()
    if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
        message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
...
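The hunk above runs once per instance subdirectory; its enclosing loop is outside this diff. A rough sketch of what such a driver could look like (the glob pattern and the add_inst consumption are assumptions, not code from this commit):

# Assumed driver: read each subdirectory's CMakeLists.txt into
# cmake_instance, apply the add_inst checks shown above, and only
# descend into subdirectories that pass.
file(GLOB dir_list RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/*)
foreach(subdir IN LISTS dir_list)
    set(subdir_path "${CMAKE_CURRENT_SOURCE_DIR}/${subdir}")
    if(IS_DIRECTORY "${subdir_path}")
        # ... the type checks above set add_inst to 0 or 1 ...
        if(add_inst EQUAL 1)
            add_subdirectory(${subdir})
        endif()
    endif()
endforeach()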
set(DEVICE_AVGPOOL_BWD_INSTANCES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp
                                         device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp
                                         device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
set(BATCHED_GEMM_INSTANCES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
                                   device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
                                   device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
                                   device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
                                   device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
                                   device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
                                   device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
    device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
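Since device_batched_gemm_add_relu_gemm_add_instance carries only f16 sources, its old if(DTYPES MATCHES "fp16" ...) guard became redundant: when fp16 is not selected, add_instance_library prunes every source, skips the target, and reports result as 1. A hedged illustration, assuming a configure run with -DDTYPES=fp64:

# With DTYPES="fp64", both *_f16_* sources are pruned inside
# add_instance_library, add_library() is skipped, and result is 1.
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
    device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
if(result EQUAL 1)
    message(STATUS "no sources left after type filtering; target skipped")
endif()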