Unverified Commit 70eebf22 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into grouped_gemm_multi_abd_fixed_nk_example

parents 5608328c 98fd41f5
......@@ -145,7 +145,8 @@ template <index_t NumDimM,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
typename ComputeDataType = ADataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceContractionMultipleD_Xdl_CShuffle
: public DeviceContractionMultipleD<NumDimM,
NumDimN,
......@@ -156,7 +157,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
......@@ -310,8 +312,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using ComputeDataType = ADataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
......
......@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwiseImpl<" ;
str << "NumDim_" << NumDim << ",";
str << "MPerThread_" << MPerThread << ",";
str << "InScalarPerVector";
static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; });
str << ",";
str << "OutScalarPerVector";
static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; });
str << ">";
// clang-format on
return str.str();
}
}; // namespace device
} // namespace device
......
......@@ -184,7 +184,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false;
}
}
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")
else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
......
......@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// clang-format off
str << "DeviceGemm_Xdl_CShuffle"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
......@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];;
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
......
......@@ -59,7 +59,8 @@ template <typename ADataType,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename ComputeType = CDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BLayout,
......@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;
static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
......@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t K0Padded_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
......@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
MPadded_,
NPadded_,
KPadded_,
K0_,
K0Padded_,
k_batch_),
a_element_op(a_element_op_),
b_element_op(b_element_op_),
......@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto b2c_map = DefaultBlock2CTileMap{};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto K0 = karg.K0;
const auto K0Padded = karg.K0Padded;
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded);
float ave_time = 0;
......@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch,
a_element_op,
b_element_op,
......@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch),
GridwiseGemm::CalculateK0Padded(K, KBatch),
KBatch,
a_element_op,
b_element_op,
......@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
// polymorphic
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched]
<< ", PipelineVersion: " << PipelineVersionToString[PipelineVer];
return str.str();
}
};
} // namespace device
......
......@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
......@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
m_padded,
n_padded,
k_padded,
k0,
k0_padded,
K_BATCH};
gemm_kernel_args_.emplace_back(
......@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);
const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
......@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded;
karg.K0 = k0;
karg.K0Padded = k0_padded;
karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
......@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
......@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
throw std::runtime_error(err.str());
}
K0 = karg.K0;
K0 = karg.K0Padded;
bool not_all_have_main_k0_block_loop_same =
all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
......
......@@ -15,15 +15,18 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Image to column for input layout NDHWC:
// input : input image [N, Di, Hi, Wi, C]
// output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
// Image to column:
// input : input image [G, N, Di, Hi, Wi, C]
// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C]
// input : input image [N, Di, Hi, Wi, G, C]
// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C]
template <index_t NDimSpatial,
typename ImageLayout,
typename InputDataType,
......@@ -41,6 +44,14 @@ struct DeviceImageToColumnImpl
OutputDataType,
conv_tensor_rearrange_op::ImageToColumn>
{
static constexpr bool is_NSpatialGC =
std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
static constexpr bool is_GNSpatialC =
std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -109,7 +120,7 @@ struct DeviceImageToColumnImpl
const ck::index_t C,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, 2>& gemm_m_k_strides)
const std::array<index_t, 3>& gemm_g_m_k_strides)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
......@@ -117,11 +128,10 @@ struct DeviceImageToColumnImpl
const index_t CZYX =
C * ck::accumulate_n<index_t>(
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
return desc_m_k;
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2]));
return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
}
using InputGridDesc =
......@@ -132,34 +142,38 @@ struct DeviceImageToColumnImpl
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
OutputGridDesc{}))>;
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
Block2ETileMap>;
using GridwiseTensorRearrangeKernel =
GridwiseTensorRearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
BlockSize,
MPerBlock,
KPerBlock,
ThreadClusterLengths,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>>;
struct Argument : public BaseArgument
{
Argument(const void* p_in, // input image
void* p_out, // gemm form
const ck::index_t G,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
: C_(C),
: G_(G),
C_(C),
X_(filter_spatial_lengths[NDimSpatial - I1]),
p_in_{static_cast<const InputDataType*>(p_in)},
p_out_{static_cast<OutputDataType*>(p_out)},
......@@ -176,14 +190,16 @@ struct DeviceImageToColumnImpl
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides);
N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides);
compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0];
compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0];
}
void Print() const
......@@ -192,6 +208,7 @@ struct DeviceImageToColumnImpl
std::cout << out_grid_desc_m_k_ << std::endl;
}
const ck::index_t G_;
const ck::index_t C_;
const ck::index_t X_;
......@@ -206,6 +223,8 @@ struct DeviceImageToColumnImpl
InputGridDesc in_grid_desc_m_k_;
OutputGridDesc out_grid_desc_m_k_;
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
};
struct Invoker : public BaseInvoker
......@@ -220,12 +239,14 @@ struct DeviceImageToColumnImpl
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
arg.out_grid_desc_m_k_);
const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
const index_t grid_size =
block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_;
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
InputDataType,
OutputGridDesc,
OutputDataType,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I0>,
GridwiseTensorRearrangeKernel>;
float elapsed_time = launch_and_time_kernel(stream_config,
......@@ -237,7 +258,9 @@ struct DeviceImageToColumnImpl
arg.p_in_,
arg.out_grid_desc_m_k_,
arg.p_out_,
block_2_tile_map);
arg.G_,
block_2_tile_map,
arg.compute_ptr_offset_of_batch_);
return elapsed_time;
}
......@@ -250,9 +273,7 @@ struct DeviceImageToColumnImpl
bool IsSupportedArgument(const Argument& arg)
{
using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
if constexpr(!(is_NSpatialGC || is_GNSpatialC))
{
return false;
}
......@@ -295,13 +316,14 @@ struct DeviceImageToColumnImpl
static auto MakeArgument(const void* p_in, // input image
void* p_out, // gemm form
const ck::index_t G,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -309,13 +331,14 @@ struct DeviceImageToColumnImpl
{
return Argument{static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out),
G,
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
gemm_m_k_strides,
gemm_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -327,13 +350,14 @@ struct DeviceImageToColumnImpl
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in, // input image
void* p_out, // gemm form
const ck::index_t G,
const ck::index_t N,
const ck::index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, 3>& gemm_g_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......@@ -341,13 +365,14 @@ struct DeviceImageToColumnImpl
{
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
static_cast<OutputDataType*>(p_out),
G,
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
gemm_m_k_strides,
gemm_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......
......@@ -186,6 +186,25 @@ struct Bilinear
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x0_tmp = type_convert<float>(x0);
const float x1_tmp = type_convert<float>(x1);
const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
y = type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
y = y_tmp;
};
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
......
......@@ -311,6 +311,71 @@ struct AddAddFastGelu
}
};
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu
{
ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
: alpha1_(alpha1), alpha2_(alpha2)
{
}
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
const float& c,
const float& d0,
const float& d1) const
{
const float x = c * alpha1_ + alpha2_ * d0 + d1;
Relu{}.template operator()<float>(e, x);
}
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
type_convert<float>(d1);
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<half_t>(result);
}
template <>
__host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
type_convert<float>(d1);
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<bhalf_t>(result);
}
template <>
__host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
int8_t& e, const int8_t& c, const float& d0, const float& d1) const
{
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
float result = 0;
Relu{}.template operator()<float>(result, x);
e = type_convert<int8_t>(result);
}
const float alpha1_;
const float alpha2_;
};
struct Normalize
{
// FIXME: is double absolutely necessary?
......
......@@ -16,6 +16,57 @@ namespace element_wise {
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
struct PassThroughPack2
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
{
// fake conversion
uint16_t t = ck::bit_cast<uint32_t>(x);
y = ck::bit_cast<ck::f8x2_t>(t);
}
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
{
y = x;
}
__host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
{
y = x;
}
constexpr const static bool is_pack2_invocable = true;
};
struct PassThrough
{
template <typename Y, typename X>
......@@ -33,6 +84,12 @@ struct PassThrough
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -69,6 +126,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
......@@ -207,7 +270,8 @@ struct ConvertF8SR
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
......@@ -224,6 +288,20 @@ struct Scale
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = ck::type_convert<half_t>(scale_) * x;
};
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
const float x_tmp = ck::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -442,10 +520,11 @@ struct Sigmoid
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = 1 / (ck::type_convert<T>(1) + exp(-x));
constexpr T one = type_convert<T>(1);
y = one / (one + ck::math::exp(-x));
};
};
......@@ -455,7 +534,8 @@ struct TanH
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck::math::tanh(x);
......@@ -481,7 +561,101 @@ struct Swish
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
};
float beta_ = 1.0f;
const float beta_;
};
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
const float alpha_;
};
} // namespace element_wise
......
......@@ -136,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t K0;
index_t K0Padded;
index_t k_batch;
Argument(const FloatA* p_a_grid_,
......@@ -151,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t K0Padded_,
index_t k_batch_)
: p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_),
......@@ -165,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MPadded(MPadded_),
NPadded(NPadded_),
KPadded(KPadded_),
K0(K0_),
K0Padded(K0Padded_),
k_batch(k_batch_)
{
}
......@@ -182,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "K0:" << K0 << ", "
<< "K0Padded:" << K0Padded << ", "
<< "KB:" << k_batch << "}" << std::endl;
}
};
......@@ -205,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return math::integer_least_multiple(N, NPerBlock);
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
__host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
{
// k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1;
......@@ -214,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
auto K0Padded = CalculateK0Padded(K, K_Batch);
return K_Batch * K0Padded * K1;
}
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
......@@ -223,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t K,
index_t StrideA,
index_t KBatch,
index_t K0,
index_t K0Padded,
index_t KPad)
{
const auto a_grid_desc_m_k = [&]() {
......@@ -237,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -259,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -272,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t N,
index_t StrideB,
index_t KBatch,
index_t K0,
index_t K0Padded,
index_t KPad)
{
const auto b_grid_desc_k_n = [&]() {
......@@ -286,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
{
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -308,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
......@@ -398,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
......@@ -410,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = karg.k_batch * K0PerBlock * K1;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
......@@ -478,11 +522,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{
#if DEBUG_LOG
std::cout
<< "Arg N (" << karg.N
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
......@@ -493,25 +537,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{
#if DEBUG_LOG
std::cout
<< "Arg M (" << karg.M
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
const auto num_k_loop = karg.K0 / K0PerBlock;
const auto num_k_loop = karg.K0Padded / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
#if DEBUG_LOG
std::cout << "The number of k loops (" << num_k_loop
<< ") value is not supported by GridwiseGemm Pipeline."
<< " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
<< " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
......@@ -521,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
{
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0 * K1;
const index_t K0Padded =
math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
const index_t KPad = KBatch * K0Padded * K1;
return KPad;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
{
const index_t num_loop = K0 / K0PerBlock;
const index_t num_loop = K0Padded / K0PerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
......@@ -595,9 +640,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const FloatB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
......
......@@ -21,6 +21,7 @@ template <typename InputGridDesc,
typename OutputGridDesc,
typename OutputDataType,
typename Block2ETileMap,
typename ComputePtrOffsetOfStridedBatch,
typename GridwiseTensorRearrangeKernel>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -30,13 +31,20 @@ __global__ void
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap block_2_tile_map)
const index_t batch_count,
const Block2ETileMap block_2_tile_map,
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseTensorRearrangeKernel::Run(
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
p_in_global,
out_grid_desc,
p_out_global,
batch_count,
block_2_tile_map,
compute_ptr_offset_of_batch);
#else
ignore = in_grid_desc;
ignore = p_in_global;
......@@ -56,7 +64,8 @@ template <typename InputGridDesc,
typename ThreadClusterLengths,
index_t ScalarPerVector,
InMemoryDataOperationEnum DstInMemOp,
typename Block2ETileMap>
typename Block2ETileMap,
typename ComputePtrOffsetOfStridedBatch>
struct GridwiseTensorRearrange
{
......@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
const InputDataType* __restrict__ p_in_global,
const OutputGridDesc& out_grid_desc,
OutputDataType* __restrict__ p_out_global,
const Block2ETileMap& block_2_tile_map)
const index_t batch_count,
const Block2ETileMap& block_2_tile_map,
const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
{
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
// Global Memory
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
auto copy_global_to_global =
ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
Tuple<InputDataType>,
......@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
tensor_operation::element_wise::PassThrough{}};
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Global Memory
const index_t a_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const index_t c_batch_offset =
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());
copy_global_to_global.Run(
tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
}
......
......@@ -1075,6 +1075,7 @@ struct NumericUtils<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
......@@ -1091,6 +1092,7 @@ struct NumericUtils<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
......@@ -1107,6 +1109,8 @@ struct NumericUtils<f8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
static constexpr int bias = 8; // negative zero nan mode
// static constexpr int bias = 7; // ieee mode
};
template <>
......@@ -1114,6 +1118,7 @@ struct NumericUtils<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
//
} // namespace ck
......@@ -5,7 +5,6 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
namespace ck {
// fp8 rounding modes
......@@ -17,6 +16,9 @@ enum class f8_rounding_mode
stochastic
};
__host__ inline int clz(uint32_t x) { return __builtin_clz(x); }
__device__ inline int clz(uint32_t x) { return __clz(x); }
} // namespace ck
namespace ck::utils {
......@@ -34,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
constexpr int in_exp = NumericUtils<X>::exp;
constexpr int in_mant = NumericUtils<X>::mant;
int exponent;
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = 0x80;
......@@ -49,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
mantissa = x_bitwise & NumericUtils<X>::mant_mask;
exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
sign = head >> (in_exp + in_mant);
bias = NumericUtils<X>::bias;
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff =
(1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan)
{
......@@ -67,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
// if input is half and output is bf8
if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
exponent == 0)
{
exponent += 1;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent -= 1;
}
mantissa &= ~(1 << in_mant);
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
mantissa += 1 << in_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << in_mant))
{
mantissa >>= 1;
exponent++;
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// exponent and mantissa again3
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, out_exponent, exponent_diff;
if(exponent == 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = out_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= out_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = out_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff =
0; // exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
}
mantissa >>= (in_mant - out_mant);
// check negative exponent
if(exponent <= 0)
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
(1 << (in_mant - out_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << in_mant);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
bool odd =
mantissa &
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if(out_exponent == 0)
{
if(x_bitwise == 0)
return 0;
else
if((1 << in_mant) & mantissa)
{
// subnormal range; represented by a subnormal float8 (exponent 0)
// and involves loss of accuracy
mantissa >>= 1 - exponent;
exponent = 0;
out_exponent = 1; // denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
}
}
// above range: quantize to maximum possible float of the same sign
else if(exponent > max_exp)
else
{
if((1 << (in_mant + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
// No need to make 1 implicit now as it will be addressed later
}
}
mantissa >>= (in_mant - out_mant);
if(out_exponent > max_exp)
{
if(clip)
{
mantissa = (1 << out_mant) - 1;
exponent = max_exp;
mantissa = (1 << out_mant) - 1;
out_exponent = max_exp;
}
else
{
......@@ -125,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
}
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0)
if(out_exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
mantissa &= (1 << out_mant) - 1;
return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
}
template <typename X, typename Y, bool negative_zero_nan>
......@@ -194,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent++;
while(mantissa < (1 << in_mant))
{
mantissa <<= 1;
exponent--;
}
int sh = 1 + clz(mantissa) - (32 - in_mant);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << in_mant) - 1);
}
exponent += exp_low_cutoff - 1;
......
......@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
c = __builtin_amdgcn_sudot4(true, bit_cast<int32_t>(a), true, bit_cast<int32_t>(b), c, false);
#else
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
......
......@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound);
}
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
static inline __host__ float exp(float x) { return std::expf(x); }
static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
......
......@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck {
namespace math {
......@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x)
template <typename T>
inline __host__ T tanh(T x)
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
};
static inline __host__ float tanh(float x) { return std::tanh(x); };
template <>
inline __host__ float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
inline __host__ double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
inline __host__ T exp(T x)
{
return ck::type_convert<T>(std::expf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float exp<float>(float x)
{
return std::expf(x);
}
static inline __host__ double tanh(double x) { return std::tanh(x); };
template <>
inline __host__ double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
inline __host__ T log(T x)
{
return ck::type_convert<T>(std::logf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float log<float>(float x)
{
return std::logf(x);
}
template <>
inline __host__ double log<double>(double x)
{
return std::log(x);
}
template <typename T>
inline __host__ T pow(T x, T gamma)
{
return ck::type_convert<T>(
std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
}
template <>
inline __host__ float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
inline __host__ double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
inline __host__ T expm1(T x)
{
return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x)));
}
template <>
inline __host__ float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
......@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x)
template <typename T>
inline __device__ T tanh(T x)
{
return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float tanh<float>(float x)
{
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
return ::tanhf(x);
};
static inline __device__ float tanh(float x) { return ::tanhf(x); };
template <>
inline __device__ double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
inline __device__ T exp(T x)
{
return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t exp<half_t>(half_t x)
{
return hexp(x);
};
template <>
inline __device__ float exp<float>(float x)
{
return __expf(x);
};
static inline __device__ double tanh(double x) { return ::tanh(x); };
template <>
inline __device__ double exp<double>(double x)
{
return exp(x);
};
template <typename T>
inline __device__ T log(T x)
{
return ck::type_convert<T>(__logf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t log<half_t>(half_t x)
{
return hlog(x);
};
template <>
inline __device__ float log<float>(float x)
{
return __logf(x);
};
template <>
inline __device__ double log<double>(double x)
{
return log(x);
};
template <typename T>
inline __device__ T pow(T x, T gamma)
{
return ck::type_convert<T>(powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
};
template <>
inline __device__ float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
template <>
inline __device__ double pow<double>(double x, double gamma)
{
return pow(x, gamma);
};
template <typename T>
inline __device__ T expm1(T x)
{
return ck::type_convert<T>(expm1f(ck::type_convert<float>(x)));
};
template <>
inline __device__ float expm1<float>(float x)
{
return expm1f(x);
};
template <>
inline __device__ double expm1<double>(double x)
{
return expm1(x);
};
} // namespace math
} // namespace ck
......@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace ck {
......
......@@ -100,6 +100,8 @@ template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
......@@ -138,6 +140,36 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
const vector_type<float, 2> f32x2_v(x);
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
f32x2_v.template AsType<float>()[Number<1>{}]);
return bit_cast<half2_t>(y);
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
......@@ -145,7 +177,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -153,8 +185,6 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<f8_t>(type_convert<float>(x));
#endif
}
......@@ -165,11 +195,9 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
......@@ -223,7 +251,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<bf8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
......@@ -231,8 +259,6 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return type_convert<bf8_t>(type_convert<float>(x));
#endif
}
......@@ -243,11 +269,9 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif
}
......@@ -347,7 +371,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
......@@ -356,8 +380,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return f8_convert_sr<f8_t>(type_convert<float>(x));
#endif
}
......@@ -396,7 +418,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#elif 0
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
......@@ -406,8 +428,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#else
return f8_convert_sr<bf8_t>(type_convert<float>(x));
#endif
}
......
......@@ -19,9 +19,7 @@ namespace host {
* \brief Reference implementation for column to image.
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
......@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
float Run(const Argument& arg)
{
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 2))
arg.input_.GetNumOfDimension() == 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
......@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{
float v_in = ck::type_convert<float>(arg.input_(row, column));
float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi));
arg.output_(0, n, c, wi) =
float v_in =
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
arg.output_(g, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
......@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
......@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
......@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[4])
{
float v_in =
ck::type_convert<float>(arg.input_(row, column));
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, hi, wi));
arg.output_(0, n, c, hi, wi) =
arg.output_(g, n, c, hi, wi));
arg.output_(g, n, c, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
......@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
......@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
......@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[5])
{
float v_in = ck::type_convert<float>(
arg.input_(row, column));
arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, di, hi, wi));
arg.output_(0, n, c, di, hi, wi) =
arg.output_(g, n, c, di, hi, wi));
arg.output_(g, n, c, di, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
......@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
......@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{
return false;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment