Commit 84dcf5d0 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents 705d5a08 c9553832
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data) add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
endif()
add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
endif()
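For reference, the list/flag idiom in this hunk adds the examples once if at least one entry of GPU_TARGETS is a supported architecture; a minimal standalone CMake sketch of the same idiom (hypothetical example names):

list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        add_example_executable(example_foo_fp16 foo_fp16.cpp) # hypothetical
        set(target 1) # ensure the block runs only once
    endif()
endforeach()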
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_custom_target(example_permute)
add_custom_target(example_permute)
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_1xHxW_fp16) add_dependencies(example_permute example_permute_1xHxW_fp16)
endif()
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_NxHxW_fp16) add_dependencies(example_permute example_permute_NxHxW_fp16)
endif()
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_HxWx4_fp16) add_dependencies(example_permute example_permute_HxWx4_fp16)
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
...@@ -11,7 +10,6 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -11,7 +10,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif() endif()
endforeach() endforeach()
if(DL_KERNELS)
# Conv perlayer quantization # Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization # Conv perchannel quantization
...@@ -24,5 +22,3 @@ endforeach() ...@@ -24,5 +22,3 @@ endforeach()
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization # Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
endif()
endif()
\ No newline at end of file
@@ -3,15 +3,9 @@ list(APPEND gpu_list2 gfx908 gfx90a)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
 if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
-if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
 add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
-endif()
-if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
 add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
-endif()
-if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
 add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
-endif()
 if(USE_BITINT_EXTENSION_INT4)
 add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
 endif(USE_BITINT_EXTENSION_INT4)
@@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
 endforeach()
 if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
-if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
 add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
-endif()
 endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
endif()
@@ -14,18 +14,22 @@ using ComputeDataType = float;
 struct YElementOp
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
-                          ck::is_same<T, ck::half_t>::value,
+        static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
+                          ck::is_same<X, ck::half_t>::value,
                       "Data type is not supported by this operation!");
+        static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
+                          ck::is_same<Y, ck::half_t>::value,
+                      "Data type is not supported by this operation!");

-        T a;
+        X a;

         ck::tensor_operation::element_wise::Sigmoid{}(a, x);

-        y = x * a;
+        y = ck::type_convert<Y>(x * a);
     };
 };
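Side note: YElementOp implements the swish/SiLU activation used by the groupnorm_swish example. A minimal host-only C++ sketch of the per-element math (not part of this commit):

#include <cmath>

// Reference for one element: a = sigmoid(x), then y = x * a (swish/SiLU).
// The commit's change only widens the operator to distinct X/Y types and
// converts the product back to Y via ck::type_convert.
float y_element_op_reference(float x)
{
    const float a = 1.0f / (1.0f + std::exp(-x));
    return x * a;
}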
...
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
if(DL_KERNELS) add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
endif()
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
endif() add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
endif()
@@ -7,20 +7,114 @@ add_custom_target(examples)
 function(add_example_executable EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
-    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
-    target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
-    add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
-    add_dependencies(examples ${EXAMPLE_NAME})
-    add_dependencies(check ${EXAMPLE_NAME})
-    rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+    set(result 1)
+    if(DEFINED DTYPES)
+        foreach(source IN LISTS FILE_NAME)
+            set(test 0)
+            foreach(type IN LISTS DTYPES)
+                if(type MATCHES "fp16")
+                    set(type1 "_f16")
+                elseif(type MATCHES "fp32")
+                    set(type1 "_f32")
+                elseif(type MATCHES "fp8")
+                    set(type1 "_f8")
+                elseif(type MATCHES "bf16")
+                    set(type1 "_b16")
+                elseif(type MATCHES "fp64")
+                    set(type1 "_f64")
+                elseif(type MATCHES "int8")
+                    set(type1 "_i8")
+                endif()
+                if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
+                    # if the filename matches any selected type, exit the type loop and do not exclude the file from the list
+                    set(test 0)
+                    break()
+                elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
+                        source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
+                       NOT(source MATCHES type OR source MATCHES type1))
+                    # if the filename contains a type which doesn't match any selected type, mark it for removal
+                    set(test 1)
+                endif()
+            endforeach()
+            if(test EQUAL 1)
+                message("removing example source file ${source} ")
+                list(REMOVE_ITEM FILE_NAME "${source}")
+            endif()
+        endforeach()
+    endif()
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
+            message("removing dl example ${source} ")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    # only continue if there are some source files left on the list
+    if(FILE_NAME)
+        add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+        target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
+        add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
+        add_dependencies(examples ${EXAMPLE_NAME})
+        add_dependencies(check ${EXAMPLE_NAME})
+        rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+        set(result 0)
+    endif()
+    #message("add_example returns ${result}")
+    return(PROPAGATE result)
 endfunction(add_example_executable EXAMPLE_NAME)

 function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
-    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
-    target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
-    add_dependencies(examples ${EXAMPLE_NAME})
-    rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+    set(result 1)
+    if(DEFINED DTYPES)
+        foreach(source IN LISTS FILE_NAME)
+            set(test 0)
+            foreach(type IN LISTS DTYPES)
+                if(type MATCHES "fp16")
+                    set(type1 "_f16")
+                elseif(type MATCHES "fp32")
+                    set(type1 "_f32")
+                elseif(type MATCHES "fp8")
+                    set(type1 "_f8")
+                elseif(type MATCHES "bf16")
+                    set(type1 "_b16")
+                elseif(type MATCHES "fp64")
+                    set(type1 "_f64")
+                elseif(type MATCHES "int8")
+                    set(type1 "_i8")
+                endif()
+                if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
+                    # if the filename matches any selected type, exit the type loop and do not exclude the file from the list
+                    set(test 0)
+                    break()
+                elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
+                        source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
+                       NOT(source MATCHES type OR source MATCHES type1))
+                    # if the filename contains a type which doesn't match any selected type, mark it for removal
+                    set(test 1)
+                endif()
+            endforeach()
+            if(test EQUAL 1)
+                message("removing example ${source} ")
+                list(REMOVE_ITEM FILE_NAME "${source}")
+            endif()
+        endforeach()
+    endif()
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
+            message("removing dl example ${source} ")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    # only continue if there are some source files left on the list
+    if(FILE_NAME)
+        add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+        target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
+        add_dependencies(examples ${EXAMPLE_NAME})
+        rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
+        set(result 0)
+    endif()
+    #message("add_example returns ${result}")
+    return(PROPAGATE result)
 endfunction(add_example_executable_no_testing EXAMPLE_NAME)
 # add all example subdir
...
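With the per-type guards gone from the individual CMakeLists, callers rely on the result variable that these functions now propagate (via return(PROPAGATE result), available since CMake 3.25): 0 when the executable was actually created, 1 when every source file was filtered out. A minimal caller sketch with hypothetical names, mirroring the grouped_conv_bwd_data and permute hunks above:

add_custom_target(example_foo)
add_example_executable(example_foo_fp16 foo_fp16.cpp) # skipped unless fp16 is selected in DTYPES
if(result EQUAL 0)
    add_dependencies(example_foo example_foo_fp16) # only if the target exists
endif()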
@@ -43,6 +43,9 @@
 #ifndef CK_ENABLE_FP8
 #define CK_ENABLE_FP8 "ON"
 #endif
+#ifndef CK_ENABLE_BF8
+#define CK_ENABLE_BF8 "ON"
+#endif
 #ifndef CK_ENABLE_FP16
 #define CK_ENABLE_FP16 "ON"
 #endif
@@ -66,6 +69,10 @@
 #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
 #endif
+#ifndef CK_ENABLE_BF8
+#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
+#endif
 #ifndef CK_ENABLE_FP16
 #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
 #endif
...
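Client code can then gate bf8 paths on the generated macro the same way the existing fp8/fp16 flags are used; a trivial sketch:

#ifdef CK_ENABLE_BF8
// bf8 data-type support was enabled at configure time; bf8 kernels and
// instances can be compiled in here.
#endif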
@@ -273,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
+              M_raw_{M},
+              N_raw_{N},
+              K_raw_{K},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
@@ -314,6 +317,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
         index_t M01_;
         index_t N01_;
+        index_t M_raw_;
+        index_t N_raw_;
+        index_t K_raw_;
         // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -485,6 +492,50 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
+        // Make sure that the M, N, K dimensions before padding are divisible by respective
+        // vector lengths.
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        {
+            constexpr auto A_K_vec_length =
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
+            if(arg.K_raw_ % A_K_vec_length != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            constexpr auto A_M_vec_length =
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
+            if(arg.M_raw_ % A_M_vec_length != 0)
+            {
+                return false;
+            }
+        }
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        {
+            constexpr auto B_N_vec_length =
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
+            if(arg.N_raw_ % B_N_vec_length != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            constexpr auto B_K_vec_length =
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
+            if(arg.K_raw_ % B_K_vec_length != 0)
+            {
+                return false;
+            }
+        }
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
            ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
            ck::get_device_name() == "gfx1102")
...
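The added IsSupportedArgument checks reject raw (pre-padding) GEMM sizes that the configured vector loads cannot divide. A standalone C++ sketch of the row-major-A case, with hypothetical tile lengths:

#include <cstdint>

// Hypothetical ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = {4, 1, 1, 2}.
constexpr std::int64_t A_K0_vec = 4, A_K1_vec = 2;

// For row-major A the K dimension is contiguous, so the raw K extent must be
// divisible by the combined vector length K0 * K1.
bool is_supported_k(std::int64_t K_raw)
{
    constexpr std::int64_t A_K_vec_length = A_K0_vec * A_K1_vec; // = 8
    return K_raw % A_K_vec_length == 0;
}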
@@ -144,7 +144,8 @@ template <typename ALayout,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler(),
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          typename ComputeDataType = EDataType>
 struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                                                                      BLayout,
                                                                      DsLayout,
@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

-    using ComputeDataType = EDataType;
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
         BDataType,
         ComputeDataType,
         AccDataType,
...
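The new trailing template parameter keeps existing instantiations source-compatible while allowing a separate compute type. A reduced C++ sketch of the pattern (not CK's real signature):

template <typename EDataType, typename ComputeDataType = EDataType>
struct DeviceGemmSketch
{
    using Compute = ComputeDataType;
};

using DefaultGemm = DeviceGemmSketch<float>;         // Compute = float, as before
using MixedGemm   = DeviceGemmSketch<float, double>; // override the compute type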
@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -72,6 +73,9 @@ __global__ void
     const Block2CTileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
+    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
+    defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -96,9 +100,23 @@ __global__ void
         block_2_ctile_map,
         integral_constant<bool, HasMainKBlockLoop>{},
         integral_constant<bool, HasDoubleTailKBlockLoop>{});
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = batch_count;
+    ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
+    ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
+    ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
+    ignore = block_2_ctile_map;
+    ignore = compute_ptr_offset_of_batch;
+#endif
 }
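The guard compiles the kernel body only on the host pass and on the gfx targets this DL kernel supports, and otherwise consumes every argument so the stripped-down kernel produces no unused-parameter warnings. A minimal HIP sketch of the same idiom (hypothetical kernel, shortened target list):

#include <hip/hip_runtime.h>

__global__ void kernel_sketch(const float* p_in, float* p_out)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx908__))
    p_out[0] = p_in[0]; // real body would go here
#else
    (void)p_in;  // plays the role of CK's `ignore = ...` sinks
    (void)p_out;
#endif
}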
 template <ck::index_t NDimSpatial,
+          typename InLayout,
+          typename WeiLayout,
+          typename OutLayout,
           typename InDataType,
           typename WeiDataType,
           typename OutDataType,
@@ -134,29 +152,46 @@ template <ck::index_t NDimSpatial,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
-struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
-    : public DeviceGroupedConvBwdWeight<
-          NDimSpatial,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GNWC,
-                                        ck::tensor_layout::convolution::GNHWC,
-                                        ck::tensor_layout::convolution::GNDHWC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GKXC,
-                                        ck::tensor_layout::convolution::GKYXC,
-                                        ck::tensor_layout::convolution::GKZYXC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GNWK,
-                                        ck::tensor_layout::convolution::GNHWK,
-                                        ck::tensor_layout::convolution::GNDHWK>>,
-          InDataType,
-          WeiDataType,
-          OutDataType,
-          InElementwiseOperation,
-          WeiElementwiseOperation,
-          OutElementwiseOperation>
+struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpatial,
+                                                                         InLayout,
+                                                                         WeiLayout,
+                                                                         OutLayout,
+                                                                         InDataType,
+                                                                         WeiDataType,
+                                                                         OutDataType,
+                                                                         InElementwiseOperation,
+                                                                         WeiElementwiseOperation,
+                                                                         OutElementwiseOperation>
 {
+    // 1d
+    static constexpr bool is_NWGK_GKXC_NWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
+    static constexpr bool is_GNWK_GKXC_GNWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
+    // 2d
+    static constexpr bool is_NHWGK_GKYXC_NHWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
+    static constexpr bool is_GNHWK_GKYXC_GNHWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
+    // 3d
+    static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
+    static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
+
-    using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
+    using DeviceOp = DeviceGroupedConvBwdWeight_Dl;

     using ADataType = OutDataType;
     using BDataType = InDataType;
@@ -176,6 +211,8 @@
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};
+
+    static constexpr auto spatial_offset = I3;

     static constexpr auto K1Number     = Number<K1>{};
     static constexpr auto GemmK1Number = K1Number;
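For orientation, the new length/stride arrays are ordered (G, N, C, spatial...), which is why spatial_offset is I3. A C++ sketch with hypothetical 2D sizes:

#include <array>

// Hypothetical 2D conv input with G = 1, N = 4, C = 64, Hi = 28, Wi = 28,
// stored in g_n_c_wis order as the new interface expects:
constexpr std::array<int, 5> a_g_n_c_wis_lengths{1, 4, 64, 28, 28};
constexpr int spatial_offset = 3; // indices 0..2 hold G, N, C
constexpr int Hi = a_g_n_c_wis_lengths[spatial_offset];     // 28
constexpr int Wi = a_g_n_c_wis_lengths[spatial_offset + 1]; // 28
static_assert(Hi == 28 && Wi == 28, "spatial dims start at index 3");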
@@ -195,12 +232,12 @@
     template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        const ck::index_t N,
-        const ck::index_t K,
-        const ck::index_t C,
-        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -209,90 +246,102 @@
     {
         using namespace ck;

-        const index_t Wi = input_spatial_lengths[0];
-        const index_t Wo = output_spatial_lengths[0];
-        const index_t X  = filter_spatial_lengths[0];
-        const index_t InLeftPadW    = input_left_pads[0];
-        const index_t InRightPadW   = input_right_pads[0];
-        const index_t ConvStrideW   = conv_filter_strides[0];
-        const index_t ConvDilationW = conv_filter_dilations[0];
+        const index_t N  = a_g_n_c_wis_lengths[I1];
+        const index_t K  = b_g_k_c_xs_lengths[I1];
+        const index_t C  = a_g_n_c_wis_lengths[I2];
+        const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
+        const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
+        const index_t X  = b_g_k_c_xs_lengths[spatial_offset];
+        const index_t InLeftPadW    = input_left_pads[I0];
+        const index_t InRightPadW   = input_right_pads[I0];
+        const index_t ConvStrideW   = conv_filter_strides[I0];
+        const index_t ConvDilationW = conv_filter_dilations[I0];
+
+        const auto InNStride  = a_g_n_c_wis_strides[I1];
+        const auto InCStride  = a_g_n_c_wis_strides[I2];
+        const auto InWStride  = a_g_n_c_wis_strides[spatial_offset];
+        const auto WeiKStride = b_g_k_c_xs_strides[I1];
+        const auto WeiCStride = b_g_k_c_xs_strides[I2];
+        const auto OutKStride = e_g_n_k_wos_strides[I2];
+        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];

         const index_t GemmKTotal = N * Wo;
-        const index_t GemmM      = K;
-        const index_t GemmN      = C * X;

         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
             K0PerBlock;
-        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

         if constexpr(ConvBackwardWeightSpecialization ==
                      ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
         {
             // A: output tensor
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                out_gemmkpad_gemmmpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // B: input tensor
-            const auto in_gemmktotal_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                in_gemmkpad_gemmnpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weights tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                  make_tuple(MPerBlock, NPerBlock),
+                                                                  Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
         else
         {
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
-            const auto in_n_wi_c_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));

             // A: output tensor
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                out_gemmkpad_gemmmpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

@@ -321,38 +370,43 @@
                 make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
                 make_tuple(Sequence<1>{}, Sequence<0>{}));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmN)),
+                in_gemmkpad_gemmnpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                  make_tuple(MPerBlock, NPerBlock),
+                                                                  Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
     } // function end
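A worked example (hypothetical numbers) of the K-padding arithmetic these descriptor builders share:

#include <cstdio>

int main()
{
    // Say GemmKTotal = N * Wo = 4 * 100 = 400, GemmK1 = 4, K0PerBlock = 16, KBatch = 2.
    const int GemmKTotal = 400, GemmK1 = 4, K0PerBlock = 16, GemmKBatch = 2;
    const int chunk  = GemmK1 * K0PerBlock * GemmKBatch;              // 128
    const int GemmK0 = (GemmKTotal + chunk - 1) / chunk * K0PerBlock; // ceil(400/128) * 16 = 64
    std::printf("padded K = %d\n", GemmKBatch * GemmK0 * GemmK1);     // 2 * 64 * 4 = 512 >= 400
    return 0;
}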
     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        const ck::index_t N,
-        const ck::index_t K,
-        const ck::index_t C,
-        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -361,103 +415,111 @@
     {
         using namespace ck;

-        const index_t Hi = input_spatial_lengths[0];
-        const index_t Wi = input_spatial_lengths[1];
-        const index_t Ho = output_spatial_lengths[0];
-        const index_t Wo = output_spatial_lengths[1];
-        const index_t Y  = filter_spatial_lengths[0];
-        const index_t X  = filter_spatial_lengths[1];
-
-        const index_t InLeftPadH  = input_left_pads[0];
-        const index_t InLeftPadW  = input_left_pads[1];
-        const index_t InRightPadH = input_right_pads[0];
-        const index_t InRightPadW = input_right_pads[1];
-
-        const index_t ConvStrideH   = conv_filter_strides[0];
-        const index_t ConvStrideW   = conv_filter_strides[1];
-        const index_t ConvDilationH = conv_filter_dilations[0];
-        const index_t ConvDilationW = conv_filter_dilations[1];
+        const index_t N  = a_g_n_c_wis_lengths[I1];
+        const index_t K  = b_g_k_c_xs_lengths[I1];
+        const index_t C  = a_g_n_c_wis_lengths[I2];
+        const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
+        const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
+        const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
+        const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
+        const index_t Y  = b_g_k_c_xs_lengths[spatial_offset];
+        const index_t X  = b_g_k_c_xs_lengths[spatial_offset + I1];
+
+        const index_t InLeftPadH  = input_left_pads[I0];
+        const index_t InLeftPadW  = input_left_pads[I1];
+        const index_t InRightPadH = input_right_pads[I0];
+        const index_t InRightPadW = input_right_pads[I1];
+        const index_t ConvStrideH   = conv_filter_strides[I0];
+        const index_t ConvStrideW   = conv_filter_strides[I1];
+        const index_t ConvDilationH = conv_filter_dilations[I0];
+        const index_t ConvDilationW = conv_filter_dilations[I1];
+
+        const auto InNStride  = a_g_n_c_wis_strides[I1];
+        const auto InCStride  = a_g_n_c_wis_strides[I2];
+        const auto InHStride  = a_g_n_c_wis_strides[spatial_offset];
+        const auto InWStride  = a_g_n_c_wis_strides[spatial_offset + I1];
+        const auto WeiKStride = b_g_k_c_xs_strides[I1];
+        const auto WeiCStride = b_g_k_c_xs_strides[I2];
+        const auto OutKStride = e_g_n_k_wos_strides[I2];
+        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];

         const index_t GemmKTotal = N * Ho * Wo;
-        const index_t GemmM      = K;
-        const index_t GemmN      = C * X * Y;

         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
             K0PerBlock;
-        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

         if constexpr(ConvBackwardWeightSpecialization ==
                      ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
         {
             // A: output tensor
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                out_gemmkpad_gemmmpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // B: input tensor
-            const auto in_gemmktotal_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                in_gemmkpad_gemmnpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                  make_tuple(MPerBlock, NPerBlock),
+                                                                  Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
         else
         {
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
-            const auto in_n_hi_wi_c_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));

             // A: output tensor
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                out_gemmkpad_gemmmpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

@@ -488,39 +550,44 @@
                 make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                 make_tuple(Sequence<1>{}, Sequence<0>{}));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmN)),
+                in_gemmkpad_gemmnpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                  make_tuple(MPerBlock, NPerBlock),
+                                                                  Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
     } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -529,110 +596,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -529,110 +596,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{ {
using namespace ck; using namespace ck;
const index_t Di = input_spatial_lengths[0]; const index_t N = a_g_n_c_wis_lengths[I1];
const index_t Hi = input_spatial_lengths[1]; const index_t K = b_g_k_c_xs_lengths[I1];
const index_t Wi = input_spatial_lengths[2]; const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
const index_t Do = output_spatial_lengths[0]; const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
const index_t Ho = output_spatial_lengths[1]; const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
const index_t Wo = output_spatial_lengths[2]; const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
const index_t Z = filter_spatial_lengths[0]; const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
const index_t Y = filter_spatial_lengths[1]; const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0];
const index_t X = filter_spatial_lengths[2]; const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1];
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadD = input_left_pads[I0];
const index_t InLeftPadW = input_left_pads[2]; const index_t InLeftPadH = input_left_pads[I1];
const index_t InLeftPadW = input_left_pads[I2];
const index_t InRightPadD = input_right_pads[0]; const index_t InRightPadD = input_right_pads[I0];
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadH = input_right_pads[I1];
const index_t InRightPadW = input_right_pads[2]; const index_t InRightPadW = input_right_pads[I2];
const index_t ConvStrideD = conv_filter_strides[I0];
const index_t ConvStrideD = conv_filter_strides[0]; const index_t ConvStrideH = conv_filter_strides[I1];
const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[I2];
const index_t ConvStrideW = conv_filter_strides[2]; const index_t ConvDilationD = conv_filter_dilations[I0];
const index_t ConvDilationH = conv_filter_dilations[I1];
const index_t ConvDilationD = conv_filter_dilations[0]; const index_t ConvDilationW = conv_filter_dilations[I2];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2]; const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InDStride = a_g_n_c_wis_strides[spatial_offset];
const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];
const index_t GemmKTotal = N * Do * Ho * Wo; const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
    make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));

const auto in_gemmkpad_gemmnpad_grid_desc =
    ck::tensor_operation::device::PadTensorDescriptor(
        in_gemmktotal_gemmn_grid_desc,
        make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
        Sequence<true, true>{});

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
    in_gemmkpad_gemmnpad_grid_desc,
    make_tuple(
        make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
        make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
    make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));

const auto wei_gemmmpad_gemmnpad_grid_desc =
    ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
                                                      make_tuple(MPerBlock, NPerBlock),
                                                      Sequence<true, true>{});

return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                  wei_gemmmpad_gemmnpad_grid_desc);
}
else
{
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
    make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
    make_tuple(N, Di, Hi, Wi, C),
    make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
// A: output tensor
const auto out_gemmkpad_gemmmpad_grid_desc =
    ck::tensor_operation::device::PadTensorDescriptor(
        out_gemmktotal_gemmm_grid_desc,
        make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
        Sequence<true, true>{});

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
    out_gemmkpad_gemmmpad_grid_desc,
    make_tuple(
        make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
        make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -672,27 +749,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
    make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
    ck::tensor_operation::device::PadTensorDescriptor(
        in_gemmktotal_gemmn_grid_desc,
        make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
        Sequence<true, true>{});

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
    in_gemmkpad_gemmnpad_grid_desc,
    make_tuple(
        make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
        make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
    make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));

const auto wei_gemmmpad_gemmnpad_grid_desc =
    ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
                                                      make_tuple(MPerBlock, NPerBlock),
                                                      Sequence<true, true>{});

return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                  wei_gemmmpad_gemmnpad_grid_desc);
}
} // function end
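// A minimal standalone sketch of the padding arithmetic used above. Assumptions: this
// mirrors what PadTensorDescriptor does conceptually, and pad_to_multiple is a
// hypothetical helper, not a CK API. Each GEMM dimension is rounded up to the next
// multiple of its tile size so the grid tiles the problem evenly; the
// Sequence<true, true> flags enable the padding per dimension.
#include <cstdint>

inline std::int64_t pad_to_multiple(std::int64_t length, std::int64_t tile)
{
    // Round `length` up to the next multiple of `tile`.
    return (length + tile - 1) / tile * tile;
}

// Usage: pad_to_multiple(1000, 128) == 1024, matching the worked K-padding example above.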
@@ -701,22 +783,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto GetABCGridDesc()
{
    return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
        {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
    return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
        {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
    return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1},
                                                              {1, 1, 1},
                                                              {1, 1, 1},
                                                              {1, 1, 1},
                                                              {1, 1, 1},
                                                              {1, 1, 1},
@@ -785,11 +867,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -809,38 +891,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_G_{a_g_n_c_wis_lengths[I0]},
Conv_K_{b_g_k_c_xs_lengths[I1]},
Conv_C_{a_g_n_c_wis_lengths[I2]},
filter_lengths_{b_g_k_c_xs_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
const auto descs =
    DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
        a_g_n_c_wis_lengths, // input
        a_g_n_c_wis_strides,
        b_g_k_c_xs_lengths, // weight
        b_g_k_c_xs_strides,
        e_g_n_k_wos_lengths, // output
        e_g_n_k_wos_strides,
        conv_filter_strides,
        conv_filter_dilations,
        input_left_pads,
@@ -863,24 +931,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
}
const ADataType* p_a_grid_;
@@ -908,13 +961,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// for checking IsSupportedArgument()
const index_t Conv_G_;
const index_t Conv_K_;
const index_t Conv_C_;

std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
@@ -1036,10 +1086,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static bool IsSupportedArgument(const Argument& arg)
{
    // DL version only supports split_k equal to 1
    if(arg.k_batch_ != 1)
        return false;

    if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
                   (NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
                   (NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
    {
        return false;
    }
@@ -1050,8 +1104,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{
    if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
         arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
         arg.input_right_pads_[i] == 0))
    {
        return false;
    }
@@ -1206,7 +1261,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
auto str = std::stringstream();

// clang-format off
str << "DeviceGroupedConvBwdWeight_Dl"
    << "<"
    << BlockSize << ", "
    << MPerBlock << ", "
...
@@ -193,6 +193,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                                        BLayout,
@@ -217,6 +218,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
    ADataType, // TODO: distinguish A/B datatype
    BDataType,
    ComputeType,
    AccDataType,
    CShuffleDataType,
    DsDataType,
...
@@ -27,6 +27,12 @@ struct PassThrough
    y = x;
}
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
@@ -69,18 +75,36 @@ struct PassThrough
    y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<float, half_t>(float& y, const half_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
    y = x;
}
template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
    y = type_convert<int8_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
@@ -89,6 +113,7 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
@@ -118,6 +143,7 @@ struct PassThrough
{
    y = type_convert<f8_t>(x);
}
#endif
};
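// Standalone sketch of what the PassThrough specializations above amount to (assumed
// semantics for illustration; CK's real functor goes through type_convert and the
// explicit specializations listed in this file rather than a blanket static_cast):
#include <cstdint>

struct PassThroughSketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(x); // convert-on-copy, e.g. float <- double, half_t <- int8
    }
};

// Usage: float y; PassThroughSketch{}(y, 42.0); // y == 42.0f, like the float <- double path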
struct UnaryConvert
@@ -146,6 +172,7 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
    // convert to fp8 using stochastic rounding (SR)
@@ -162,6 +189,7 @@ struct ConvertF8SR
    y = f8_convert_sr<Y>(x);
}
};
#endif
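// Sketch of the stochastic-rounding idea behind f8_convert_sr (illustrative only; the
// round_stochastic helper below is hypothetical and rounds to integers rather than to
// fp8, but the principle is the same): round up with probability equal to the fractional
// distance to the next representable value, so the conversion is unbiased in expectation.
#include <cstdlib>

inline int round_stochastic(float x) // assumes x >= 0
{
    float lo   = static_cast<float>(static_cast<long>(x)); // truncate toward zero
    float frac = x - lo;                                   // distance to the next value up
    float u    = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
    return static_cast<int>(lo) + (u < frac ? 1 : 0);
}

// E.g. round_stochastic(1.25f) yields 2 with probability 0.25 and 1 with probability 0.75.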
struct Scale
{
@@ -412,14 +440,19 @@ struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}

template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
    static_assert(is_same<X, float>::value || is_same<X, double>::value ||
                      is_same<X, ck::half_t>::value,
                  "Data type is not supported by this operation!");

    static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
                      is_same<Y, ck::half_t>::value,
                  "Data type is not supported by this operation!");

    float bx = -beta_ * type_convert<float>(x);
    y        = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
};
float beta_ = 1.0f;
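// Worked example (beta_ = 1, values assumed for illustration): x = 2 gives
// y = 2 / (1 + e^-2) ~= 2 / 1.1353 ~= 1.7616. Doing the exp() and division in float
// before converting to Y keeps half-precision inputs from overflowing or underflowing
// in the intermediate term.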
...