"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "3c0001768805b13a7d630cd03928ad85245d4bf7"
Commit 613dcc6b authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Remove elementwise-op objects from interfaces

parent 9fdc3fc8
......@@ -89,10 +89,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
StrideC);
if(!gemm.IsSupportedArgument(argument))
{
......
......@@ -356,10 +356,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
index_t StrideC)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -385,9 +382,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
N,
GridwiseGemm::CalculateNPadded(N),
StrideC)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
kraw_{K}
{
}
......@@ -402,9 +396,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t kraw_;
};
......@@ -451,9 +442,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N,
......@@ -467,9 +455,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_element_op_,
karg.b_element_op_,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_);
......@@ -480,9 +465,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N,
......@@ -495,9 +477,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_element_op_,
karg.b_element_op_,
karg.c_element_op_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_);
......@@ -554,23 +533,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
index_t StrideC)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op};
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -585,9 +550,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -597,10 +562,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
StrideC);
}
// polymorphic
......
......@@ -20,9 +20,6 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
......@@ -34,9 +31,6 @@ __global__ void
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n)
......@@ -48,9 +42,6 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n);
......@@ -58,9 +49,6 @@ __global__ void
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m_n;
......@@ -339,9 +327,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n)
......@@ -356,8 +341,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = MakeBlock2CTileMap(c_grid_desc_m_n);
const auto block_work_idx =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment