Commit 8354aad7 authored by Bartlomiej Kocot

Make ref_contraction generic and extend interface tests

parent 1864dfe1
@@ -249,6 +249,8 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
+        Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
+
         using ReferenceOpInstance =
             ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
                                                                       NumDimN,
@@ -258,24 +260,32 @@ int main(int argc, char* argv[])
                                                                       CShuffleDataType,
                                                                       AccDataType,
                                                                       AElementOp,
-                                                                      BElementOp,
-                                                                      CDEElementOp,
-                                                                      true,
-                                                                      DDataType>;
+                                                                      BElementOp>;

         auto ref_gemm    = ReferenceOpInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();

-        auto ref_argument = ref_gemm.MakeArgument(a_ms_ks,
-                                                  b_ns_ks,
-                                                  d_ms_ns,
-                                                  e_ms_ns_host_result,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  cde_element_op);
+        auto ref_argument = ref_gemm.MakeArgument(
+            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);

         ref_invoker.Run(ref_argument);

+        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
+        {
+            for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
+            {
+                for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
+                {
+                    for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
+                    {
+                        cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
+                                       c_ms_ns_host_result(m0, m1, n0, n1),
+                                       d_ms_ns(m0, m1, n0, n1));
+                    }
+                }
+            }
+        }
+
         return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
     }
...
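The hunk above shows the new verification flow: the reference contraction only fills a plain C tensor, and the CDE elementwise step (here a three-operand op combining C with D) is applied afterwards on the host. Below is a minimal standalone sketch of that pattern; Tensor4D and the Bilinear functor are illustrative stand-ins for CK's Tensor class and ck::tensor_operation::element_wise::Bilinear, and the alpha/beta formula is an assumption, not CK's implementation.

// Standalone sketch of the new verification pattern: the reference contraction
// writes the raw accumulator into a C tensor, and the CDE elementwise op
// (a bilinear combine of C and D here) is applied afterwards on the host.
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative stand-in for ck::Tensor with (m0, m1, n0, n1) indexing.
struct Tensor4D
{
    std::vector<std::size_t> lengths;
    std::vector<float> data;

    explicit Tensor4D(std::vector<std::size_t> l)
        : lengths(l), data(l[0] * l[1] * l[2] * l[3], 0.f)
    {
    }

    float& operator()(std::size_t m0, std::size_t m1, std::size_t n0, std::size_t n1)
    {
        return data[((m0 * lengths[1] + m1) * lengths[2] + n0) * lengths[3] + n1];
    }
};

// Illustrative stand-in for the bilinear CDE elementwise op: e = alpha * c + beta * d.
struct Bilinear
{
    float alpha, beta;
    void operator()(float& e, const float& c, const float& d) const { e = alpha * c + beta * d; }
};

int main()
{
    const std::vector<std::size_t> lens{2, 2, 2, 2};
    Tensor4D c_host(lens), d(lens), e_host(lens);

    // ... the reference contraction would fill c_host here; d holds user data ...

    const Bilinear cde_element_op{1.f, 1.f};
    for(std::size_t m0 = 0; m0 < lens[0]; ++m0)
        for(std::size_t m1 = 0; m1 < lens[1]; ++m1)
            for(std::size_t n0 = 0; n0 < lens[2]; ++n0)
                for(std::size_t n1 = 0; n1 < lens[3]; ++n1)
                    cde_element_op(e_host(m0, m1, n0, n1),
                                   c_host(m0, m1, n0, n1),
                                   d(m0, m1, n0, n1));

    std::cout << "e_host(0,0,0,0) = " << e_host(0, 0, 0, 0) << std::endl;
    return 0;
}

Keeping the epilogue outside the reference op is what lets the same slimmed-down ReferenceContraction_M2_N2_K2 serve both the bilinear and the scale examples below.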
@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
+        Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
+
         using ReferenceOpInstance =
             ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
                                                                       NumDimN,
@@ -241,24 +243,32 @@ int main(int argc, char* argv[])
                                                                       CShuffleDataType,
                                                                       AccDataType,
                                                                       AElementOp,
-                                                                      BElementOp,
-                                                                      CDEElementOp,
-                                                                      false>;
+                                                                      BElementOp>;

         auto ref_gemm    = ReferenceOpInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();

         Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
-        auto ref_argument = ref_gemm.MakeArgument(a_ms_ks,
-                                                  b_ns_ks,
-                                                  empty_tensor,
-                                                  e_ms_ns_host_result,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  cde_element_op);
+        auto ref_argument = ref_gemm.MakeArgument(
+            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);

         ref_invoker.Run(ref_argument);

+        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
+        {
+            for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
+            {
+                for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
+                {
+                    for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
+                    {
+                        cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
+                                       c_ms_ns_host_result(m0, m1, n0, n1));
+                    }
+                }
+            }
+        }
+
         return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
     }
...
@@ -23,13 +23,10 @@ template <ck::index_t NumDimM,
           ck::index_t NumDimK,
           typename ADataType,
           typename BDataType,
-          typename EDataType,
+          typename CDataType,
           typename AccDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CDEElementwiseOperation,
-          bool UseDToBinaryOp,
-          typename DDataType = float,
           ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
 struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
 {
@@ -38,29 +35,23 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
     {
         Argument(const Tensor<ADataType>& a_ms_ks,
                  const Tensor<BDataType>& b_ns_ks,
-                 const Tensor<DDataType>& d_ms_ns,
-                 Tensor<EDataType>& e_ms_ns,
+                 Tensor<CDataType>& c_ms_ns,
                  AElementwiseOperation a_element_op,
-                 BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation cde_element_op)
+                 BElementwiseOperation b_element_op)
             : a_ms_ks_{a_ms_ks},
               b_ns_ks_{b_ns_ks},
-              d_ms_ns_{d_ms_ns},
-              e_ms_ns_{e_ms_ns},
+              c_ms_ns_{c_ms_ns},
               a_element_op_{a_element_op},
-              b_element_op_{b_element_op},
-              cde_element_op_{cde_element_op}
+              b_element_op_{b_element_op}
         {
         }

         const Tensor<ADataType>& a_ms_ks_;
         const Tensor<BDataType>& b_ns_ks_;
-        const Tensor<DDataType>& d_ms_ns_;
-        Tensor<EDataType>& e_ms_ns_;
+        Tensor<CDataType>& c_ms_ns_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
-        CDEElementwiseOperation cde_element_op_;
     };

     // Invoker
@@ -68,19 +59,6 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
     {
         using Argument = ReferenceContraction_M2_N2_K2::Argument;

-        void apply_unary_op(const CDEElementwiseOperation& op, EDataType& v_e, AccDataType& v_acc)
-        {
-            op(v_e, v_acc);
-        }
-
-        void apply_binary_op(const CDEElementwiseOperation& op,
-                             EDataType& v_e,
-                             AccDataType& v_acc,
-                             DDataType& v_d)
-        {
-            op(v_e, v_acc, v_d);
-        }
-
         float Run(const Argument& arg)
         {
             auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
@@ -105,26 +83,14 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
                     }
                 }

-                AccDataType v_e;
-                DDataType v_d =
-                    arg.d_ms_ns_.GetNumOfDimension() == 0 ? 0 : arg.d_ms_ns_(m0, m1, n0, n1);
-                if constexpr(UseDToBinaryOp)
-                {
-                    apply_binary_op(arg.cde_element_op_, v_e, v_acc, v_d);
-                }
-                else
-                {
-                    apply_unary_op(arg.cde_element_op_, v_e, v_acc);
-                }
-                arg.e_ms_ns_(m0, m1, n0, n1) = v_e;
+                arg.c_ms_ns_(m0, m1, n0, n1) = v_acc;
             };

             make_ParallelTensorFunctor(f_ms_ns,
-                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[3])(
+                                       arg.c_ms_ns_.mDesc.GetLengths()[0],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[1],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[2],
+                                       arg.c_ms_ns_.mDesc.GetLengths()[3])(
                 std::thread::hardware_concurrency());

             return 0;
@@ -150,24 +116,11 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
     static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
                              const Tensor<BDataType>& b_ns_ks,
-                             const Tensor<DDataType>& d_ms_ns,
-                             Tensor<EDataType>& e_ms_ns,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CDEElementwiseOperation cde_element_op)
-    {
-        return Argument{
-            a_ms_ks, b_ns_ks, d_ms_ns, e_ms_ns, a_element_op, b_element_op, cde_element_op};
-    }
-
-    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
-                             const Tensor<BDataType>& b_ns_ks,
-                             Tensor<EDataType>& e_ms_ns,
+                             Tensor<CDataType>& c_ms_ns,
                              AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CDEElementwiseOperation cde_element_op)
+                             BElementwiseOperation b_element_op)
     {
-        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
+        return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }
...
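With the CDE operation removed, Run reduces to a plain accumulate-and-store over the two K dimensions. The following is a standalone sketch of that loop nest, using flat std::vector buffers with explicit index math in place of CK's Tensor and make_ParallelTensorFunctor, and omitting the A/B elementwise operations; the function name and packed layout are illustrative assumptions.

// Standalone sketch of what ReferenceContraction_M2_N2_K2::Invoker::Run now does:
// for each (m0, m1, n0, n1), accumulate A*B over (k0, k1) and store the raw
// accumulator into C. No CDE elementwise op is applied here any more.
#include <cstddef>
#include <iostream>
#include <vector>

void reference_contraction_m2n2k2(const std::vector<float>& a, // [M0][M1][K0][K1]
                                  const std::vector<float>& b, // [N0][N1][K0][K1]
                                  std::vector<float>& c,       // [M0][M1][N0][N1]
                                  std::size_t M0, std::size_t M1,
                                  std::size_t N0, std::size_t N1,
                                  std::size_t K0, std::size_t K1)
{
    auto a_at = [&](std::size_t m0, std::size_t m1, std::size_t k0, std::size_t k1) {
        return a[((m0 * M1 + m1) * K0 + k0) * K1 + k1];
    };
    auto b_at = [&](std::size_t n0, std::size_t n1, std::size_t k0, std::size_t k1) {
        return b[((n0 * N1 + n1) * K0 + k0) * K1 + k1];
    };

    for(std::size_t m0 = 0; m0 < M0; ++m0)
        for(std::size_t m1 = 0; m1 < M1; ++m1)
            for(std::size_t n0 = 0; n0 < N0; ++n0)
                for(std::size_t n1 = 0; n1 < N1; ++n1)
                {
                    float v_acc = 0.f; // plays the role of AccDataType
                    for(std::size_t k0 = 0; k0 < K0; ++k0)
                        for(std::size_t k1 = 0; k1 < K1; ++k1)
                            v_acc += a_at(m0, m1, k0, k1) * b_at(n0, n1, k0, k1);
                    c[((m0 * M1 + m1) * N0 + n0) * N1 + n1] = v_acc;
                }
}

int main()
{
    // Tiny contraction with K0 = K1 = 1, so c[m][n] = a[m] * b[n].
    const std::vector<float> a{1.f, 2.f, 3.f, 4.f}; // M0 = 2, M1 = 2, K0 = 1, K1 = 1
    const std::vector<float> b{1.f, 1.f, 2.f, 2.f}; // N0 = 2, N1 = 2, K0 = 1, K1 = 1
    std::vector<float> c(16, 0.f);                  // M0 * M1 * N0 * N1
    reference_contraction_m2n2k2(a, b, c, 2, 2, 2, 2, 1, 1);
    std::cout << "c[0] = " << c[0] << ", c[15] = " << c[15] << std::endl; // 1 and 8
    return 0;
}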
@@ -41,7 +41,7 @@ int profile_contraction_impl(ck::index_t do_verification,
                              ck::index_t init_method,
                              bool do_log,
                              bool time_kernel,
-                             CDElementOp cd_element_op,
+                             CDElementOp cde_element_op,
                              const std::vector<ck::index_t>& M,
                              const std::vector<ck::index_t>& N,
                              const std::vector<ck::index_t>& K,
@@ -64,14 +64,14 @@ int profile_contraction_impl(ck::index_t do_verification,
     Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
     Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
-    Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC));
-    Tensor<DataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StridesC));
+    Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC));
+    Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesC));
     Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
-    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;

     switch(init_method)
     {
@@ -92,18 +92,18 @@ int profile_contraction_impl(ck::index_t do_verification,
     DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem c_device_buf(sizeof(DataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());

     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
-    c_device_buf.SetZero();
+    e_device_buf.SetZero();
     d_device_buf.ToDevice(d_m_n.mData.data());

-    const std::vector<index_t> a_ms_ks_lengths = {M[0], M[1], K[0], K[1]};
-    const std::vector<index_t> b_ns_ks_lengths = {N[0], N[1], K[0], K[1]};
-    const std::vector<index_t> c_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
-    const std::vector<index_t> d_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
+    const std::vector<index_t> a_m_k_lengths = {M[0], M[1], K[0], K[1]};
+    const std::vector<index_t> b_n_k_lengths = {N[0], N[1], K[0], K[1]};
+    const std::vector<index_t> c_m_n_lengths = {M[0], M[1], N[0], N[1]};
+    const std::vector<index_t> d_m_n_lengths = {M[0], M[1], N[0], N[1]};

     const auto a_element_op = AElementOp{};
     const auto b_element_op = BElementOp{};
@@ -129,8 +129,8 @@ int profile_contraction_impl(ck::index_t do_verification,
     // Run reference op
     if(do_verification)
     {
-        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<
-            NumDim,
+        using ReferenceGemmInstance =
+            ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDim,
                                                                       NumDim,
                                                                       NumDim,
                                                                       DataType,
@@ -138,21 +138,41 @@ int profile_contraction_impl(ck::index_t do_verification,
                                                                       DataType,
                                                                       DataType,
                                                                       AElementOp,
-                                                                      BElementOp,
-                                                                      CDElementOp,
-                                                                      std::is_same<CDElementOp, Bilinear>::value,
-                                                                      DataType>;
+                                                                      BElementOp>;

         auto ref_op      = ReferenceGemmInstance{};
         auto ref_invoker = ref_op.MakeInvoker();

-        if constexpr(std::is_same<CDElementOp, Scale>::value)
-            d_m_n = Tensor<DataType>(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
-
-        auto ref_argument = ref_op.MakeArgument(
-            a_m_k, b_k_n, d_m_n, c_m_n_host_result, a_element_op, b_element_op, cd_element_op);
+        Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC));
+
+        auto ref_argument =
+            ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);

         ref_invoker.Run(ref_argument);

+        for(size_t m0 = 0; m0 < e_m_n_host_result.mDesc.GetLengths()[0]; ++m0)
+        {
+            for(size_t m1 = 0; m1 < e_m_n_host_result.mDesc.GetLengths()[1]; ++m1)
+            {
+                for(size_t n0 = 0; n0 < e_m_n_host_result.mDesc.GetLengths()[2]; ++n0)
+                {
+                    for(size_t n1 = 0; n1 < e_m_n_host_result.mDesc.GetLengths()[3]; ++n1)
+                    {
+                        if constexpr(is_same<CDElementOp, Bilinear>::value)
+                        {
+                            cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
+                                           c_m_n_host_result(m0, m1, n0, n1),
+                                           d_m_n(m0, m1, n0, n1));
+                        }
+                        else if constexpr(is_same<CDElementOp, Scale>::value)
+                        {
+                            cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
+                                           c_m_n_host_result(m0, m1, n0, n1));
+                        }
+                    }
+                }
+            }
+        }
     }

     std::string best_op_name;
@@ -170,18 +190,18 @@ int profile_contraction_impl(ck::index_t do_verification,
                 static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
                 static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
                 std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
-                static_cast<DataType*>(c_device_buf.GetDeviceBuffer()),
-                a_ms_ks_lengths,
+                static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
+                a_m_k_lengths,
                 StridesA,
-                b_ns_ks_lengths,
+                b_n_k_lengths,
                 StridesB,
-                std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
+                std::array<std::vector<ck::index_t>, 1>{d_m_n_lengths},
                 std::array<std::vector<ck::index_t>, 1>{StridesD},
-                c_ms_ns_lengths,
+                c_m_n_lengths,
                 StridesC,
                 a_element_op,
                 b_element_op,
-                cd_element_op);
+                cde_element_op);
         }
         else
         {
@@ -189,18 +209,18 @@ int profile_contraction_impl(ck::index_t do_verification,
             op_ptr->MakeArgumentPointer(static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
                                         static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
                                         std::array<const void*, 0>{},
-                                        static_cast<DataType*>(c_device_buf.GetDeviceBuffer()),
-                                        a_ms_ks_lengths,
+                                        static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
+                                        a_m_k_lengths,
                                         StridesA,
-                                        b_ns_ks_lengths,
+                                        b_n_k_lengths,
                                         StridesB,
                                         std::array<std::vector<ck::index_t>, 0>{},
                                         std::array<std::vector<ck::index_t>, 0>{},
-                                        c_ms_ns_lengths,
+                                        c_m_n_lengths,
                                         StridesC,
                                         a_element_op,
                                         b_element_op,
-                                        cd_element_op);
+                                        cde_element_op);
         }

         auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -212,7 +232,7 @@ int profile_contraction_impl(ck::index_t do_verification,
         if(op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
             // re-init C to zero before profiling next kernel
-            c_device_buf.SetZero();
+            e_device_buf.SetZero();

             std::string op_name = op_ptr->GetTypeString();
@@ -242,12 +262,12 @@ int profile_contraction_impl(ck::index_t do_verification,
             if(do_verification)
             {
-                c_device_buf.FromDevice(c_m_n_device_result.mData.data());
+                e_device_buf.FromDevice(e_m_n_device_result.mData.data());

                 float threshold =
                     static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
-                pass = pass & ck::utils::check_err(c_m_n_device_result,
-                                                   c_m_n_host_result,
+                pass = pass & ck::utils::check_err(e_m_n_device_result,
+                                                   e_m_n_host_result,
                                                    "Error: incorrect results!",
                                                    threshold,
                                                    threshold);
@@ -256,9 +276,9 @@ int profile_contraction_impl(ck::index_t do_verification,
                 {
                     LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
                     LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
-                    LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
+                    LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
                         << std::endl;
-                    LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
+                    LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
                         << std::endl;
                 }
             }
...
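The profiler now applies the epilogue itself, dispatching at compile time on CDElementOp: Bilinear consumes C and D, Scale consumes only C. Below is a standalone sketch of that if constexpr dispatch; the two functors are simplified stand-ins for ck::tensor_operation::element_wise::Bilinear and Scale, so their exact formulas are assumptions.

// Standalone sketch of the compile-time epilogue dispatch used in the profiler:
// a Bilinear op combines C and D into E, a Scale op only rescales C.
#include <iostream>
#include <type_traits>

struct Bilinear
{
    float alpha, beta;
    void operator()(float& e, const float& c, const float& d) const { e = alpha * c + beta * d; }
};

struct Scale
{
    float scale;
    void operator()(float& e, const float& c) const { e = scale * c; }
};

template <typename CDElementOp>
float apply_epilogue(const CDElementOp& op, float c, float d)
{
    float e = 0.f;
    if constexpr(std::is_same<CDElementOp, Bilinear>::value)
    {
        op(e, c, d); // binary epilogue: E = alpha * C + beta * D
    }
    else if constexpr(std::is_same<CDElementOp, Scale>::value)
    {
        op(e, c); // unary epilogue: E = scale * C, D is ignored
    }
    return e;
}

int main()
{
    std::cout << apply_epilogue(Bilinear{1.f, 1.f}, 2.f, 3.f) << '\n'; // prints 5
    std::cout << apply_epilogue(Scale{2.f}, 2.f, 3.f) << '\n';         // prints 4
    return 0;
}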
@@ -23,7 +23,7 @@ template <typename DataTypeA,
           typename DataTypeB,
           typename DataTypeC,
           typename DataTypeD,
-          int NumDim>
+          ck::index_t NumDim>
 class ContractionDeviceWrapper
 {
@@ -40,10 +40,27 @@ class ContractionDeviceWrapper
                                                       Bilinear>;

     public:
+    ContractionDeviceWrapper(std::vector<ck::index_t>& Dims, std::vector<ck::index_t>& Strides)
+        : InputDims_(Dims), OutputDims_(Dims), InputStrides_(Strides), OutputStrides_(Strides)
+    {
+    }
+
+    ContractionDeviceWrapper(std::vector<ck::index_t>& InDims,
+                             std::vector<ck::index_t>& OutDims,
+                             std::vector<ck::index_t>& InStrides,
+                             std::vector<ck::index_t>& OutStrides)
+        : InputDims_(InDims),
+          OutputDims_(OutDims),
+          InputStrides_(InStrides),
+          OutputStrides_(OutStrides)
+    {
+    }
+
+    std::vector<ck::index_t>& InputDims_;
+    std::vector<ck::index_t>& OutputDims_;
+    std::vector<ck::index_t>& InputStrides_;
+    std::vector<ck::index_t>& OutputStrides_;
+
     bool IsSupported() const
     {
-        std::vector<ck::index_t> dummy_dims(NumDim * 2, 4);
-        std::vector<ck::index_t> dummy_strides(NumDim * 2, 1);
         bool supported = false;
         const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -56,14 +73,14 @@ class ContractionDeviceWrapper
                 nullptr,
                 std::array<const void*, 1>{nullptr},
                 nullptr,
-                dummy_dims,
-                dummy_strides,
-                dummy_dims,
-                dummy_strides,
-                std::array<std::vector<ck::index_t>, 1>{dummy_dims},
-                std::array<std::vector<ck::index_t>, 1>{dummy_strides},
-                dummy_dims,
-                dummy_strides,
+                InputStrides_,
+                InputStrides_,
+                InputStrides_,
+                InputStrides_,
+                std::array<std::vector<ck::index_t>, 1>{InputStrides_},
+                std::array<std::vector<ck::index_t>, 1>{InputStrides_},
+                OutputDims_,
+                OutputStrides_,
                 Pass{},
                 Pass{},
                 Bilinear{1.f, 1.f});
@@ -76,9 +93,11 @@ class ContractionDeviceWrapper
 TEST(TestContractionInterface, IncorrectNumDims)
 {
-    ContractionDeviceWrapper<F32, F32, F32, F32, 1> wrapper_1d;
-    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper_2d;
-    ContractionDeviceWrapper<F32, F32, F32, F32, 3> wrapper_3d;
+    std::vector<std::vector<ck::index_t>> Dims    = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
+    std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
+    ContractionDeviceWrapper<F32, F32, F32, F32, 1> wrapper_1d(Dims[0], Strides[0]);
+    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper_2d(Dims[1], Strides[1]);
+    ContractionDeviceWrapper<F32, F32, F32, F32, 3> wrapper_3d(Dims[2], Strides[2]);
     EXPECT_FALSE(wrapper_1d.IsSupported());
     EXPECT_TRUE(wrapper_2d.IsSupported());
     EXPECT_FALSE(wrapper_3d.IsSupported());
@@ -86,13 +105,30 @@ TEST(TestContractionInterface, IncorrectNumDims)
 TEST(TestContractionInterface, IncorrectDataTypes)
 {
-    ContractionDeviceWrapper<F32, F32, F64, F64, 2> wrapper_1;
-    ContractionDeviceWrapper<F64, F64, F32, F32, 2> wrapper_2;
+    std::vector<ck::index_t> Dims    = {4, 4, 4, 4};
+    std::vector<ck::index_t> Strides = {64, 16, 4, 1};
+    ContractionDeviceWrapper<F32, F32, F64, F64, 2> wrapper_1(Dims, Strides);
+    ContractionDeviceWrapper<F64, F64, F32, F32, 2> wrapper_2(Dims, Strides);
     EXPECT_FALSE(wrapper_1.IsSupported());
     EXPECT_FALSE(wrapper_2.IsSupported());
 }

-// TEST(TestContractionInterface, CornerCases)
-// {
-//     EXPECT_FALSE()
-// }
+TEST(TestContractionInterface, GridwiseGemm)
+{
+    std::vector<ck::index_t> InDims     = {1, 2, 3, 4};
+    std::vector<ck::index_t> InStrides  = {24, 12, 4, 1};
+    std::vector<ck::index_t> OutDims    = {4, 3, 2, 1};
+    std::vector<ck::index_t> OutStrides = {6, 2, 1, 1};
+    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(InDims, OutDims, InStrides, OutStrides);
+    EXPECT_FALSE(wrapper.IsSupported());
+}
+
+TEST(TestContractionInterface, MemoryAccess)
+{
+    std::vector<ck::index_t> Dims    = {4, 4, 4, 4};
+    std::vector<ck::index_t> Strides = {4, 16, 64, 256};
+    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(Dims, Strides);
+    EXPECT_FALSE(wrapper.IsSupported());
+}
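The tests above hard-code packed row-major strides such as {64, 16, 4, 1} for dims {4, 4, 4, 4}. A small helper could derive them instead; this is a hypothetical addition, not part of the commit, and the name packed_strides is invented for illustration.

// Hypothetical helper (not in this commit): compute packed row-major strides
// from a dims vector, e.g. {4, 4, 4, 4} -> {64, 16, 4, 1}.
#include <cassert>
#include <cstddef>
#include <vector>

template <typename Int>
std::vector<Int> packed_strides(const std::vector<Int>& dims)
{
    std::vector<Int> strides(dims.size(), Int{1});
    for(std::size_t i = dims.size(); i-- > 1;)
    {
        strides[i - 1] = strides[i] * dims[i]; // innermost dimension is contiguous
    }
    return strides;
}

int main()
{
    const std::vector<int> dims{4, 4, 4, 4};
    assert((packed_strides(dims) == std::vector<int>{64, 16, 4, 1}));
    return 0;
}

Usage sketch in the tests (the wrapper constructors take lvalue references, so keep the result in a named variable): std::vector<ck::index_t> Strides = packed_strides(Dims); then pass Dims and Strides to ContractionDeviceWrapper as before.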