Unverified commit 3550cefe, authored by rocking and committed by GitHub

Merge branch 'develop' into max-pool-bwd

Parents: 508d7a59 027e46ee
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/*
 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B and C
 * matrices for a given batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets
 * for evenly strided batches, but this can easily be extended to other layouts. The returned
 * offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not
 * subject to the 2 GB limitation.
 *
 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility to let two workgroups compute
 * two tiles from different matrices. Keep in mind that these two matrices can share the same grid
 * descriptor (as in BatchedGEMM), or use their own grid descriptors (as in GroupedGemm). \link
 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
 * computation of the pointer offset into \p ComputePtrOffsetOfStridedBatch.
 */
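// Illustrative sketch (not part of this header; the shapes and strides below are assumptions):
// for a row-major A of shape [Batch, M, K] stored contiguously, the evenly strided offset
// computation reduces to a multiply by the batch stride:
//
//   // ComputePtrOffsetOfStridedBatch offsets{/*BatchStrideA=*/M * K,
//   //                                        /*BatchStrideB=*/K * N,
//   //                                        /*BatchStrideDs=*/{},
//   //                                        /*BatchStrideE=*/M * N};
//   // offsets.GetAPtrOffset(2) == 2 * static_cast<long_index_t>(M * K),
//   // so p_a_grid + offsets.GetAPtrOffset(2) is the first element of batch 2 of A.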
template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename DsGridDesc_M0_M10_M11_N0_N10_N11,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_multiple_d(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const index_t batch_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
// TODO: Enable for gfx90a after compiler fix
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::Run(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
ds_grid_desc_m0_m10_m11_n0_n10_n11,
e_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_k0_m0_m1_k1;
ignore = b_grid_desc_k0_n0_n1_k1;
ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
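// Worked example of the block-to-batch mapping in the kernel above (illustrative numbers only):
// with batch_count = 8 and a launch of grid_size = 240 workgroups, each batch owns
// num_blocks_per_batch = 240 / 8 = 30 workgroups, so workgroup 65 computes batch
// g_idx = 65 / 30 = 2 and shifts its A/B/Ds/E pointers by the batch-2 offsets returned by
// compute_ptr_offset_of_batch.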
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmMultipleD_Dl;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
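// Worked example of the K0/K1 split above (illustrative numbers, not taken from this file):
// with K = 64 and K1 = 8, K0 = 64 / 8 = 8 and the unmerge transform addresses element (m, k)
// of A as (k0, m, k1) = (k / 8, m, k % 8). K % K1 == 0 is assumed here; arguments violating it
// are rejected later in IsSupportedArgument().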
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
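// Worked padding example for the descriptor above (illustrative numbers): under
// GemmSpecialization::MNPadding with MPerBlock = 128, NPerBlock = 128, M = 300 and N = 200,
// PadM = (128 - 300 % 128) % 128 = 84 and PadN = (128 - 200 % 128) % 128 = 56, so the E
// descriptor is right-padded to 384 x 256 and covers whole MPerBlock x NPerBlock tiles.
// When M and N are already multiples of the block sizes, both pads are 0.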
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
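// Note: the dummy (1, 1, 1) and empty-array arguments above exist only so that decltype can
// deduce the descriptor types at compile time; the actual shapes and strides are supplied when
// Argument is constructed below.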
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
std::array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]);
});
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
};
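// Illustrative note on the 64-bit offsets above (made-up numbers, assuming the default 32-bit
// index_t): each BatchStride is an index_t, but g_idx * BatchStride is computed as
// long_index_t, so e.g. g_idx = 600 with BatchStrideE = 4'000'000 yields an offset of
// 2'400'000'000 elements, which would overflow a 32-bit index_t yet remains representable in
// the returned long_index_t.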
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
ADataType,
AccDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
EGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using DsGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
using EGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
K_(K),
Batch_(Batch),
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
e_grid_desc_m0_m10_m11_n0_n10_n11_{},
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
a_grid_desc_k0_m_k1_ =
DeviceBatchedGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceBatchedGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
});
e_grid_desc_m_n_ =
DeviceBatchedGemmMultipleD_Dl::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
ds_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_);
e_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
index_t K_;
// Batch
index_t Batch_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;
// for calculating batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceBatchedGemmMultipleD_Dl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0),
arg.e_grid_desc_m_n_.GetLength(I1)) *
arg.Batch_;
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
const auto kernel =
kernel_gemm_dl_multiple_d<GridwiseGemm,
ADataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_K0_M0_M1_K1,
DeviceOp::BGridDesc_K0_N0_N1_K1,
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
DeviceOp::EGridDesc_M0_M10_M11_N0_N10_N11,
ComputePtrOffsetOfStridedBatch,
DefaultBlock2CTileMap,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.Batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.e_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// TODO: Enable for gfx90a after compiler fix
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
bool pass = true;
pass = pass && arg.K_ % K1 == 0;
pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.e_grid_desc_m_n_);
return pass;
}
else
{
return false;
}
}
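// Illustrative example of the K1 divisibility check above (made-up numbers): with K1 = 8, an
// argument with K = 256 passes arg.K_ % K1 == 0, while K = 100 is rejected here before any
// kernel launch is attempted.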
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
Batch,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t Batch,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, NumDTensor>& StrideDs,
index_t StrideE,
index_t BatchStrideA,
index_t BatchStrideB,
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
Batch,
StrideA,
StrideB,
StrideDs,
StrideE,
BatchStrideA,
BatchStrideB,
BatchStrideDs,
BatchStrideE,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmMultipleD_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
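// Illustrative host-side call flow (a sketch; the concrete template arguments of
// DeviceBatchedGemmMultipleD_Dl and all pointer/size/stride values below are assumptions,
// not taken from this file). A and B elementwise ops must be PassThrough per the enable_if
// constraint above.
//
//   // using DeviceOp = DeviceBatchedGemmMultipleD_Dl</* layouts, data types, tuning params */>;
//   // DeviceOp device_op;
//   // auto argument = device_op.MakeArgument(p_a, p_b, p_ds, p_e,
//   //                                        M, N, K, Batch,
//   //                                        StrideA, StrideB, StrideDs, StrideE,
//   //                                        BatchStrideA, BatchStrideB, BatchStrideDs,
//   //                                        BatchStrideE,
//   //                                        PassThrough{}, PassThrough{}, cde_element_op);
//   // auto invoker = device_op.MakeInvoker();
//   // if(device_op.IsSupportedArgument(argument))
//   // {
//   //     invoker.Run(argument, StreamConfig{});
//   // }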
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -69,7 +69,7 @@ __global__ void
         const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -806,7 +806,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -60,7 +60,7 @@ __global__ void
         const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
@@ -68,7 +68,7 @@ __global__ void
         const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -197,7 +197,8 @@ template <index_t NumDimG,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
-          LoopScheduler LoopSched = LoopScheduler::Default>
+          int D0sTransferSrcScalarPerVector = 4,
+          LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
                                                  NumDimM,
@@ -438,7 +439,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
         Transform::matrix_padder.PadN,
-        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+        D0sTransferSrcScalarPerVector>;
     // Argument
     // FIXME: constness
@@ -530,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
                 // D0 pointer
                 p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
+                // for check
+                d0s_nl_ns_lengths_strides_[i].push_back(
+                    acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
+                d0s_nl_ns_lengths_strides_[i].push_back(
+                    acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
             });
         if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
@@ -608,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         std::vector<index_t> b_nz_kz_strides_;
         std::vector<index_t> b1_nz_kz_strides_;
         std::vector<index_t> c_mz_gemm1nz_strides_;
+        std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
@@ -716,7 +724,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
 #endif
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
@@ -772,6 +781,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         {
             return false;
         }
+        for(int i = 0; i < NumD0Tensor; i++)
+        {
+            if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
+               arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
+            {
+                std::cout << "first" << std::endl;
+                return false;
+            }
+            if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
+            {
+                std::cout << "second" << std::endl;
+                return false;
+            }
+        }
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
...
@@ -63,7 +63,7 @@ __global__ void
         const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -614,7 +614,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -53,7 +53,7 @@ __global__ void
     kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
@@ -53,7 +53,7 @@ __global__ void
         const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -583,7 +583,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -56,7 +56,7 @@ __global__ void
         const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
@@ -13,6 +13,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/stream_utility.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -171,10 +172,7 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
               inStridesArray_(inStridesArray),
               outStridesArray_(outStridesArray),
               elementwise_op_(elementwise_op),
-              blockSize_(256),
-              gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
-              num_threads_m_((gridSize_ * blockSize_) / 16),
-              num_threads_n_(16)
+              blockSize_(256)
         {
             static_assert(NumDim_m > 0, "");
             static_assert(NumDim_n > 0, "");
@@ -192,34 +190,10 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
                     return static_cast<DataType*>(out_dev_buffers[I.value]);
                 },
                 Number<NumOutput>{});
-            in_grid_2d_desc_tuple_ = generate_tuple(
-                [&](auto I) {
-                    return MakeDescriptor_MN(lengths,
-                                             inStridesArray[I.value],
-                                             gridSize_,
-                                             blockSize_,
-                                             num_threads_m_,
-                                             num_threads_n_);
-                },
-                Number<NumInput>{});
-            out_grid_2d_desc_tuple_ = generate_tuple(
-                [&](auto I) {
-                    return MakeDescriptor_MN(lengths,
-                                             outStridesArray[I.value],
-                                             gridSize_,
-                                             blockSize_,
-                                             num_threads_m_,
-                                             num_threads_n_);
-                },
-                Number<NumOutput>{});
         }
         InDataTypePointerTuple in_dev_buffers_;
         OutDataTypePointerTuple out_dev_buffers_;
-        InGrid2dDescTuple in_grid_2d_desc_tuple_;
-        OutGrid2dDescTuple out_grid_2d_desc_tuple_;
         std::array<index_t, NumDim> lengths_;
         std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
@@ -227,15 +201,38 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
         ElementwiseOperation elementwise_op_;
         index_t blockSize_;
-        index_t gridSize_;
-        index_t num_threads_m_;
-        index_t num_threads_n_;
     };
     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
+            index_t gridSize = getAvailableComputeUnitCount(stream_config);
+            index_t num_threads_m = (gridSize * arg.blockSize_) / 16;
+            index_t num_threads_n = 16;
+            auto in_grid_2d_desc_tuple = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_MN(arg.lengths_,
+                                             arg.inStridesArray_[I.value],
+                                             gridSize,
+                                             arg.blockSize_,
+                                             num_threads_m,
+                                             num_threads_n);
+                },
+                Number<NumInput>{});
+            auto out_grid_2d_desc_tuple = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_MN(arg.lengths_,
+                                             arg.outStridesArray_[I.value],
+                                             gridSize,
+                                             arg.blockSize_,
+                                             num_threads_m,
+                                             num_threads_n);
+                },
+                Number<NumOutput>{});
             const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
                                                       InGrid2dDescTuple,
                                                       OutGrid2dDescTuple,
@@ -245,16 +242,16 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
             float elapsed_time = launch_and_time_kernel(stream_config,
                                                         kernel,
-                                                        dim3(arg.gridSize_),
+                                                        dim3(gridSize),
                                                         dim3(arg.blockSize_),
                                                         0,
-                                                        arg.in_grid_2d_desc_tuple_,
-                                                        arg.out_grid_2d_desc_tuple_,
+                                                        in_grid_2d_desc_tuple,
+                                                        out_grid_2d_desc_tuple,
                                                         arg.in_dev_buffers_,
                                                         arg.out_dev_buffers_,
                                                         arg.elementwise_op_,
-                                                        arg.num_threads_m_,
-                                                        arg.num_threads_n_);
+                                                        num_threads_m,
+                                                        num_threads_n);
             return elapsed_time;
         }
...
@@ -13,6 +13,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/stream_utility.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -144,8 +145,7 @@ struct DeviceElementwiseImpl
               inStridesArray_(inStridesArray),
               outStridesArray_(outStridesArray),
               elementwise_op_(elementwise_op),
-              blockSize_(256),
-              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
+              blockSize_(256)
         {
             in_dev_buffers_ = generate_tuple(
                 [&](auto I) {
@@ -160,26 +160,10 @@ struct DeviceElementwiseImpl
                     return static_cast<DataType*>(out_dev_buffers[I.value]);
                 },
                 Number<NumOutput>{});
-            in_grid_1d_desc_tuple_ = generate_tuple(
-                [&](auto I) {
-                    return MakeDescriptor_M(
-                        lengths, inStridesArray[I.value], gridSize_, blockSize_);
-                },
-                Number<NumInput>{});
-            out_grid_1d_desc_tuple_ = generate_tuple(
-                [&](auto I) {
-                    return MakeDescriptor_M(
-                        lengths, outStridesArray[I.value], gridSize_, blockSize_);
-                },
-                Number<NumOutput>{});
         }
         InDataTypePointerTuple in_dev_buffers_;
         OutDataTypePointerTuple out_dev_buffers_;
-        InGrid1dDescTuple in_grid_1d_desc_tuple_;
-        OutGrid1dDescTuple out_grid_1d_desc_tuple_;
         std::array<index_t, NumDim> lengths_;
         std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
@@ -187,13 +171,28 @@ struct DeviceElementwiseImpl
         ElementwiseOperation elementwise_op_;
         index_t blockSize_;
-        index_t gridSize_;
     };
     struct Invoker : public BaseInvoker
    {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
+            index_t gridSize = getAvailableComputeUnitCount(stream_config);
+            auto in_grid_1d_desc_tuple = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_M(
+                        arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_);
+                },
+                Number<NumInput>{});
+            auto out_grid_1d_desc_tuple = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_M(
+                        arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_);
+                },
+                Number<NumOutput>{});
             const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
                                                       InGrid1dDescTuple,
                                                       OutGrid1dDescTuple,
@@ -203,11 +202,11 @@ struct DeviceElementwiseImpl
             float elapsed_time = launch_and_time_kernel(stream_config,
                                                         kernel,
-                                                        dim3(arg.gridSize_),
+                                                        dim3(gridSize),
                                                         dim3(arg.blockSize_),
                                                         0,
-                                                        arg.in_grid_1d_desc_tuple_,
-                                                        arg.out_grid_1d_desc_tuple_,
+                                                        in_grid_1d_desc_tuple,
+                                                        out_grid_1d_desc_tuple,
                                                         arg.in_dev_buffers_,
                                                         arg.out_dev_buffers_,
                                                         arg.elementwise_op_);
...
@@ -52,7 +52,7 @@ __global__ void
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
-    defined(__gfx1101__) || defined(__gfx1102__))
+    defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
@@ -555,7 +555,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
            ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" ||
            ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
-           ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
+           ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" ||
+           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")
         {
             return GridwiseGemm::CheckValidity(
                 arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
...
@@ -64,7 +64,7 @@ __global__ void
         index_t NRaw)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
     GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
@@ -856,7 +856,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -61,7 +61,7 @@ __global__ void
         const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -556,7 +556,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -52,7 +52,7 @@ __global__ void
         const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -492,7 +492,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -189,7 +189,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -649,7 +649,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -38,7 +38,7 @@ __global__ void
         const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t block_id = get_block_1d_id();
@@ -706,7 +706,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940"))
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
...
@@ -131,7 +131,7 @@ __global__ void
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
@@ -79,7 +79,7 @@ __global__ void
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
@@ -156,7 +156,7 @@ __global__ void
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -811,7 +811,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
                 return false;
             }
         }
-        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
+        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" ||
+                get_device_name() == "gfx941" || get_device_name() == "gfx942")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
...