"docs/vscode:/vscode.git/clone" did not exist on "8516999495114926c9838c2d6e0feb580d4d983f"
Commit c9013009 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 114c2646 84dcf5d0
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
endif()
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
endif()
add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
endif()
set(target 1)
endif()
endforeach()
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(example_permute)
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
add_custom_target(example_permute)
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_1xHxW_fp16)
endif()
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_NxHxW_fp16)
endif()
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_HxWx4_fp16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......@@ -11,7 +10,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()
if(DL_KERNELS)
# Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization
......@@ -24,5 +22,3 @@ endforeach()
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
endif()
endif()
\ No newline at end of file
......@@ -3,15 +3,9 @@ list(APPEND gpu_list2 gfx908 gfx90a)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
endif()
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
......@@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
endforeach()
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
endif()
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
......@@ -14,18 +14,22 @@ using ComputeDataType = float;
struct YElementOp
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
ck::is_same<T, ck::half_t>::value,
static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
ck::is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
T a;
static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
ck::is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
X a;
ck::tensor_operation::element_wise::Sigmoid{}(a, x);
y = x * a;
y = ck::type_convert<Y>(x * a);
};
};
......
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
endif()
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
endif()
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
if(DL_KERNELS)
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
endif()
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
endif()
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
endif()
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
endif()
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
endif()
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
......@@ -7,20 +7,114 @@ add_custom_target(examples)
function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example source file ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
return(PROPAGATE result)
endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_dependencies(examples ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_dependencies(examples ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
return(PROPAGATE result)
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
# add all example subdir
......
......@@ -43,6 +43,9 @@
#ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON"
#endif
#ifndef CK_ENABLE_BF8
#define CK_ENABLE_BF8 "ON"
#endif
#ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON"
#endif
......@@ -66,6 +69,10 @@
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif
#ifndef CK_ENABLE_BF8
#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
#endif
#ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif
......
......@@ -273,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
M_raw_{M},
N_raw_{N},
K_raw_{K},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
......@@ -314,6 +317,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
index_t M01_;
index_t N01_;
index_t M_raw_;
index_t N_raw_;
index_t K_raw_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -485,6 +492,50 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
// Make sure that the M, N, K dimensions before padding are divisible by respective vector
// lengths.
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto A_K_vec_length =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
if(arg.K_raw_ % A_K_vec_length != 0)
{
return false;
}
}
else
{
constexpr auto A_M_vec_lenght =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
if(arg.M_raw_ % A_M_vec_lenght != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
constexpr auto B_N_vec_lenght =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
if(arg.N_raw_ % B_N_vec_lenght != 0)
{
return false;
}
}
else
{
constexpr auto B_K_vec_length =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
if(arg.K_raw_ % B_K_vec_length != 0)
{
return false;
}
}
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
......
......@@ -144,7 +144,8 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
......@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
ComputeDataType,
AccDataType,
......
......@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -72,6 +73,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......@@ -96,9 +100,23 @@ __global__ void
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
#endif
}
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
......@@ -134,29 +152,46 @@ template <ck::index_t NDimSpatial,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
: public DeviceGroupedConvBwdWeight<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
// 1d
static constexpr bool is_NWGK_GKXC_NWGC =
is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
static constexpr bool is_GNWK_GKXC_GNWC =
is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
// 2d
static constexpr bool is_NHWGK_GKYXC_NHWGC =
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
static constexpr bool is_GNHWK_GKYXC_GNHWC =
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
// 3d
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
using DeviceOp = DeviceGroupedConvBwdWeight_Dl;
using ADataType = OutDataType;
using BDataType = InDataType;
......@@ -176,6 +211,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto spatial_offset = I3;
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;
......@@ -195,12 +232,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -209,90 +246,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{
using namespace ck;
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t N = a_g_n_c_wis_lengths[I1];
const index_t K = b_g_k_c_xs_lengths[I1];
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
const index_t X = b_g_k_c_xs_lengths[spatial_offset];
const index_t InLeftPadW = input_left_pads[I0];
const index_t InRightPadW = input_right_pads[I0];
const index_t ConvStrideW = conv_filter_strides[I0];
const index_t ConvDilationW = conv_filter_dilations[I0];
const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];
const index_t GemmKTotal = N * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weights tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
else
{
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -321,38 +370,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -361,103 +415,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{
using namespace ck;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t N = a_g_n_c_wis_lengths[I1];
const index_t K = b_g_k_c_xs_lengths[I1];
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
const index_t Y = b_g_k_c_xs_lengths[spatial_offset];
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1];
const index_t InLeftPadH = input_left_pads[I0];
const index_t InLeftPadW = input_left_pads[I1];
const index_t InRightPadH = input_right_pads[I0];
const index_t InRightPadW = input_right_pads[I1];
const index_t ConvStrideH = conv_filter_strides[I0];
const index_t ConvStrideW = conv_filter_strides[I1];
const index_t ConvDilationH = conv_filter_dilations[I0];
const index_t ConvDilationW = conv_filter_dilations[I1];
const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InHStride = a_g_n_c_wis_strides[spatial_offset];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];
const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
else
{
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -488,39 +550,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -529,110 +596,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{
using namespace ck;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t N = a_g_n_c_wis_lengths[I1];
const index_t K = b_g_k_c_xs_lengths[I1];
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0];
const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1];
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2];
const index_t InLeftPadD = input_left_pads[I0];
const index_t InLeftPadH = input_left_pads[I1];
const index_t InLeftPadW = input_left_pads[I2];
const index_t InRightPadD = input_right_pads[I0];
const index_t InRightPadH = input_right_pads[I1];
const index_t InRightPadW = input_right_pads[I2];
const index_t ConvStrideD = conv_filter_strides[I0];
const index_t ConvStrideH = conv_filter_strides[I1];
const index_t ConvStrideW = conv_filter_strides[I2];
const index_t ConvDilationD = conv_filter_dilations[I0];
const index_t ConvDilationH = conv_filter_dilations[I1];
const index_t ConvDilationW = conv_filter_dilations[I2];
const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InDStride = a_g_n_c_wis_strides[spatial_offset];
const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];
const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
else
{
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -672,27 +749,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
} // function end
......@@ -701,22 +783,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
{1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
1,
1,
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
......@@ -785,11 +867,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -809,38 +891,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{a_g_n_c_wis_lengths[2]},
input_spatial_lengths_{},
filter_spatial_lengths_{},
output_spatial_lengths_{},
Conv_G_{a_g_n_c_wis_lengths[I0]},
Conv_K_{b_g_k_c_xs_lengths[I1]},
Conv_C_{a_g_n_c_wis_lengths[I2]},
filter_lengths_{b_g_k_c_xs_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -863,24 +931,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ =
Conv_N_ * Conv_K_ *
std::accumulate(begin(output_spatial_lengths_),
end(output_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
Conv_N_ * Conv_C_ *
std::accumulate(begin(input_spatial_lengths_),
end(input_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
}
const ADataType* p_a_grid_;
......@@ -908,13 +961,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// for checking IsSupportedArgument()
const index_t Conv_G_;
const index_t Conv_N_;
const index_t Conv_K_;
const index_t Conv_C_;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
......@@ -1036,10 +1086,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
// DL version only supports split_k equal to 1
if(arg.k_batch_ != 1)
return false;
if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
(NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
(NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
{
return false;
}
......@@ -1050,8 +1104,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
arg.input_right_pads_[i] == 0))
{
return false;
}
......@@ -1206,7 +1261,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
str << "DeviceGroupedConvBwdWeight_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -193,6 +193,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
BLayout,
......@@ -217,6 +218,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeType,
AccDataType,
CShuffleDataType,
DsDataType,
......
......@@ -27,6 +27,12 @@ struct PassThrough
y = x;
}
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -69,18 +75,36 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<float, half_t>(float& y, const half_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
......@@ -89,6 +113,7 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
......@@ -118,6 +143,7 @@ struct PassThrough
{
y = type_convert<f8_t>(x);
}
#endif
};
struct UnaryConvert
......@@ -146,6 +172,7 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
......@@ -162,6 +189,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x);
}
};
#endif
struct Scale
{
......@@ -412,14 +440,19 @@ struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
};
float beta_ = 1.0f;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment