Commit 304adaad authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Allow to use any elementwise operator for ref_contraction

parent 93ce856f
...@@ -260,6 +260,7 @@ int main(int argc, char* argv[]) ...@@ -260,6 +260,7 @@ int main(int argc, char* argv[])
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp, CDEElementOp,
true,
DDataType>; DDataType>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_gemm = ReferenceOpInstance{};
......
...@@ -260,6 +260,7 @@ int main(int argc, char* argv[]) ...@@ -260,6 +260,7 @@ int main(int argc, char* argv[])
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp, CDEElementOp,
true,
DDataType>; DDataType>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_gemm = ReferenceOpInstance{};
......
...@@ -242,7 +242,8 @@ int main(int argc, char* argv[]) ...@@ -242,7 +242,8 @@ int main(int argc, char* argv[])
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp>; CDEElementOp,
false>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
......
...@@ -242,7 +242,8 @@ int main(int argc, char* argv[]) ...@@ -242,7 +242,8 @@ int main(int argc, char* argv[])
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp>; CDEElementOp,
false>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
......
...@@ -23,11 +23,12 @@ template <ck::index_t NumDimM, ...@@ -23,11 +23,12 @@ template <ck::index_t NumDimM,
ck::index_t NumDimK, ck::index_t NumDimK,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename EDataType,
typename AccDataType, typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
bool UseDToBinaryOp,
typename DDataType = float, typename DDataType = float,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false> ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
...@@ -38,14 +39,14 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -38,14 +39,14 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
Argument(const Tensor<ADataType>& a_ms_ks, Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks, const Tensor<BDataType>& b_ns_ks,
const Tensor<DDataType>& d_ms_ns, const Tensor<DDataType>& d_ms_ns,
Tensor<CDataType>& c_ms_ns, Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks}, : a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks}, b_ns_ks_{b_ns_ks},
d_ms_ns_{d_ms_ns}, d_ms_ns_{d_ms_ns},
c_ms_ns_{c_ms_ns}, e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op} cde_element_op_{cde_element_op}
...@@ -55,7 +56,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -55,7 +56,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
const Tensor<ADataType>& a_ms_ks_; const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_; const Tensor<BDataType>& b_ns_ks_;
const Tensor<DDataType>& d_ms_ns_; const Tensor<DDataType>& d_ms_ns_;
Tensor<CDataType>& c_ms_ns_; Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
...@@ -67,19 +68,17 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -67,19 +68,17 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{ {
using Argument = ReferenceContraction_M2_N2_K2::Argument; using Argument = ReferenceContraction_M2_N2_K2::Argument;
template <typename Op> void apply_unary_op(const CDEElementwiseOperation& op, EDataType& v_e, AccDataType& v_acc)
void apply_op(Op& op, DDataType& v_d, CDataType& v_c, AccDataType& v_acc)
{ {
op(v_c, static_cast<AccDataType>(v_d + v_acc)); op(v_e, v_acc);
} }
template <> void apply_binary_op(const CDEElementwiseOperation& op,
void apply_op<const Bilinear>(const Bilinear& bilinear, EDataType& v_e,
DDataType& v_d, AccDataType& v_acc,
CDataType& v_c, DDataType& v_d)
AccDataType& v_acc)
{ {
bilinear(v_c, v_d, v_acc); op(v_e, v_acc, v_d);
} }
float Run(const Argument& arg) float Run(const Argument& arg)
...@@ -106,19 +105,26 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -106,19 +105,26 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
} }
} }
AccDataType v_c; AccDataType v_e;
DDataType v_d = DDataType v_d =
arg.d_ms_ns_.GetNumOfDimension() == 0 ? 0 : arg.d_ms_ns_(m0, m1, n0, n1); arg.d_ms_ns_.GetNumOfDimension() == 0 ? 0 : arg.d_ms_ns_(m0, m1, n0, n1);
apply_op(arg.cde_element_op_, v_d, v_c, v_acc); if constexpr(UseDToBinaryOp)
{
apply_binary_op(arg.cde_element_op_, v_e, v_acc, v_d);
}
else
{
apply_unary_op(arg.cde_element_op_, v_e, v_acc);
}
arg.c_ms_ns_(m0, m1, n0, n1) = v_c; arg.e_ms_ns_(m0, m1, n0, n1) = v_e;
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
arg.c_ms_ns_.mDesc.GetLengths()[0], arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.c_ms_ns_.mDesc.GetLengths()[1], arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.c_ms_ns_.mDesc.GetLengths()[2], arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.c_ms_ns_.mDesc.GetLengths()[3])( arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -145,23 +151,23 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -145,23 +151,23 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks, static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks, const Tensor<BDataType>& b_ns_ks,
const Tensor<DDataType>& d_ms_ns, const Tensor<DDataType>& d_ms_ns,
Tensor<CDataType>& c_ms_ns, Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
{ {
return Argument{ return Argument{
a_ms_ks, b_ns_ks, d_ms_ns, c_ms_ns, a_element_op, b_element_op, cde_element_op}; a_ms_ks, b_ns_ks, d_ms_ns, e_ms_ns, a_element_op, b_element_op, cde_element_op};
} }
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks, static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks, const Tensor<BDataType>& b_ns_ks,
Tensor<CDataType>& c_ms_ns, Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
{ {
return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op, cde_element_op}; return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
...@@ -129,18 +129,19 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -129,18 +129,19 @@ int profile_contraction_impl(ck::index_t do_verification,
// Run reference op // Run reference op
if(do_verification) if(do_verification)
{ {
using ReferenceGemmInstance = using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDim, NumDim,
NumDim, NumDim,
NumDim, NumDim,
DataType, DataType,
DataType, DataType,
DataType, DataType,
DataType, DataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDElementOp, CDElementOp,
DataType>; std::is_same<CDElementOp, Bilinear>::value,
DataType>;
auto ref_op = ReferenceGemmInstance{}; auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
......
...@@ -125,7 +125,7 @@ TYPED_TEST(TestContractionBilinear, bilinear) ...@@ -125,7 +125,7 @@ TYPED_TEST(TestContractionBilinear, bilinear)
{ {
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f); this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run(); this->Run();
this->p_cd_element_op = std::make_unique<Bilinear>(0.5f, 0.5f); this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run(); this->Run();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment