Unverified commit a8629a98, authored by zjing14, committed by GitHub

Merge branch 'develop' into gemm_v2r3_kpad_fix

parents 8dc713ea 94bfa502
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static inline constexpr ck::index_t NDimSpatial = 2;
using FP32 = float;
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
};
#define DefaultConvParams \
ck::utils::conv::ConvParam \
{ \
NDimSpatial, 1, 32, 1, 1, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
}
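// A minimal standalone sketch (not part of the example) of the GEMM shape implied by
// DefaultConvParams, assuming ConvParam's argument order is (ndim, G, N, K, C,
// filter lengths, image lengths, strides, dilations, left pads, right pads):
namespace default_shape_sketch {
constexpr int Ho = (64 - 4) / 1 + 1; // (Hi - Y) / stride + 1; dilation 1, no padding
constexpr int Wo = (64 - 4) / 1 + 1;
static_assert(Ho == 61 && Wo == 61);
static_assert(32 * Ho * Wo == 119072); // M = N * Ho * Wo (NDoHoWo in the example below)
static_assert(1 * 4 * 4 == 16);        // K = C * Y * X   (CZYX in the example below)
} // namespace default_shape_sketch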
inline void print_help_msg()
{
std::cerr << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline bool parse_cmd_args(int argc,
char* argv[],
ExecutionConfig& config,
ck::utils::conv::ConvParam& conv_params)
{
constexpr int num_execution_config_args =
3; // arguments for do_verification, init_method, time_kernel
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
constexpr int threshold_to_catch_all_args =
threshold_to_catch_partial_args + num_conv_param_leading_args;
if(argc == 1)
{
// use default
config = ExecutionConfig{};
}
// catch only ExecutionConfig arguments
else if(argc == threshold_to_catch_partial_args)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
// catch both ExecutionConfig & ConvParam arguments
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_params = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args, argv);
}
else
{
print_help_msg();
return false;
}
return true;
}
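// Illustrative invocations (a sketch only; the binary name is hypothetical and the
// authoritative per-dimension argument order is whatever
// get_conv_param_parser_helper_msg() prints, so check that output first):
//   ./image_to_column                  -> all defaults (verify, integer init, time kernel)
//   ./image_to_column 1 2 1            -> verify on, decimal init, timing on
//   ./image_to_column 1 2 1  2 1 32 1 1  <2x filter> <2x image> <2x strides> <2x dilations> <2x lpads> <2x rpads>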
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using InDataType = FP32;
using OutDataType = FP32;
using InLayout = ck::tensor_layout::convolution::GNHWC;
// clang-format off
using DeviceImgToColInstance = ck::tensor_operation::device::DeviceImageToColumnImpl
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
< NDimSpatial, InLayout, InDataType, OutDataType, 256, 128, 128, S<16, 16>, 1>;
// clang-format on
bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params)
{
const auto N = conv_params.N_;
const auto C = conv_params.C_;
const ck::index_t NDoHoWo =
N * ck::accumulate_n<ck::index_t>(
conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const ck::index_t CZYX =
C * ck::accumulate_n<ck::index_t>(
conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const auto in_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_params);
const auto out_desc = HostTensorDescriptor({NDoHoWo, CZYX});
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{};
std::array<ck::index_t, 2> output_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(conv_params.input_spatial_lengths_, input_spatial_lengths);
copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_params.output_spatial_lengths_, output_spatial_lengths);
copy(in_desc.GetStrides(), input_g_n_c_wis_strides);
copy(out_desc.GetStrides(), output_m_k_strides);
copy(conv_params.conv_filter_strides_, conv_filter_strides);
copy(conv_params.conv_filter_dilations_, conv_filter_dilations);
copy(conv_params.input_left_pads_, input_left_pads);
copy(conv_params.input_right_pads_, input_right_pads);
Tensor<InDataType> in(in_desc);
Tensor<OutDataType> out_device(out_desc);
Tensor<OutDataType> out_host(out_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "out: " << out_device.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1: in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
default: in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
// reset output to zero
out_device_buf.SetZero();
static_assert(std::is_default_constructible_v<DeviceImgToColInstance>);
// run image-to-column
auto img2col = DeviceImgToColInstance{};
auto invoker = img2col.MakeInvoker();
auto argument = img2col.MakeArgument(in_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(),
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
if(!img2col.IsSupportedArgument(argument))
{
std::cerr << "wrong! device_img2col with the specified compilation parameters does "
"not support this img2col problem"
<< std::endl;
return false;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
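// num_btype below counts one input read and one output write per element of the
// [NDoHoWo, CZYX] output matrix; ave_time is in ms, so bytes / 1e6 / ms yields GB/s.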
std::size_t num_btype = NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType));
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
if(config.do_verification)
{
auto ref_image_to_column = ck::tensor_operation::host::
ReferenceImageToColumn<NDimSpatial, InLayout, InDataType, OutDataType>();
auto ref_invoker = ref_image_to_column.MakeInvoker();
auto ref_argument = ref_image_to_column.MakeArgument(in,
out_host,
conv_params.filter_spatial_lengths_,
conv_params.conv_filter_strides_,
conv_params.conv_filter_dilations_,
conv_params.input_left_pads_,
conv_params.input_right_pads_);
if(!ref_image_to_column.IsSupportedArgument(&ref_argument))
{
std::cerr << "wrong! ref_img2col with the specified compilation parameters does "
"not support this img2col problem"
<< std::endl;
return false;
}
ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_device.mData, out_host.mData);
}
return true;
}
int RunImageToColumnExample(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_params = DefaultConvParams;
if(!parse_cmd_args(argc, argv, config, conv_params))
{
return EXIT_FAILURE;
}
if(conv_params.num_dim_spatial_ != NDimSpatial)
{
std::cerr << "unsupported # of spatial dimensions" << std::endl;
return EXIT_FAILURE;
}
return !RunImageToColumn(config, conv_params);
}
int main(int argc, char* argv[]) { return RunImageToColumnExample(argc, argv); }
@@ -7,20 +7,114 @@ add_custom_target(examples)
function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example source file ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
return(PROPAGATE result)
endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_dependencies(examples ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
return(PROPAGATE result)
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
# add all example subdir
......
@@ -43,6 +43,9 @@
#ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON"
#endif
#ifndef CK_ENABLE_BF8
#define CK_ENABLE_BF8 "ON"
#endif
#ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON"
#endif
@@ -66,6 +69,10 @@
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif
#ifndef CK_ENABLE_BF8
#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
#endif
#ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp"
namespace ck {
/**
* DPP8 version of the blockwise GEMM algorithm. It uses the DPP8 instruction modifier to limit
* the amount of data loaded from LDS into registers.
*
* The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C
* between them in such a way that threads from the same group need the same chunk of either
* matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the
* whole chunk from LDS to its own register space.
* Usage of DPP8 modifiers allows each thread to load less data, exactly `1 / dpp8::lane_group_size`
* of the chunk, and then share it with the other threads from the same lane group.
*
* Assumptions coming from the usage of DPP8:
 * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or
 *    `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size`:
 *    - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either
 *      matrix A or B;
 *    - based on these values we determine which matrix to share.
 * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or
 *    `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B):
 *    - the data to split must be evenly divisible among the threads in the lane group.
*
* General algorithm:
* C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
* A and B are visible to the whole block, C is distributed among each thread
* Assume:
* 1. A:
* 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
* 2. ABlockBuffer is DynamicBuffer
* 2. B:
* 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
* 2. BBlockBuffer is DynamicBuffer
* 3. C:
* 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
* 2. CThreadBuffer is StaticBuffer
* 4. BM10BN10ThreadClusterBM10Xs::Size() == 2 and BM10BN10ThreadClusterBN10Xs::Size() == 2
*/
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_BK0_BM_BK1,
typename BBlockDesc_BK0_BN_BK1,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
{
using AIndex = MultiIndex<4>;
using BIndex = MultiIndex<4>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t BM1 = BM100 * BM101 * BM11;
static constexpr index_t BN1 = BN100 * BN101 * BN11;
static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1;
// We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all
// threads in a lane group need the same chunk of B or A matrices and we can share them using
// DPP.
static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size);
static constexpr bool ShareB = (BM101 == dpp8::lane_group_size);
static constexpr bool ShareA = !ShareB;
// If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`,
// respectively) elements, so we split them between threads in lane group so each thread loads
// less data from LDS.
static constexpr index_t BM1PerThread =
ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11;
static constexpr index_t BN1PerThread =
ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11;
__host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_block_bk0_bm0_bm1_bk1;
}
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_block_desc_bk0_bn0_bn1_bk1;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_unmerge_transform(make_tuple(
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<BM0>{}),
make_unmerge_transform(
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_pass_through_transform(Number<BN0>{}),
make_unmerge_transform(
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
}
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<BM0, BM11, BN0, BN11>{};
}
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()},
b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()}
{
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
BM10BN10ThreadClusterBN10Xs::Size() == 2,
"wrong!");
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr auto adaptor0 =
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
make_pass_through_transform(BM0),
make_pass_through_transform(BM11),
make_pass_through_transform(BN0),
make_pass_through_transform(BN11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
}
__device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()
{
const auto offsetBM0 = c_thread_origin_data_idx_[I0];
// If sharing matrix A, we need a separate BM1 offset for each thread in lane group.
const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] +
dpp8::get_thread_idx_in_lane_group() * BM1PerThread
: c_thread_origin_data_idx_[I1];
return make_tuple(0, offsetBM0, offsetBM1, 0);
}
__device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()
{
const auto offsetBN0 = c_thread_origin_data_idx_[I2];
// If sharing matrix B, we need a separate BN1 offset for each thread in lane group.
const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] +
dpp8::get_thread_idx_in_lane_group() * BN1PerThread
: c_thread_origin_data_idx_[I3];
return make_tuple(0, offsetBN0, offsetBN1, 0);
}
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction =
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
CThreadDesc_BM0_BM11_BN0_BN11,
Sequence<BK0PerThread, BK1>,
Sequence<1, BM1PerThreadBM11>,
Sequence<1, BN1PerThreadBN11>,
ShareA>{};
static_for<0, BN0, 1>{}([&](auto bn0) {
static_for<0, BM0, 1>{}([&](auto bm0) {
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, bm0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, bn0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(bm0, I0, bn0, I0));
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, bm0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, bn0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(bm0, I0, bn0, I0));
});
});
});
}
private:
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThread>{}, Number<BK1>{}));
// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThread>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
Sequence<BK0PerThread, 1, BM1PerThread, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
Sequence<BK0PerThread, 1, BN1PerThread, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
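// A minimal standalone sketch (with hypothetical tile sizes) of the per-thread LDS-load
// reduction described in the comment block above: threads in a DPP8 lane group need the
// same chunk, so each thread loads only 1 / dpp8::lane_group_size of it and DPP shares the rest.
namespace dpp8_split_sketch {
constexpr int lane_group_size  = 8;    // DPP8 shares data across groups of 8 lanes
constexpr bool ShareB          = true; // assumption: BM101 == lane_group_size, so B is shared
constexpr int BN1PerThreadBN11 = 8;    // assumption: divisible by the lane-group size
constexpr int BN1PerThread =
    ShareB ? BN1PerThreadBN11 / lane_group_size : BN1PerThreadBN11;
static_assert(BN1PerThread == 1); // each thread loads 1 of the 8 elements
} // namespace dpp8_split_sketch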
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp"
namespace ck {
/**
* Blockwise GEMM that uses the DPP instruction modifier to limit the amount of data loaded by
* each thread, by sharing data between threads in a lane group.
*
* In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`; there are
* `MRepeat` iterations along the `M` dimension and `NRepeat` along the `N` one.
* In total, the algorithm runs using
* `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
*/
template <index_t BlockSize,
typename ABDataType,
typename AccDataType,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerDpp,
index_t NPerDpp,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto dpp_gemm = DppGemm<ABDataType, MPerDpp, NPerDpp, KPack>{};
static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
MRepeat * NRepeat,
dpp_gemm.GetRegSizePerDpp(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
const auto dpp_a_idx_k = dpp_a_idx[I0];
const auto dpp_a_idx_m = dpp_a_idx[I1];
return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
}
__device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
const auto dpp_b_idx_k = dpp_b_idx[I0];
const auto dpp_b_idx_n = dpp_b_idx[I1];
return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
const auto blk_m_offset = blk_idx[I0];
const auto blk_n_offset = blk_idx[I1];
constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_m_offset))[I0];
const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_n_offset))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"Wrong! Block descriptors should be known at the time of compilation.");
#if defined(__HIP_DEVICE_COMPILE__)
// The host wave size can differ from the device wave size, so this assert could fail on the
// host; it only matters for the device.
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
#endif
static_assert(MPerBlock % (MPerDpp * MRepeat) == 0,
"Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat.");
static_assert(NPerBlock % (NPerDpp * NRepeat) == 0,
"Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat.");
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
constexpr auto M = c_m_n_tblk_lens[I0];
constexpr auto N = c_m_n_tblk_lens[I1];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
constexpr auto M = c_m_n_tblk_lens[I0];
constexpr auto N = c_m_n_tblk_lens[I1];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerDpp>{},
Number<NPerDpp>{}));
return c_block_desc_m0_n0_m1_n1_m2_n2;
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerDpp>{},
Number<NPerDpp>{}));
return c_block_desc_g_m0_n0_m1_n1_m2_n2;
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return c_grid_desc_m0_n0_m1_n1_m2_n2;
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return c_grid_desc_g_m0_n0_m1_n1_m2_n2;
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerDpp>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerDpp>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<ABDataType, KPack> a_thread_vec;
vector_type<ABDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using dpp_input_type =
typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
dpp_gemm.template Run(a_thread_vec.template AsType<dpp_input_type>(),
b_thread_vec.template AsType<dpp_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// C[M, N, NumRegDpp]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
ABDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
ABDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()};
};
} // namespace ck
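// A small sketch of the wave-count arithmetic from the comment block above, with
// hypothetical tile sizes: a 128x128 block tile, 16x16 DPP tiles, and 2x2 repeats.
namespace dpp_wave_sketch {
constexpr int MPerBlock = 128, NPerBlock = 128; // assumed tile sizes
constexpr int MPerDpp = 16, NPerDpp = 16;
constexpr int MRepeat = 2, NRepeat = 2;
constexpr int MWaves = MPerBlock / (MRepeat * MPerDpp); // = 4
constexpr int NWaves = NPerBlock / (NRepeat * NPerDpp); // = 4
static_assert(MWaves * NWaves == 16); // the block tile runs on 16 waves
} // namespace dpp_wave_sketch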
@@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
// Choose the loop-over direction at compile time based on MRepeat vs NRepeat.
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected:
......
@@ -4,27 +4,13 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/**
* \brief Image to column.
*
* This device operator converts an image ([G, N, Di, Hi, Wi, C]) into the GEMM
* problem ([N * Do * Ho * Wo, Z * Y * X * C]). G must be equal to 1.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam InputLayout Input Layout.
* \tparam InputDataType Input Data Type.
* \tparam OutputDataType Output Data Type.
*/
template <index_t NDimSpatial,
typename InputLayout,
typename InputDataType,
typename OutputDataType>
struct DeviceImageToColumn : public BaseOperator
{
/**
* \brief Make argument pointer for image to column.
*
* \param p_in A pointer to the device memory of the input image.
* \param p_out A pointer to the device memory of the output.
* \param N Convolution batch size.
* \param C Convolution number of channels.
* \param input_spatial_lengths Input spatial lengths.
* \param filter_spatial_lengths Filter spatial lengths.
* \param output_spatial_lengths Output spatial lengths.
* \param input_g_n_c_wis_strides Input strides in order [G, N, C, D, H, W].
* \param output_m_k_strides Output strides.
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Convolution left pads.
* \param input_right_pads Convolution right pads.
* \return Pointer to the argument.
*/
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
void* p_out,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
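// A minimal host-side sketch of the [N * Ho * Wo, Y * X * C] mapping documented above, for a
// 2-D NHWC image (independent of CK's implementation; out-of-bounds taps model zero padding):
#include <cstddef>
#include <vector>
inline std::vector<float> im2col_nhwc_sketch(const std::vector<float>& in,
                                             int N, int Hi, int Wi, int C,
                                             int Y, int X, int Ho, int Wo,
                                             int stride_h, int stride_w,
                                             int dil_h, int dil_w,
                                             int pad_h, int pad_w)
{
    std::vector<float> out(static_cast<std::size_t>(N) * Ho * Wo * Y * X * C, 0.f);
    std::size_t m = 0; // row index: n * Ho * Wo + ho * Wo + wo
    for(int n = 0; n < N; ++n)
        for(int ho = 0; ho < Ho; ++ho)
            for(int wo = 0; wo < Wo; ++wo, ++m)
            {
                std::size_t k = 0; // column index: y * X * C + x * C + c
                for(int y = 0; y < Y; ++y)
                    for(int x = 0; x < X; ++x)
                        for(int c = 0; c < C; ++c, ++k)
                        {
                            const int hi = ho * stride_h + y * dil_h - pad_h;
                            const int wi = wo * stride_w + x * dil_w - pad_w;
                            if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                                out[m * (static_cast<std::size_t>(Y) * X * C) + k] =
                                    in[((static_cast<std::size_t>(n) * Hi + hi) * Wi + wi) * C + c];
                        }
            }
    return out;
}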
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
enum struct GemmDlAlgorithm
{
Default, // Uses DOT vector instructions
Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
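// A hypothetical compile-time dispatch on the tag above (illustration only, not CK API):
template <ck::tensor_operation::device::GemmDlAlgorithm Alg>
constexpr const char* gemm_dl_algorithm_name()
{
    using ck::tensor_operation::device::GemmDlAlgorithm;
    return Alg == GemmDlAlgorithm::Dpp8 ? "Dpp8" : "Default";
}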
@@ -11,7 +11,6 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -60,7 +59,6 @@ template <
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
@@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
GemmDlAlg>;
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
@@ -276,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
M_raw_{M},
N_raw_{N},
K_raw_{K},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
@@ -317,6 +317,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
index_t M01_;
index_t N01_;
index_t M_raw_;
index_t N_raw_;
index_t K_raw_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
@@ -375,8 +379,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true,
GemmDlAlg>;
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
@@ -402,8 +405,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false,
GemmDlAlg>;
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
@@ -429,8 +431,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true,
GemmDlAlg>;
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
@@ -456,8 +457,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false,
GemmDlAlg>;
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
@@ -492,14 +492,48 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
// Make sure that the M, N, K dimensions before padding are divisible by the respective
// vector lengths.
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto A_K_vec_length =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
if(arg.K_raw_ % A_K_vec_length != 0)
{
return false;
}
}
else
{
constexpr auto A_M_vec_length =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
if(arg.M_raw_ % A_M_vec_length != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
constexpr auto B_N_vec_length =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
if(arg.N_raw_ % B_N_vec_length != 0)
{
return false;
}
}
else
{
constexpr auto B_K_vec_length =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
if(arg.K_raw_ % B_K_vec_length != 0)
{
return false;
}
}
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
......
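// Worked instance of the divisibility rule above, assuming (hypothetically)
// ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 4, 2>: for a
// row-major A the K vector length is At(I0) * At(I3) = 1 * 2 = 2, so K_raw_ % 2 == 0 is
// required; for a column-major A the M vector length is At(I1) * At(I2) = 1 * 4 = 4, so
// M_raw_ % 4 == 0 is required instead.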
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDlDpp8 : public DeviceGemmDl<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
GemmDlAlgorithm::Dpp8>
{
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDlDpp8"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerDpp,
ck::index_t NPerDpp,
ck::index_t MDppPerWave,
ck::index_t NDppPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1,
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmDpp : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp<
BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
MPerBlock,
NPerBlock,
KPerBlock,
MPerDpp,
NPerDpp,
AK1,
BK1,
MDppPerWave,
NDppPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
NumPrefetch,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
karg.Print();
}
if(!GridwiseGemm::CheckValidity(karg))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting");
}
const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
{
const auto kernel = kernel_gemm_dpp<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_gemm_dpp<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
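    // The invoker resolves the "has main K-block loop" predicate on the host and
    // bakes it into the kernel as a template parameter, so the device code never
    // branches on it at runtime. A minimal sketch of the idiom (illustrative
    // names, not the CK API):
    //
    //   template <bool HasMainKBlockLoop> __global__ void gemm_kernel(Karg karg);
    //   ...
    //   if(CalculateHasMainKBlockLoop(K)) { launch(gemm_kernel<true>, karg);  }
    //   else                              { launch(gemm_kernel<false>, karg); }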
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& karg)
{
if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{
return GridwiseGemm::CheckValidity(karg);
}
return false;
}
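    // Note: the DPP kernels are only wired up for the RDNA targets listed above
    // (gfx1030 and gfx1100/1101/1102); every other device is rejected up front
    // regardless of the GEMM shape.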
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmDpp"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerDpp << ", "
<< NPerDpp << ", "
<< MDppPerWave << ", "
<< MDppPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
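// Typical driver flow for a device op like the one above (a sketch; the
// instance alias and pointers are assumptions, not defined in this file):
//
//   auto gemm     = DeviceGemmDppInstance{};
//   auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K,
//                                     StrideA, StrideB, StrideC,
//                                     PassThrough{}, PassThrough{}, PassThrough{});
//   if(!gemm.IsSupportedArgument(argument)) { /* pick another instance */ }
//   auto invoker  = gemm.MakeInvoker();
//   float ave_ms  = invoker.Run(argument, StreamConfig{nullptr, true});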
......@@ -144,7 +144,8 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
......@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
ComputeDataType,
AccDataType,
......
......@@ -280,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
BK1,
MPerBlock,
NPerBlock,
KPerBlock,
DoPadGemmM,
DoPadGemmN>{};
......
......@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -72,6 +73,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......@@ -96,9 +100,23 @@ __global__ void
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
#endif
}
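// The preprocessor guard above compiles the kernel body only for the listed gfx
// targets (and for the host pass, where __HIP_DEVICE_COMPILE__ is undefined);
// on excluded targets every parameter is consumed via `ignore` so the kernel
// still compiles, just as an empty stub. A minimal sketch of the pattern:
//
//   #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__))
//       do_real_work(arg);
//   #else
//       ignore = arg; // silence unused-parameter warnings on excluded targets
//   #endif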
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
......@@ -134,29 +152,46 @@ template <ck::index_t NDimSpatial,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
: public DeviceGroupedConvBwdWeight<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
// 1d
static constexpr bool is_NWGK_GKXC_NWGC =
is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
static constexpr bool is_GNWK_GKXC_GNWC =
is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
// 2d
static constexpr bool is_NHWGK_GKYXC_NHWGC =
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
static constexpr bool is_GNHWK_GKYXC_GNHWC =
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
// 3d
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
using DeviceOp = DeviceGroupedConvBwdWeight_Dl;
using ADataType = OutDataType;
using BDataType = InDataType;
......@@ -176,6 +211,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto spatial_offset = I3;
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;
......@@ -195,12 +232,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -209,90 +246,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{
using namespace ck;
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t N = a_g_n_c_wis_lengths[I1];
const index_t K = b_g_k_c_xs_lengths[I1];
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
const index_t X = b_g_k_c_xs_lengths[spatial_offset];
const index_t InLeftPadW = input_left_pads[I0];
const index_t InRightPadW = input_right_pads[I0];
const index_t ConvStrideW = conv_filter_strides[I0];
const index_t ConvDilationW = conv_filter_dilations[I0];
const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];
const index_t GemmKTotal = N * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weights tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
else
{
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -321,38 +370,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
} // function end
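    // Worked example of the K-padding arithmetic above (illustrative numbers):
    // with GemmKTotal = N * Wo = 2 * 100 = 200, GemmK1Number = 4, K0PerBlock = 16
    // and GemmKBatch = 1,
    //   GemmK0   = ceil(200 / (4 * 16 * 1)) * 16 = 4 * 16 = 64
    //   GemmKPad = 1 * 64 * 4                    = 256
    // so PadTensorDescriptor right-pads the K dimension from 200 to 256 before
    // the (KBatch, K0, K1) unmerge splits it into 1 x 64 x 4.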
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -361,103 +415,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{
using namespace ck;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t N = a_g_n_c_wis_lengths[I1];
const index_t K = b_g_k_c_xs_lengths[I1];
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
const index_t Y = b_g_k_c_xs_lengths[spatial_offset];
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1];
const index_t InLeftPadH = input_left_pads[I0];
const index_t InLeftPadW = input_left_pads[I1];
const index_t InRightPadH = input_right_pads[I0];
const index_t InRightPadW = input_right_pads[I1];
const index_t ConvStrideH = conv_filter_strides[I0];
const index_t ConvStrideW = conv_filter_strides[I1];
const index_t ConvDilationH = conv_filter_dilations[I0];
const index_t ConvDilationW = conv_filter_dilations[I1];
const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InHStride = a_g_n_c_wis_strides[spatial_offset];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];
const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
else
{
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -488,39 +550,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
} // function end
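    // Assuming the usual CK semantics, PadTensorDescriptor(desc,
    // make_tuple(TileK, TileM), Sequence<true, true>{}) right-pads each flagged
    // dimension up to the next multiple of the paired tile length, e.g. lengths
    // (200, 96) with tiles (64, 128) become (256, 128). This is why the unmerge
    // transforms above read the M/N length back from the padded descriptor via
    // GetLength(I1) instead of reusing the raw GemmM/GemmN.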
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N,
const ck::index_t K,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -529,110 +596,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{
using namespace ck;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t N = a_g_n_c_wis_lengths[I1];
const index_t K = b_g_k_c_xs_lengths[I1];
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0];
const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1];
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2];
const index_t InLeftPadD = input_left_pads[I0];
const index_t InLeftPadH = input_left_pads[I1];
const index_t InLeftPadW = input_left_pads[I2];
const index_t InRightPadD = input_right_pads[I0];
const index_t InRightPadH = input_right_pads[I1];
const index_t InRightPadW = input_right_pads[I2];
const index_t ConvStrideD = conv_filter_strides[I0];
const index_t ConvStrideH = conv_filter_strides[I1];
const index_t ConvStrideW = conv_filter_strides[I2];
const index_t ConvDilationD = conv_filter_dilations[I0];
const index_t ConvDilationH = conv_filter_dilations[I1];
const index_t ConvDilationW = conv_filter_dilations[I2];
const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InDStride = a_g_n_c_wis_strides[spatial_offset];
const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];
const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
else
{
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkpad_gemmmpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
Sequence<true, true>{});
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
out_gemmkpad_gemmmpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -672,27 +749,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
Sequence<true, true>{});
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
in_gemmkpad_gemmnpad_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmmpad_gemmnpad_grid_desc);
}
} // function end
......@@ -701,22 +783,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
{1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
1,
1,
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
......@@ -785,11 +867,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......@@ -809,38 +891,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{a_g_n_c_wis_lengths[2]},
input_spatial_lengths_{},
filter_spatial_lengths_{},
output_spatial_lengths_{},
Conv_G_{a_g_n_c_wis_lengths[I0]},
Conv_K_{b_g_k_c_xs_lengths[I1]},
Conv_C_{a_g_n_c_wis_lengths[I2]},
filter_lengths_{b_g_k_c_xs_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -863,24 +931,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ =
Conv_N_ * Conv_K_ *
std::accumulate(begin(output_spatial_lengths_),
end(output_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
Conv_N_ * Conv_C_ *
std::accumulate(begin(input_spatial_lengths_),
end(input_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
}
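        // Note: stride[I0] of each G-major tensor is its per-group stride, so the
        // batched GEMM can advance one group per batch index directly, e.g.
        // (sketch): p_a_group = p_a_grid_ + g * BatchStrideA_;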
const ADataType* p_a_grid_;
......@@ -908,13 +961,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// for checking IsSupportedArgument()
const index_t Conv_G_;
const index_t Conv_N_;
const index_t Conv_K_;
const index_t Conv_C_;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
......@@ -1036,10 +1086,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
// DL version only supports split_k equal to 1
if(arg.k_batch_ != 1)
return false;
if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
(NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
(NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
{
return false;
}
......@@ -1050,8 +1104,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
            // check if it's 1x1, stride=1, pad=0 conv
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
arg.input_right_pads_[i] == 0))
{
return false;
}
......@@ -1206,7 +1261,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
str << "DeviceGroupedConvBwdWeight_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
......@@ -616,7 +616,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
......
......@@ -193,6 +193,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
BLayout,
......@@ -217,6 +218,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeType,
AccDataType,
CShuffleDataType,
DsDataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
typename OutputDataType,
typename Block2ETileMap,
typename GridwiseImageToColumnKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_image_to_column(const InputGridDesc in_grid_desc,
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseImageToColumnKernel::Run(
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
#else
ignore = in_grid_desc;
ignore = p_in_global;
ignore = out_grid_desc;
ignore = p_out_global;
ignore = block_2_tile_map;
#endif
}
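// Each workgroup resolves its (M, K) output tile through block_2_tile_map and
// copies the matching image window into the column buffer; the grid is sized
// from the tile map rather than from the image shape directly (see
// Invoker::Run below, which builds the map from the output descriptor).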
// Image to column, e.g. for the 3-d input layout GNDHWC:
// input  : input image [G, N, Di, Hi, Wi, C],
// output : output image [N * Do * Ho * Wo, Z * Y * X * C]
template <index_t NDimSpatial,
typename InputLayout,
typename InputDataType,
typename OutputDataType,
index_t BlockSize,
index_t MPerBlock,
index_t KPerBlock,
typename ThreadClusterLengths,
index_t ScalarPerVector>
struct DeviceImageToColumnImpl
: public DeviceImageToColumn<NDimSpatial, InputLayout, InputDataType, OutputDataType>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
MPerBlock, 0 /* NPerBlock*/, KPerBlock};
// Use MakeADescriptor_M_K from grouped convolution forward
static auto
MakeInputDescriptor_M_K(const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{1};
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{1};
std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{1};
auto copy = [](const auto& x, auto& y, index_t dst_offset) {
std::copy(x.begin(), x.end(), y.begin() + dst_offset);
};
constexpr index_t spatial_offset = 3;
copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset);
copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset);
copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset);
// fill only significant values (C and N)
a_g_n_c_wis_lengths[I1] = N;
a_g_n_c_wis_lengths[I2] = C;
b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<InputLayout>(
a_g_n_c_wis_lengths,
input_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
static auto
MakeOutDescriptor_M_K(const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, 2>& output_m_k_strides)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(NDoHoWo, CZYX), make_tuple(output_m_k_strides[I0], output_m_k_strides[I1]));
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
return desc_m_k;
}
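    // Worked example (illustrative 2-d shapes): N = 2, C = 32, filter
    // (Y, X) = (3, 3) and output (Ho, Wo) = (30, 30) give
    //   NDoHoWo = 2 * 30 * 30 = 1800,  CZYX = 32 * 3 * 3 = 288,
    // matching the documented im2col output [N * Do * Ho * Wo, Z * Y * X * C].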
using InputGridDesc =
remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}, {}, {}))>;
using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(1, 1, {}, {}, {}))>;
using Block2ETileMap = remove_cvref_t<
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
OutputGridDesc{}))>;
using GridwiseImageToColumnKernel = GridwiseImageToColumn<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
Block2ETileMap>;
struct Argument : public BaseArgument
{
Argument(const void* p_in, // input image
void* p_out, // output image
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
: C_(C),
X_(filter_spatial_lengths[NDimSpatial - I1]),
p_in_{static_cast<const InputDataType*>(p_in)},
p_out_{static_cast<OutputDataType*>(p_out)},
input_g_n_c_wis_strides_{input_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
in_grid_desc_m_k_ = MakeInputDescriptor_M_K(N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
N, C, filter_spatial_lengths, output_spatial_lengths, output_m_k_strides);
}
void Print() const
{
std::cout << in_grid_desc_m_k_ << std::endl;
std::cout << out_grid_desc_m_k_ << std::endl;
}
const ck::index_t C_;
const ck::index_t X_;
const InputDataType* p_in_;
OutputDataType* p_out_;
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<index_t, NDimSpatial>& input_left_pads_;
const std::array<index_t, NDimSpatial>& input_right_pads_;
InputGridDesc in_grid_desc_m_k_;
OutputGridDesc out_grid_desc_m_k_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
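// one workgroup per MPerBlock x KPerBlock tile; the grid size computed below
// is the number of such tiles in the padded output matrix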
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
arg.out_grid_desc_m_k_);
const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
const auto kernel = kernel_image_to_column<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
Block2ETileMap,
GridwiseImageToColumnKernel>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.in_grid_desc_m_k_,
arg.p_in_,
arg.out_grid_desc_m_k_,
arg.p_out_,
block_2_tile_map);
return elapsed_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const Argument& arg)
{
using namespace tensor_layout::convolution;
if(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> ||
std::is_same_v<InputLayout, GNDHWC>))
{
return false;
}
if(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{
return false;
}
const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1];
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
bool is_w_packed = arg.input_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
bool is_c_packed = arg.input_g_n_c_wis_strides_[I2] == 1;
// check vector access when C is not packed
if(!is_c_packed && ScalarPerVector != 1)
return false;
// check vector access of filter window row (only C when W is not packed)
if(!is_w_packed && arg.C_ % ScalarPerVector != 0)
return false;
// check vector access of filter window row (X * C)
if(arg.X_ * arg.C_ % ScalarPerVector != 0)
return false;
// check vector access of pads (w_pad_left/w_pad_right * C)
if(w_pad_left * arg.C_ % ScalarPerVector != 0 ||
w_pad_right * arg.C_ % ScalarPerVector != 0)
return false;
// check vector access with stride and pad
if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0)
return false;
// check vector access with dilation
if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
return false;
return GridwiseImageToColumnKernel::CheckValidity(arg.in_grid_desc_m_k_,
arg.out_grid_desc_m_k_);
}
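// Illustrative rejection for the checks above (assumed numbers): with
// ScalarPerVector = 4, C = 3 and a 3x3 filter, X * C = 9 is not divisible by 4,
// so the argument is rejected even when the image W dimension is packed.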
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_in, // input image
void* p_out, // output image
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
return Argument{static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out),
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, // input image
void* p_out, // output image
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) override
{
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out),
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceImageToColumn"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< KPerBlock << ", "
<< ScalarPerVector
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
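The header above only defines the device operation; a caller still has to build an argument, check support, and launch. Below is a minimal host-side usage sketch for the 2D FP32 instance declared earlier, assuming a packed GNHWC image with G = 1 and valid device pointers p_in_dev/p_out_dev of sufficient size (e.g. obtained from DeviceMem::GetDeviceBuffer()); all sizes and variable names here are illustrative assumptions, not code from this commit:

// sizes for one 64x64 image, C = 4, 3x3 filter, stride 1, no padding
const ck::index_t N = 1;
const ck::index_t C = 4;
const std::array<ck::index_t, 2> input_spatial  = {64, 64}; // Hi, Wi
const std::array<ck::index_t, 2> filter_spatial = {3, 3};   // Y, X
const std::array<ck::index_t, 2> output_spatial = {62, 62}; // Ho, Wo
// packed GNHWC strides ordered {G, N, C, Hi, Wi}
const std::array<ck::index_t, 5> image_strides = {
    N * 64 * 64 * C, 64 * 64 * C, 1, 64 * C, C};
// row-major M x K output matrix, K = C * Y * X
const std::array<ck::index_t, 2> matrix_strides = {C * 3 * 3, 1};

auto device_op = DeviceImgToColInstance{};
auto argument  = device_op.MakeArgument(p_in_dev,
                                        p_out_dev,
                                        N,
                                        C,
                                        input_spatial,
                                        filter_spatial,
                                        output_spatial,
                                        image_strides,
                                        matrix_strides,
                                        {1, 1},  // conv filter strides
                                        {1, 1},  // conv filter dilations
                                        {0, 0},  // input left pads
                                        {0, 0}); // input right pads
if(device_op.IsSupportedArgument(argument))
{
    auto invoker   = device_op.MakeInvoker();
    // returns the measured kernel time in ms when timing is enabled
    const float ms = invoker.Run(argument, StreamConfig{nullptr, true});
}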
@@ -186,6 +186,13 @@ struct Bilinear
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
{
// apply alpha_/beta_ consistently with the other specializations
y = type_convert<std::int8_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
float alpha_;
float beta_;
};
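For context, a minimal sketch of calling this functor on the host (the enclosing namespace is an assumption based on where CK keeps its element-wise ops):

ck::tensor_operation::element_wise::Bilinear bilinear{2.0f, 3.0f}; // alpha, beta
ck::half_t y;
// picks the <half_t, float, half_t> specialization shown in the context above:
// y = alpha * x0 + beta * x1 = 2.0 * 1.5 + 3.0 * 0.5 = 4.5
bilinear(y, 1.5f, ck::type_convert<ck::half_t>(0.5f));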