Commit b2ba0a69 authored by Jing Zhang

Merge remote-tracking branch 'origin/develop' into grouped_gemm_dev_args

parents 48d356fc 49180fd6
...@@ -5,6 +5,31 @@ project(composable_kernel)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
if (DTYPES)
add_definitions(-DDTYPES)
if (DTYPES MATCHES "int8")
add_definitions(-D__int8__)
endif()
if (DTYPES MATCHES "fp8")
add_definitions(-D__fp8__)
endif()
if (DTYPES MATCHES "fp16")
add_definitions(-D__fp16__)
endif()
if (DTYPES MATCHES "fp32")
add_definitions(-D__fp32__)
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-D__fp64__)
endif()
if (DTYPES MATCHES "bf16")
add_definitions(-D__bf16__)
endif()
message("DTYPES macro set to ${DTYPES}")
else()
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
endif()
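Since these are plain compile definitions, library and example sources can gate data-type-specific code on the resulting macros. A minimal standalone C++ sketch (illustrative only, not code from this commit):

// Hypothetical illustration of the -D__fp16__ / -D__bf16__ gating above;
// compile with e.g. -D__fp16__ to enable the corresponding branch.
#include <cstdio>

int main()
{
#ifdef __fp16__
    std::printf("fp16 instances enabled\n");
#endif
#ifdef __bf16__
    std::printf("bf16 instances enabled\n");
#endif
    return 0;
}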
enable_testing()
set(ROCM_SYMLINK_LIBS OFF)
...@@ -16,11 +41,24 @@ include(ROCMSetupVersion)
include(ROCMInstallSymlinks)
include(ROCMCreatePackage)
include(CheckCXXCompilerFlag)
include(ROCMCheckTargetIds)
rocm_setup_version(VERSION 0.2.0)
include(TargetFlags)
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
message("GPU_TARGETS= ${GPU_TARGETS}")
message("checking which targets are supported")
# Default list of GPU targets used when GPU_TARGETS is not set on the command line.
# The list is filtered so that only targets supported by the compiler are kept.
# Setting GPU_TARGETS on the command line overrides this list.
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
)
message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}")
set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " ")
find_package(hip)
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
......
...@@ -749,6 +749,22 @@ pipeline {
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
}
}
stage("Build CK and run Tests on Navi32")
{
when {
beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() }
}
agent{ label rocmnode("navi32") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
}
}
}
}
......
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations)
...@@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable
add_executable(client_gemm_quantization gemm_quantization.cpp)
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations)
endif()
...@@ -2,6 +2,31 @@ cmake_minimum_required(VERSION 3.15)
project(ck_app)
add_compile_options(-std=c++17)
if (DTYPES)
add_definitions(-DDTYPES)
if (DTYPES MATCHES "int8")
add_definitions(-D__int8__)
endif()
if (DTYPES MATCHES "fp8")
add_definitions(-D__fp8__)
endif()
if (DTYPES MATCHES "fp16")
add_definitions(-D__fp16__)
endif()
if (DTYPES MATCHES "fp32")
add_definitions(-D__fp32__)
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-D__fp64__)
endif()
if (DTYPES MATCHES "bf16")
add_definitions(-D__bf16__)
endif()
message("DTYPES macro set to ${DTYPES}")
else()
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
endif()
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")
......
...@@ -2,11 +2,14 @@ add_custom_target(example_gemm_dl)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
...@@ -19,13 +22,16 @@ add_custom_target(example_gemm_xdl)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_int4)
......
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
# dlops
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
...@@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
set(target 1)
endif()
endforeach()
endif()
\ No newline at end of file
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
...@@ -25,4 +26,5 @@ add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_i
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
endif()
\ No newline at end of file
...@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
switch(s)
{
case ConvolutionBackwardDataSpecialization::Default: return "Default";
case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
default: return "Unrecognized specialization!";
}
}
......
...@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEElementwiseOp>
{
// FIXME
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
"wrong! only implemented for 2D and 3D now");

using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;

...@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
});
static constexpr auto NonSpatialDimsNum = Number<3>{};
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
static constexpr auto HIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
static constexpr auto YIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
// problem definition
const index_t Z = b_g_k_c_xs_lengths[ZIdx];
const index_t Y = b_g_k_c_xs_lengths[YIdx];
const index_t X = b_g_k_c_xs_lengths[XIdx];

const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];

const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;

for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto ZDotSlice =
NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

if(YDotSlice * XDotSlice * ZDotSlice <= 0)
{
continue;
}

std::array<index_t, NDimSpatial> tildes;
if constexpr(NDimSpatial == 2)
{
tildes = {i_ytilde, i_xtilde};
}
else if constexpr(NDimSpatial == 3)
{
tildes = {i_ztilde, i_ytilde, i_xtilde};
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}

const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);

const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);

DsGridDesc_M_N ds_grid_desc_m_n;

// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});

const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);

// desc for problem definition
const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);

a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);

// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);

// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);

block_2_etile_map_container_.push_back(block_2_etile_map);

if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));

e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
}
}
}
}
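For orientation: the loops above instantiate one GEMM per (i_ztilde, i_ytilde, i_xtilde) phase of the backward-data convolution and skip phases whose filter slice is empty. A standalone sketch of the phase count (example stride/dilation values, not taken from this commit):

// Number of sub-GEMMs for one 2D problem; Tilde = stride / gcd(stride, dilation).
#include <cstdio>
#include <numeric>

int main()
{
    const int ConvStrideH = 2, ConvDilationH = 1;
    const int ConvStrideW = 3, ConvDilationW = 2;

    const int YTilde = ConvStrideH / std::gcd(ConvStrideH, ConvDilationH); // 2
    const int XTilde = ConvStrideW / std::gcd(ConvStrideW, ConvDilationW); // 3

    // upper bound; phases with an empty YDotSlice/XDotSlice are skipped at runtime
    std::printf("up to %d sub-GEMMs\n", YTilde * XTilde); // 6
}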
...@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector load for A matrix from global memory to LDS
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
...@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}

// vector load for B matrix from global memory to LDS
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
{
...@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
...@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector store for E
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
{
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
......
...@@ -18,32 +18,53 @@ template <
index_t NDimSpatial,
typename ALayout,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
constexpr auto make_out_grid_desc(const index_t N,
const index_t Do,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides)
{
const auto KStride = Number<1>{};

if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t HiStride = out_g_n_k_wos_strides[3];
const index_t WiStride = out_g_n_k_wos_strides[4];

if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WiStride, KStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K),
make_tuple(NStride, HiStride, WiStride, KStride));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t DoStride = out_g_n_k_wos_strides[3];
const index_t HoStride = out_g_n_k_wos_strides[4];
const index_t WoStride = out_g_n_k_wos_strides[5];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Do, Ho, Wo, K),
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
} }
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
...@@ -60,12 +81,80 @@ make_out_n_ho_wo_k_grid_desc(const index_t N,
return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
template <typename BLayout>
constexpr auto make_wei_grid_desc(
const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C)
{
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
template <index_t NDimSpatial, typename CLayout>
constexpr auto make_in_grid_desc(const index_t N,
const index_t Di,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides)
{
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[2]));
}
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
{
return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[5],
in_g_n_c_wis_strides[2]));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
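For reference, the stride arrays consumed by these helpers are always indexed in [G, N, C, D, H, W] order, whatever the memory layout. A standalone sketch of the strides of a packed NDHWGC input (hypothetical sizes, not from this commit):

// Packed NDHWGC tensor: C varies fastest, then G, W, H, D, N.
#include <array>
#include <cstdio>

int main()
{
    const long G = 2, N = 4, C = 64, D = 8, H = 16, W = 16;

    const std::array<long, 6> in_g_n_c_wis_strides = {
        C,                 // [0] G stride
        D * H * W * G * C, // [1] N stride
        1,                 // [2] C stride
        H * W * G * C,     // [3] D stride
        W * G * C,         // [4] H stride
        G * C,             // [5] W stride
    };

    std::printf("N = %ld, N stride = %ld\n", N, in_g_n_c_wis_strides[1]);
}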
} // namespace

template <
...@@ -82,10 +171,26 @@ struct TransformConvBwdDataToGemm_v1
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto NonSpatialDimsNum = Number<3>{};
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
static constexpr auto HIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
static constexpr auto YIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
template <typename ALayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>),
bool>::type = false>
static auto MakeADescriptor_AK0_M_AK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
...@@ -100,35 +205,43 @@ struct TransformConvBwdDataToGemm_v1
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];

const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];

const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];

const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];

const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];

const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];

const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];

const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

const index_t AK0 = K / AK1;

// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc =
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
N, Do, Ho, Wo, K, out_g_n_k_wos_strides);

if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
...@@ -136,8 +249,8 @@ struct TransformConvBwdDataToGemm_v1
{
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
...@@ -152,103 +265,208 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;

const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);

const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

// only work on the DTilde, HTilde and WTilde that contribute to the non-padding area of the input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if constexpr(NDimSpatial == 2)
{
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(AK1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
else if constexpr(NDimSpatial == 3)
{
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(
make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(
make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(
make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7, 8>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, AK0)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(AK1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
}
}
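A worked example of the padded-window arithmetic in MakeADescriptor_AK0_M_AK1 above (assumed values, not taken from the diff); math::integer_divide_ceil(a, b) is (a + b - 1) / b:

#include <algorithm>
#include <cstdio>

int main()
{
    const int Di = 16, Do = 8, Z = 3;
    const int ConvStrideD = 2, ConvDilationD = 1, InLeftPadD = 1;
    const int ZTilde = 2; // ConvStrideD / gcd(ConvStrideD, ConvDilationD)

    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };

    // output depth extended so every filter phase sees a full window
    const int DTilde = Do + ceil_div(ConvDilationD * (Z - 1), ConvStrideD); // 8 + 1 = 9

    // restrict to the DTilde positions that touch non-padding input
    const int begin = std::max(0, InLeftPadD - ConvDilationD * (ZTilde - 1)) / ConvStrideD; // 0
    const int end   = std::min(DTilde, ceil_div(InLeftPadD + Di - 1, ConvStrideD) + 1);     // 9

    std::printf("DTilde = %d, slice = [%d, %d), DTildeSlice = %d\n",
                DTilde, begin, end, end - begin);
}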
template <typename BLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>),
bool>::type = false>
static auto MakeBDescriptor_BK0_N_BK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
...@@ -263,30 +481,35 @@ struct TransformConvBwdDataToGemm_v1
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];

const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];

const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];

const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];

const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];

const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

const index_t BK0 = K / BK1;

// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);

if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
...@@ -299,7 +522,7 @@ struct TransformConvBwdDataToGemm_v1
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1));

const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
...@@ -311,75 +534,163 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;

const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);

// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// B weight tensor
if constexpr(NDimSpatial == 2)
{
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
make_pass_through_transform(C),
make_pass_through_transform(BK1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
GemmNPerBlock,
BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
else if constexpr(NDimSpatial == 3)
{
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2>{},
Sequence<4>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<>{},
Sequence<>{},
Sequence<>{},
Sequence<5>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, BK0)),
make_pass_through_transform(C),
make_pass_through_transform(BK1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0),
GemmNPerBlock,
BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
}
}
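The ZDot/ZDotSlice bookkeeping above partitions the Z filter taps across the ZTilde phases, so the per-phase slices must sum back to Z. A standalone check (illustrative values, not from this commit):

#include <cstdio>

int main()
{
    const int Z = 5;
    const int ZTilde = 2; // ConvStrideD / gcd(ConvStrideD, ConvDilationD) for stride 2, dilation 1

    int sum = 0;
    for(int i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
    {
        const int ZDotSlice = (Z - i_ztilde + ZTilde - 1) / ZTilde; // integer_divide_ceil
        sum += ZDotSlice;
    }
    std::printf("sum of ZDotSlice = %d, Z = %d\n", sum, Z); // 5 == 5
}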
    template <typename CLayout,
              typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
                                          (is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
                                           is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
                                           is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
                                           is_same_v<CLayout, tensor_layout::convolution::NDHWGC> ||
                                           is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
                                      bool>::type = false>
    static auto
...@@ -395,153 +706,309 @@ struct TransformConvBwdDataToGemm_v1
                           const std::array<index_t, NDimSpatial>& input_right_pads,
                           const std::array<index_t, NDimSpatial>& tildes)
    {
        index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
        index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
        index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];

        const index_t N = in_g_n_c_wis_lengths[1];
        const index_t C = wei_g_k_c_xs_lengths[2];

        const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
        const index_t Hi = in_g_n_c_wis_lengths[HIdx];
        const index_t Wi = in_g_n_c_wis_lengths[WIdx];

        const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
        const index_t Ho = out_g_n_k_wos_lengths[HIdx];
        const index_t Wo = out_g_n_k_wos_lengths[WIdx];

        const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
        const index_t Y = wei_g_k_c_xs_lengths[YIdx];
        const index_t X = wei_g_k_c_xs_lengths[XIdx];

        const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
        const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
        const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];

        const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum];
        const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum];
        const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum];

        const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
        const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
        const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];

        const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
        const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
        const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

        // assume strided
        // n_hi_wi_c for 2d, n_di_hi_wi_c for 3d
        const auto in_grid_desc =
            make_in_grid_desc<NDimSpatial, CLayout>(N, Di, Hi, Wi, C, in_g_n_c_wis_strides);

        if constexpr(ConvBwdDataSpecialization ==
                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
                         Filter1x1Stride1Pad0)
        {
            // C: input tensor
            if constexpr(NDimSpatial == 2)
            {
                const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                    in_grid_desc,
                    make_tuple(
                        make_pass_through_transform(N),
                        make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
                        make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
                        make_pass_through_transform(C)),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

                const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
                    in_n_y_ho_x_wo_c_grid_desc,
                    make_tuple(make_freeze_transform(I0),
                               make_freeze_transform(I0),
                               make_merge_transform(make_tuple(N, Ho, Wo)),
                               make_pass_through_transform(C)),
                    make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
                    make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));

                const auto in_gemmm_gemmn_grid_desc =
                    ck::tensor_operation::device::PadTensorDescriptor(
                        in_gemmmraw_gemmnraw_grid_desc,
                        make_tuple(GemmMPerBlock, GemmNPerBlock),
                        Sequence<DoPadGemmM, DoPadGemmN>{});

                return in_gemmm_gemmn_grid_desc;
            }
else if constexpr(NDimSpatial == 3)
{
// C: input tensor
                const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                    in_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
                    in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<0, 2, 4, 6>{},
Sequence<7>{}),
make_tuple(
Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
        }
        else
        {
            const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

            const auto ZTilde = ConvStrideD / GcdStrideDilationD;
            const auto YTilde = ConvStrideH / GcdStrideDilationH;
            const auto XTilde = ConvStrideW / GcdStrideDilationW;

            const auto DTilde =
                Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
            const auto HTilde =
                Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
            const auto WTilde =
                Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

            // only work on DTilde, HTilde and WTilde that contribute to
            // non-padding area of input tensor
            const auto IDTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
            const auto IHTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
            const auto IWTildeSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

            const auto IDTildeSliceEnd = math::min(
                DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
            const auto IHTildeSliceEnd = math::min(
                HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
            const auto IWTildeSliceEnd = math::min(
                WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

            const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
            const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
            const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
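            // Worked example with assumed sizes (not from this change): with
            // ConvStrideH = 2, ConvDilationH = 3, Y = 3, Ho = 16, Hi = 32,
            // InLeftPadH = 1:
            //   GcdStrideDilationH = gcd(2, 3) = 1, so YTilde = 2
            //   HTilde = 16 + ceil(3 * (3 - 1) / 2) = 19
            //   IHTildeSliceBegin = floor(max(0, 1 - 3 * (2 - 1)) / 2) = 0
            //   IHTildeSliceEnd   = min(19, ceil((1 + 32 - 1) / 2) + 1) = 17
            //   HTildeSlice = 17, i.e. only 17 of the 19 HTilde positions
            //   touch the non-padding area of the input.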
            // C: input tensor
            if constexpr(NDimSpatial == 2)
            {
                const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                    in_grid_desc,
                    make_tuple(make_pass_through_transform(N),
                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
                               make_pass_through_transform(C)),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

                const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
                    transform_tensor_descriptor(
                        in_n_hip_wip_c_grid_desc,
                        make_tuple(make_pass_through_transform(N),
                                   make_embed_transform(make_tuple(YTilde, HTilde),
                                                        make_tuple(ConvDilationH, ConvStrideH)),
                                   make_embed_transform(make_tuple(XTilde, WTilde),
                                                        make_tuple(ConvDilationW, ConvStrideW)),
                                   make_pass_through_transform(C)),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                        make_tuple(
                            Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

                const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
                    in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
                    make_tuple(make_pass_through_transform(N),
                               make_freeze_transform(i_ytilde),
                               make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                               make_freeze_transform(i_xtilde),
                               make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                               make_pass_through_transform(C)),
                    make_tuple(Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4>{},
                               Sequence<5>{}),
                    make_tuple(Sequence<0>{},
                               Sequence<>{},
                               Sequence<1>{},
                               Sequence<>{},
                               Sequence<2>{},
                               Sequence<3>{}));

                const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
                    in_n_htildeslice_wtildeslice_c_grid_desc,
                    make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                               make_pass_through_transform(C)),
                    make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));

                const auto in_gemmm_gemmn_grid_desc =
                    ck::tensor_operation::device::PadTensorDescriptor(
                        in_gemmmraw_gemmnraw_grid_desc,
                        make_tuple(GemmMPerBlock, GemmNPerBlock),
                        Sequence<DoPadGemmM, DoPadGemmN>{});

                return in_gemmm_gemmn_grid_desc;
            }
            else if constexpr(NDimSpatial == 3)
{
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<>{},
Sequence<3>{},
Sequence<4>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
        }
    }
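// Net effect of the transforms above: the backward-data input descriptor is
// reduced to a 2D GEMM view with
//   GemmM = N * DTildeSlice * HTildeSlice * WTildeSlice   (merge transform)
//   GemmN = C                                             (pass-through)
// padded up to multiples of GemmMPerBlock / GemmNPerBlock when DoPadGemmM /
// DoPadGemmN are set.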
...
...@@ -62,24 +62,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_
    return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
...
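For reference, a hedged sketch of how the int8 specialization above is presumably completed; the body is elided in this hunk, and the via-fp32 pattern is assumed from the surrounding converters:

template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
    float x_fp32 = type_convert<float>(x); // widen bf16 to fp32 first
    return static_cast<int8_t>(x_fp32);    // then narrow fp32 to int8
}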
...@@ -39,7 +39,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
#ifdef __int8__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<1,
                                                  NWC,
...@@ -51,7 +51,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
#endif
// conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -88,7 +88,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
#ifdef __int8__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2,
                                                  NHWC,
...@@ -100,7 +100,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
#endif
// conv2d dl
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -125,7 +125,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
#ifdef __int8__
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2,
                                                  NHWC,
...@@ -137,6 +137,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
#endif
// conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<3,
...@@ -173,7 +174,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
#ifdef __int8__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<3,
                                                  NDHWC,
...@@ -185,7 +186,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
...@@ -239,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
        {
            add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
        }
#ifdef __int8__
        else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                          is_same_v<OutDataType, int8_t>)
        {
            add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
        }
#endif
    }
    else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                      is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
...@@ -266,12 +269,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
        {
            add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
        }
#ifdef __int8__
        else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                          is_same_v<OutDataType, int8_t>)
        {
            add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
            add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
        }
#endif
    }
    else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                      is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
...@@ -292,11 +297,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
        {
            add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
        }
#ifdef __int8__
        else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                          is_same_v<OutDataType, int8_t>)
        {
            add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
        }
#endif
    }
    return op_ptrs;
...
...@@ -77,7 +77,7 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
#ifdef __int8__
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
...@@ -118,6 +118,27 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
        DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -183,26 +204,6 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -388,6 +389,7 @@ struct DeviceOperationInstanceFactory<
            add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
        }
    }
#ifdef __int8__
    else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
                      is_same_v<CDataType, int8_t>)
    {
...@@ -420,7 +422,7 @@ struct DeviceOperationInstanceFactory<
            add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
        }
    }
#endif
    return op_ptrs;
}
};
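A minimal caller-side sketch, assuming the Row/Col/PassThrough aliases from this header are in scope; only GetInstances() itself comes from the factory above:

#ifdef __int8__
inline auto get_int8_gemm_instances()
{
    using DeviceGemmInt8 = ck::tensor_operation::device::
        DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>;
    // Returns every int8 instance registered above (dl and xdl_c_shuffle variants).
    return ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceGemmInt8>::GetInstances();
}
#endif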
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
// f16_f16_f32_f16
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f16_instances =
std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
#ifdef CK_WORKAROUND_SWDEV_3318619
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
#endif
// clang-format on
>;
// bf16_bf16_f32_bf16
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bf16_instances =
std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
#ifdef CK_WORKAROUND_SWDEV_3318619
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
#endif
// clang-format on
>;
// f32_f32_f32_f32
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f32_instances =
std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>
#ifdef CK_WORKAROUND_SWDEV_3318619
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
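A sketch of how one of the tuples above is typically registered, in a separate translation unit inside the same namespaces; the add_device_operation_instances helper comes from the included add_device_operation_instance.hpp, and the exact call site is assumed, not shown in this change:

void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // Register the default and the 1x1-stride-1-pad-0 specializations.
    add_device_operation_instances(instances,
                                   device_grouped_conv_bwd_data_xdl_f16_instances<
                                       2, GNHWK, GKYXC, Empty_Tuple, GNHWC, ConvBwdDataDefault>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_f16_instances<2,
                                                       GNHWK,
                                                       GKYXC,
                                                       Empty_Tuple,
                                                       GNHWC,
                                                       ConvBwdDataFilter1x1Stride1Pad0>{});
}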
...@@ -16,7 +16,7 @@ namespace device {
namespace instance {

// conv2d backward data
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
...@@ -30,7 +30,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
...@@ -44,7 +44,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
...@@ -58,7 +58,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
...@@ -72,7 +72,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
...@@ -86,7 +86,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
...@@ -100,6 +100,91 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
                                                                  PassThrough,
                                                                  PassThrough>>>& instances);
// conv3d backward data
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
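Each declaration above follows the same registration convention: the caller owns the vector of device-op pointers and the function appends every compiled-in specialization to it. A minimal direct-call sketch follows; the ck::tensor_operation::device::instance namespace is an assumption based on the library's usual layout, not something shown in this diff:

    // Sketch only: assumes the declarations above live in
    // ck::tensor_operation::device::instance and the CK headers are on the include path.
    #include <memory>
    #include <vector>

    using namespace ck::tensor_operation::device;

    void collect_conv3d_bwd_data_f16_instances()
    {
        // Layout/type arguments copied verbatim from the GNDHWK f16 declaration above.
        std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
                                                                      GNDHWK,
                                                                      GKZYXC,
                                                                      Empty_Tuple,
                                                                      GNDHWC,
                                                                      F16,
                                                                      F16,
                                                                      Empty_Tuple,
                                                                      F16,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      PassThrough>>>
            op_ptrs;

        // Appends one entry per kernel specialization that survived the DTYPES filter.
        instance::add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(op_ptrs);
    }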
template <ck::index_t NumDimSpatial,
          typename OutLayout,
          typename WeiLayout,
...@@ -139,43 +224,96 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
        if constexpr(NumDimSpatial == 2)
        {
            if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
                         is_same_v<OutLayout, GNHWK>)
            {
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
                        op_ptrs);
                }
            }
            else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
                              is_same_v<OutLayout, NHWGK>)
            {
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16>)
                {
                    add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
                        op_ptrs);
                }
            }
        }
        else if constexpr(NumDimSpatial == 3)
        {
            if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
                         is_same_v<OutLayout, GNDHWK>)
            {
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
                        op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
                        op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
                        op_ptrs);
                }
            }
            else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
                              is_same_v<OutLayout, NDHWGK>)
            {
                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                             is_same_v<OutDataType, F16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
                        op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
                                  is_same_v<OutDataType, F32>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                        op_ptrs);
                }
                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
                                  is_same_v<OutDataType, BF16>)
                {
                    add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
                        op_ptrs);
                }
            }
        }
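For client code, the factory above is the intended entry point: it dispatches at compile time on the layout and data-type template arguments and calls the matching registration function. A minimal enumeration sketch, assuming the conventional ck::tensor_operation::device::instance namespace and the GetTypeString() method that CK device ops typically expose (neither is confirmed by this diff):

    #include <iostream>

    // Same template-argument spelling the 2D NHWGC branch above matches on.
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
        2, NHWGK, GKYXC, Empty_Tuple, NHWGC,
        F16, F16, Empty_Tuple, F16,
        PassThrough, PassThrough, PassThrough>;

    void list_conv2d_bwd_data_f16_instances()
    {
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        // One line per compiled-in kernel specialization.
        for(const auto& op : op_ptrs)
            std::cout << op->GetTypeString() << std::endl;
    }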
...
...@@ -12,9 +12,42 @@ set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list})
    set(target_dir)
    IF(IS_DIRECTORY "${subdir_path}")
        set(cmake_instance)
        file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
        set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
#message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES")
#message("instance should be built for all types!")
set(add_inst 1)
endif()
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
    ENDIF()
ENDFOREACH()
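One consequence of this filtering that is easy to miss: when a data type is excluded via DTYPES, the corresponding instance directory is never added at all, so a factory lookup for that type should be expected to come back empty rather than fail at configure time. A defensive sketch of how a caller might treat this (the empty-result behavior is inferred from the filter above, not stated in the diff):

    #include <cstdlib>
    #include <iostream>

    // op_ptrs as returned by DeviceOperationInstanceFactory<DeviceOp>::GetInstances().
    template <typename OpPtrVector>
    int require_instances(const OpPtrVector& op_ptrs)
    {
        if(op_ptrs.empty())
        {
            // Nothing was compiled in, most likely because DTYPES excluded this type.
            std::cerr << "no device instances available; rebuild with a wider DTYPES list\n";
            return EXIT_FAILURE;
        }
        return EXIT_SUCCESS;
    }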
...
set(BATCHED_GEMM_MULTID_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
endif()
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
set(CONV2D_BWD_DATA_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()
add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})
...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -151,3 +151,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -151,3 +151,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
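The same guard generalizes to the other per-type macros (__fp16__, __fp32__, __fp64__, __bf16__, __fp8__): wrapping an instance translation unit in the matching #ifdef lets the file be compiled unconditionally yet contribute no symbols when its type is filtered out. A pattern sketch for a hypothetical fp16 counterpart (not a file in this commit):

    // Hypothetical fp16 analogue of the int8 guard above.
    #ifdef __fp16__
    namespace ck {
    namespace tensor_operation {
    namespace device {

    // ... f16 instance definitions and their add_device_..._f16_instances() ...

    } // namespace device
    } // namespace tensor_operation
    } // namespace ck
    #endif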
set(GEMM_INSTANCES)
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
endif()
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
set(ENABLE_PIPELINE_V2_OPT OFF)
...