"docs/vscode:/vscode.git/clone" did not exist on "5d7ea6616fc127469f43605464803d8521fcc51d"
Commit 304adaad authored by Bartlomiej Kocot
Browse files

Allow to use any elementwise operator for ref_contraction

parent 93ce856f
......@@ -260,6 +260,7 @@ int main(int argc, char* argv[])
AElementOp,
BElementOp,
CDEElementOp,
true,
DDataType>;
auto ref_gemm = ReferenceOpInstance{};
......
......@@ -260,6 +260,7 @@ int main(int argc, char* argv[])
AElementOp,
BElementOp,
CDEElementOp,
true,
DDataType>;
auto ref_gemm = ReferenceOpInstance{};
......
......@@ -242,7 +242,8 @@ int main(int argc, char* argv[])
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
CDEElementOp,
false>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......
......@@ -242,7 +242,8 @@ int main(int argc, char* argv[])
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
CDEElementOp,
false>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......
......@@ -23,11 +23,12 @@ template <ck::index_t NumDimM,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename CDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
bool UseDToBinaryOp,
typename DDataType = float,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
......@@ -38,14 +39,14 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
const Tensor<DDataType>& d_ms_ns,
Tensor<CDataType>& c_ms_ns,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
d_ms_ns_{d_ms_ns},
c_ms_ns_{c_ms_ns},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
......@@ -55,7 +56,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
const Tensor<DDataType>& d_ms_ns_;
Tensor<CDataType>& c_ms_ns_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -67,19 +68,17 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
template <typename Op>
void apply_op(Op& op, DDataType& v_d, CDataType& v_c, AccDataType& v_acc)
void apply_unary_op(const CDEElementwiseOperation& op, EDataType& v_e, AccDataType& v_acc)
{
op(v_c, static_cast<AccDataType>(v_d + v_acc));
op(v_e, v_acc);
}
template <>
void apply_op<const Bilinear>(const Bilinear& bilinear,
DDataType& v_d,
CDataType& v_c,
AccDataType& v_acc)
void apply_binary_op(const CDEElementwiseOperation& op,
EDataType& v_e,
AccDataType& v_acc,
DDataType& v_d)
{
bilinear(v_c, v_d, v_acc);
op(v_e, v_acc, v_d);
}
float Run(const Argument& arg)
......@@ -106,19 +105,26 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
}
AccDataType v_c;
AccDataType v_e;
DDataType v_d =
arg.d_ms_ns_.GetNumOfDimension() == 0 ? 0 : arg.d_ms_ns_(m0, m1, n0, n1);
apply_op(arg.cde_element_op_, v_d, v_c, v_acc);
if constexpr(UseDToBinaryOp)
{
apply_binary_op(arg.cde_element_op_, v_e, v_acc, v_d);
}
else
{
apply_unary_op(arg.cde_element_op_, v_e, v_acc);
}
arg.c_ms_ns_(m0, m1, n0, n1) = v_c;
arg.e_ms_ns_(m0, m1, n0, n1) = v_e;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.c_ms_ns_.mDesc.GetLengths()[0],
arg.c_ms_ns_.mDesc.GetLengths()[1],
arg.c_ms_ns_.mDesc.GetLengths()[2],
arg.c_ms_ns_.mDesc.GetLengths()[3])(
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
......@@ -145,23 +151,23 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
const Tensor<DDataType>& d_ms_ns,
Tensor<CDataType>& c_ms_ns,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_ms_ks, b_ns_ks, d_ms_ns, c_ms_ns, a_element_op, b_element_op, cde_element_op};
a_ms_ks, b_ns_ks, d_ms_ns, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<CDataType>& c_ms_ns,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op, cde_element_op};
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -129,8 +129,8 @@ int profile_contraction_impl(ck::index_t do_verification,
// Run reference op
if(do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDim,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<
NumDim,
NumDim,
NumDim,
DataType,
......@@ -140,6 +140,7 @@ int profile_contraction_impl(ck::index_t do_verification,
AElementOp,
BElementOp,
CDElementOp,
std::is_same<CDElementOp, Bilinear>::value,
DataType>;
auto ref_op = ReferenceGemmInstance{};
......
......@@ -125,7 +125,7 @@ TYPED_TEST(TestContractionBilinear, bilinear)
{
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Bilinear>(0.5f, 0.5f);
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment