"include/ck/utility/amd_inline_asm.hpp" did not exist on "86cc678f1824076467a011bd2d3e176214f7d99c"
Commit d6ea89ec authored by Mirza Halilcevic
Browse files

Add descriptor and RTC workarounds for batched_gemm_multiple_d_gemm_multiple_d.

parent d20c20a6
...@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
prob.N, prob.N,
prob.K, prob.K,
prob.O, prob.O,
x.tile_desc.gemm0_m_per_block, x.tile_desc.gemm01_m_per_block,
x.tile_desc.gemm0_n_per_block, x.tile_desc.gemm0_n_per_block,
x.tile_desc.gemm0_k_per_block, x.tile_desc.gemm0_k_per_block,
x.tile_desc.gemm1_n_per_block, x.tile_desc.gemm1_n_per_block,
...@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std::unordered_map<std::string, std::string> values = { std::unordered_map<std::string, std::string> values = {
{"name", {"name",
std::to_string(this->tile_desc.block_size) + "_" + std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.gemm0_m_per_block) + "_" + std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
std::to_string(this->tile_desc.a0k1) + "_" + std::to_string(this->tile_desc.b0k1) + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
"_" + std::to_string(this->tile_desc.b1k1) + "_" + std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" + std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" + std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
...@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.layout); }))}, MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.layout); }))},
{"E1Layout", ToString(this->E1.layout)}, {"E1Layout", ToString(this->E1.layout)},
{"ADataType", ToString(this->A0.element)}, {"A0DataType", ToString(this->A0.element)},
{"B0DataType", ToString(this->B0.element)}, {"B0DataType", ToString(this->B0.element)},
{"Acc0DataType", ToString(this->acc_type)}, {"Acc0DataType", ToString(this->acc_type)},
{"D0sDataType", {"D0sDataType",
...@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{"PadGemm1N", std::to_string(this->padding_desc.pad_gemm1_n)}, {"PadGemm1N", std::to_string(this->padding_desc.pad_gemm1_n)},
{"PadGemm1K", std::to_string(this->padding_desc.pad_gemm1_k)}, {"PadGemm1K", std::to_string(this->padding_desc.pad_gemm1_k)},
{"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemm0k_prefetch_stage)}, {"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)}, {"BlockSize", std::to_string(this->tile_desc.block_size)},
{"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm0_m_per_block)}, {"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
{"A0K1", std::to_string(this->tile_desc.a0k1)}, {"A0K1", std::to_string(this->tile_desc.ak1)},
{"B0K1", std::to_string(this->tile_desc.b0k1)}, {"B0K1", std::to_string(this->tile_desc.bk1)},
{"B1K1", std::to_string(this->tile_desc.b1k1)}, {"B1K1", std::to_string(this->tile_desc.b1k1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
......
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#endif
#include "device_base.hpp" #include "device_base.hpp"
...@@ -31,6 +33,7 @@ template <typename A0Layout, ...@@ -31,6 +33,7 @@ template <typename A0Layout,
typename CDE1ElementwiseOperation> typename CDE1ElementwiseOperation>
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
{ {
#ifndef __HIPCC_RTC__
static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size();
...@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator ...@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
CDE1ElementwiseOperation cde1_element_op) = 0; CDE1ElementwiseOperation cde1_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
}; };
} // namespace device } // namespace device
......
...@@ -3,8 +3,12 @@ ...@@ -3,8 +3,12 @@
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -13,8 +17,6 @@ ...@@ -13,8 +17,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -350,9 +352,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -350,9 +352,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw); return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
} }
static auto MakeD0sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws, static auto MakeD0sGridDescriptor_M_N(const Array<index_t, NumD1Tensor>& MRaws,
const std::array<index_t, NumD1Tensor>& NRaws, const Array<index_t, NumD1Tensor>& NRaws,
const std::array<index_t, NumD1Tensor>& DsStride) const Array<index_t, NumD1Tensor>& DsStride)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -363,9 +365,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -363,9 +365,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
Number<NumD0Tensor>{}); Number<NumD0Tensor>{});
} }
static auto MakeD1sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws, static auto MakeD1sGridDescriptor_M_N(const Array<index_t, NumD1Tensor>& MRaws,
const std::array<index_t, NumD1Tensor>& NRaws, const Array<index_t, NumD1Tensor>& NRaws,
const std::array<index_t, NumD1Tensor>& DsStride) const Array<index_t, NumD1Tensor>& DsStride)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -380,9 +382,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -380,9 +382,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
index_t BatchStrideB0, index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s, Array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1, index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s, Array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1) index_t BatchStrideE1)
: BatchStrideA0_(BatchStrideA0), : BatchStrideA0_(BatchStrideA0),
BatchStrideB0_(BatchStrideB0), BatchStrideB0_(BatchStrideB0),
...@@ -429,9 +431,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -429,9 +431,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
private: private:
index_t BatchStrideA0_; index_t BatchStrideA0_;
index_t BatchStrideB0_; index_t BatchStrideB0_;
std::array<index_t, NumD0Tensor> BatchStrideD0s_; Array<index_t, NumD0Tensor> BatchStrideD0s_;
index_t BatchStrideB1_; index_t BatchStrideB1_;
std::array<index_t, NumD1Tensor> BatchStrideD1s_; Array<index_t, NumD1Tensor> BatchStrideD1s_;
index_t BatchStrideE1_; index_t BatchStrideE1_;
}; };
...@@ -520,6 +522,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -520,6 +522,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
B1GridDesc_N_K{}))>; B1GridDesc_N_K{}))>;
#ifndef __HIPCC_RTC__
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -790,6 +793,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -790,6 +793,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
#endif
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
...@@ -799,9 +803,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -799,9 +803,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
// check if DsLayout is supported // check if DsLayout is supported
template <typename RefLayout, typename DsLayout, const index_t NumDTensor> template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static bool CheckDLayout() static constexpr bool CheckDLayout()
{ {
static bool valid = true; bool valid = true;
// iterate over DLayout tuple // iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
...@@ -811,13 +815,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -811,13 +815,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return valid; return valid;
} }
static bool IsSupportedArgument(const Argument& arg) static constexpr bool IsSupported()
{
if(!ck::is_xdl_supported())
{ {
return false;
}
// Check supported layouts // Check supported layouts
// A0 - Row // A0 - Row
// B0 - Col // B0 - Col
...@@ -829,16 +828,25 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -829,16 +828,25 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> && is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() && CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> || (is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
is_same_v<tensor_layout::gemm::ColumnMajor, is_same_v<tensor_layout::gemm::ColumnMajor, B1Layout>) &&
B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor, CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
D1sLayout,
NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>)) is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{ {
return false; return false;
} }
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
return IsSupported() and GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_, arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_, arg.b1_grid_desc_n_k_,
arg.e1_grid_desc_m_n_, arg.e1_grid_desc_m_n_,
...@@ -989,6 +997,328 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -989,6 +997,328 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return str.str(); return str.str();
} }
#endif
// Descriptor
//
// Aggregate that is built from six caller-supplied tensor descriptors
// (A0, B0, D0s, B1, D1s, E1) plus the five element-wise operations, and that
// stores everything the static Run() entry point reads from it:
//   - the descriptors returned by gemm0_padder / gemm1_padder,
//   - the descriptors derived from those by GridwiseGemm's Make* helpers,
//   - the block-to-E1-tile map,
//   - the element-wise operation objects,
//   - the has_main_k_block_loop flag.
// All member functions and the constructor are constexpr.
// NOTE(review): gemm0_padder / gemm1_padder are declared outside this block;
// their exact padding behaviour is not visible here.
template <class A0Desc, class B0Desc, class D0sDesc, class B1Desc, class D1sDesc, class E1Desc>
struct Descriptor
{
// for Gemm0
// Passes the A0 descriptor through gemm0_padder.PadADescriptor_M_K.
template <class A0GridDescriptor>
static constexpr auto MakeA0GridDescriptor_M_K(const A0GridDescriptor& a0_grid_desc)
{
return gemm0_padder.PadADescriptor_M_K(a0_grid_desc);
}
// for Gemm0
// Passes the B0 descriptor through gemm0_padder.PadBDescriptor_N_K.
template <class B0GridDescriptor>
static constexpr auto MakeB0GridDescriptor_N_K(const B0GridDescriptor& b0_grid_desc)
{
return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc);
}
// for Gemm0
// Applies gemm0_padder.PadCDescriptor_M_N to every element of the D0s tuple
// via transform_tuples.
template <class D0sGridDescriptor>
static constexpr auto MakeD0sGridDescriptor_M_N(const D0sGridDescriptor& d0s_grid_desc)
{
return transform_tuples(
[&](auto d) constexpr { return gemm0_padder.PadCDescriptor_M_N(d); },
d0s_grid_desc);
}
// for Gemm1
// Passes the B1 descriptor through gemm1_padder.PadBDescriptor_N_K.
template <class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_N_K(const B1GridDescriptor& b1_grid_desc)
{
return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc);
}
// for Gemm1
// Applies gemm1_padder.PadCDescriptor_M_N to every element of the D1s tuple
// via transform_tuples.
template <class D1sGridDescriptor>
static constexpr auto MakeD1sGridDescriptor_M_N(const D1sGridDescriptor& d1s_grid_desc)
{
return transform_tuples(
[&](auto d) constexpr { return gemm1_padder.PadCDescriptor_M_N(d); },
d1s_grid_desc);
}
// for Gemm1
// Passes the E1 descriptor through gemm1_padder.PadCDescriptor_M_N.
template <class E1GridDescriptor>
static constexpr auto MakeE1GridDescriptor_M_N(const E1GridDescriptor& e1_grid_desc)
{
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc);
}
// Result types of the Make* helpers above, deduced by calling them on
// value-initialized instances of the input descriptor types.
using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(A0Desc{}));
using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(B0Desc{}));
using D0sGridDesc_M_N = remove_cvref_t<decltype(MakeD0sGridDescriptor_M_N(D0sDesc{}))>;
using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(B1Desc{}));
using D1sGridDesc_M_N = remove_cvref_t<decltype(MakeD1sGridDescriptor_M_N(D1sDesc{}))>;
using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N(E1Desc{}));
// GridwiseGemm
// Instantiated with the descriptor types deduced above; every other template
// argument is a name taken from the enclosing scope. Arguments are positional.
// NOTE(review): the bare true/true/false literals are flags whose meaning is
// defined by the gridwise kernel's parameter list, which is not visible here.
using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
A0DataType, // TODO: distinguish A/B datatype
Acc0DataType,
D0sDataType,
Acc1DataType,
C1ShuffleDataType,
D1sDataType,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
A0GridDesc_M_K,
B0GridDesc_N_K,
D0sGridDesc_M_N,
B1GridDesc_N_K,
D1sGridDesc_M_N,
E1GridDesc_M_N,
NumGemm0KPrefetchStage,
BlockSize,
Gemm0MPerBlock,
Gemm0NPerBlock,
Gemm0KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
A0K1,
B0K1,
B1K1,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm0NXdlPerWave,
Gemm1NXdlPerWave,
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
A0BlockTransferThreadClusterArrangeOrder,
A0BlockTransferSrcAccessOrder,
A0BlockTransferSrcVectorDim,
A0BlockTransferSrcScalarPerVector,
A0BlockTransferDstScalarPerVector_AK1,
true,
A0BlockLdsExtraM,
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
B0BlockTransferThreadClusterArrangeOrder,
B0BlockTransferSrcAccessOrder,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_BK1,
true,
B0BlockLdsExtraN,
CDE0BlockTransferSrcVectorDim,
CDE0BlockTransferSrcScalaerPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
C1ShuffleMXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Types of the descriptors that GridwiseGemm derives from the ones above,
// plus the type of the default block-to-E1-tile map.
using A0GridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(
A0GridDesc_M_K{}))>;
using B0GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(
B0GridDesc_N_K{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
B1GridDesc_N_K{}))>;
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
E1GridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<
decltype(GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
D0sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
D1sGridDesc_M_N{}))>;
using DefaultBlock2E1TileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
// Data members. The constructor's initializers for the later members read the
// earlier ones, and C++ initializes members in declaration order, so the order
// below must not be changed.
// tensor descriptors for problem definition
A0GridDesc_M_K a0_grid_desc_m_k;
B0GridDesc_N_K b0_grid_desc_n_k;
D0sGridDesc_M_N d0s_grid_desc_m_n;
B1GridDesc_N_K b1_grid_desc_n_k;
D1sGridDesc_M_N d1s_grid_desc_m_n;
E1GridDesc_M_N e1_grid_desc_m_n;
// tensor descriptors for block/thread-wise copy
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1;
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1;
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock;
// block-to-e1-tile map
DefaultBlock2E1TileMap block_2_e1tile_map;
// element-wise op
A0ElementwiseOperation a0_element_op;
B0ElementwiseOperation b0_element_op;
CDE0ElementwiseOperation cde0_element_op;
B1ElementwiseOperation b1_element_op;
CDE1ElementwiseOperation cde1_element_op;
// Selects between GridwiseGemm::Run<true> and Run<false> in the static Run().
// The in-class default is always overwritten by the constructor below.
bool has_main_k_block_loop = true;
// Builds every member from the six input descriptors and the five
// element-wise operations:
//   1. the *_m_k / *_n_k / *_m_n members come from the Make* helpers above,
//   2. the block/thread-wise descriptors and the tile map are derived from
//      those members through GridwiseGemm's Make* functions,
//   3. the element-wise operations are copied in,
//   4. has_main_k_block_loop is computed from a0_grid_desc_m_k.GetLength(I1)
//      (NOTE(review): presumably the K extent of A0 - confirm).
constexpr Descriptor(A0Desc a0,
B0Desc b0,
D0sDesc d0s,
B1Desc b1,
D1sDesc d1s,
E1Desc e1,
A0ElementwiseOperation a0_element_op_,
B0ElementwiseOperation b0_element_op_,
CDE0ElementwiseOperation cde0_element_op_,
B1ElementwiseOperation b1_element_op_,
CDE1ElementwiseOperation cde1_element_op_)
: a0_grid_desc_m_k{MakeA0GridDescriptor_M_K(a0)},
b0_grid_desc_n_k{MakeB0GridDescriptor_N_K(b0)},
d0s_grid_desc_m_n{MakeD0sGridDescriptor_M_N(d0s)},
b1_grid_desc_n_k{MakeB1GridDescriptor_N_K(b1)},
d1s_grid_desc_m_n{MakeD1sGridDescriptor_M_N(d1s)},
e1_grid_desc_m_n{MakeE1GridDescriptor_M_N(e1)},
a0_grid_desc_ak0_m_ak1{
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k)},
b0_grid_desc_bk0_n_bk1{
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k)},
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5{
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0s_grid_desc_m_n)},
b1_grid_desc_bk0_n_bk1{
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k)},
d1s_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d1s_grid_desc_m_n)},
e1_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e1_grid_desc_m_n)},
block_2_e1tile_map{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n)},
a0_element_op{a0_element_op_},
b0_element_op{b0_element_op_},
cde0_element_op{cde0_element_op_},
b1_element_op{b1_element_op_},
cde1_element_op{cde1_element_op_},
has_main_k_block_loop{
GridwiseGemm::CalculateHasMainKBlockLoop(a0_grid_desc_m_k.GetLength(I1))}
{
}
// Returns true when the enclosing class's IsSupported() holds and
// GridwiseGemm::CheckValidity accepts the stored A0/B0/B1/E1 descriptors
// together with the block-to-E1-tile map. The D0s/D1s descriptors are not
// passed to CheckValidity here.
constexpr bool IsValid() const
{
return IsSupported() and GridwiseGemm::CheckValidity(a0_grid_desc_m_k,
b0_grid_desc_n_k,
b1_grid_desc_n_k,
e1_grid_desc_m_n,
block_2_e1tile_map);
}
};
// Factory for Descriptor: deduces the six descriptor types from the arguments
// and forwards everything to the Descriptor constructor. Each element-wise
// operation may be omitted, in which case a value-initialized instance of the
// corresponding operation type is used.
template <class A0Desc, class B0Desc, class D0sDesc, class B1Desc, class D1sDesc, class E1Desc>
static constexpr auto
make_descriptor(A0Desc a0,
                B0Desc b0,
                D0sDesc d0s,
                B1Desc b1,
                D1sDesc d1s,
                E1Desc e1,
                A0ElementwiseOperation a0_element_op     = A0ElementwiseOperation{},
                B0ElementwiseOperation b0_element_op     = B0ElementwiseOperation{},
                CDE0ElementwiseOperation cde0_element_op = CDE0ElementwiseOperation{},
                B1ElementwiseOperation b1_element_op     = B1ElementwiseOperation{},
                CDE1ElementwiseOperation cde1_element_op = CDE1ElementwiseOperation{})
{
    // Name the concrete instantiation once so the return statement stays short.
    using DescriptorType = Descriptor<A0Desc, B0Desc, D0sDesc, B1Desc, D1sDesc, E1Desc>;

    return DescriptorType(a0,
                          b0,
                          d0s,
                          b1,
                          d1s,
                          e1,
                          a0_element_op,
                          b0_element_op,
                          cde0_element_op,
                          b1_element_op,
                          cde1_element_op);
}
// Device-side entry point. Declares the __shared__ scratch buffer, checks the
// descriptor in non-RTC builds, and then calls Desc::GridwiseGemm::Run with the
// template flag chosen from desc.has_main_k_block_loop. All descriptors, the
// tile map and the element-wise operations are read from `desc`; the grid
// pointers are passed through unchanged.
template <class Desc, class D0sPointer, class D1sPointer>
__device__ static void Run(const Desc& desc,
                           const A0DataType* __restrict__ p_a0_grid,
                           const B0DataType* __restrict__ p_b0_grid,
                           D0sPointer p_d0s_grid,
                           const B1DataType* __restrict__ p_b1_grid,
                           D1sPointer p_d1s_grid,
                           E1DataType* __restrict__ p_e1_grid)
{
    // NOTE(review): the buffer is sized through the enclosing class's
    // GridwiseGemm, while the kernel invoked below is Desc::GridwiseGemm.
    // Confirm both report the same shared-memory size.
    __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];

#ifndef __HIPCC_RTC__
    assert(desc.IsValid());
#endif

    using Kernel = typename Desc::GridwiseGemm;

    // Handle the "no main K block loop" case first and leave early.
    if(!desc.has_main_k_block_loop)
    {
        Kernel::template Run<false>(p_a0_grid,
                                    p_b0_grid,
                                    p_d0s_grid,
                                    p_b1_grid,
                                    p_d1s_grid,
                                    p_e1_grid,
                                    p_shared_block,
                                    desc.a0_element_op,
                                    desc.b0_element_op,
                                    desc.cde0_element_op,
                                    desc.b1_element_op,
                                    desc.cde1_element_op,
                                    desc.a0_grid_desc_ak0_m_ak1,
                                    desc.b0_grid_desc_bk0_n_bk1,
                                    desc.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                    desc.b1_grid_desc_bk0_n_bk1,
                                    desc.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
                                    desc.e1_grid_desc_mblock_mperblock_nblock_nperblock,
                                    desc.block_2_e1tile_map);
        return;
    }

    Kernel::template Run<true>(p_a0_grid,
                               p_b0_grid,
                               p_d0s_grid,
                               p_b1_grid,
                               p_d1s_grid,
                               p_e1_grid,
                               p_shared_block,
                               desc.a0_element_op,
                               desc.b0_element_op,
                               desc.cde0_element_op,
                               desc.b1_element_op,
                               desc.cde1_element_op,
                               desc.a0_grid_desc_ak0_m_ak1,
                               desc.b0_grid_desc_bk0_n_bk1,
                               desc.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                               desc.b1_grid_desc_bk0_n_bk1,
                               desc.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
                               desc.e1_grid_desc_mblock_mperblock_nblock_nperblock,
                               desc.block_2_e1tile_map);
}
}; };
} // namespace device } // namespace device
......
...@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return false; return false;
} }
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n)) // if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{ // {
return false; // return false;
} // }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
...@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
else else
{ {
static_for<0, acc0_thread_buf.Size(), 1>{}( static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); }); [&](auto i) { cde0_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
} }
// gemm1 // gemm1
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment