Commit 8354aad7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Make ref_contraction generic and extend interface tests

parent 1864dfe1
......@@ -249,6 +249,8 @@ int main(int argc, char* argv[])
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
......@@ -258,24 +260,32 @@ int main(int argc, char* argv[])
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp,
true,
DDataType>;
BElementOp>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_ms_ks,
b_ns_ks,
d_ms_ns,
e_ms_ns_host_result,
a_element_op,
b_element_op,
cde_element_op);
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
......
......@@ -249,6 +249,8 @@ int main(int argc, char* argv[])
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
......@@ -258,24 +260,32 @@ int main(int argc, char* argv[])
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp,
true,
DDataType>;
BElementOp>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_ms_ks,
b_ns_ks,
d_ms_ns,
e_ms_ns_host_result,
a_element_op,
b_element_op,
cde_element_op);
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
......
......@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
......@@ -241,24 +243,32 @@ int main(int argc, char* argv[])
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp,
false>;
BElementOp>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
auto ref_argument = ref_gemm.MakeArgument(a_ms_ks,
b_ns_ks,
empty_tensor,
e_ms_ns_host_result,
a_element_op,
b_element_op,
cde_element_op);
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
......
......@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
......@@ -241,24 +243,32 @@ int main(int argc, char* argv[])
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp,
false>;
BElementOp>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
auto ref_argument = ref_gemm.MakeArgument(a_ms_ks,
b_ns_ks,
empty_tensor,
e_ms_ns_host_result,
a_element_op,
b_element_op,
cde_element_op);
auto ref_argument = ref_gemm.MakeArgument(
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1));
}
}
}
}
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
......
......@@ -23,13 +23,10 @@ template <ck::index_t NumDimM,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
bool UseDToBinaryOp,
typename DDataType = float,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
......@@ -38,29 +35,23 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
const Tensor<DDataType>& d_ms_ns,
Tensor<EDataType>& e_ms_ns,
Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
BElementwiseOperation b_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
d_ms_ns_{d_ms_ns},
e_ms_ns_{e_ms_ns},
c_ms_ns_{c_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
b_element_op_{b_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
const Tensor<DDataType>& d_ms_ns_;
Tensor<EDataType>& e_ms_ns_;
Tensor<CDataType>& c_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
......@@ -68,19 +59,6 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
void apply_unary_op(const CDEElementwiseOperation& op, EDataType& v_e, AccDataType& v_acc)
{
op(v_e, v_acc);
}
void apply_binary_op(const CDEElementwiseOperation& op,
EDataType& v_e,
AccDataType& v_acc,
DDataType& v_d)
{
op(v_e, v_acc, v_d);
}
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
......@@ -105,26 +83,14 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
}
AccDataType v_e;
DDataType v_d =
arg.d_ms_ns_.GetNumOfDimension() == 0 ? 0 : arg.d_ms_ns_(m0, m1, n0, n1);
if constexpr(UseDToBinaryOp)
{
apply_binary_op(arg.cde_element_op_, v_e, v_acc, v_d);
}
else
{
apply_unary_op(arg.cde_element_op_, v_e, v_acc);
}
arg.e_ms_ns_(m0, m1, n0, n1) = v_e;
arg.c_ms_ns_(m0, m1, n0, n1) = v_acc;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
arg.c_ms_ns_.mDesc.GetLengths()[0],
arg.c_ms_ns_.mDesc.GetLengths()[1],
arg.c_ms_ns_.mDesc.GetLengths()[2],
arg.c_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
......@@ -150,24 +116,11 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
const Tensor<DDataType>& d_ms_ns,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_ms_ks, b_ns_ks, d_ms_ns, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
Tensor<CDataType>& c_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
BElementwiseOperation b_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -41,7 +41,7 @@ int profile_contraction_impl(ck::index_t do_verification,
ck::index_t init_method,
bool do_log,
bool time_kernel,
CDElementOp cd_element_op,
CDElementOp cde_element_op,
const std::vector<ck::index_t>& M,
const std::vector<ck::index_t>& N,
const std::vector<ck::index_t>& K,
......@@ -64,14 +64,14 @@ int profile_contraction_impl(ck::index_t do_verification,
Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC));
Tensor<DataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StridesC));
Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC));
Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesC));
Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
......@@ -92,18 +92,18 @@ int profile_contraction_impl(ck::index_t do_verification,
DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(DataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.SetZero();
e_device_buf.SetZero();
d_device_buf.ToDevice(d_m_n.mData.data());
const std::vector<index_t> a_ms_ks_lengths = {M[0], M[1], K[0], K[1]};
const std::vector<index_t> b_ns_ks_lengths = {N[0], N[1], K[0], K[1]};
const std::vector<index_t> c_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
const std::vector<index_t> d_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
const std::vector<index_t> a_m_k_lengths = {M[0], M[1], K[0], K[1]};
const std::vector<index_t> b_n_k_lengths = {N[0], N[1], K[0], K[1]};
const std::vector<index_t> c_m_n_lengths = {M[0], M[1], N[0], N[1]};
const std::vector<index_t> d_m_n_lengths = {M[0], M[1], N[0], N[1]};
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
......@@ -129,8 +129,8 @@ int profile_contraction_impl(ck::index_t do_verification,
// Run reference op
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<
NumDim,
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDim,
NumDim,
NumDim,
DataType,
......@@ -138,21 +138,41 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType,
DataType,
AElementOp,
BElementOp,
CDElementOp,
std::is_same<CDElementOp, Bilinear>::value,
DataType>;
BElementOp>;
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
if constexpr(std::is_same<CDElementOp, Scale>::value)
d_m_n = Tensor<DataType>(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesC));
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, d_m_n, c_m_n_host_result, a_element_op, b_element_op, cd_element_op);
auto ref_argument =
ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_m_n_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_m_n_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_m_n_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_m_n_host_result.mDesc.GetLengths()[3]; ++n1)
{
if constexpr(is_same<CDElementOp, Bilinear>::value)
{
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
c_m_n_host_result(m0, m1, n0, n1),
d_m_n(m0, m1, n0, n1));
}
else if constexpr(is_same<CDElementOp, Scale>::value)
{
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
c_m_n_host_result(m0, m1, n0, n1));
}
}
}
}
}
}
std::string best_op_name;
......@@ -170,18 +190,18 @@ int profile_contraction_impl(ck::index_t do_verification,
static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
static_cast<DataType*>(c_device_buf.GetDeviceBuffer()),
a_ms_ks_lengths,
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
a_m_k_lengths,
StridesA,
b_ns_ks_lengths,
b_n_k_lengths,
StridesB,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_m_n_lengths},
std::array<std::vector<ck::index_t>, 1>{StridesD},
c_ms_ns_lengths,
c_m_n_lengths,
StridesC,
a_element_op,
b_element_op,
cd_element_op);
cde_element_op);
}
else
{
......@@ -189,18 +209,18 @@ int profile_contraction_impl(ck::index_t do_verification,
op_ptr->MakeArgumentPointer(static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 0>{},
static_cast<DataType*>(c_device_buf.GetDeviceBuffer()),
a_ms_ks_lengths,
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
a_m_k_lengths,
StridesA,
b_ns_ks_lengths,
b_n_k_lengths,
StridesB,
std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{},
c_ms_ns_lengths,
c_m_n_lengths,
StridesC,
a_element_op,
b_element_op,
cd_element_op);
cde_element_op);
}
auto invoker_ptr = op_ptr->MakeInvokerPointer();
......@@ -212,7 +232,7 @@ int profile_contraction_impl(ck::index_t do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
e_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString();
......@@ -242,12 +262,12 @@ int profile_contraction_impl(ck::index_t do_verification,
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
float threshold =
static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
pass = pass & ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
pass = pass & ck::utils::check_err(e_m_n_device_result,
e_m_n_host_result,
"Error: incorrect results!",
threshold,
threshold);
......@@ -256,9 +276,9 @@ int profile_contraction_impl(ck::index_t do_verification,
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
<< std::endl;
}
}
......
......@@ -23,7 +23,7 @@ template <typename DataTypeA,
typename DataTypeB,
typename DataTypeC,
typename DataTypeD,
int NumDim>
ck::index_t NumDim>
class ContractionDeviceWrapper
{
......@@ -40,10 +40,27 @@ class ContractionDeviceWrapper
Bilinear>;
public:
ContractionDeviceWrapper(std::vector<ck::index_t>& Dims, std::vector<ck::index_t>& Strides)
: InputDims_(Dims), OutputDims_(Dims), InputStrides_(Strides), OutputStrides_(Strides)
{
}
ContractionDeviceWrapper(std::vector<ck::index_t>& InDims,
std::vector<ck::index_t>& OutDims,
std::vector<ck::index_t>& InStrides,
std::vector<ck::index_t>& OutStrides)
: InputDims_(InDims),
OutputDims_(OutDims),
InputStrides_(InStrides),
OutputStrides_(OutStrides)
{
}
std::vector<ck::index_t>& InputDims_;
std::vector<ck::index_t>& OutputDims_;
std::vector<ck::index_t>& InputStrides_;
std::vector<ck::index_t>& OutputStrides_;
bool IsSupported() const
{
std::vector<ck::index_t> dummy_dims(NumDim * 2, 4);
std::vector<ck::index_t> dummy_strides(NumDim * 2, 1);
bool supported = false;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......@@ -56,14 +73,14 @@ class ContractionDeviceWrapper
nullptr,
std::array<const void*, 1>{nullptr},
nullptr,
dummy_dims,
dummy_strides,
dummy_dims,
dummy_strides,
std::array<std::vector<ck::index_t>, 1>{dummy_dims},
std::array<std::vector<ck::index_t>, 1>{dummy_strides},
dummy_dims,
dummy_strides,
InputStrides_,
InputStrides_,
InputStrides_,
InputStrides_,
std::array<std::vector<ck::index_t>, 1>{InputStrides_},
std::array<std::vector<ck::index_t>, 1>{InputStrides_},
OutputDims_,
OutputStrides_,
Pass{},
Pass{},
Bilinear{1.f, 1.f});
......@@ -76,9 +93,11 @@ class ContractionDeviceWrapper
TEST(TestContractionInterface, IncorrectNumDims)
{
ContractionDeviceWrapper<F32, F32, F32, F32, 1> wrapper_1d;
ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper_2d;
ContractionDeviceWrapper<F32, F32, F32, F32, 3> wrapper_3d;
std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
ContractionDeviceWrapper<F32, F32, F32, F32, 1> wrapper_1d(Dims[0], Strides[0]);
ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper_2d(Dims[1], Strides[1]);
ContractionDeviceWrapper<F32, F32, F32, F32, 3> wrapper_3d(Dims[2], Strides[2]);
EXPECT_FALSE(wrapper_1d.IsSupported());
EXPECT_TRUE(wrapper_2d.IsSupported());
EXPECT_FALSE(wrapper_3d.IsSupported());
......@@ -86,13 +105,30 @@ TEST(TestContractionInterface, IncorrectNumDims)
TEST(TestContractionInterface, IncorrectDataTypes)
{
    // 4-D problem (2 M-dims x 2 N-dims), contiguous row-major strides.
    std::vector<ck::index_t> Dims    = {4, 4, 4, 4};
    std::vector<ck::index_t> Strides = {64, 16, 4, 1};
    // Mixed-precision combinations (F32 A/B with F64 C/D and vice versa) have
    // no registered device instances, so neither wrapper may report support.
    // NOTE: the diff residue previously declared wrapper_1/wrapper_2 twice
    // (old default-constructed and new arg-constructed forms); only the
    // arg-constructed declarations are kept.
    ContractionDeviceWrapper<F32, F32, F64, F64, 2> wrapper_1(Dims, Strides);
    ContractionDeviceWrapper<F64, F64, F32, F32, 2> wrapper_2(Dims, Strides);
    EXPECT_FALSE(wrapper_1.IsSupported());
    EXPECT_FALSE(wrapper_2.IsSupported());
}
// TEST(TestContractionInterface, CornerCases)
// {
// EXPECT_FALSE()
// }
TEST(TestContractionInterface, GridwiseGemm)
{
    // Input and output descriptors deliberately disagree in both lengths and
    // strides; no device instance is expected to accept such a problem.
    std::vector<ck::index_t> in_lengths  = {1, 2, 3, 4};
    std::vector<ck::index_t> in_strides  = {24, 12, 4, 1};
    std::vector<ck::index_t> out_lengths = {4, 3, 2, 1};
    std::vector<ck::index_t> out_strides = {6, 2, 1, 1};

    // Wrapper ctor order: (InDims, OutDims, InStrides, OutStrides).
    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(
        in_lengths, out_lengths, in_strides, out_strides);

    EXPECT_FALSE(wrapper.IsSupported());
}
TEST(TestContractionInterface, MemoryAccess)
{
    // Strides increase toward the innermost dimension ({4,16,64,256}), the
    // reverse of a packed row-major layout; expected to be rejected.
    std::vector<ck::index_t> lengths = {4, 4, 4, 4};
    std::vector<ck::index_t> strides = {4, 16, 64, 256};

    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(lengths, strides);
    EXPECT_FALSE(wrapper.IsSupported());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment