Commit 6e9ef894 authored by rtmadduri's avatar rtmadduri
Browse files

applied changes to tail num lambda, clean up ctrs

parent bf73d297
......@@ -76,7 +76,8 @@ __global__ void
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op gemm_desc_ptr[group_id].block_2_ctile_map_);
karg.c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_);
#else
ignore = gemm_descs_const;
ignore = group_count;
......@@ -137,20 +138,18 @@ template <typename ALayout,
// MultipleD not supported for now.
enable_if_t<is_same_v<DsLayout, ck::Tuple<>> && is_same_v<DsDataType, ck::Tuple<>>,
bool> = false>
>
struct DeviceGroupedGemmXdlSplitKCShuffle
: public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -221,7 +220,7 @@ template <typename ALayout,
GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
// GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
......@@ -243,11 +242,8 @@ template <typename ALayout,
Argument(std::vector<const void*>& p_a_grid,
std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op))
: Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch, a_element_op, b_element_op, cde_element_op )
std::vector<GemmDesc>& gemm_descs)
: Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch)
{
// TODO: use occupancy api to calculate appropriate batch size.
}
......@@ -256,10 +252,7 @@ template <typename ALayout,
std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs,
index_t kbatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op))
index_t kbatch)
: K_BATCH{kbatch}
{
grid_size_ = 0;
......@@ -307,19 +300,14 @@ template <typename ALayout,
KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]),
type_convert<const BDataType*>(p_b_grid[i]),
{}, // p_ds_grid
type_convert<EDataType*>(p_c_grid[i]),
M,
N,
K,
stride_a,
stride_b,
{}, // StrideDs_
stride_c,
K_BATCH,
a_element_op,
b_element_op,
cde_element_op};
K_BATCH};
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
......@@ -341,8 +329,8 @@ template <typename ALayout,
auto& karg = gemm_kernel_args_[i].karg_;
const auto local_b2c_tile_map = Block2ETileMap{M, N, 4};
index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N);
const auto local_b2c_tile_map = Block2ETileMap{karg.M, karg.N, 4};
index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(karg.M, karg.N);
grid_size_grp *= K_BATCH;
const index_t block_start = grid_size_;
......@@ -380,16 +368,17 @@ template <typename ALayout,
bool all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split0);
const auto tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split0);
bool all_have_kbatch_gt_one = karg0.KBatch > 1;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& karg = arg.gemm_kernel_args_[i].karg_;
if(stream_config.log_level_ > 0)
{
karg.Print();
}
const auto& karg = arg.gemm_kernel_args_[i].karg_;
index_t k_grain = karg.KBatch * KPerBlock;
index_t K_split = (karg.K + k_grain - 1) / k_grain * KPerBlock;
......@@ -460,11 +449,16 @@ template <typename ALayout,
rotating_mem.Next();
// clear c mem
// TODO: should be loop here through all groups
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
const auto& karg = trans_arg.karg_;
if(karg.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(karg.p_c_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
......@@ -480,11 +474,15 @@ template <typename ALayout,
else
{
// TODO: should be loop here through all groups
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
const auto& karg = trans_arg.karg_;
if(karg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
ave_time = launch_and_time_kernel(
stream_config,
......@@ -537,7 +535,8 @@ template <typename ALayout,
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg true,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
......@@ -546,313 +545,231 @@ template <typename ALayout,
//// TODO: Fix below as above!
else if(calculate_tail_number() == TailNumber::Full)
else if(tail_num == TailNumber::Full)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(calculate_tail_number() == TailNumber::Two)
if(tail_num == TailNumber::Two)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_grouped_gemm_xdl_splitk<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(calculate_tail_number() == TailNumber::Three)
if(tail_num == TailNumber::Three)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_grouped_gemm_xdl_splitk<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(calculate_tail_number() == TailNumber::Four)
if(tail_num == TailNumber::Four)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_grouped_gemm_xdl_splitk<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(calculate_tail_number() == TailNumber::Five)
if(tail_num == TailNumber::Five)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_grouped_gemm_xdl_splitk<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(calculate_tail_number() == TailNumber::Six)
if(tail_num == TailNumber::Six)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_grouped_gemm_xdl_splitk<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(calculate_tail_number() == TailNumber::Seven)
if(tail_num == TailNumber::Seven)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_grouped_gemm_xdl_splitk<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(calculate_tail_number() == TailNumber::One)
if(tail_num == TailNumber::One)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(calculate_tail_number() == TailNumber::Full)
else if(tail_num == TailNumber::Full)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(calculate_tail_number() == TailNumber::Two)
if(tail_num == TailNumber::Two)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(calculate_tail_number() == TailNumber::Three)
if(tail_num == TailNumber::Three)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(calculate_tail_number() == TailNumber::Four)
if(tail_num == TailNumber::Four)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(calculate_tail_number() == TailNumber::Five)
if(tail_num == TailNumber::Five)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(calculate_tail_number() == TailNumber::Six)
if(tail_num == TailNumber::Six)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(calculate_tail_number() == TailNumber::Seven)
if(tail_num == TailNumber::Seven)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
......@@ -862,77 +779,57 @@ template <typename ALayout,
{
if(all_have_kbatch_gt_one)
{
if(calculate_tail_number() == TailNumber::Odd)
if(tail_num == TailNumber::Odd)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(calculate_tail_number() == TailNumber::Odd)
if(tail_num == TailNumber::Odd)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
......@@ -941,78 +838,57 @@ template <typename ALayout,
{
if(all_have_kbatch_gt_one)
{
if(calculate_tail_number() == TailNumber::Odd)
if(tail_num == TailNumber::Odd)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(calculate_tail_number() == TailNumber::Odd)
if(tail_num == TailNumber::Odd)
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
if(all_have_same_tail_number())
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
......@@ -1025,19 +901,21 @@ template <typename ALayout,
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
......@@ -1115,12 +993,11 @@ template <typename ALayout,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation)
{
return Argument{
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op};
return Argument{p_a_grid, p_b_grid, p_c_grid, gemm_descs};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -1132,12 +1009,11 @@ template <typename ALayout,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation) override
{
return std::make_unique<Argument>(
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op);
return std::make_unique<Argument>(p_a_grid, p_b_grid, p_c_grid, gemm_descs);
}
// polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment