Commit 690b0ec9 authored by Po-Yen, Chen

Create descriptor on device side

parent d73041a6
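Summary of the change: the batched-GEMM kernel no longer receives host-built tensor descriptors as individual kernel parameters. The host now packs pointers, sizes, strides and batch information into a single `Argument` struct (which derives from `GridwiseGemm::Problem`), passes it to the kernel by value, and the kernel rebuilds the A/B/C grid descriptors on the device from those scalars — hence "create descriptor on device side". The sketch below illustrates only that general pattern; it is not the composable_kernel API, and the simplified `Problem`, `Argument`, `DescK0MK1`, `make_a_desc` and `fake_kernel` names are invented for the example.

// Illustrative sketch only (hypothetical names, not the CK types): the host packs
// plain scalars and pointers into one argument struct, and the "kernel" rebuilds
// whatever descriptor it needs from those scalars instead of receiving it.
#include <cstdint>
#include <cstdio>

struct Problem // scalar problem description, cheap to pass by value
{
    std::int64_t M, N, K;
    std::int64_t StrideA, StrideB, StrideC;
};

struct Argument : Problem // adds pointers and batch information on top of the scalars
{
    const float* p_a;
    const float* p_b;
    float* p_c;
    std::int64_t Batch;
    std::int64_t BatchStrideA, BatchStrideB, BatchStrideC;
};

struct DescK0MK1 // stand-in for the (K0, M, K1) descriptor built on the device
{
    std::int64_t K0, M, K1;
};

constexpr std::int64_t K1 = 8;

DescK0MK1 make_a_desc(const Problem& p) { return {p.K / K1, p.M, K1}; }

// Stand-in for the kernel body: everything it needs comes from one by-value struct.
void fake_kernel(Argument karg, std::int64_t g_idx)
{
    const std::int64_t a_off = g_idx * karg.BatchStrideA; // per-batch pointer offset
    const DescK0MK1 a_desc   = make_a_desc(karg);         // descriptor built "on device"
    std::printf("batch %lld: A offset %lld, desc {%lld, %lld, %lld}\n",
                static_cast<long long>(g_idx), static_cast<long long>(a_off),
                static_cast<long long>(a_desc.K0), static_cast<long long>(a_desc.M),
                static_cast<long long>(a_desc.K1));
}

int main()
{
    Argument karg{};
    karg.M = 128; karg.N = 128; karg.K = 64;
    karg.StrideA = 64; karg.StrideB = 128; karg.StrideC = 128;
    karg.p_a = nullptr; karg.p_b = nullptr; karg.p_c = nullptr; // never dereferenced here
    karg.Batch = 2;
    karg.BatchStrideA = karg.M * karg.K;
    karg.BatchStrideB = karg.K * karg.N;
    karg.BatchStrideC = karg.M * karg.N;

    for(std::int64_t g = 0; g < karg.Batch; ++g)
        fake_kernel(karg, g);
}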
@@ -45,58 +45,44 @@ namespace device {
  * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
  *
  */
-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename AGridDesc_K0_M_K1,
-          typename BGridDesc_K0_N_K1,
-          typename CGridDesc_M_N,
-          typename ComputePtrOffsetOfBatch,
-          bool HasMainKBlockLoop>
+template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
-                                    const FloatAB* __restrict__ p_b_grid,
-                                    FloatC* __restrict__ p_c_grid,
-                                    const index_t batch_count,
-                                    const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-                                    const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-                                    const CGridDesc_M_N c_grid_desc_m_n,
-                                    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+    kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__))
     const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+        __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+        static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+        static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+        static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
-                                                  p_c_grid + c_batch_offset,
+    const auto a_grid_desc_k0_m_k1 = readfirstlane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
+        karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
+    const auto b_grid_desc_k0_n_k1 = readfirstlane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
+        karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
+    const auto c_grid_desc_m_n = readfirstlane(GridwiseGemm::MakeCGridDescriptor_M_N(
+        karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
+    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid + a_batch_offset,
+                                                  karg.p_b_grid + b_batch_offset,
+                                                  karg.p_c_grid + c_batch_offset,
                                                   p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m_n);
 #else
-    ignore = p_a_grid;
-    ignore = p_b_grid;
-    ignore = p_c_grid;
-    ignore = batch_count;
-    ignore = a_grid_desc_k0_m_k1;
-    ignore = b_grid_desc_k0_n_k1;
-    ignore = c_grid_desc_m_n;
-    ignore = compute_ptr_offset_of_batch;
+    ignore = karg;
 #endif
 }
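Two details of the new kernel body are worth noting. First, the batch index and pointer offsets are wrapped in `__builtin_amdgcn_readfirstlane`, which broadcasts lane 0's value across the wavefront so these quantities stay uniform and can live in scalar registers. Second, the descriptors are built from precomputed fields such as `MPadded`, `NPadded` and `K0` carried by the argument struct. One plausible way to precompute those values on the host, assumed here rather than taken from this diff, is to round each dimension up to the block tile and split K into K0 * K1, as sketched below.

// Hedged sketch (assumed convention, not CK code): derive MPadded/NPadded/K0 on the host.
#include <cstdio>

int round_up(int x, int tile) { return (x + tile - 1) / tile * tile; }

int main()
{
    const int M = 100, N = 130, K = 72;
    const int MPerBlock = 64, NPerBlock = 64, K1 = 8;

    const int MPadded = round_up(M, MPerBlock); // 128
    const int NPadded = round_up(N, NPerBlock); // 192
    const int K0      = K / K1;                 // 9 (K must be divisible by K1)

    std::printf("MPadded=%d NPadded=%d K0=%d\n", MPadded, NPadded, K0);
}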
@@ -154,93 +140,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
     static constexpr auto K1Number = Number<K1>{};
-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-        const index_t K0 = K / K1;
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-        }();
-        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-        const auto a_grid_desc_k0_mp_k1 =
-            transform_tensor_descriptor(a_grid_desc_m_k,
-                                        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                                                   make_right_pad_transform(M, PadM)),
-                                        make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        return a_grid_desc_k0_mp_k1;
-    }
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-    {
-        assert(K % K1 == 0);
-        const index_t K0 = K / K1;
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-        const auto b_grid_desc_k0_np_k1 =
-            transform_tensor_descriptor(b_grid_desc_k_n,
-                                        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                                                   make_right_pad_transform(N, PadN)),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        return b_grid_desc_k0_np_k1;
-    }
-    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
-    {
-        const auto c_grid_desc_m_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
-            }
-        }();
-        const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-        const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-        const auto c_grid_desc_mp_np = transform_tensor_descriptor(
-            c_grid_desc_m_n,
-            make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-        return c_grid_desc_mp_np;
-    }
-    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
-    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
-    using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
     struct ComputePtrOffsetOfStridedBatch
     {
         ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
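The three `Make*GridDescriptor*` helpers and the descriptor typedefs derived from them are deleted from the device op because the gridwise GEMM now exposes its own descriptor factories, which the kernel calls on the device. For readers unfamiliar with the transforms being removed: `make_unmerge_transform` splits the K axis into (K0, K1) and `make_right_pad_transform` extends M (or N) up to a multiple of the block tile. The following plain-arithmetic illustration, with made-up numbers, shows the resulting (k0, m, k1) → (m, k) mapping for the A descriptor.

// Plain-arithmetic illustration (not CK code) of the removed A-descriptor transforms:
// unmerge K into (K0, K1) and right-pad M up to a multiple of MPerBlock.
#include <cassert>
#include <cstdio>

int main()
{
    const int M = 100, K = 32, K1 = 8, MPerBlock = 64;

    const int K0   = K / K1;                                  // 4
    const int PadM = (MPerBlock - M % MPerBlock) % MPerBlock; // 28
    const int MP   = M + PadM;                                // padded M = 128

    // A (k0, m, k1) coordinate of a_grid_desc_k0_m_k1 addresses element (m, k) of A:
    const int k0 = 3, m = 10, k1 = 5;
    assert(m < MP && k0 < K0 && k1 < K1);
    const int k = k0 * K1 + k1;

    std::printf("K0=%d PadM=%d: (k0=%d, m=%d, k1=%d) -> A(m=%d, k=%d)\n",
                K0, PadM, k0, m, k1, m, k);
}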
@@ -272,15 +171,19 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
     };
     // GridwiseGemm
-    using GridwiseGemm =
-        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
+    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext<
+        BlockSize,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
         InMemoryDataOperationEnum::Set,
+        ALayout,
+        BLayout,
+        CLayout,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
+        GemmSpecialization::MNPadding,
         MPerBlock,
         NPerBlock,
         K0PerBlock,
@@ -312,46 +215,38 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
         LoopSched,
         PipelineVer>;
+    using Problem = typename GridwiseGemm::Problem;
     // Argument
-    struct Argument : public BaseArgument
+    struct Argument : public Problem, public BaseArgument
     {
-        Argument(const ADataType* p_a_grid,
-                 const BDataType* p_b_grid,
-                 CDataType* p_c_grid,
-                 index_t M,
-                 index_t N,
-                 index_t K,
-                 index_t StrideA,
-                 index_t StrideB,
-                 index_t StrideC,
+        Argument(const ADataType* p_a_grid_,
+                 const BDataType* p_b_grid_,
+                 CDataType* p_c_grid_,
+                 index_t M_,
+                 index_t N_,
+                 index_t K_,
+                 index_t StrideA_,
+                 index_t StrideB_,
+                 index_t StrideC_,
                  index_t BatchStrideA,
                  index_t BatchStrideB,
                  index_t BatchStrideC,
-                 index_t Batch)
-            : p_a_grid_{p_a_grid},
-              p_b_grid_{p_b_grid},
-              p_c_grid_{p_c_grid},
-              Batch_(Batch),
-              a_grid_desc_k0_m_k1_{
-                  DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
-              b_grid_desc_k0_n_k1_{
-                  DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
-              c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
-              compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
-              kraw_{K}
+                 index_t Batch_)
+            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
+              p_a_grid{p_a_grid_},
+              p_b_grid{p_b_grid_},
+              p_c_grid{p_c_grid_},
+              Batch(Batch_),
+              compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC}
         {
         }
-        // private:
-        const ADataType* p_a_grid_;
-        const BDataType* p_b_grid_;
-        CDataType* p_c_grid_;
-        index_t Batch_;
-        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
-        CGridDesc_M_N c_grid_desc_m_n_;
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
-        index_t kraw_;
+        const ADataType* p_a_grid;
+        const BDataType* p_b_grid;
+        CDataType* p_c_grid;
+        index_t Batch;
+        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
     };
     // Invoker
@@ -359,89 +254,39 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
     {
         using Argument = DeviceBatchedGemmXdl::Argument;
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
         {
-#if DEBUG_LOG
+            if(stream_config.log_level_ > 0)
             {
-                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+                karg.Print();
             }
-#endif
-            if(!GridwiseGemm::CheckValidity(
-                   arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
+            if(!GridwiseGemm::CheckValidity(karg))
             {
                 throw std::runtime_error(
-                    "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
+                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
             }
-            auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
-            gdx *= arg.Batch_;
+            auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
+            gdx *= karg.Batch;
-            const auto K =
-                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
             float ave_time = 0;
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
             {
                 const auto kernel =
-                    kernel_batched_gemm_xdlops_v2r3<GridwiseGemm,
-                                                    ADataType, // TODO: distiguish A/B datatype
-                                                    CDataType,
-                                                    DeviceBatchedGemmXdl::AGridDesc_K0_M_K1,
-                                                    DeviceBatchedGemmXdl::BGridDesc_K0_N_K1,
-                                                    DeviceBatchedGemmXdl::CGridDesc_M_N,
-                                                    ComputePtrOffsetOfStridedBatch,
-                                                    true>;
-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(gdx, gdy, gdz),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.Batch_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.c_grid_desc_m_n_,
-                                                  arg.compute_ptr_offset_of_batch_);
+                    kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, true>;
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
             }
             else
             {
                 const auto kernel =
-                    kernel_batched_gemm_xdlops_v2r3<GridwiseGemm,
-                                                    ADataType, // TODO: distiguish A/B datatype
-                                                    CDataType,
-                                                    DeviceBatchedGemmXdl::AGridDesc_K0_M_K1,
-                                                    DeviceBatchedGemmXdl::BGridDesc_K0_N_K1,
-                                                    DeviceBatchedGemmXdl::CGridDesc_M_N,
-                                                    ComputePtrOffsetOfStridedBatch,
-                                                    false>;
-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(gdx, gdy, gdz),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.Batch_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.c_grid_desc_m_n_,
-                                                  arg.compute_ptr_offset_of_batch_);
+                    kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, false>;
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
             }
             return ave_time;
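With the descriptors gone from the argument list, the launch collapses to a single `karg` passed by value. The grid is still laid out so that all work-groups of one batch are contiguous: `CalculateGridSize(M, N)` gives the per-batch block count, the x dimension is multiplied by `Batch`, and the kernel recovers its batch index by dividing the 1-D block id by the per-batch block count (the `num_blocks_per_batch` / `g_idx` computation in the kernel hunk above). The small sketch below replays that bookkeeping with made-up numbers.

// Sketch of the batch/block bookkeeping (values are made up):
// grid_size = blocks_per_batch * Batch, and each block recovers its batch index.
#include <cstdio>

int main()
{
    const int blocks_per_batch = 6; // what CalculateGridSize(M, N) would return
    const int Batch            = 3;
    const int grid_size        = blocks_per_batch * Batch; // gdx after "gdx *= Batch"

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx          = block_id / blocks_per_batch; // batch this block works on
        const int local_block_id = block_id % blocks_per_batch; // tile index within the batch
        std::printf("block %2d -> batch %d, tile %d\n", block_id, g_idx, local_block_id);
    }
}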
@@ -461,15 +306,14 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
         return true;
     }
-    static bool IsSupportedArgument(const Argument& arg)
+    static bool IsSupportedArgument(const Problem& problem)
     {
-        if(arg.kraw_ % K1 != 0)
+        if(problem.K % K1 != 0)
         {
             return false;
         }
-        return GridwiseGemm::CheckValidity(
-            arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
+        return GridwiseGemm::CheckValidity(problem);
     }
     // polymorphic