"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "097fa65be2b1324d32bfebd3b0398b8307dae6aa"
Commit 8713ade3 authored by mtgu0705
Browse files

Enable splitK

parent 7a17ead7
...@@ -45,6 +45,17 @@ using DeviceGemmV2Instance = ...@@ -45,6 +45,17 @@ using DeviceGemmV2Instance =
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 8, 1, 1, S<1, 32, 1, 8>, 8,
// 128, Scale_Block_N, Scale_Block_K,
// 16, 128,
// KPerBlock, 8, 32,
// 16, 16,
// 1, 4,
// S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 32, 32, 0,
// 1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>;
// clang-format on // clang-format on
...@@ -273,6 +284,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -273,6 +284,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
Scale_Stride_BN,
static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()), static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()),
KBatch, KBatch,
a_element_op, a_element_op,
......
...@@ -100,6 +100,7 @@ struct DeviceGemmV2BScale : public BaseOperator ...@@ -100,6 +100,7 @@ struct DeviceGemmV2BScale : public BaseOperator
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
ck::index_t StrideScaleB,
const void* p_b_scale, const void* p_b_scale,
ck::index_t KSplit, ck::index_t KSplit,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
......
...@@ -663,6 +663,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout, ...@@ -663,6 +663,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideScaleB,
const BScaleDataType* p_b_scale, const BScaleDataType* p_b_scale,
index_t KBatch, index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -678,6 +679,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout, ...@@ -678,6 +679,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideScaleB,
p_b_scale, p_b_scale,
KBatch, KBatch,
a_element_op, a_element_op,
...@@ -697,6 +699,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout, ...@@ -697,6 +699,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideScaleB,
const void* p_b_scale, const void* p_b_scale,
index_t KBatch, index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -712,6 +715,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout, ...@@ -712,6 +715,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideScaleB,
static_cast<const BScaleDataType*>(p_b_scale), static_cast<const BScaleDataType*>(p_b_scale),
KBatch, KBatch,
a_element_op, a_element_op,
......
...@@ -37,18 +37,16 @@ __global__ void ...@@ -37,18 +37,16 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
// GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.scale_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
// p_shared,
// karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, karg.p_b_scale_grid, p_shared, karg); karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,
karg);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx9__)) #endif // end of if (defined(__gfx9__))
...@@ -72,24 +70,17 @@ __global__ void ...@@ -72,24 +70,17 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
// GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
// p_shared_0,
// p_shared_1,
// karg);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid, karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid, karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared_0, p_shared_0,
p_shared_1, p_shared_1,
karg); karg);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx9__)) #endif // end of if (defined(__gfx9__))
...@@ -533,6 +524,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -533,6 +524,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_, index_t StrideC_,
index_t StrideScaleB_,
index_t KBatch_) index_t KBatch_)
: M{M_}, : M{M_},
N{N_}, N{N_},
...@@ -540,6 +532,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -540,6 +532,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
StrideA{StrideA_}, StrideA{StrideA_},
StrideB{StrideB_}, StrideB{StrideB_},
StrideC{StrideC_}, StrideC{StrideC_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_}, KBatch{KBatch_},
MPadded{CalculateMPadded(M_)}, MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)}, NPadded{CalculateNPadded(N_)},
...@@ -561,6 +554,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -561,6 +554,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
<< "SA:" << StrideA << ", " << "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", " << "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", " << "SC:" << StrideC << ", "
<< "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", " << "KRead:" << KRead << ", "
...@@ -577,6 +571,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -577,6 +571,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA; index_t StrideA;
index_t StrideB; index_t StrideB;
index_t StrideC; index_t StrideC;
index_t StrideScaleB;
index_t KBatch; index_t KBatch;
index_t MPadded; index_t MPadded;
index_t NPadded; index_t NPadded;
...@@ -600,13 +595,14 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -600,13 +595,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_, index_t StrideC_,
index_t StrideScaleB_,
const BScaleType* p_b_scale_grid_, const BScaleType* p_b_scale_grid_,
index_t k_batch_, index_t k_batch_,
AElementwiseOperation a_element_op_, AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_, CElementwiseOperation c_element_op_,
bool is_reduce_ = false) bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_}, p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_}, p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}, p_c_grid{p_c_grid_},
...@@ -670,15 +666,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -670,15 +666,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
} }
// // Calculate B scale offset // Calculate B scale offset
// if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// { {
// scale_k_split_offset = blockIdx.z * (karg.K / 64) * karg.StrideB; scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
// } }
// else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
// { {
// scale_k_split_offset = blockIdx.z * (karg.K / 64) * karg.N; scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
// } }
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1)) if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{ {
...@@ -701,7 +697,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -701,7 +697,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t a_k_split_offset; index_t a_k_split_offset;
index_t b_k_split_offset; index_t b_k_split_offset;
// index_t scale_k_split_offset; // New member for scale matrix offset index_t scale_k_split_offset; // New member for scale matrix offset
index_t c_reduce_offset; index_t c_reduce_offset;
}; };
...@@ -1273,6 +1269,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1273,6 +1269,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
template <typename AGridDesc_AK0_M_K1, template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1, typename BGridDesc_BK0_N_K1,
typename BScaleGridDesc_BN_AK,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
...@@ -1285,6 +1282,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1285,6 +1282,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
const Problem& problem, const Problem& problem,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock) c_grid_desc_mblock_mperblock_nblock_nperblock)
{ {
...@@ -1295,12 +1293,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1295,12 +1293,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// B Scale grid and buffer // B Scale buffer
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
...@@ -1703,8 +1696,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1703,8 +1696,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock); c_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
Run<decltype(a_grid_desc_ak0_m_ak1), Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_scale_grid_desc_bn_ak),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop, HasMainKBlockLoop,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
...@@ -1716,11 +1716,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1716,11 +1716,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
problem, problem,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_scale_grid_desc_bn_ak,
c_grid_desc_mblock_mperblock_nblock_nperblock); c_grid_desc_mblock_mperblock_nblock_nperblock);
} }
template <typename AGridDesc_AK0_M_K1, template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1, typename BGridDesc_BK0_N_K1,
typename BScaleGridDesc_BN_AK,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
...@@ -1734,6 +1736,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1734,6 +1736,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
const Problem& problem, const Problem& problem,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock) c_grid_desc_mblock_mperblock_nblock_nperblock)
{ {
...@@ -1744,12 +1747,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1744,12 +1747,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// B Scale grid and buffer // B Scale buffer
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
...@@ -2164,8 +2162,14 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2164,8 +2162,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock); c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1), Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_scale_grid_desc_bn_ak),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop, HasMainKBlockLoop,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
...@@ -2178,6 +2182,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2178,6 +2182,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
problem, problem,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_scale_grid_desc_bn_ak,
c_grid_desc_mblock_mperblock_nblock_nperblock); c_grid_desc_mblock_mperblock_nblock_nperblock);
} }
}; };
......
...@@ -301,6 +301,7 @@ bool profile_gemm_b_scale_impl(int do_verification, ...@@ -301,6 +301,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
Scale_Stride_BN,
static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()),
kbatch_curr, kbatch_curr,
a_element_op, a_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment