Commit 648f1f13 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/gemm_tile_loop

parents 4e5190f5 cb538740
......@@ -193,6 +193,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
BLayout,
......@@ -217,6 +218,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeType,
AccDataType,
CShuffleDataType,
DsDataType,
......
......@@ -5,64 +5,41 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InputGridDesc,
typename InputDataType,
typename OutputGridDesc,
typename OutputDataType,
typename Block2ETileMap,
typename GridwiseImageToColumnKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_image_to_column(const InputGridDesc in_grid_desc,
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseImageToColumnKernel::Run(
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
#else
ignore = in_grid_desc;
ignore = p_in_global;
ignore = out_grid_desc;
ignore = p_out_global;
ignore = block_2_tile_map;
#endif
}
// Image to column for input layout NDHWC:
// input : input image [N, Di, Hi, Wi, C],
// output : output image [N * Do * Ho * Wo, Z * Y * X * C]
// input : input image [N, Di, Hi, Wi, C]
// output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
template <index_t NDimSpatial,
typename InputLayout,
typename ImageLayout,
typename InputDataType,
typename OutputDataType,
index_t BlockSize,
index_t MPerBlock,
index_t KPerBlock,
typename ThreadClusterLengths,
index_t ScalarPerVector>
index_t ScalarPerVector,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct DeviceImageToColumnImpl
: public DeviceImageToColumn<NDimSpatial, InputLayout, InputDataType, OutputDataType>
: public DeviceConvTensorRearrange<NDimSpatial,
ImageLayout,
InputDataType,
OutputDataType,
conv_tensor_rearrange_op::ImageToColumn>
{
static constexpr auto I0 = Number<0>{};
......@@ -83,7 +60,7 @@ struct DeviceImageToColumnImpl
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -110,9 +87,9 @@ struct DeviceImageToColumnImpl
c_g_n_k_wos_lengths[I1] = N;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<InputLayout>(
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
input_g_n_c_wis_strides,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
......@@ -132,7 +109,7 @@ struct DeviceImageToColumnImpl
const ck::index_t C,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, 2>& output_m_k_strides)
const std::array<index_t, 2>& gemm_m_k_strides)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
......@@ -141,7 +118,7 @@ struct DeviceImageToColumnImpl
C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(NDoHoWo, CZYX), make_tuple(output_m_k_strides[I0], output_m_k_strides[I1]));
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
return desc_m_k;
......@@ -155,7 +132,7 @@ struct DeviceImageToColumnImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
OutputGridDesc{}))>;
using GridwiseImageToColumnKernel = GridwiseImageToColumn<InputGridDesc,
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
......@@ -164,19 +141,20 @@ struct DeviceImageToColumnImpl
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
Block2ETileMap>;
struct Argument : public BaseArgument
{
Argument(const void* p_in, // input image
void* p_out, // output image
void* p_out, // gemm form
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -185,7 +163,7 @@ struct DeviceImageToColumnImpl
X_(filter_spatial_lengths[NDimSpatial - I1]),
p_in_{static_cast<const InputDataType*>(p_in)},
p_out_{static_cast<OutputDataType*>(p_out)},
input_g_n_c_wis_strides_{input_g_n_c_wis_strides},
image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
......@@ -197,7 +175,7 @@ struct DeviceImageToColumnImpl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
image_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
......@@ -205,7 +183,7 @@ struct DeviceImageToColumnImpl
input_right_pads);
out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
N, C, filter_spatial_lengths, output_spatial_lengths, output_m_k_strides);
N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides);
}
void Print() const
......@@ -220,7 +198,7 @@ struct DeviceImageToColumnImpl
const InputDataType* p_in_;
OutputDataType* p_out_;
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides_;
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<index_t, NDimSpatial>& input_left_pads_;
......@@ -243,12 +221,12 @@ struct DeviceImageToColumnImpl
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
arg.out_grid_desc_m_k_);
const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
const auto kernel = kernel_image_to_column<InputGridDesc,
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
Block2ETileMap,
GridwiseImageToColumnKernel>;
GridwiseTensorRearrangeKernel>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -273,12 +251,8 @@ struct DeviceImageToColumnImpl
bool IsSupportedArgument(const Argument& arg)
{
using namespace tensor_layout::convolution;
if(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> ||
std::is_same_v<InputLayout, GNDHWC>))
{
return false;
}
if(!(NDimSpatial >= 1 && NDimSpatial <= 3))
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
......@@ -287,8 +261,8 @@ struct DeviceImageToColumnImpl
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
bool is_w_packed = arg.input_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
bool is_c_packed = arg.input_g_n_c_wis_strides_[I2] == 1;
bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
// check vector acces with c not packed
if(!is_c_packed && ScalarPerVector != 1)
......@@ -310,7 +284,7 @@ struct DeviceImageToColumnImpl
if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
return false;
return GridwiseImageToColumnKernel::CheckValidity(arg.in_grid_desc_m_k_,
return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_,
arg.out_grid_desc_m_k_);
}
......@@ -320,14 +294,14 @@ struct DeviceImageToColumnImpl
}
static auto MakeArgument(const void* p_in, // input image
void* p_out, // output image
void* p_out, // gemm form
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -340,8 +314,8 @@ struct DeviceImageToColumnImpl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
image_g_n_c_wis_strides,
gemm_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -352,14 +326,14 @@ struct DeviceImageToColumnImpl
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, // input image
void* p_out, // output image
void* p_out, // gemm form
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -372,8 +346,8 @@ struct DeviceImageToColumnImpl
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
image_g_n_c_wis_strides,
gemm_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsTuple() { return true; }
};
template <>
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment