Commit 03cd2692 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Merge remote-tracking branch 'origin/develop' into bwroblew/warp_wise_dpp8

parents bf445c31 f5ec04f0
......@@ -310,9 +310,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using ComputeDataType = ADataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
......
......@@ -20,7 +20,8 @@
namespace ck {
template <typename GridwiseGemm,
typename ABDataType,
typename ADataType,
typename BDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
......@@ -36,8 +37,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
......@@ -242,9 +243,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = EDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
......@@ -442,6 +447,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
BDataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
......
......@@ -355,9 +355,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using ComputeDataType = ADataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
......
......@@ -355,6 +355,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ABDataType, // TODO: distinguish A/B datatype
ABDataType, // TODO: distinguish A/B datatype
ABDataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
......
......@@ -367,9 +367,13 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using ComputeDataType = ADataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
......
......@@ -228,9 +228,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
using ComputeDataType = ADataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -25,7 +25,7 @@ template <typename DOutDataType,
typename IndexDataType,
typename DInDataType,
ck::index_t InOutVectorSize>
struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDataType, DInDataType>
struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>
{
using DInDataType_AutomicAddPreCast =
conditional_t<is_same_v<DInDataType, float> || is_same_v<DInDataType, double>,
......@@ -91,7 +91,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
index_t dout_length,
index_t din_length,
const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides)
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations)
: p_dout_{p_dout},
p_indices_{p_indices},
p_din_{p_din},
......@@ -102,7 +103,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
{
for(size_t i = 0; i < window_lengths.size(); ++i)
{
windowOverlap_ |= window_lengths.at(i) > window_strides.at(i);
auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1;
windowOverlap_ |= eff > window_strides.at(i);
}
}
......@@ -228,6 +230,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
}
else
{
hip_check_error(hipMemsetAsync(arg.p_din_,
0,
arg.din_length_raw_ * sizeof(DInDataType),
stream_config.stream_id_));
const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
InOutGrid1dDesc,
DOutDataType,
......@@ -292,7 +299,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
index_t dout_length,
index_t din_length,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides) override
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations) override
{
// Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are
// physical size of the packed tensor
......@@ -302,7 +310,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
dout_length,
din_length,
window_lengths,
window_strides);
window_strides,
window_dilations);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
......@@ -36,6 +36,13 @@ struct Add
y = x0 + type_convert<half_t>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
{
y = type_convert<half_t>(x0 + x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
......
......@@ -195,6 +195,51 @@ struct AddMultiply
}
};
// C = A * B
// E = C x D0 + D1
struct MultiplyAdd
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c * d0) + d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = type_convert<half_t>(c) * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = c * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
const float& c,
const float& d0,
const float& d1) const
{
const float y = c * d0 + d1;
e = y;
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
......
......@@ -39,6 +39,12 @@ struct PassThrough
y = x;
}
template <>
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
......
......@@ -587,7 +587,8 @@ struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t block_start)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
......
......@@ -15,6 +15,9 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
// GEMM:
......@@ -26,7 +29,9 @@ namespace ck {
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ABDataType, // FIXME: don't assume A/B have same datatype
template <typename ADataType,
typename BDataType,
typename ComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
......@@ -72,6 +77,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -92,15 +99,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
#if CK_WORKAROUND_DENORM_FIX
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
using ComputeDataType =
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
#else
using ABDataTypeAdjusted = ABDataType;
using ComputeDataType = ComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
......@@ -170,7 +173,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ABDataType),
sizeof(ComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
......@@ -313,8 +316,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
......@@ -332,14 +335,102 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using DsGridPointer = decltype(MakeDsGridPointer());
template <typename ALayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename ELayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
......@@ -408,8 +499,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataTypeAdjusted,
ADataType,
ComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
......@@ -439,8 +530,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataTypeAdjusted,
BDataType,
ComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
......@@ -470,11 +561,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ABDataTypeAdjusted,
ComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
......@@ -492,11 +583,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
......@@ -761,6 +851,85 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
});
}
}
template <bool HasMainKBlockLoop,
GemmSpecialization GemmSpec,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap>
__device__ static void Run(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
DsGridPointer p_ds_grid,
void* __restrict__ p_e_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const Block2ETileMap& block_2_etile_map)
{
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
// tensor descriptors for problem definiton
const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
DsGridDesc_M_N ds_grid_desc_m_n;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
});
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
};
} // namespace ck
......@@ -37,7 +37,8 @@ __global__ void
index_t StrideC,
typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
......
......@@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
// apply pointwise operation
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
SrcData v;
DstData v;
// apply element-wise operation
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
// apply type convert
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
dst_vector_container.template AsType<DstData>()(i) = v;
});
const bool is_dst_valid =
......
......@@ -116,7 +116,15 @@ struct Max
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Lowest();
}
};
__host__ __device__ static constexpr bool
......@@ -138,6 +146,15 @@ struct Max
a = b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -152,6 +169,18 @@ struct Max
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
};
struct Min
......@@ -159,6 +188,15 @@ struct Min
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Max();
}
return NumericLimits<T>::Max();
};
......@@ -181,6 +219,15 @@ struct Min
a = b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -195,6 +242,18 @@ struct Min
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
};
struct AMax
......
......@@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
CDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
arg.c_m_n_(m, n) = v_c;
};
make_ParallelTensorFunctor(
......
......@@ -53,7 +53,16 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator
{
int index = arg.indices_.mData[i];
if(index >= 0 && index < din_length)
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
{
if constexpr(is_same_v<ConputeDataType, bhalf_t>)
{
float buf_val = ck::type_convert<float>(buf[index]);
buf_val += ck::type_convert<float>(arg.dout_.mData[i]);
buf[index] = ck::type_convert<ConputeDataType>(buf_val);
}
else
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
}
}
for(int i = 0; i < din_length; ++i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment