Commit d6ea89ec authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Add descriptor and RTC workarounds for batched_gemm_multiple_d_gemm_multiple_d.

parent d20c20a6
......@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
prob.N,
prob.K,
prob.O,
x.tile_desc.gemm0_m_per_block,
x.tile_desc.gemm01_m_per_block,
x.tile_desc.gemm0_n_per_block,
x.tile_desc.gemm0_k_per_block,
x.tile_desc.gemm1_n_per_block,
......@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std::unordered_map<std::string, std::string> values = {
{"name",
std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.gemm0_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
std::to_string(this->tile_desc.a0k1) + "_" + std::to_string(this->tile_desc.b0k1) +
"_" + std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
......@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.layout); }))},
{"E1Layout", ToString(this->E1.layout)},
{"ADataType", ToString(this->A0.element)},
{"A0DataType", ToString(this->A0.element)},
{"B0DataType", ToString(this->B0.element)},
{"Acc0DataType", ToString(this->acc_type)},
{"D0sDataType",
......@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{"PadGemm1N", std::to_string(this->padding_desc.pad_gemm1_n)},
{"PadGemm1K", std::to_string(this->padding_desc.pad_gemm1_k)},
{"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemm0k_prefetch_stage)},
{"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)},
{"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm0_m_per_block)},
{"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
{"A0K1", std::to_string(this->tile_desc.a0k1)},
{"B0K1", std::to_string(this->tile_desc.b0k1)},
{"A0K1", std::to_string(this->tile_desc.ak1)},
{"B0K1", std::to_string(this->tile_desc.bk1)},
{"B1K1", std::to_string(this->tile_desc.b1k1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
......
......@@ -3,8 +3,10 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <vector>
#endif
#include "device_base.hpp"
......@@ -31,6 +33,7 @@ template <typename A0Layout,
typename CDE1ElementwiseOperation>
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
{
#ifndef __HIPCC_RTC__
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
......@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
CDE1ElementwiseOperation cde1_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
} // namespace device
......
......@@ -3,8 +3,12 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -13,8 +17,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -350,9 +352,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
}
static auto MakeD0sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
const std::array<index_t, NumD1Tensor>& NRaws,
const std::array<index_t, NumD1Tensor>& DsStride)
static auto MakeD0sGridDescriptor_M_N(const Array<index_t, NumD1Tensor>& MRaws,
const Array<index_t, NumD1Tensor>& NRaws,
const Array<index_t, NumD1Tensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
......@@ -363,9 +365,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
Number<NumD0Tensor>{});
}
static auto MakeD1sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
const std::array<index_t, NumD1Tensor>& NRaws,
const std::array<index_t, NumD1Tensor>& DsStride)
static auto MakeD1sGridDescriptor_M_N(const Array<index_t, NumD1Tensor>& MRaws,
const Array<index_t, NumD1Tensor>& NRaws,
const Array<index_t, NumD1Tensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
......@@ -380,9 +382,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
Array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
Array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1)
: BatchStrideA0_(BatchStrideA0),
BatchStrideB0_(BatchStrideB0),
......@@ -429,9 +431,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
private:
index_t BatchStrideA0_;
index_t BatchStrideB0_;
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
Array<index_t, NumD0Tensor> BatchStrideD0s_;
index_t BatchStrideB1_;
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
Array<index_t, NumD1Tensor> BatchStrideD1s_;
index_t BatchStrideE1_;
};
......@@ -520,6 +522,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
B1GridDesc_N_K{}))>;
#ifndef __HIPCC_RTC__
// Argument
struct Argument : public BaseArgument
{
......@@ -790,6 +793,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
static constexpr bool IsValidCompilationParameter()
{
......@@ -799,9 +803,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
// check if DsLayout is supported
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static bool CheckDLayout()
static constexpr bool CheckDLayout()
{
static bool valid = true;
bool valid = true;
// iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
......@@ -811,13 +815,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return valid;
}
static bool IsSupportedArgument(const Argument& arg)
static constexpr bool IsSupported()
{
if(!ck::is_xdl_supported())
{
return false;
}
// Check supported layouts
// A0 - Row
// B0 - Col
......@@ -829,20 +828,29 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
is_same_v<tensor_layout::gemm::ColumnMajor,
B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
D1sLayout,
NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::ColumnMajor, B1Layout>) &&
CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_,
arg.e1_grid_desc_m_n_,
arg.block_2_e1tile_map_);
return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
return IsSupported() and GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_,
arg.e1_grid_desc_m_n_,
arg.block_2_e1tile_map_);
}
// polymorphic
......@@ -989,6 +997,328 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return str.str();
}
#endif
// RTC (runtime-compilation) workaround: a compile-time "descriptor" bundle that
// precomputes every grid descriptor the gridwise kernel needs, entirely from the
// raw tensor descriptors passed in as template parameters. This mirrors what the
// host-side Argument struct does, but is constexpr-constructible so it can be
// built inside a __HIPCC_RTC__ compilation where the host-only Argument/Invoker
// machinery (guarded out above) is unavailable.
//
// NOTE(review): relies on the enclosing class's gemm0_padder / gemm1_padder and
// template parameters (A0DataType, BlockSize, ...) — presumably static/constexpr
// members of the surrounding DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
// confirm against the full file.
template <class A0Desc, class B0Desc, class D0sDesc, class B1Desc, class D1sDesc, class E1Desc>
struct Descriptor
{
// for Gemm0: pad the A0 (M x K) descriptor up to block-tile multiples
template <class A0GridDescriptor>
static constexpr auto MakeA0GridDescriptor_M_K(const A0GridDescriptor& a0_grid_desc)
{
return gemm0_padder.PadADescriptor_M_K(a0_grid_desc);
}
// for Gemm0: pad the B0 (N x K) descriptor
template <class B0GridDescriptor>
static constexpr auto MakeB0GridDescriptor_N_K(const B0GridDescriptor& b0_grid_desc)
{
return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc);
}
// for Gemm0: pad every D0 (M x N) descriptor in the tuple element-wise
template <class D0sGridDescriptor>
static constexpr auto MakeD0sGridDescriptor_M_N(const D0sGridDescriptor& d0s_grid_desc)
{
return transform_tuples(
[&](auto d) constexpr { return gemm0_padder.PadCDescriptor_M_N(d); },
d0s_grid_desc);
}
// for Gemm1: pad the B1 (N x K) descriptor
template <class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_N_K(const B1GridDescriptor& b1_grid_desc)
{
return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc);
}
// for Gemm1: pad every D1 (M x N) descriptor in the tuple element-wise
template <class D1sGridDescriptor>
static constexpr auto MakeD1sGridDescriptor_M_N(const D1sGridDescriptor& d1s_grid_desc)
{
return transform_tuples(
[&](auto d) constexpr { return gemm1_padder.PadCDescriptor_M_N(d); },
d1s_grid_desc);
}
// for Gemm1: pad the E1 output (M x N) descriptor
template <class E1GridDescriptor>
static constexpr auto MakeE1GridDescriptor_M_N(const E1GridDescriptor& e1_grid_desc)
{
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc);
}
// Padded descriptor types derived from the raw descriptor template parameters.
using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(A0Desc{}));
using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(B0Desc{}));
using D0sGridDesc_M_N = remove_cvref_t<decltype(MakeD0sGridDescriptor_M_N(D0sDesc{}))>;
using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(B1Desc{}));
using D1sGridDesc_M_N = remove_cvref_t<decltype(MakeD1sGridDescriptor_M_N(D1sDesc{}))>;
using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N(E1Desc{}));
// GridwiseGemm instantiated on the padded descriptor types above; the long
// parameter list forwards the enclosing device-op's tuning template parameters.
using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
A0DataType, // TODO: distinguish A/B datatype
Acc0DataType,
D0sDataType,
Acc1DataType,
C1ShuffleDataType,
D1sDataType,
E1DataType,
A0ElementwiseOperation,
B0ElementwiseOperation,
CDE0ElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
A0GridDesc_M_K,
B0GridDesc_N_K,
D0sGridDesc_M_N,
B1GridDesc_N_K,
D1sGridDesc_M_N,
E1GridDesc_M_N,
NumGemm0KPrefetchStage,
BlockSize,
Gemm0MPerBlock,
Gemm0NPerBlock,
Gemm0KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
A0K1,
B0K1,
B1K1,
Gemm0MPerXdl,
Gemm0NPerXdl,
Gemm0MXdlPerWave,
Gemm0NXdlPerWave,
Gemm1NXdlPerWave,
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
A0BlockTransferThreadClusterArrangeOrder,
A0BlockTransferSrcAccessOrder,
A0BlockTransferSrcVectorDim,
A0BlockTransferSrcScalarPerVector,
A0BlockTransferDstScalarPerVector_AK1,
true,
A0BlockLdsExtraM,
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
B0BlockTransferThreadClusterArrangeOrder,
B0BlockTransferSrcAccessOrder,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_BK1,
true,
B0BlockLdsExtraN,
CDE0BlockTransferSrcVectorDim,
// NOTE(review): "Scalaer" typo mirrors the enclosing class's template
// parameter spelling — must stay until renamed at the declaration site.
CDE0BlockTransferSrcScalaerPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
C1ShuffleMXdlPerWavePerShuffle,
C1ShuffleGemm0NXdlPerWavePerShuffle,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// Block/thread-wise copy descriptor types derived via GridwiseGemm helpers.
using A0GridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(
A0GridDesc_M_K{}))>;
using B0GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(
B0GridDesc_N_K{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
B1GridDesc_N_K{}))>;
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
E1GridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<
decltype(GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
D0sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
D1sGridDesc_M_N{}))>;
using DefaultBlock2E1TileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
// tensor descriptors for problem definition
A0GridDesc_M_K a0_grid_desc_m_k;
B0GridDesc_N_K b0_grid_desc_n_k;
D0sGridDesc_M_N d0s_grid_desc_m_n;
B1GridDesc_N_K b1_grid_desc_n_k;
D1sGridDesc_M_N d1s_grid_desc_m_n;
E1GridDesc_M_N e1_grid_desc_m_n;
// tensor descriptors for block/thread-wise copy
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1;
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1;
D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock;
// block-to-e1-tile map
DefaultBlock2E1TileMap block_2_e1tile_map;
// element-wise op
A0ElementwiseOperation a0_element_op;
B0ElementwiseOperation b0_element_op;
CDE0ElementwiseOperation cde0_element_op;
B1ElementwiseOperation b1_element_op;
CDE1ElementwiseOperation cde1_element_op;
// Overwritten by the constructor from the actual K extent; default is the
// common case (K larger than one block-tile).
bool has_main_k_block_loop = true;
// Builds every padded / transformed descriptor eagerly from the raw input
// descriptors, so device code only reads precomputed members.
constexpr Descriptor(A0Desc a0,
B0Desc b0,
D0sDesc d0s,
B1Desc b1,
D1sDesc d1s,
E1Desc e1,
A0ElementwiseOperation a0_element_op_,
B0ElementwiseOperation b0_element_op_,
CDE0ElementwiseOperation cde0_element_op_,
B1ElementwiseOperation b1_element_op_,
CDE1ElementwiseOperation cde1_element_op_)
: a0_grid_desc_m_k{MakeA0GridDescriptor_M_K(a0)},
b0_grid_desc_n_k{MakeB0GridDescriptor_N_K(b0)},
d0s_grid_desc_m_n{MakeD0sGridDescriptor_M_N(d0s)},
b1_grid_desc_n_k{MakeB1GridDescriptor_N_K(b1)},
d1s_grid_desc_m_n{MakeD1sGridDescriptor_M_N(d1s)},
e1_grid_desc_m_n{MakeE1GridDescriptor_M_N(e1)},
a0_grid_desc_ak0_m_ak1{
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k)},
b0_grid_desc_bk0_n_bk1{
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k)},
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5{
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0s_grid_desc_m_n)},
b1_grid_desc_bk0_n_bk1{
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k)},
d1s_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d1s_grid_desc_m_n)},
e1_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e1_grid_desc_m_n)},
block_2_e1tile_map{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n)},
a0_element_op{a0_element_op_},
b0_element_op{b0_element_op_},
cde0_element_op{cde0_element_op_},
b1_element_op{b1_element_op_},
cde1_element_op{cde1_element_op_},
has_main_k_block_loop{
GridwiseGemm::CalculateHasMainKBlockLoop(a0_grid_desc_m_k.GetLength(I1))}
{
}
// Static layout/config checks (IsSupported) plus runtime descriptor
// validation against the instantiated GridwiseGemm.
constexpr bool IsValid() const
{
return IsSupported() and GridwiseGemm::CheckValidity(a0_grid_desc_m_k,
b0_grid_desc_n_k,
b1_grid_desc_n_k,
e1_grid_desc_m_n,
block_2_e1tile_map);
}
};
/// Factory for Descriptor: deduces the descriptor's template arguments from the
/// raw grid-descriptor arguments and forwards the (optionally defaulted)
/// element-wise operations unchanged.
template <class A0Desc, class B0Desc, class D0sDesc, class B1Desc, class D1sDesc, class E1Desc>
static constexpr auto
make_descriptor(A0Desc a0,
                B0Desc b0,
                D0sDesc d0s,
                B1Desc b1,
                D1sDesc d1s,
                E1Desc e1,
                A0ElementwiseOperation a0_element_op = A0ElementwiseOperation{},
                B0ElementwiseOperation b0_element_op = B0ElementwiseOperation{},
                CDE0ElementwiseOperation cde0_element_op = CDE0ElementwiseOperation{},
                B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                CDE1ElementwiseOperation cde1_element_op = CDE1ElementwiseOperation{})
{
    // Name the target type once, then brace-construct it from the inputs.
    using ResultDesc = Descriptor<A0Desc, B0Desc, D0sDesc, B1Desc, D1sDesc, E1Desc>;
    return ResultDesc{a0,
                      b0,
                      d0s,
                      b1,
                      d1s,
                      e1,
                      a0_element_op,
                      b0_element_op,
                      cde0_element_op,
                      b1_element_op,
                      cde1_element_op};
}
// Device-side entry point for the RTC path: runs the fused
// Gemm0(+D0s) -> Gemm1(+D1s) pipeline for one batch using the descriptors
// precomputed in `desc`. The host-side validity assert is compiled out under
// __HIPCC_RTC__ (no <cassert> there).
//
// NOTE(review): the LDS workspace below is sized via the enclosing class's
// GridwiseGemm, while the Run calls dispatch through Desc::GridwiseGemm —
// presumably the two instantiations share the same LDS requirement; confirm.
template <class Desc, class D0sPointer, class D1sPointer>
__device__ static void Run(const Desc& desc,
const A0DataType* __restrict__ p_a0_grid,
const B0DataType* __restrict__ p_b0_grid,
D0sPointer p_d0s_grid,
const B1DataType* __restrict__ p_b1_grid,
D1sPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid)
{
// Workgroup-shared scratch (LDS) for the block-wise GEMM pipeline.
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
#ifndef __HIPCC_RTC__
assert(desc.IsValid());
#endif
// Runtime branch selects the compile-time variant: whether the K dimension
// spans more than one block-tile (main K loop present or not).
if(desc.has_main_k_block_loop)
{
Desc::GridwiseGemm::template Run<true>(
p_a0_grid,
p_b0_grid,
p_d0s_grid,
p_b1_grid,
p_d1s_grid,
p_e1_grid,
p_shared_block,
desc.a0_element_op,
desc.b0_element_op,
desc.cde0_element_op,
desc.b1_element_op,
desc.cde1_element_op,
desc.a0_grid_desc_ak0_m_ak1,
desc.b0_grid_desc_bk0_n_bk1,
desc.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
desc.b1_grid_desc_bk0_n_bk1,
desc.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e1_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_e1tile_map);
}
else
{
// Same argument list as above; only the HasMainKBlockLoop flag differs.
Desc::GridwiseGemm::template Run<false>(
p_a0_grid,
p_b0_grid,
p_d0s_grid,
p_b1_grid,
p_d1s_grid,
p_e1_grid,
p_shared_block,
desc.a0_element_op,
desc.b0_element_op,
desc.cde0_element_op,
desc.b1_element_op,
desc.cde1_element_op,
desc.a0_grid_desc_ak0_m_ak1,
desc.b0_grid_desc_bk0_n_bk1,
desc.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
desc.b1_grid_desc_bk0_n_bk1,
desc.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e1_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_e1tile_map);
}
}
};
} // namespace device
......
......@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return false;
}
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{
return false;
}
// if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
// {
// return false;
// }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
......@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
else
{
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); });
[&](auto i) { cde0_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
}
// gemm1
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment