Unverified Commit 9684677a authored by arai713, committed by GitHub

Merge branch 'develop' into hip_tensor_permute

parents 36f6966a 98fd41f5
@@ -33,7 +33,8 @@ template <index_t NumDimM,
           typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CDEElementwiseOperation>
+          typename CDEElementwiseOperation,
+          typename ComputeDataType = ADataType>
 struct DeviceContractionMultipleD : public BaseOperator
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
......
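The hunk above adds a defaulted ComputeDataType template parameter. A minimal sketch (hypothetical operator, not CK's full parameter list) of why defaulting it to ADataType is backward compatible:

// Hypothetical stand-in: a defaulted trailing template parameter lets every
// existing instantiation keep compiling, while new callers can opt into a
// different compute precision.
#include <cstdio>
#include <type_traits>

template <typename ADataType,
          typename EDataType,
          typename ComputeDataType = ADataType> // default preserves old call sites
struct DeviceOpSketch
{
    static void Describe()
    {
        std::printf("compute == A: %d\n",
                    int(std::is_same_v<ComputeDataType, ADataType>));
    }
};

int main()
{
    DeviceOpSketch<float, float>::Describe();         // old-style, still valid
    DeviceOpSketch<float, float, double>::Describe(); // new: explicit compute type
}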
@@ -14,11 +14,12 @@ namespace device {
 /**
  * \brief Convolution Tensor Rearrange.
  *
- * This Device operator supports conversion image ([G, N, Di, Hi, Wi, C]) to
- * the gemm problem([N * Do * Ho * Wo, Z * Y * X * C]) (Image to Column) and
- * conversion gemm form to the image (Column to Image).
- *
- * Note that G must be equal to 1.
+ * This Device operator supports converting an image to
+ * the GEMM representation (Image to Column) and
+ * converting a GEMM form to the image (Column to Image).
+ * Supported layouts:
+ * [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C]
+ * [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C]
  *
  * \tparam NDimSpatial Number of spatial dimensions.
  * \tparam ImageLayout Input Layout.
@@ -39,13 +40,14 @@ struct DeviceConvTensorRearrange : public BaseOperator
  *
  * \param p_in A pointer to the device memory of the input image.
  * \param p_out A pointer to the device memory of the output.
+ * \param G Convolution number of groups.
  * \param N Convolution batch size.
  * \param C Convolution number of channels.
  * \param input_spatial_lengths Input spatial lengths.
  * \param filter_spatial_lengths Filter spatial lengths.
  * \param output_spatial_lengths Output spatial lengths.
  * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
- * \param gemm_m_k_strides Gemm form strides.
+ * \param gemm_g_m_k_strides Gemm form strides.
  * \param conv_filter_strides Convolution filter strides.
  * \param conv_filter_dilations Convolution filter dilations.
  * \param input_left_pads Convolution left pads.
@@ -55,13 +57,14 @@ struct DeviceConvTensorRearrange : public BaseOperator
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_in,
                         void* p_out,
+                        const ck::index_t G,
                         const ck::index_t N,
                         const ck::index_t C,
                         const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                         const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                         const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                         const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
-                        const std::array<index_t, 2>& gemm_m_k_strides,
+                        const std::array<index_t, 3>& gemm_g_m_k_strides,
                         const std::array<index_t, NDimSpatial>& conv_filter_strides,
                         const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                         const std::array<index_t, NDimSpatial>& input_left_pads,
......
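A worked example of the GEMM view described in the doxygen above, for a 2D convolution (illustrative sizes; the packed stride choice is an assumption, the API accepts arbitrary strides):

#include <cstdio>

int main()
{
    const int G = 2, N = 4, C = 8;
    const int Ho = 16, Wo = 16; // output spatial lengths
    const int Y = 3, X = 3;     // filter spatial lengths

    const int M = N * Ho * Wo; // 1024 rows per group
    const int K = Y * X * C;   // 72 columns per group

    // One plausible stride set for a densely packed [G, M, K] buffer:
    const int gemm_g_m_k_strides[3] = {M * K, K, 1};

    std::printf("G=%d M=%d K=%d, strides={%d,%d,%d}\n",
                G, M, K,
                gemm_g_m_k_strides[0], gemm_g_m_k_strides[1], gemm_g_m_k_strides[2]);
}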
@@ -17,15 +17,18 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/io.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-// Image to column for input layout NDHWC:
-// input : image converted to the gemm problem [N * Do * Ho * Wo, Z * Y * X * C]
-// output : image [N, Di, Hi, Wi, C]
+// Column to Image:
+// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
+// output : input image [G, N, Di, Hi, Wi, C]
+// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
+// output : input image [N, Di, Hi, Wi, G, C]
 template <index_t NDimSpatial,
           typename ImageLayout,
           typename InputDataType,
@@ -43,6 +46,14 @@ struct DeviceColumnToImageImpl
                                 OutputDataType,
                                 conv_tensor_rearrange_op::ColumnToImage>
 {
+    static constexpr bool is_NSpatialGC =
+        std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
+    static constexpr bool is_GNSpatialC =
+        std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -90,7 +101,7 @@ struct DeviceColumnToImageImpl
         const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
         const std::array<index_t, NDimSpatial>& output_spatial_lengths,
         const std::array<index_t, NDimSpatial>& conv_filter_strides,
-        const std::array<index_t, 2>& gemm_m_k_strides,
+        const std::array<index_t, 3>& gemm_g_m_k_strides,
         const std::array<index_t, NDimSpatial>& independent_filters,
         const std::array<index_t, NDimSpatial>& effs)
     {
@@ -100,23 +111,23 @@ struct DeviceColumnToImageImpl
             C * ck::accumulate_n<index_t>(
                     filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
-        const index_t NStride = DoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
+        const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2];
         // Calculate the appropriate stride for each set of independent filters
         // in each dimension
-        const index_t WStride =
-            math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * gemm_m_k_strides[I0];
+        const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) *
+                                gemm_g_m_k_strides[I1];
         const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
-                                output_spatial_lengths[XIdx] * gemm_m_k_strides[I0];
+                                output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1];
         const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
                                 output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
-                                gemm_m_k_strides[I0];
+                                gemm_g_m_k_strides[I1];
         // Create descriptor for independent filters in each dimension and
         // then merge them into column form
         if constexpr(NDimSpatial == 1)
         {
             const auto desc_gemm_form =
                 make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
-                                             make_tuple(NStride, WStride, gemm_m_k_strides[I1]));
+                                             make_tuple(NStride, WStride, gemm_g_m_k_strides[I2]));
             const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
                 desc_gemm_form,
                 make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
@@ -130,7 +141,7 @@ struct DeviceColumnToImageImpl
         {
             const auto desc_gemm_form = make_naive_tensor_descriptor(
                 make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
-                make_tuple(NStride, HStride, WStride, gemm_m_k_strides[I1]));
+                make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2]));
             const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
                 desc_gemm_form,
                 make_tuple(make_merge_transform(
@@ -149,7 +160,7 @@ struct DeviceColumnToImageImpl
                            independent_filters[YIdx],
                            independent_filters[XIdx],
                            CZYX),
-                make_tuple(NStride, DStride, HStride, WStride, gemm_m_k_strides[I1]));
+                make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2]));
             const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
                 desc_gemm_form,
                 make_tuple(make_merge_transform(make_tuple(N,
@@ -252,34 +263,38 @@ struct DeviceColumnToImageImpl
         decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
             InputGridDesc{}))>;

-    using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
-                                                                  InputDataType,
-                                                                  OutputGridDesc,
-                                                                  OutputDataType,
-                                                                  BlockSize,
-                                                                  MPerBlock,
-                                                                  KPerBlock,
-                                                                  ThreadClusterLengths,
-                                                                  ScalarPerVector,
-                                                                  InMemoryDataOperationEnum::Add,
-                                                                  Block2ETileMap>;
+    using GridwiseTensorRearrangeKernel =
+        GridwiseTensorRearrange<InputGridDesc,
+                                InputDataType,
+                                OutputGridDesc,
+                                OutputDataType,
+                                BlockSize,
+                                MPerBlock,
+                                KPerBlock,
+                                ThreadClusterLengths,
+                                ScalarPerVector,
+                                InMemoryDataOperationEnum::Add,
+                                Block2ETileMap,
+                                ComputePtrOffsetOfStridedBatch<I0>>;

     struct Argument : public BaseArgument
     {
         Argument(const void* p_in, // input image
                  void* p_out,      // output image
+                 const ck::index_t G,
                  const ck::index_t N,
                  const ck::index_t C,
                  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
-                 const std::array<index_t, 2>& gemm_m_k_strides,
+                 const std::array<index_t, 3>& gemm_g_m_k_strides,
                  const std::array<index_t, NDimSpatial>& conv_filter_strides,
                  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                  const std::array<index_t, NDimSpatial>& input_left_pads,
                  const std::array<index_t, NDimSpatial>& input_right_pads)
-            : C_(C),
+            : G_(G),
+              C_(C),
               X_(filter_spatial_lengths[NDimSpatial - I1]),
               p_in_{static_cast<const InputDataType*>(p_in)},
               p_out_{static_cast<OutputDataType*>(p_out)},
@@ -289,6 +304,9 @@ struct DeviceColumnToImageImpl
               input_left_pads_{input_left_pads},
               input_right_pads_{input_right_pads}
         {
+            compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0];
+            compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0];
+
             const index_t x_eff =
                 (filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1;
             const index_t y_eff =
@@ -354,7 +372,7 @@ struct DeviceColumnToImageImpl
                     filter_spatial_lengths,
                     output_spatial_lengths,
                     conv_filter_strides,
-                    gemm_m_k_strides,
+                    gemm_g_m_k_strides,
                     independent_filters,
                     effs);
                 const auto out_grid_desc_m_k =
@@ -387,10 +405,9 @@ struct DeviceColumnToImageImpl
                 // Memory offsets to next set of independent filters,
                 // move to independent filters in each dimension
                 const index_t in_offset =
-                    x_idx * gemm_m_k_strides[0] +
-                    y_idx * gemm_m_k_strides[0] * output_spatial_lengths[XIdx] +
-                    z_idx * gemm_m_k_strides[0] * output_spatial_lengths[YIdx] *
-                        output_spatial_lengths[XIdx];
+                    (x_idx + y_idx * output_spatial_lengths[XIdx] +
+                     z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) *
+                    gemm_g_m_k_strides[I1];
                 // Move to independent filters in appropriate dimensions
                 const index_t out_offset =
                     x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
@@ -417,6 +434,7 @@ struct DeviceColumnToImageImpl
             }
         }

+        const ck::index_t G_;
         const ck::index_t C_;
         const ck::index_t X_;
@@ -434,6 +452,8 @@ struct DeviceColumnToImageImpl
         std::vector<const InputDataType*> p_in_container_;
         std::vector<OutputDataType*> p_out_container_;
+
+        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
     };

     struct Invoker : public BaseInvoker
@@ -451,6 +471,7 @@ struct DeviceColumnToImageImpl
                                                       OutputGridDesc,
                                                       OutputDataType,
                                                       Block2ETileMap,
+                                                      ComputePtrOffsetOfStridedBatch<I0>,
                                                       GridwiseTensorRearrangeKernel>;

             // Execute each set of independent filters
@@ -460,7 +481,7 @@ struct DeviceColumnToImageImpl
                     BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
                         arg.out_grid_desc_m_k_container_[i]);
                 const index_t grid_size =
-                    block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]);
+                    block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_;
                 elapsed_time += launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(grid_size),
@@ -470,7 +491,9 @@ struct DeviceColumnToImageImpl
                                                        arg.p_in_container_[i],
                                                        arg.out_grid_desc_m_k_container_[i],
                                                        arg.p_out_container_[i],
-                                                       block_2_tile_map);
+                                                       arg.G_,
+                                                       block_2_tile_map,
+                                                       arg.compute_ptr_offset_of_batch_);
             }
             return elapsed_time;
         }
@@ -485,8 +508,7 @@ struct DeviceColumnToImageImpl
     bool IsSupportedArgument(const Argument& arg)
     {
         using namespace tensor_layout::convolution;
-        if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
-                       std::is_same_v<ImageLayout, GNDHWC>))
+        if constexpr(!(is_NSpatialGC || is_GNSpatialC))
         {
             return false;
         }
@@ -534,13 +556,14 @@ struct DeviceColumnToImageImpl
     static auto MakeArgument(const void* p_in, // input image
                              void* p_out,      // output image
+                             const ck::index_t G,
                              const ck::index_t N,
                              const ck::index_t C,
                              const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                              const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                              const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                              const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
-                             const std::array<index_t, 2>& gemm_m_k_strides,
+                             const std::array<index_t, 3>& gemm_g_m_k_strides,
                              const std::array<index_t, NDimSpatial>& conv_filter_strides,
                              const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                              const std::array<index_t, NDimSpatial>& input_left_pads,
@@ -548,13 +571,14 @@ struct DeviceColumnToImageImpl
     {
         return Argument{static_cast<const InputDataType*>(p_in),
                         static_cast<OutputDataType*>(p_out),
+                        G,
                         N,
                         C,
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
                         image_g_n_c_wis_strides,
-                        gemm_m_k_strides,
+                        gemm_g_m_k_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
@@ -566,13 +590,14 @@ struct DeviceColumnToImageImpl
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_in, // input image
                         void* p_out,      // output image
+                        const ck::index_t G,
                         const ck::index_t N,
                         const ck::index_t C,
                         const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                         const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                         const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                         const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
-                        const std::array<index_t, 2>& gemm_m_k_strides,
+                        const std::array<index_t, 3>& gemm_g_m_k_strides,
                         const std::array<index_t, NDimSpatial>& conv_filter_strides,
                         const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                         const std::array<index_t, NDimSpatial>& input_left_pads,
@@ -580,13 +605,14 @@ struct DeviceColumnToImageImpl
     {
         return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
                                           static_cast<OutputDataType*>(p_out),
+                                          G,
                                           N,
                                           C,
                                           input_spatial_lengths,
                                           filter_spatial_lengths,
                                           output_spatial_lengths,
                                           image_g_n_c_wis_strides,
-                                          gemm_m_k_strides,
+                                          gemm_g_m_k_strides,
                                           conv_filter_strides,
                                           conv_filter_dilations,
                                           input_left_pads,
......
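A hedged sketch of the strided-batch addressing this file wires up through ComputePtrOffsetOfStridedBatch: group g reads at g * BatchStrideA_ and writes at g * BatchStrideC_, and the invoker scales the grid size by G. The struct below is illustrative, not CK's definition:

#include <cstdio>

struct PtrOffsetOfBatchSketch
{
    long BatchStrideA_;
    long BatchStrideC_;

    long GetAPtrOffset(int g) const { return static_cast<long>(g) * BatchStrideA_; }
    long GetCPtrOffset(int g) const { return static_cast<long>(g) * BatchStrideC_; }
};

int main()
{
    const int G = 3;
    PtrOffsetOfBatchSketch offsets{/*BatchStrideA_=*/1024, /*BatchStrideC_=*/4096};
    for(int g = 0; g < G; ++g)
        std::printf("g=%d inOff=%ld outOff=%ld\n",
                    g, offsets.GetAPtrOffset(g), offsets.GetCPtrOffset(g));
}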
@@ -145,7 +145,8 @@ template <index_t NumDimM,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
+          typename ComputeDataType = ADataType,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceContractionMultipleD_Xdl_CShuffle
     : public DeviceContractionMultipleD<NumDimM,
                                         NumDimN,
@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                                         EDataType,
                                         AElementwiseOperation,
                                         BElementwiseOperation,
-                                        CDEElementwiseOperation>
+                                        CDEElementwiseOperation,
+                                        ComputeDataType>
 {
     using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N({}, {}));

-    using ComputeDataType = ADataType;
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
......
@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
                 return false;
             }
         }
-        else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")
+        else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
+                ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......
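The gfx941/gfx942 additions above extend a growing || chain. A sketch of the same gate written against a target list; ck::get_device_name() is the real call from the diff, while the helper and list below are illustrative:

#include <algorithm>
#include <array>
#include <cstdio>
#include <string>

inline bool is_xdl_supported_device(const std::string& name)
{
    // Assumed target set, mirroring the check in the diff.
    constexpr std::array<const char*, 4> targets{"gfx90a", "gfx940", "gfx941", "gfx942"};
    return std::any_of(targets.begin(), targets.end(),
                       [&](const char* t) { return name == t; });
}

int main()
{
    std::printf("gfx942 supported: %d\n", int(is_xdl_supported_device("gfx942")));
}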
@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
         // clang-format off
         str << "DeviceGemm_Xdl_CShuffle"
             << "<"
+            << getGemmSpecializationString(GemmSpec) << ", "
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
             << " LoopScheduler: "
             << LoopSchedToString[LoopSched] << ", "
             << "PipelineVersion: "
-            << PipelineVersionToString[PipelineVer];;
+            << PipelineVersionToString[PipelineVer];
         // clang-format on

         return str.str();
......
@@ -59,7 +59,8 @@ template <typename ADataType,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
           typename ComputeType        = CDataType,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          LoopScheduler LoopSched     = make_default_loop_scheduler()>
 struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                              BLayout,
@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
     // TODO: should be exposed as Tparams.
     static constexpr index_t NumGemmKPrefetchStage = 1;
-    static constexpr LoopScheduler LoopSched       = make_default_loop_scheduler();

     using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                  index_t MPadded_,
                  index_t NPadded_,
                  index_t KPadded_,
-                 index_t K0_,
+                 index_t K0Padded_,
                  index_t k_batch_,
                  AElementwiseOperation a_element_op_,
                  BElementwiseOperation b_element_op_,
@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         MPadded_,
                         NPadded_,
                         KPadded_,
-                        K0_,
+                        K0Padded_,
                         k_batch_),
               a_element_op(a_element_op_),
               b_element_op(b_element_op_),
@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
             const auto b2c_map = DefaultBlock2CTileMap{};
             index_t gdx, gdy, gdz;
             std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
-            const auto K0 = karg.K0;
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K0Padded = karg.K0Padded;
+            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);

             float ave_time = 0;
@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         GridwiseGemm::CalculateMPadded(M),
                         GridwiseGemm::CalculateNPadded(N),
                         GridwiseGemm::CalculateKPadded(K, KBatch),
-                        GridwiseGemm::CalculateK0(K, KBatch),
+                        GridwiseGemm::CalculateK0Padded(K, KBatch),
                         KBatch,
                         a_element_op,
                         b_element_op,
@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                           GridwiseGemm::CalculateMPadded(M),
                                           GridwiseGemm::CalculateNPadded(N),
                                           GridwiseGemm::CalculateKPadded(K, KBatch),
-                                          GridwiseGemm::CalculateK0(K, KBatch),
+                                          GridwiseGemm::CalculateK0Padded(K, KBatch),
                                           KBatch,
                                           a_element_op,
                                           b_element_op,
@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
     }

     // polymorphic
-    std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        std::map<LoopScheduler, std::string> LoopSchedToString{
+            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
+
+        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
+                                                                       {PipelineVersion::v2, "v2"}};
+
+        str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
+            << ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
+
+        return str.str();
+    }
 };

 } // namespace device
......
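A worked example of the split-K quantity the K0 -> K0Padded rename clarifies, assuming the padded K0 rounds K up to a multiple of KBatch * K0PerBlock * K1 (one plausible reading of the comment in CalculateK0Padded further down in this diff; values illustrative):

#include <cstdio>

int main()
{
    const int K = 1000, KBatch = 4, K0PerBlock = 8, K1 = 4;

    // Round K up, expressed in K0 units per k-batch slice.
    const int K_t      = KBatch * K0PerBlock * K1;         // 128
    const int K0Padded = (K + K_t - 1) / K_t * K0PerBlock; // ceil(1000/128) * 8 = 64
    const int KPadded  = KBatch * K0Padded * K1;           // 4 * 64 * 4 = 1024

    std::printf("K0Padded=%d KPadded=%d\n", K0Padded, KPadded);
}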
@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             const index_t stride_b = gemm_descs[i].stride_B_;
             const index_t stride_c = gemm_descs[i].stride_C_;

             const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
             const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
             const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
-            const index_t k0       = GridwiseGemm::CalculateK0(K, K_BATCH);
+            const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);

             const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                 m_padded,
                                 n_padded,
                                 k_padded,
-                                k0,
+                                k0_padded,
                                 K_BATCH};

             gemm_kernel_args_.emplace_back(
@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             auto& karg = gemm_kernel_args_[i].karg_;

             const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
-            const index_t k0       = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
+            const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);

             const auto c_grid_desc_m_n =
                 GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
             karg.KPadded = k_padded;
-            karg.K0      = k0;
+            karg.K0Padded = k0_padded;
             karg.k_batch = K_BATCH;
             gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
             gemm_kernel_args_[i].block_start_       = block_start;
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
+            index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
             bool all_have_kbatch_gt_one      = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
             bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                     throw std::runtime_error(err.str());
                 }

-                K0 = karg.K0;
+                K0 = karg.K0Padded;
                 bool not_all_have_main_k0_block_loop_same =
                     all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
                 bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
......
@@ -15,15 +15,18 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/io.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-// Image to column for input layout NDHWC:
-// input : input image [N, Di, Hi, Wi, C]
-// output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
+// Image to column:
+// input : input image [G, N, Di, Hi, Wi, C]
+// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C]
+// input : input image [N, Di, Hi, Wi, G, C]
+// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C]
 template <index_t NDimSpatial,
           typename ImageLayout,
           typename InputDataType,
@@ -41,6 +44,14 @@ struct DeviceImageToColumnImpl
                                 OutputDataType,
                                 conv_tensor_rearrange_op::ImageToColumn>
 {
+    static constexpr bool is_NSpatialGC =
+        std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
+    static constexpr bool is_GNSpatialC =
+        std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
+        std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -109,7 +120,7 @@ struct DeviceImageToColumnImpl
                              const ck::index_t C,
                              const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                              const std::array<index_t, NDimSpatial>& output_spatial_lengths,
-                             const std::array<index_t, 2>& gemm_m_k_strides)
+                             const std::array<index_t, 3>& gemm_g_m_k_strides)
     {
         const index_t NDoHoWo =
             N * ck::accumulate_n<index_t>(
@@ -117,11 +128,10 @@ struct DeviceImageToColumnImpl
         const index_t CZYX =
             C * ck::accumulate_n<index_t>(
                     filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
-        const auto desc_mraw_kraw = make_naive_tensor_descriptor(
-            make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
-        const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
-        return desc_m_k;
+        const auto desc_mraw_kraw = make_naive_tensor_descriptor(
+            make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2]));
+        return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
     }

     using InputGridDesc =
@@ -132,34 +142,38 @@ struct DeviceImageToColumnImpl
         decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
             OutputGridDesc{}))>;

-    using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
-                                                                  InputDataType,
-                                                                  OutputGridDesc,
-                                                                  OutputDataType,
-                                                                  BlockSize,
-                                                                  MPerBlock,
-                                                                  KPerBlock,
-                                                                  ThreadClusterLengths,
-                                                                  ScalarPerVector,
-                                                                  InMemoryDataOperationEnum::Set,
-                                                                  Block2ETileMap>;
+    using GridwiseTensorRearrangeKernel =
+        GridwiseTensorRearrange<InputGridDesc,
+                                InputDataType,
+                                OutputGridDesc,
+                                OutputDataType,
+                                BlockSize,
+                                MPerBlock,
+                                KPerBlock,
+                                ThreadClusterLengths,
+                                ScalarPerVector,
+                                InMemoryDataOperationEnum::Set,
+                                Block2ETileMap,
+                                ComputePtrOffsetOfStridedBatch<I0>>;
     struct Argument : public BaseArgument
     {
         Argument(const void* p_in, // input image
                  void* p_out,      // gemm form
+                 const ck::index_t G,
                  const ck::index_t N,
                  const ck::index_t C,
                  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
-                 const std::array<index_t, 2>& gemm_m_k_strides,
+                 const std::array<index_t, 3>& gemm_g_m_k_strides,
                  const std::array<index_t, NDimSpatial>& conv_filter_strides,
                  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                  const std::array<index_t, NDimSpatial>& input_left_pads,
                  const std::array<index_t, NDimSpatial>& input_right_pads)
-            : C_(C),
+            : G_(G),
+              C_(C),
               X_(filter_spatial_lengths[NDimSpatial - I1]),
               p_in_{static_cast<const InputDataType*>(p_in)},
               p_out_{static_cast<OutputDataType*>(p_out)},
@@ -176,14 +190,16 @@ struct DeviceImageToColumnImpl
                                                  filter_spatial_lengths,
                                                  output_spatial_lengths,
                                                  image_g_n_c_wis_strides,
                                                  conv_filter_strides,
                                                  conv_filter_dilations,
                                                  input_left_pads,
                                                  input_right_pads);
             out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
-                N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides);
+                N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides);
+
+            compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0];
+            compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0];
         }

         void Print() const
@@ -192,6 +208,7 @@ struct DeviceImageToColumnImpl
             std::cout << out_grid_desc_m_k_ << std::endl;
         }

+        const ck::index_t G_;
         const ck::index_t C_;
         const ck::index_t X_;
@@ -206,6 +223,8 @@ struct DeviceImageToColumnImpl
         InputGridDesc in_grid_desc_m_k_;
         OutputGridDesc out_grid_desc_m_k_;
+
+        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
     };
     struct Invoker : public BaseInvoker
@@ -220,12 +239,14 @@ struct DeviceImageToColumnImpl
             const auto block_2_tile_map =
                 BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
                     arg.out_grid_desc_m_k_);
-            const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
+            const index_t grid_size =
+                block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_;
             const auto kernel = kernel_tensor_rearrange<InputGridDesc,
                                                         InputDataType,
                                                         OutputGridDesc,
                                                         OutputDataType,
                                                         Block2ETileMap,
+                                                        ComputePtrOffsetOfStridedBatch<I0>,
                                                         GridwiseTensorRearrangeKernel>;

             float elapsed_time = launch_and_time_kernel(stream_config,
@@ -237,7 +258,9 @@ struct DeviceImageToColumnImpl
                                                         arg.p_in_,
                                                         arg.out_grid_desc_m_k_,
                                                         arg.p_out_,
-                                                        block_2_tile_map);
+                                                        arg.G_,
+                                                        block_2_tile_map,
+                                                        arg.compute_ptr_offset_of_batch_);
             return elapsed_time;
         }
@@ -250,9 +273,7 @@ struct DeviceImageToColumnImpl
     bool IsSupportedArgument(const Argument& arg)
     {
-        using namespace tensor_layout::convolution;
-        if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
-                       std::is_same_v<ImageLayout, GNDHWC>))
+        if constexpr(!(is_NSpatialGC || is_GNSpatialC))
         {
             return false;
         }
@@ -295,13 +316,14 @@ struct DeviceImageToColumnImpl
     static auto MakeArgument(const void* p_in, // input image
                              void* p_out,      // gemm form
+                             const ck::index_t G,
                              const ck::index_t N,
                              const ck::index_t C,
                              const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                              const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                              const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                              const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
-                             const std::array<index_t, 2>& gemm_m_k_strides,
+                             const std::array<index_t, 3>& gemm_g_m_k_strides,
                              const std::array<index_t, NDimSpatial>& conv_filter_strides,
                              const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                              const std::array<index_t, NDimSpatial>& input_left_pads,
@@ -309,13 +331,14 @@ struct DeviceImageToColumnImpl
     {
         return Argument{static_cast<const InputDataType*>(p_in),
                         static_cast<OutputDataType*>(p_out),
+                        G,
                         N,
                         C,
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
                         image_g_n_c_wis_strides,
-                        gemm_m_k_strides,
+                        gemm_g_m_k_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
@@ -327,13 +350,14 @@ struct DeviceImageToColumnImpl
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_in, // input image
                         void* p_out,      // gemm form
+                        const ck::index_t G,
                         const ck::index_t N,
                         const ck::index_t C,
                         const std::array<index_t, NDimSpatial>& input_spatial_lengths,
                         const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
                         const std::array<index_t, NDimSpatial>& output_spatial_lengths,
                         const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
-                        const std::array<index_t, 2>& gemm_m_k_strides,
+                        const std::array<index_t, 3>& gemm_g_m_k_strides,
                         const std::array<index_t, NDimSpatial>& conv_filter_strides,
                         const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                         const std::array<index_t, NDimSpatial>& input_left_pads,
@@ -341,13 +365,14 @@ struct DeviceImageToColumnImpl
     {
         return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
                                           static_cast<OutputDataType*>(p_out),
+                                          G,
                                           N,
                                           C,
                                           input_spatial_lengths,
                                           filter_spatial_lengths,
                                           output_spatial_lengths,
                                           image_g_n_c_wis_strides,
-                                          gemm_m_k_strides,
+                                          gemm_g_m_k_strides,
                                           conv_filter_strides,
                                           conv_filter_dilations,
                                           input_left_pads,
......
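A naive host-side reference of the image-to-column mapping above, for a 1D grouped convolution with no padding or dilation (an assumption-labeled stand-in, not CK code):

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const int G = 2, N = 1, C = 2, Wi = 5, X = 3, stride = 1;
    const int Wo = (Wi - X) / stride + 1; // 3 output positions
    const int M = N * Wo, K = X * C;

    // image laid out [G, N, Wi, C], filled with a ramp for inspection
    std::vector<float> image(static_cast<std::size_t>(G) * N * Wi * C);
    for(std::size_t i = 0; i < image.size(); ++i)
        image[i] = static_cast<float>(i);

    // gemm form laid out [G, M, K]
    std::vector<float> gemm(static_cast<std::size_t>(G) * M * K);
    for(int g = 0; g < G; ++g)
        for(int n = 0; n < N; ++n)
            for(int wo = 0; wo < Wo; ++wo)
                for(int x = 0; x < X; ++x)
                    for(int c = 0; c < C; ++c)
                    {
                        const int wi  = wo * stride + x;  // image column read
                        const int row = n * Wo + wo;      // gemm M index
                        const int col = x * C + c;        // gemm K index
                        gemm[(g * M + row) * K + col] =
                            image[((g * N + n) * Wi + wi) * C + c];
                    }

    std::printf("G=%d M=%d K=%d, gemm[0..3]=%.0f %.0f %.0f %.0f\n",
                G, M, K, gemm[0], gemm[1], gemm[2], gemm[3]);
}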
@@ -186,6 +186,25 @@ struct Bilinear
         y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
     };

+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
+    {
+        const float x0_tmp = type_convert<float>(x0);
+        const float x1_tmp = type_convert<float>(x1);
+        const float y_tmp  = alpha_ * x0_tmp + beta_ * x1_tmp;
+        y                  = type_convert<bhalf_t>(y_tmp);
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
+    {
+        const float x1_tmp = ck::type_convert<float>(x1);
+        const float y_tmp  = alpha_ * x0 + beta_ * x1_tmp;
+        y                  = ck::type_convert<bhalf_t>(y_tmp);
+    };
+
     template <>
     __host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
         std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
......
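Why the new bhalf_t overloads round-trip through float (a host-side illustration, not CK code): bf16 keeps only 8 mantissa bits, so blending in float and converting once at the end accumulates less rounding than converting each operand to bf16 first.

#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t float_to_bf16(float f) // truncation, the cheapest rounding mode
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<uint16_t>(u >> 16);
}

static float bf16_to_float(uint16_t h)
{
    const uint32_t u = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main()
{
    const float alpha = 0.5f, beta = 0.25f;
    const float x0 = 1.001f, x1 = 2.003f;

    // compute in float, convert once (what the new overloads do)
    const float once = bf16_to_float(float_to_bf16(alpha * x0 + beta * x1));

    // convert the inputs first, then compute (more rounding steps)
    const float twice = bf16_to_float(
        float_to_bf16(alpha * bf16_to_float(float_to_bf16(x0)) +
                      beta * bf16_to_float(float_to_bf16(x1))));

    std::printf("single-round=%f multi-round=%f\n", once, twice);
}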
@@ -311,6 +311,71 @@ struct AddAddFastGelu
     }
 };

+// E = Relu(alpha1 * C + alpha2 * D0 + D1)
+struct ScaleAddScaleAddRelu
+{
+    ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
+        : alpha1_(alpha1), alpha2_(alpha2)
+    {
+    }
+
+    template <typename E, typename C, typename D0, typename D1>
+    __host__ __device__ constexpr void
+    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
+                                                                              const float& c,
+                                                                              const float& d0,
+                                                                              const float& d1) const
+    {
+        const float x = c * alpha1_ + alpha2_ * d0 + d1;
+        Relu{}.template operator()<float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
+        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
+    {
+        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
+                        type_convert<float>(d1);
+        float result = 0;
+        Relu{}.template operator()<float>(result, x);
+        e = type_convert<half_t>(result);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
+        bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
+    {
+        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
+                        type_convert<float>(d1);
+        float result = 0;
+        Relu{}.template operator()<float>(result, x);
+        e = type_convert<bhalf_t>(result);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
+        int8_t& e, const int8_t& c, const float& d0, const float& d1) const
+    {
+        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
+        float result = 0;
+        Relu{}.template operator()<float>(result, x);
+        e = type_convert<int8_t>(result);
+    }
+
+    const float alpha1_;
+    const float alpha2_;
+};
+
 struct Normalize
 {
     // FIXME: is double absolutely necessary?
......
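A host-side sanity check of the ScaleAddScaleAddRelu contract (plain C++ stand-in; CK's functor differs only in type plumbing):

#include <algorithm>
#include <cstdio>

int main()
{
    const float alpha1 = 0.5f, alpha2 = 2.0f;
    const float c = -8.0f, d0 = 1.5f, d1 = 0.5f;

    // E = Relu(alpha1 * C + alpha2 * D0 + D1)
    const float x = alpha1 * c + alpha2 * d0 + d1; // -4 + 3 + 0.5 = -0.5
    const float e = std::max(x, 0.0f);             // clamped to 0

    std::printf("x=%.2f e=%.2f\n", x, e);
}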
@@ -16,6 +16,57 @@ namespace element_wise {
 extern "C" __device__ float __ocml_native_recip_f32(float);
 #endif

+struct PassThroughPack2
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
+    {
+        // fake conversion
+        uint16_t t = ck::bit_cast<uint32_t>(x);
+        y          = ck::bit_cast<ck::f8x2_t>(t);
+    }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
+    {
+        auto t = type_convert<float2_t>(x);
+        y      = type_convert<half2_t>(t);
+    }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
+    {
+        y = x;
+    }
+
+    __host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
+    {
+        y = x;
+    }
+
+    constexpr const static bool is_pack2_invocable = true;
+};
+
 struct PassThrough
 {
     template <typename Y, typename X>
@@ -33,6 +84,12 @@ struct PassThrough
         y = type_convert<float>(x);
     }

+    template <>
+    __host__ __device__ void operator()<double, float>(double& y, const float& x) const
+    {
+        y = type_convert<double>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {
@@ -69,6 +126,12 @@ struct PassThrough
         y = type_convert<bhalf_t>(x);
     }

+    template <>
+    __host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
+    {
+        y = type_convert<float>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
     {
@@ -228,7 +291,15 @@ struct Scale
     template <>
     __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = scale_ * x;
+        y = ck::type_convert<half_t>(scale_) * x;
+    };
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
+    {
+        const float x_tmp = ck::type_convert<float>(x);
+        const float y_tmp = scale_ * x_tmp;
+        y                 = ck::type_convert<bhalf_t>(y_tmp);
     };

     template <>
......
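A hedged sketch of how a caller might detect the is_pack2_invocable flag introduced above at compile time (the detection idiom is illustrative, not CK's actual trait):

#include <type_traits>

template <typename Op, typename = void>
struct is_pack2_invocable_t : std::false_type
{
};

template <typename Op>
struct is_pack2_invocable_t<Op, std::enable_if_t<Op::is_pack2_invocable>>
    : std::bool_constant<Op::is_pack2_invocable>
{
};

// Hypothetical ops standing in for PassThroughPack2 and a scalar functor.
struct PackedOp
{
    constexpr const static bool is_pack2_invocable = true;
};
struct ScalarOp
{
};

static_assert(is_pack2_invocable_t<PackedOp>::value, "uses the 2-wide path");
static_assert(!is_pack2_invocable_t<ScalarOp>::value, "falls back to scalar");

int main() { return 0; }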
@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         index_t MPadded;
         index_t NPadded;
         index_t KPadded;
-        index_t K0;
+        index_t K0Padded;
         index_t k_batch;

         Argument(const FloatA* p_a_grid_,
@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                  index_t MPadded_,
                  index_t NPadded_,
                  index_t KPadded_,
-                 index_t K0_,
+                 index_t K0Padded_,
                  index_t k_batch_)
             : p_a_grid(p_a_grid_),
               p_b_grid(p_b_grid_),
@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
               MPadded(MPadded_),
               NPadded(NPadded_),
               KPadded(KPadded_),
-              K0(K0_),
+              K0Padded(K0Padded_),
               k_batch(k_batch_)
         {
         }
@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                       << "MP:" << MPadded << ", "
                       << "NP:" << NPadded << ", "
                       << "KP:" << KPadded << ", "
-                      << "K0:" << K0 << ", "
+                      << "K0Padded:" << K0Padded << ", "
                       << "KB:" << k_batch << "}" << std::endl;
         }
     };
...@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return math::integer_least_multiple(N, NPerBlock); return math::integer_least_multiple(N, NPerBlock);
} }
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
{ {
// k_batch * k0 * k0_per_block * k1 // k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1; auto K_t = K_Batch * K0PerBlock * K1;
...@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{ {
auto K0 = CalculateK0(K, K_Batch); auto K0Padded = CalculateK0Padded(K, K_Batch);
return K_Batch * K0 * K1; return K_Batch * K0Padded * K1;
} }
    __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
...@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                       index_t K,
                                                                       index_t StrideA,
                                                                       index_t KBatch,
                                                                       index_t K0Padded,
                                                                       index_t KPad)
    {
        const auto a_grid_desc_m_k = [&]() {
...@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            }
        }();
        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            return transform_tensor_descriptor(
                a_grid_desc_m_kpad,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_right_pad_transform(M, MPad - M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
        }
        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                          GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
        {
            // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_right_pad_transform(M, MPad - M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        else
        {
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
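
The unmerge transform reshapes the flat (padded) K axis into [KBatch, K0Padded, K1] with KBatch outermost, so a descriptor coordinate maps back to the flat k index as k = (kb * K0Padded + k0) * K1 + k1. A small sketch of that index math, with values chosen to match KPadded = 512 (names illustrative):

#include <iostream>

int main()
{
    const int K0Padded = 32, K1 = 8; // with KBatch = 2: 2 * 32 * 8 == 512 == KPadded
    // Coordinate (kb, k0, k1) in the unmerged view -> flat k in [0, KPadded).
    auto flat_k = [&](int kb, int k0, int k1) { return (kb * K0Padded + k0) * K1 + k1; };
    std::cout << flat_k(0, 0, 0) << " "    // 0
              << flat_k(1, 0, 0) << " "    // 256: the second K-batch starts halfway
              << flat_k(1, 31, 7) << "\n"; // 511: the last element
}
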
...@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                       index_t N,
                                                                       index_t StrideB,
                                                                       index_t KBatch,
                                                                       index_t K0Padded,
                                                                       index_t KPad)
    {
        const auto b_grid_desc_k_n = [&]() {
...@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            }
        }();
        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
            return transform_tensor_descriptor(
                b_grid_desc_kpad_n,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
        }
        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                          GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
        {
            // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        else
        {
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                return false;
            }
        }
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
...@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = karg.k_batch * K0PerBlock * K1;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }
...@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CBlockTransferScalarPerVector_NWaveNPerXDL ("
                          << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
...@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
            {
#if DEBUG_LOG
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CBlockTransferScalarPerVector_NWaveNPerXDL ("
                          << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
                return false;
            }
        }
        const auto num_k_loop = karg.K0Padded / K0PerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
#if DEBUG_LOG
            std::cout << "The number of k loops (" << num_k_loop
                      << ") value is not supported by GridwiseGemm Pipeline."
                      << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
            return false;
        }
...@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
    __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
    {
        const index_t K0Padded =
            math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
        const index_t KPad = KBatch * K0Padded * K1;
        return KPad;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
    {
        const index_t num_loop = K0Padded / K0PerBlock;
        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
...@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        const FloatB* p_b_grid = karg.p_b_grid;
        FloatC* p_c_grid       = karg.p_c_grid;
        const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
        const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
......
...@@ -21,6 +21,7 @@ template <typename InputGridDesc,
          typename OutputGridDesc,
          typename OutputDataType,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfStridedBatch,
          typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
...@@ -30,13 +31,20 @@ __global__ void
        const InputDataType* __restrict__ p_in_global,
        const OutputGridDesc out_grid_desc,
        OutputDataType* __restrict__ p_out_global,
        const index_t batch_count,
        const Block2ETileMap block_2_tile_map,
        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
    defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
    GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                       p_in_global,
out_grid_desc,
p_out_global,
batch_count,
block_2_tile_map,
compute_ptr_offset_of_batch);
#else
    ignore = in_grid_desc;
    ignore = p_in_global;
...@@ -56,7 +64,8 @@ template <typename InputGridDesc,
          typename ThreadClusterLengths,
          index_t ScalarPerVector,
          InMemoryDataOperationEnum DstInMemOp,
          typename Block2ETileMap,
typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange
{
...@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
                               const InputDataType* __restrict__ p_in_global,
                               const OutputGridDesc& out_grid_desc,
                               OutputDataType* __restrict__ p_out_global,
                               const index_t batch_count,
                               const Block2ETileMap& block_2_tile_map,
                               const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
    {
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
        auto copy_global_to_global =
            ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
                                              Tuple<InputDataType>,
...@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
                make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
                tensor_operation::element_wise::PassThrough{}};
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Global Memory
const index_t a_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const index_t c_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());
        copy_global_to_global.Run(
            tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
    }
......
...@@ -100,6 +100,8 @@ template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
...@@ -138,6 +140,36 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
#endif
}
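
The clamp added before the hardware conversion saturates the input to ±240, which is the largest finite magnitude of the fnuz e4m3 fp8 format used on gfx94x; without it, out-of-range floats would not saturate cleanly. A standalone sketch of just the saturation step:

#include <iostream>

// Saturate a float into the representable range of e4m3 fnuz fp8 (max magnitude 240).
static float saturate_to_f8_range(float x)
{
    const float max_fp8 = 240.0f;
    return x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
}

int main()
{
    std::cout << saturate_to_f8_range(1e6f) << " "   // 240
              << saturate_to_f8_range(-500.f) << " " // -240
              << saturate_to_f8_range(3.5f) << "\n"; // 3.5
}
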
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
const vector_type<float, 2> f32x2_v(x);
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
f32x2_v.template AsType<float>()[Number<1>{}]);
return bit_cast<half2_t>(y);
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
......
...@@ -19,9 +19,7 @@ namespace host {
 * \brief Reference implementation for column to image.
 *
 * Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
 * Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
 *
 * \tparam NDimSpatial Number of spatial dimensions.
 * \tparam ImageLayout Image Layout.
...@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
    float Run(const Argument& arg)
    {
        if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
             arg.input_.GetNumOfDimension() == 3))
        {
            throw std::runtime_error("wrong! inconsistent dimension");
        }

        const index_t G = arg.output_.GetLengths()[0];
        const index_t N = arg.output_.GetLengths()[1];
        const index_t C = arg.output_.GetLengths()[2];
        if constexpr(NDimSpatial == 1)
        {
            const index_t Wo = arg.output_spatial_lengths_[0];

            auto func = [&](auto g, auto n) {
                for(index_t wo = 0; wo < Wo; ++wo)
                {
                    index_t row = n * Wo + wo;
...@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
                            if(wi >= 0 &&
                               ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
                            {
                                float v_in =
                                    ck::type_convert<float>(arg.input_(g, row, column));
                                float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
                                arg.output_(g, n, c, wi) =
                                    ck::type_convert<OutDataType>(v_in + v_out);
                            }
                            column++;
} }
}; };
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0; return 0;
} }
...@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
            const index_t Ho = arg.output_spatial_lengths_[0];
            const index_t Wo = arg.output_spatial_lengths_[1];

            auto func = [&](auto g, auto n) {
                for(index_t ho = 0; ho < Ho; ++ho)
                {
                    for(index_t wo = 0; wo < Wo; ++wo)
...@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
                                       arg.output_.GetLengths()[4])
                                    {
                                        float v_in =
                                            ck::type_convert<float>(arg.input_(g, row, column));
                                        float v_out = ck::type_convert<float>(
                                            arg.output_(g, n, c, hi, wi));
                                        arg.output_(g, n, c, hi, wi) =
                                            ck::type_convert<OutDataType>(v_in + v_out);
                                    }
                                    column++;
...@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
                }
            };

            make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
            return 0;
        }
...@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
            const index_t Ho = arg.output_spatial_lengths_[1];
            const index_t Wo = arg.output_spatial_lengths_[2];

            auto func = [&](auto g, auto n) {
                for(index_t d_o = 0; d_o < Do; ++d_o)
                {
                    for(index_t ho = 0; ho < Ho; ++ho)
...@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
                                           arg.output_.GetLengths()[5])
                                        {
                                            float v_in = ck::type_convert<float>(
                                                arg.input_(g, row, column));
                                            float v_out = ck::type_convert<float>(
                                                arg.output_(g, n, c, di, hi, wi));
                                            arg.output_(g, n, c, di, hi, wi) =
                                                ck::type_convert<OutDataType>(v_in + v_out);
                                        }
                                        column++;
...@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
                }
            };

            make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
            return 0;
        }
...@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
            C * ck::accumulate_n<index_t>(
                    arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());

        if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
             arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
             arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
        {
            return false;
        }
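
With the G dimension made explicit, the GEMM-form tensor checked above is 3-D: [G, N * Do * Ho * Wo, C * Z * Y * X]. A quick shape computation for a hypothetical grouped 2-D case (all values illustrative):

#include <iostream>

int main()
{
    // Hypothetical grouped 2-D convolution problem.
    const int G = 2, N = 4, C = 16;
    const int Y = 3, X = 3;          // filter spatial lengths
    const int Ho = 28, Wo = 28;      // output spatial lengths
    const int NDoHoWo = N * Ho * Wo; // 3136
    const int CZYX    = C * Y * X;   // 144
    std::cout << "[" << G << ", " << NDoHoWo << ", " << CZYX << "]\n"; // [2, 3136, 144]
}
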
......
...@@ -23,6 +23,7 @@ template <ck::index_t NumDimM,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ComputeDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
...@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
                {
                    for(ck::index_t k1 = 0; k1 < K1; ++k1)
                    {
                        // Simulate the casting that happens when ComputeDataType differs from
                        // the A/B data types
                        ComputeDataType v_a_compute_input =
                            ck::type_convert<ComputeDataType>(arg.a_ms_ks_(m0, m1, k0, k1));
                        ComputeDataType v_b_compute_input =
                            ck::type_convert<ComputeDataType>(arg.b_ns_ks_(n0, n1, k0, k1));

                        AccDataType v_a;
                        AccDataType v_b;

                        arg.a_element_op_(v_a, ck::type_convert<AccDataType>(v_a_compute_input));
                        arg.b_element_op_(v_b, ck::type_convert<AccDataType>(v_b_compute_input));

                        v_acc += v_a * v_b;
                    }
                }

                arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert<CDataType>(v_acc);
            };
            make_ParallelTensorFunctor(f_ms_ns,
......
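
The new ComputeDataType parameter lets the reference mimic the device kernel's precision: each input is first narrowed to ComputeDataType, then widened to AccDataType for the multiply-accumulate, and the accumulator is converted to CDataType once per output element. A self-contained sketch of that cast chain; the concrete type choices below are assumptions for illustration only:

#include <iostream>

int main()
{
    // Toy cast chain: A/B stored as float, ComputeDataType = float,
    // AccDataType = double, CDataType = float.
    using ComputeDataType = float;
    using AccDataType     = double;
    using CDataType       = float;

    const float a[2] = {0.5f, 1.5f}; // A[m, k] slice over k
    const float b[2] = {2.0f, 4.0f}; // B[n, k] slice over k

    AccDataType v_acc = 0;
    for(int k = 0; k < 2; ++k)
    {
        // Narrow to the compute type first, then widen into the accumulator,
        // mirroring ReferenceContraction_M2_N2_K2.
        ComputeDataType va = static_cast<ComputeDataType>(a[k]);
        ComputeDataType vb = static_cast<ComputeDataType>(b[k]);
        v_acc += static_cast<AccDataType>(va) * static_cast<AccDataType>(vb);
    }
    CDataType c = static_cast<CDataType>(v_acc);
    std::cout << c << "\n"; // 0.5*2 + 1.5*4 = 7
}
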
...@@ -42,6 +42,7 @@ template <ck::index_t NDimSpatial,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ck::index_t NumDTensor = 0,
          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator
{
...@@ -57,10 +58,12 @@ struct ReferenceConvFwd : public device::BaseOperator
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op,
                 const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors)
            : input_{input},
              weight_{weight},
              output_{output},
              d_tensors_{d_tensors},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
...@@ -75,6 +78,8 @@ struct ReferenceConvFwd : public device::BaseOperator
        const Tensor<WeiDataType>& weight_;
        Tensor<OutDataType>& output_;

        const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors_;

        std::vector<index_t> conv_strides_;
        std::vector<index_t> conv_dilations_;
        std::vector<index_t> in_left_pads_;
...@@ -129,7 +134,26 @@ struct ReferenceConvFwd : public device::BaseOperator
                }

                OutDataType v_out;
                OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
if constexpr(NumDTensor == 0)
{
arg.out_element_op_(v_out, v_acc_converted);
}
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.d_tensors_[0](g, n, k, wo),
arg.d_tensors_[1](g, n, k, wo));
}
else
{
throw std::runtime_error("Output ElementOp not supported in reference.");
}
                arg.output_(g, n, k, wo) = v_out;
            };
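
The NumDTensor branches above forward up to two auxiliary D tensors to the output elementwise op, which is how fused epilogues such as bias-add are modeled in the reference. A hedged sketch of an op with that call shape; the AddBias name and its signature are illustrative, not CK's actual operators:

#include <iostream>

// Illustrative epilogue op: out = conv_result + bias (matches the NumDTensor == 1 call shape).
struct AddBias
{
    template <typename E, typename C, typename D0>
    void operator()(E& e, const C& c, const D0& d0) const
    {
        e = c + d0;
    }
};

int main()
{
    float v_out = 0.f;
    const float v_acc_converted = 1.25f; // accumulated convolution value
    const float bias            = 0.75f; // d_tensors_[0] element at (g, n, k, wo)
    AddBias{}(v_out, v_acc_converted, bias);
    std::cout << v_out << "\n"; // 2
}
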
...@@ -183,7 +207,27 @@ struct ReferenceConvFwd : public device::BaseOperator
                    }

                    OutDataType v_out;
                    OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
if constexpr(NumDTensor == 0)
{
arg.out_element_op_(v_out, v_acc_converted);
}
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(
v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, ho, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.d_tensors_[0](g, n, k, ho, wo),
arg.d_tensors_[1](g, n, k, ho, wo));
}
else
{
throw std::runtime_error("Output ElementOp not supported in reference.");
}
                    arg.output_(g, n, k, ho, wo) = v_out;
                };
...@@ -250,7 +294,27 @@ struct ReferenceConvFwd : public device::BaseOperator
                        }

                        OutDataType v_out;
                        OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
if constexpr(NumDTensor == 0)
{
arg.out_element_op_(v_out, v_acc_converted);
}
else if constexpr(NumDTensor == 1)
{
arg.out_element_op_(
v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, d_o, ho, wo));
}
else if constexpr(NumDTensor == 2)
{
arg.out_element_op_(v_out,
v_acc_converted,
arg.d_tensors_[0](g, n, k, d_o, ho, wo),
arg.d_tensors_[1](g, n, k, d_o, ho, wo));
}
else
{
throw std::runtime_error("Output ElementOp not supported in reference.");
}
                        arg.output_(g, n, k, d_o, ho, wo) = v_out;
                    };
...@@ -294,7 +358,8 @@ struct ReferenceConvFwd : public device::BaseOperator
                            std::vector<ck::index_t> input_right_pads,
                            InElementwiseOperation in_element_op,
                            WeiElementwiseOperation wei_element_op,
                            OutElementwiseOperation out_element_op,
                            const std::array<Tensor<OutDataType>, NumDTensor>& d_tensors = {})
    {
        return Argument{input,
                        weight,
...@@ -305,7 +370,8 @@ struct ReferenceConvFwd : public device::BaseOperator
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op,
                        d_tensors};
    }

    static auto MakeInvoker() { return Invoker{}; }
......
...@@ -19,9 +19,7 @@ namespace host {
 * \brief Reference implementation for image to column.
 *
 * Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
 * Output tensor descriptor has [G * N * Do * Ho * Wo, Z * Y * X * C] data layout.
 *
 * \tparam NDimSpatial Number of spatial dimensions.
 * \tparam ImageLayout Image Layout.
...@@ -95,18 +93,19 @@ struct ReferenceImageToColumn : public device::BaseOperator
    float Run(const Argument& arg)
    {
        if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
             arg.output_.GetNumOfDimension() == 3))
        {
            throw std::runtime_error("wrong! inconsistent dimension");
        }

        const index_t G = arg.input_.GetLengths()[0];
        const index_t N = arg.input_.GetLengths()[1];
        const index_t C = arg.input_.GetLengths()[2];
        if constexpr(NDimSpatial == 1)
        {
            const index_t Wo = arg.output_spatial_lengths_[0];

            auto func = [&](auto g, auto n, auto wo) {
                index_t row    = n * Wo + wo;
                index_t column = 0;
...@@ -121,15 +120,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
                        if(wi >= 0 &&
                           ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
                        {
                            InDataType v_in             = arg.input_(g, n, c, wi);
                            arg.output_(g, row, column) = ck::type_convert<OutDataType>(v_in);
                        }
                        column++;
                    }
                }
            };

            make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
            return 0;
        }
...@@ -138,7 +137,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
            const index_t Ho = arg.output_spatial_lengths_[0];
            const index_t Wo = arg.output_spatial_lengths_[1];

            auto func = [&](auto g, auto n, auto ho, auto wo) {
                index_t row    = n * Ho * Wo + ho * Wo + wo;
                index_t column = 0;
...@@ -162,8 +161,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
                            wi >= 0 &&
                            ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
                        {
                            InDataType v_in = arg.input_(g, n, c, hi, wi);
                            arg.output_(g, row, column) =
                                ck::type_convert<OutDataType>(v_in);
                        }
                        column++;
                    }
...@@ -171,7 +171,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
                }
            };

            make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
            return 0;
        }
...@@ -181,7 +181,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
            const index_t Ho = arg.output_spatial_lengths_[1];
            const index_t Wo = arg.output_spatial_lengths_[2];

            auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
                index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
                index_t column = 0;
...@@ -213,8 +213,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
                                ck::type_convert<std::size_t>(wi) <
                                    arg.input_.GetLengths()[5])
                            {
                                InDataType v_in = arg.input_(g, n, c, di, hi, wi);
                                arg.output_(g, row, column) =
                                    ck::type_convert<OutDataType>(v_in);
                            }
                            column++;
...@@ -224,7 +224,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
                }
            };

            make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(
                std::thread::hardware_concurrency());
            return 0;
...@@ -267,8 +267,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
            C * ck::accumulate_n<index_t>(
                    arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());

        if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
             arg.output_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
             arg.output_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
        {
            return false;
        }
......
...@@ -25,6 +25,8 @@ using BF8 = ck::bf8_t;
using Empty_Tuple = ck::Tuple<>;

using BF16_Tuple = ck::Tuple<BF16>;

using F16_Tuple     = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>;
......