Commit 613dcc6b authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Remove elementwise-op objects from interfaces

parent 9fdc3fc8
...@@ -89,10 +89,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -89,10 +89,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC);
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
......
...@@ -356,10 +356,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -356,10 +356,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -385,9 +382,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -385,9 +382,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
N, N,
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N),
StrideC)}, StrideC)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
kraw_{K} kraw_{K}
{ {
} }
...@@ -402,9 +396,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -402,9 +396,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t kraw_; index_t kraw_;
}; };
...@@ -451,9 +442,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -451,9 +442,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N, DeviceOp::CGridDesc_M_N,
...@@ -467,9 +455,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -467,9 +455,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg.p_a_grid_, karg.p_a_grid_,
karg.p_b_grid_, karg.p_b_grid_,
karg.p_c_grid_, karg.p_c_grid_,
karg.a_element_op_,
karg.b_element_op_,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_, karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_, karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_); karg.c_grid_desc_m_n_);
...@@ -480,9 +465,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -480,9 +465,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N, DeviceOp::CGridDesc_M_N,
...@@ -495,9 +477,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -495,9 +477,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg.p_a_grid_, karg.p_a_grid_,
karg.p_b_grid_, karg.p_b_grid_,
karg.p_c_grid_, karg.p_c_grid_,
karg.a_element_op_,
karg.b_element_op_,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_, karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_, karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_); karg.c_grid_desc_m_n_);
...@@ -554,23 +533,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -554,23 +533,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{ {
return Argument{p_a, return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -585,9 +550,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -585,9 +550,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation,
BElementwiseOperation b_element_op, BElementwiseOperation,
CElementwiseOperation c_element_op) override CElementwiseOperation) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -597,10 +562,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -597,10 +562,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC);
a_element_op,
b_element_op,
c_element_op);
} }
// polymorphic // polymorphic
......
...@@ -20,9 +20,6 @@ namespace ck { ...@@ -20,9 +20,6 @@ namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
...@@ -34,9 +31,6 @@ __global__ void ...@@ -34,9 +31,6 @@ __global__ void
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n) const CGridDesc_M_N c_grid_desc_m_n)
...@@ -48,9 +42,6 @@ __global__ void ...@@ -48,9 +42,6 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared, p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n); c_grid_desc_m_n);
...@@ -58,9 +49,6 @@ __global__ void ...@@ -58,9 +49,6 @@ __global__ void
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m_n; ignore = c_grid_desc_m_n;
...@@ -339,9 +327,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -339,9 +327,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n) const CGridDesc_M_N c_grid_desc_m_n)
...@@ -356,8 +341,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -356,8 +341,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N] const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = MakeBlock2CTileMap(c_grid_desc_m_n); const auto block_2_ctile_map = MakeBlock2CTileMap(c_grid_desc_m_n);
const auto block_work_idx = const auto block_work_idx =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment