Commit f0d63f25 authored by wangshaojie6

add splitK GEMM bias example and DeviceSplitKContractionMultipleD_Xdl_CShuffle device op

parent 7c7364a6
@@ -30,7 +30,7 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     static_assert(sizeof(BDataType) == sizeof(KernelBDataType));
 #endif
-    auto& [M, N, K, StrideA, StrideB, StrideC, KBatch] = problem_size;
+    auto& [M, N, K, StrideA, StrideB, StrideE, KBatch] = problem_size;

     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -48,21 +48,36 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-    // A[M0, M1, K0, K1]
-    std::vector<ck::index_t> a_ms_ks_lengths{M, KBatch, 64};
-    std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
-    // B[N0, N1, K0, K1]
-    std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
-    std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
-    // E[M0, M1, N0, N1]
-    std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
-    std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
+    Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
-    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;

+    // Build {lengths, strides} descriptors for a row x col tensor in the given layout.
+    auto f_tensor_length_stride_descriptor =
+        [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout)
+            -> std::array<std::vector<ck::index_t>, 2> {
+        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            return {std::vector<ck::index_t>{row, col},
+                    std::vector<ck::index_t>{stride, 1}};
+        }
+        else
+        {
+            return {std::vector<ck::index_t>{row, col},
+                    std::vector<ck::index_t>{1, stride}};
+        }
+    };
+
+    std::vector<ck::index_t> a_ms_ks_lengths = f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[0];
+    std::vector<ck::index_t> a_ms_ks_strides = f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[1];
+    std::vector<ck::index_t> b_ns_ks_lengths = f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[0];
+    std::vector<ck::index_t> b_ns_ks_strides = f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[1];
+    std::vector<ck::index_t> d_ms_ns_lengths = f_tensor_length_stride_descriptor(M, N, 0, Row{})[0];
+    std::vector<ck::index_t> d_ms_ns_strides = f_tensor_length_stride_descriptor(M, N, 0, Row{})[1];
+    std::vector<ck::index_t> e_ms_ns_lengths = f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[0];
+    std::vector<ck::index_t> e_ms_ns_strides = f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[1];
+
     switch(config.init_method)
     {
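For readers less familiar with the {lengths, strides} pairs the new f_tensor_length_stride_descriptor helper produces: a row-major row x col tensor gets strides {stride, 1}, a column-major one gets {1, stride}, and element (m, k) lives at offset m * strides[0] + k * strides[1]. A small standalone illustration with hypothetical sizes (plain C++, not part of the commit):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        // Hypothetical sizes, not taken from the commit.
        const std::size_t M = 4, K = 8, LdaRow = 8, LdaCol = 4;

        // Row-major M x K: element (m, k) sits at m * LdaRow + k  -> strides {LdaRow, 1}.
        std::vector<std::size_t> row_lengths{M, K}, row_strides{LdaRow, 1};
        // Column-major M x K: element (m, k) sits at m + k * LdaCol -> strides {1, LdaCol}.
        std::vector<std::size_t> col_lengths{M, K}, col_strides{1, LdaCol};

        std::cout << "row-major offset of (1,2): " << 1 * row_strides[0] + 2 * row_strides[1] << '\n'; // 10
        std::cout << "col-major offset of (1,2): " << 1 * col_strides[0] + 2 * col_strides[1] << '\n'; // 9
    }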
@@ -70,38 +85,45 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     case 1:
         a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     case 2:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
         b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        d_m_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
     }

     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

 #ifdef BUILD_INT4_EXAMPLE
     const Tensor<KernelADataType> a_m_k_converted(a_m_k);
     const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
+    const Tensor<KernelDDataType> d_m_n_converted(d_m_n);

     a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
+    d_m_n_device_buf.ToDevice(d_m_n_converted.mData.data());
 #else
     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+    d_m_n_device_buf.ToDevice(d_m_n.mData.data());
 #endif
-    c_m_n_device_buf.SetZero();
+    e_m_n_device_buf.SetZero();

     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
-    auto c_element_op = CElementOp{};
+    auto cde_element_op = CDEElementOp{};

     // do GEMM
-    auto gemm     = DeviceGemmInstance{};
+    auto gemm     = DeviceOpInstance{};
     auto invoker  = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(
 #ifdef BUILD_INT4_EXAMPLE
@@ -110,17 +132,20 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
 #else
         static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
         static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+        std::array<const void*, 1>{static_cast<DDataType*>(d_m_n_device_buf.GetDeviceBuffer())},
 #endif
-        static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
-        M,
-        N,
-        K,
-        StrideA,
-        StrideB,
-        StrideC,
+        static_cast<EDataType*>(e_m_n_device_buf.GetDeviceBuffer()),
+        a_ms_ks_lengths,
+        a_ms_ks_strides,
+        b_ns_ks_lengths,
+        b_ns_ks_strides,
+        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+        e_ms_ns_lengths,
+        e_ms_ns_strides,
         a_element_op,
         b_element_op,
-        c_element_op,
+        cde_element_op,
         KBatch);

     if(!gemm.IsSupportedArgument(argument))
@@ -135,22 +160,23 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     if(config.do_verification)
     {
-        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
+        e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());

-        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
-                                                                                BDataType,
-                                                                                CDataType,
-                                                                                AccDataType,
-                                                                                AElementOp,
-                                                                                BElementOp,
-                                                                                CElementOp>;
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<ADataType,
+                                                                                      BDataType,
+                                                                                      DDataType,
+                                                                                      EDataType,
+                                                                                      AccDataType,
+                                                                                      AElementOp,
+                                                                                      BElementOp,
+                                                                                      CDEElementOp>;
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();

-        Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+        Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

         auto ref_argument = ref_gemm.MakeArgument(
-            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_m_k, b_k_n, e_m_n_host_result, d_m_n, a_element_op, b_element_op, cde_element_op);

         ref_invoker.Run(ref_argument);
@@ -164,7 +190,7 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
         }
         else
         {
-            pass &= ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
+            pass &= ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData);
         }
     }
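To make the new verification path concrete: d_m_n is built with stride 0, so the bias is a single N-wide row broadcast over M, and the expected output is E = A*B with the bias added element-wise by the CDE op. A minimal host-side sketch of that computation (plain C++ with hypothetical small sizes; this is an illustration, not CK's ReferenceGemmBias2D):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        // Hypothetical small sizes, not the example's problem size.
        const std::size_t M = 2, N = 3, K = 4;
        std::vector<float> a(M * K, 1.0f); // A, row-major M x K
        std::vector<float> b(K * N, 2.0f); // B, row-major K x N
        std::vector<float> d(N, 0.5f);     // bias row, broadcast over M (stride 0 in the example)
        std::vector<float> e(M * N, 0.0f);

        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.0f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                e[m * N + n] = acc + d[n]; // the CDE element-op (Add) applied to (acc, bias)
            }

        std::cout << e[0] << '\n'; // 1*2 accumulated over K=4 gives 8, plus bias 0.5 -> 8.5
    }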
@@ -27,32 +27,43 @@ using F32 = float;
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;

-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ADataType   = F16;
-using BDataType   = F16;
-using AccDataType = F32;
-using CDataType   = F16;

 using ALayout = Row;
 using BLayout = Col;
-using CLayout = Row;
+using DsLayout = Row;
+using ELayout  = Row;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Add         = ck::tensor_operation::element_wise::Add;
+
+using ADataType        = F16;
+using BDataType        = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F16;
+using DDataType        = F16;
+using DsDataType       = ck::Tuple<F16>;
+using EDataType        = F16;

 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
-using CElementOp = PassThrough;
+using CDEElementOp = Add;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr ck::index_t NumDimM = 1;
+static constexpr ck::index_t NumDimN = 1;
+static constexpr ck::index_t NumDimK = 1;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
-// clang-format off
-//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
-//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+// clang-format off
+using DeviceOpInstanceKKN = ck::tensor_operation::device::
+//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
 // clang-format on
+
+using DeviceOpInstance = DeviceOpInstanceKKN;

 #include "run_splitK_gemm_bias_example.inc"

 int main(int argc, char* argv[]) { return !run_splitK_gemm_bias_example(argc, argv); }
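The switch from CElementOp = PassThrough to CDEElementOp = Add is what turns the plain GEMM into GEMM + bias: for every output element the op combines the GEMM result with the corresponding D (bias) value. A rough stand-in for that behaviour, assuming the usual e-from-(c, d) calling shape (a sketch, not CK's actual element_wise::Add):

    #include <iostream>

    // Illustration of what a CDE element-wise "Add" op does per output element:
    // the GEMM result c and the bias value d are combined into the output e.
    // This stand-in only mirrors the idea; it is not ck::tensor_operation::element_wise::Add.
    struct AddSketch
    {
        template <typename E, typename C, typename D>
        void operator()(E& e, const C& c, const D& d) const
        {
            e = static_cast<E>(c + d);
        }
    };

    int main()
    {
        float e = 0.0f;
        AddSketch{}(e, 3.0f /* GEMM accumulator */, 0.5f /* bias */);
        std::cout << e << '\n'; // 3.5
    }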
@@ -59,6 +59,43 @@ struct DeviceBatchedContractionMultipleD : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+struct DeviceSplitKContractionMultipleD : public BaseOperator
+{
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        const std::vector<index_t>& a_gs_ms_ks_lengths,
+                        const std::vector<index_t>& a_gs_ms_ks_strides,
+                        const std::vector<index_t>& b_gs_ns_ks_lengths,
+                        const std::vector<index_t>& b_gs_ns_ks_strides,
+                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
+                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
+                        const std::vector<index_t>& e_gs_ms_ns_lengths,
+                        const std::vector<index_t>& e_gs_ms_ns_strides,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op,
+                        const index_t k_batch) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
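The new interface groups the D-tensor descriptors into std::array<std::vector<index_t>, NumDTensor>, one lengths/strides pair per extra tensor. A self-contained sketch of how a caller with a single bias tensor might fill those parameters (hypothetical sizes, a local index_t alias instead of ck::index_t, and the G batch modes of the full gs_ms_ns descriptors omitted for brevity):

    #include <array>
    #include <iostream>
    #include <vector>

    using index_t = int; // local stand-in for ck::index_t so the snippet is self-contained

    int main()
    {
        // Hypothetical problem sizes for a single D tensor (the bias of the example).
        const index_t M = 3840, N = 4096;

        // One bias row of length N, broadcast over M via a zero stride in the M mode.
        std::array<std::vector<index_t>, 1> ds_ms_ns_lengths{std::vector<index_t>{M, N}};
        std::array<std::vector<index_t>, 1> ds_ms_ns_strides{std::vector<index_t>{0, 1}};

        std::cout << "D0 lengths: " << ds_ms_ns_lengths[0][0] << " x " << ds_ms_ns_lengths[0][1]
                  << ", strides: " << ds_ms_ns_strides[0][0] << ", " << ds_ms_ns_strides[0][1] << '\n';
    }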
@@ -195,8 +195,8 @@ template <index_t NumDimG,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
-    : public DeviceBatchedContractionMultipleD<NumDimG,
+struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
+    : public DeviceSplitKContractionMultipleD<NumDimG,
                                                NumDimM,
                                                NumDimN,
                                                NumDimK,
@@ -208,7 +208,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
                                               BElementwiseOperation,
                                               CDEElementwiseOperation>
 {
-    using DeviceOp = DeviceBatchedContractionMultipleD_Xdl_CShuffle;
+    using DeviceOp = DeviceSplitKContractionMultipleD_Xdl_CShuffle;

     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -658,7 +658,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
                  const std::vector<index_t>& e_gs_ms_ns_strides,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation cde_element_op)
+                 CDEElementwiseOperation cde_element_op,
+                 const index_t KBatch)
             : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
               p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
               p_ds_grid_{},
@@ -680,7 +681,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
                   GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
               ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
               e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
+              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, KBatch)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op},
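Passing KBatch into MakeDefaultBlock2ETileMap threads the split-K factor into the block-to-tile mapping. The underlying idea of split-K: the K loop is cut into KBatch slices, each slice computes a partial product for the same output tile, and the partials are reduced. A tiny standalone illustration of that equivalence (not CK's tile map or kernel, just the arithmetic):

    #include <iostream>
    #include <vector>

    int main()
    {
        // Hypothetical sizes with K divisible by KBatch.
        const int M = 2, N = 2, K = 8, KBatch = 4;
        std::vector<float> a(M * K, 1.0f), b(K * N, 2.0f), e(M * N, 0.0f);

        const int k_per_batch = K / KBatch;
        for(int kb = 0; kb < KBatch; ++kb)          // each k-batch owns one K-slice
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.0f;
                    for(int k = kb * k_per_batch; k < (kb + 1) * k_per_batch; ++k)
                        partial += a[m * K + k] * b[k * N + n];
                    e[m * N + n] += partial;        // partial products reduced into the same E tile
                }

        std::cout << e[0] << '\n'; // 16, the same value a single K loop produces
    }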
@@ -1056,7 +1057,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedContractionMultipleD_Xdl_CShuffle"
+        str << "DeviceSplitKContractionMultipleD_Xdl_CShuffle"
             << "<"
             << NumDimG << ", "
             << NumDimM << ", "