Unverified Commit 9684677a authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into hip_tensor_permute

parents 36f6966a 98fd41f5
...@@ -33,7 +33,8 @@ template <index_t NumDimM, ...@@ -33,7 +33,8 @@ template <index_t NumDimM,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceContractionMultipleD : public BaseOperator struct DeviceContractionMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
......
...@@ -14,11 +14,12 @@ namespace device { ...@@ -14,11 +14,12 @@ namespace device {
/** /**
* \brief Convolution Tensor Rearrange. * \brief Convolution Tensor Rearrange.
* *
* This Device operator supports conversion image ([G, N, Di, Hi, Wi, C]) to * This Device operator supports converting an image to
* the gemm problem([N * Do * Ho * Wo, Z * Y * X * C]) (Image to Column) and * the GEMM representation (Image to Column) and
* conversion gemm form to the image (Column to Image). * converting a GEMM form to the image (Column to Image).
* * Supported layouts:
* Note that G must be equal to 1. * [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C]
* [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C]
* *
* \tparam NDimSpatial Number of spatial dimensions. * \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Input Layout. * \tparam ImageLayout Input Layout.
...@@ -39,13 +40,14 @@ struct DeviceConvTensorRearrange : public BaseOperator ...@@ -39,13 +40,14 @@ struct DeviceConvTensorRearrange : public BaseOperator
* *
* \param p_in A pointer to the device memory of the input image. * \param p_in A pointer to the device memory of the input image.
* \param p_out A pointer to the device memory of the output. * \param p_out A pointer to the device memory of the output.
* \param G Convolution number of groups.
* \param N Convolution batch size. * \param N Convolution batch size.
* \param C Convolution number of channels. * \param C Convolution number of channels.
* \param input_spatial_lengths Input spatial lengths. * \param input_spatial_lengths Input spatial lengths.
* \param filter_spatial_lengths Filter spatial lengths. * \param filter_spatial_lengths Filter spatial lengths.
* \param output_spatial_lengths Output spatial lengths. * \param output_spatial_lengths Output spatial lengths.
* \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W]. * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
* \param gemm_m_k_strides Gemm form strides. * \param gemm_g_m_k_strides Gemm form strides.
* \param conv_filter_strides Convolution filter strides. * \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations. * \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Convolution left pads. * \param input_left_pads Convolution left pads.
...@@ -55,13 +57,14 @@ struct DeviceConvTensorRearrange : public BaseOperator ...@@ -55,13 +57,14 @@ struct DeviceConvTensorRearrange : public BaseOperator
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
void* p_out, void* p_out,
const ck::index_t G,
const ck::index_t N, const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
......
...@@ -17,15 +17,18 @@ ...@@ -17,15 +17,18 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" #include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/io.hpp" #include "ck/host_utility/io.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Image to column for input layout NDHWC: // Column to Image:
// input : image converted to the gemm problem [N * Do * Ho * Wo, Z * Y * X * C] // input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
// output : image [N, Di, Hi, Wi, C] // output : input image [G, N, Di, Hi, Wi, C]
// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
// output : input image [N, Di, Hi, Wi, G, C]
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename ImageLayout, typename ImageLayout,
typename InputDataType, typename InputDataType,
...@@ -43,6 +46,14 @@ struct DeviceColumnToImageImpl ...@@ -43,6 +46,14 @@ struct DeviceColumnToImageImpl
OutputDataType, OutputDataType,
conv_tensor_rearrange_op::ColumnToImage> conv_tensor_rearrange_op::ColumnToImage>
{ {
static constexpr bool is_NSpatialGC =
std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
static constexpr bool is_GNSpatialC =
std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -90,7 +101,7 @@ struct DeviceColumnToImageImpl ...@@ -90,7 +101,7 @@ struct DeviceColumnToImageImpl
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& independent_filters, const std::array<index_t, NDimSpatial>& independent_filters,
const std::array<index_t, NDimSpatial>& effs) const std::array<index_t, NDimSpatial>& effs)
{ {
...@@ -100,23 +111,23 @@ struct DeviceColumnToImageImpl ...@@ -100,23 +111,23 @@ struct DeviceColumnToImageImpl
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t NStride = DoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1]; const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2];
// Calculate the appropriate stride for each set of independent filters // Calculate the appropriate stride for each set of independent filters
// in each dimension // in each dimension
const index_t WStride = const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) *
math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * gemm_m_k_strides[I0]; gemm_g_m_k_strides[I1];
const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) * const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
output_spatial_lengths[XIdx] * gemm_m_k_strides[I0]; output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1];
const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) * const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
gemm_m_k_strides[I0]; gemm_g_m_k_strides[I1];
// Create descriptor for independent filters in each dimension and // Create descriptor for independent filters in each dimension and
// then merge them into column form // then merge them into column form
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const auto desc_gemm_form = const auto desc_gemm_form =
make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX), make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
make_tuple(NStride, WStride, gemm_m_k_strides[I1])); make_tuple(NStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form, desc_gemm_form,
make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])), make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
...@@ -130,7 +141,7 @@ struct DeviceColumnToImageImpl ...@@ -130,7 +141,7 @@ struct DeviceColumnToImageImpl
{ {
const auto desc_gemm_form = make_naive_tensor_descriptor( const auto desc_gemm_form = make_naive_tensor_descriptor(
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX), make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
make_tuple(NStride, HStride, WStride, gemm_m_k_strides[I1])); make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form, desc_gemm_form,
make_tuple(make_merge_transform( make_tuple(make_merge_transform(
...@@ -149,7 +160,7 @@ struct DeviceColumnToImageImpl ...@@ -149,7 +160,7 @@ struct DeviceColumnToImageImpl
independent_filters[YIdx], independent_filters[YIdx],
independent_filters[XIdx], independent_filters[XIdx],
CZYX), CZYX),
make_tuple(NStride, DStride, HStride, WStride, gemm_m_k_strides[I1])); make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2]));
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
desc_gemm_form, desc_gemm_form,
make_tuple(make_merge_transform(make_tuple(N, make_tuple(make_merge_transform(make_tuple(N,
...@@ -252,34 +263,38 @@ struct DeviceColumnToImageImpl ...@@ -252,34 +263,38 @@ struct DeviceColumnToImageImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>( decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
InputGridDesc{}))>; InputGridDesc{}))>;
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc, using GridwiseTensorRearrangeKernel =
InputDataType, GridwiseTensorRearrange<InputGridDesc,
OutputGridDesc, InputDataType,
OutputDataType, OutputGridDesc,
BlockSize, OutputDataType,
MPerBlock, BlockSize,
KPerBlock, MPerBlock,
ThreadClusterLengths, KPerBlock,
ScalarPerVector, ThreadClusterLengths,
InMemoryDataOperationEnum::Add, ScalarPerVector,
Block2ETileMap>; InMemoryDataOperationEnum::Add,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const void* p_in, // input image Argument(const void* p_in, // input image
void* p_out, // output image void* p_out, // output image
const ck::index_t G,
const ck::index_t N, const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads)
: C_(C), : G_(G),
C_(C),
X_(filter_spatial_lengths[NDimSpatial - I1]), X_(filter_spatial_lengths[NDimSpatial - I1]),
p_in_{static_cast<const InputDataType*>(p_in)}, p_in_{static_cast<const InputDataType*>(p_in)},
p_out_{static_cast<OutputDataType*>(p_out)}, p_out_{static_cast<OutputDataType*>(p_out)},
...@@ -289,6 +304,9 @@ struct DeviceColumnToImageImpl ...@@ -289,6 +304,9 @@ struct DeviceColumnToImageImpl
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0];
const index_t x_eff = const index_t x_eff =
(filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1; (filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1;
const index_t y_eff = const index_t y_eff =
...@@ -354,7 +372,7 @@ struct DeviceColumnToImageImpl ...@@ -354,7 +372,7 @@ struct DeviceColumnToImageImpl
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
conv_filter_strides, conv_filter_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
independent_filters, independent_filters,
effs); effs);
const auto out_grid_desc_m_k = const auto out_grid_desc_m_k =
...@@ -387,10 +405,9 @@ struct DeviceColumnToImageImpl ...@@ -387,10 +405,9 @@ struct DeviceColumnToImageImpl
// Memory offsets to next set of independent filters, // Memory offsets to next set of independent filters,
// move to independent filters in each dimension // move to independent filters in each dimension
const index_t in_offset = const index_t in_offset =
x_idx * gemm_m_k_strides[0] + (x_idx + y_idx * output_spatial_lengths[XIdx] +
y_idx * gemm_m_k_strides[0] * output_spatial_lengths[XIdx] + z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) *
z_idx * gemm_m_k_strides[0] * output_spatial_lengths[YIdx] * gemm_g_m_k_strides[I1];
output_spatial_lengths[XIdx];
// Move to independent filters in appropriate dimensions // Move to independent filters in appropriate dimensions
const index_t out_offset = const index_t out_offset =
x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] + x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
...@@ -417,6 +434,7 @@ struct DeviceColumnToImageImpl ...@@ -417,6 +434,7 @@ struct DeviceColumnToImageImpl
} }
} }
const ck::index_t G_;
const ck::index_t C_; const ck::index_t C_;
const ck::index_t X_; const ck::index_t X_;
...@@ -434,6 +452,8 @@ struct DeviceColumnToImageImpl ...@@ -434,6 +452,8 @@ struct DeviceColumnToImageImpl
std::vector<const InputDataType*> p_in_container_; std::vector<const InputDataType*> p_in_container_;
std::vector<OutputDataType*> p_out_container_; std::vector<OutputDataType*> p_out_container_;
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -451,6 +471,7 @@ struct DeviceColumnToImageImpl ...@@ -451,6 +471,7 @@ struct DeviceColumnToImageImpl
OutputGridDesc, OutputGridDesc,
OutputDataType, OutputDataType,
Block2ETileMap, Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>,
GridwiseTensorRearrangeKernel>; GridwiseTensorRearrangeKernel>;
// Execute each set of independent filters // Execute each set of independent filters
...@@ -460,7 +481,7 @@ struct DeviceColumnToImageImpl ...@@ -460,7 +481,7 @@ struct DeviceColumnToImageImpl
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>( BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
arg.out_grid_desc_m_k_container_[i]); arg.out_grid_desc_m_k_container_[i]);
const index_t grid_size = const index_t grid_size =
block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]); block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_;
elapsed_time += launch_and_time_kernel(stream_config, elapsed_time += launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(grid_size),
...@@ -470,7 +491,9 @@ struct DeviceColumnToImageImpl ...@@ -470,7 +491,9 @@ struct DeviceColumnToImageImpl
arg.p_in_container_[i], arg.p_in_container_[i],
arg.out_grid_desc_m_k_container_[i], arg.out_grid_desc_m_k_container_[i],
arg.p_out_container_[i], arg.p_out_container_[i],
block_2_tile_map); arg.G_,
block_2_tile_map,
arg.compute_ptr_offset_of_batch_);
} }
return elapsed_time; return elapsed_time;
} }
...@@ -485,8 +508,7 @@ struct DeviceColumnToImageImpl ...@@ -485,8 +508,7 @@ struct DeviceColumnToImageImpl
bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
using namespace tensor_layout::convolution; using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> || if constexpr(!(is_NSpatialGC || is_GNSpatialC))
std::is_same_v<ImageLayout, GNDHWC>))
{ {
return false; return false;
} }
...@@ -534,13 +556,14 @@ struct DeviceColumnToImageImpl ...@@ -534,13 +556,14 @@ struct DeviceColumnToImageImpl
static auto MakeArgument(const void* p_in, // input image static auto MakeArgument(const void* p_in, // input image
void* p_out, // output image void* p_out, // output image
const ck::index_t G,
const ck::index_t N, const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -548,13 +571,14 @@ struct DeviceColumnToImageImpl ...@@ -548,13 +571,14 @@ struct DeviceColumnToImageImpl
{ {
return Argument{static_cast<const InputDataType*>(p_in), return Argument{static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out), static_cast<OutputDataType*>(p_out),
G,
N, N,
C, C,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -566,13 +590,14 @@ struct DeviceColumnToImageImpl ...@@ -566,13 +590,14 @@ struct DeviceColumnToImageImpl
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, // input image MakeArgumentPointer(const void* p_in, // input image
void* p_out, // output image void* p_out, // output image
const ck::index_t G,
const ck::index_t N, const ck::index_t N,
const ck::index_t C, const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides, const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
...@@ -580,13 +605,14 @@ struct DeviceColumnToImageImpl ...@@ -580,13 +605,14 @@ struct DeviceColumnToImageImpl
{ {
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in), return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out), static_cast<OutputDataType*>(p_out),
G,
N, N,
C, C,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
gemm_m_k_strides, gemm_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -145,7 +145,8 @@ template <index_t NumDimM, ...@@ -145,7 +145,8 @@ template <index_t NumDimM,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> typename ComputeDataType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceContractionMultipleD_Xdl_CShuffle struct DeviceContractionMultipleD_Xdl_CShuffle
: public DeviceContractionMultipleD<NumDimM, : public DeviceContractionMultipleD<NumDimM,
NumDimN, NumDimN,
...@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation,
ComputeDataType>
{ {
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
...@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using ComputeDataType = ADataType;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
......
...@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false; return false;
} }
} }
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940") else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>)) is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......
...@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// clang-format off // clang-format off
str << "DeviceGemm_Xdl_CShuffle" str << "DeviceGemm_Xdl_CShuffle"
<< "<" << "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
...@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< " LoopScheduler: " << " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: " << "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];; << PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -59,7 +59,8 @@ template <typename ADataType, ...@@ -59,7 +59,8 @@ template <typename ADataType,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL, index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename ComputeType = CDataType, typename ComputeType = CDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BLayout, BLayout,
...@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams. // TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1; static constexpr index_t NumGemmKPrefetchStage = 1;
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize, BlockSize,
...@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t MPadded_, index_t MPadded_,
index_t NPadded_, index_t NPadded_,
index_t KPadded_, index_t KPadded_,
index_t K0_, index_t K0Padded_,
index_t k_batch_, index_t k_batch_,
AElementwiseOperation a_element_op_, AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
...@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
MPadded_, MPadded_,
NPadded_, NPadded_,
KPadded_, KPadded_,
K0_, K0Padded_,
k_batch_), k_batch_),
a_element_op(a_element_op_), a_element_op(a_element_op_),
b_element_op(b_element_op_), b_element_op(b_element_op_),
...@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto b2c_map = DefaultBlock2CTileMap{}; const auto b2c_map = DefaultBlock2CTileMap{};
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto K0 = karg.K0; const auto K0Padded = karg.K0Padded;
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
float ave_time = 0; float ave_time = 0;
...@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M), GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch), GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch), GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch, KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M), GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch), GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch), GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch, KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
} }
// polymorphic // polymorphic
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); } std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
<< ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
return str.str();
}
}; };
} // namespace device } // namespace device
......
...@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_; const index_t stride_c = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N); const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH); const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
...@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
m_padded, m_padded,
n_padded, n_padded,
k_padded, k_padded,
k0, k0_padded,
K_BATCH}; K_BATCH};
gemm_kernel_args_.emplace_back( gemm_kernel_args_.emplace_back(
...@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto& karg = gemm_kernel_args_[i].karg_; auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH); const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
...@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded; karg.KPadded = k_padded;
karg.K0 = k0; karg.K0Padded = k0_padded;
karg.k_batch = K_BATCH; karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start; gemm_kernel_args_[i].block_start_ = block_start;
...@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0; index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1; bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
...@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
K0 = karg.K0; K0 = karg.K0Padded;
bool not_all_have_main_k0_block_loop_same = bool not_all_have_main_k0_block_loop_same =
all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0); all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
......
...@@ -186,6 +186,25 @@ struct Bilinear ...@@ -186,6 +186,25 @@ struct Bilinear
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1)); y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
}; };
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x0_tmp = type_convert<float>(x0);
const float x1_tmp = type_convert<float>(x1);
const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
y = type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
y = y_tmp;
};
template <> template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>( __host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
......
...@@ -311,6 +311,71 @@ struct AddAddFastGelu ...@@ -311,6 +311,71 @@ struct AddAddFastGelu
} }
}; };
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu
{
ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
: alpha1_(alpha1), alpha2_(alpha2)
{
}
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
const float& c,
const float& d0,
const float& d1) const
{
const float x = c * alpha1_ + alpha2_ * d0 + d1;
Relu{}.template operator()<float>(e, x);
}
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
type_convert<float>(d1);
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<half_t>(result);
}
template <>
__host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
type_convert<float>(d1);
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<bhalf_t>(result);
}
template <>
__host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
int8_t& e, const int8_t& c, const float& d0, const float& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<int8_t>(result);
}
const float alpha1_;
const float alpha2_;
};
struct Normalize struct Normalize
{ {
// FIXME: is double absolutely necessary? // FIXME: is double absolutely necessary?
......
This diff is collapsed.
...@@ -25,6 +25,8 @@ using BF8 = ck::bf8_t; ...@@ -25,6 +25,8 @@ using BF8 = ck::bf8_t;
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
using BF16_Tuple = ck::Tuple<BF16>;
using F16_Tuple = ck::Tuple<F16>; using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment