Commit 6d2d39ba authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Extend test_contraction_interface

parent 1abe377b
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" #include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
...@@ -16,15 +18,65 @@ ...@@ -16,15 +18,65 @@
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float; using F32 = float;
using F64 = double; using F64 = double;
template <ck::index_t ABlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t CDEBlockTransferScalarPerVector>
class ContractionInstanceWrapper
{
public:
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr ck::index_t NumDim = 2;
// clang-format off
using ContractionDeviceInstance = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
// clang-format on
bool isSupported(std::vector<ck::index_t>& ADims,
std::vector<ck::index_t>& BDims,
std::vector<ck::index_t>& DDims,
std::vector<ck::index_t>& EDims,
std::vector<ck::index_t>& AStrides,
std::vector<ck::index_t>& BStrides,
std::vector<ck::index_t>& DStrides,
std::vector<ck::index_t>& EStrides) const
{
auto contraction = ContractionDeviceInstance{};
auto argument = contraction.MakeArgument(nullptr,
nullptr,
std::array<const void*, 1>{nullptr},
nullptr,
ADims,
AStrides,
BDims,
BStrides,
std::array<std::vector<ck::index_t>, 1>{DDims},
std::array<std::vector<ck::index_t>, 1>{DStrides},
EDims,
EStrides,
Pass{},
Pass{},
Bilinear{1.f, 1.f});
return contraction.IsSupportedArgument(argument);
}
};
template <typename DataTypeA, template <typename DataTypeA,
typename DataTypeB, typename DataTypeB,
typename DataTypeC, typename DataTypeC,
typename DataTypeD, typename DataTypeD,
ck::index_t NumDim> ck::index_t NumDim>
class ContractionDeviceWrapper class ContractionDeviceOpWrapper
{ {
protected: protected:
...@@ -40,26 +92,8 @@ class ContractionDeviceWrapper ...@@ -40,26 +92,8 @@ class ContractionDeviceWrapper
Bilinear>; Bilinear>;
public: public:
ContractionDeviceWrapper(std::vector<ck::index_t>& Dims, std::vector<ck::index_t>& Strides) bool IsSupportedInstance(std::vector<ck::index_t>& Dims,
: InputDims_(Dims), OutputDims_(Dims), InputStrides_(Strides), OutputStrides_(Strides) std::vector<ck::index_t>& Strides) const
{
}
ContractionDeviceWrapper(std::vector<ck::index_t>& InDims,
std::vector<ck::index_t>& OutDims,
std::vector<ck::index_t>& InStrides,
std::vector<ck::index_t>& OutStrides)
: InputDims_(InDims),
OutputDims_(OutDims),
InputStrides_(InStrides),
OutputStrides_(OutStrides)
{
}
std::vector<ck::index_t>& InputDims_;
std::vector<ck::index_t>& OutputDims_;
std::vector<ck::index_t>& InputStrides_;
std::vector<ck::index_t>& OutputStrides_;
bool IsSupported() const
{ {
bool supported = false; bool supported = false;
...@@ -73,14 +107,14 @@ class ContractionDeviceWrapper ...@@ -73,14 +107,14 @@ class ContractionDeviceWrapper
nullptr, nullptr,
std::array<const void*, 1>{nullptr}, std::array<const void*, 1>{nullptr},
nullptr, nullptr,
InputStrides_, Dims,
InputStrides_, Strides,
InputStrides_, Dims,
InputStrides_, Strides,
std::array<std::vector<ck::index_t>, 1>{InputStrides_}, std::array<std::vector<ck::index_t>, 1>{Dims},
std::array<std::vector<ck::index_t>, 1>{InputStrides_}, std::array<std::vector<ck::index_t>, 1>{Strides},
OutputDims_, Dims,
OutputStrides_, Strides,
Pass{}, Pass{},
Pass{}, Pass{},
Bilinear{1.f, 1.f}); Bilinear{1.f, 1.f});
...@@ -95,40 +129,67 @@ TEST(TestContractionInterface, IncorrectNumDims) ...@@ -95,40 +129,67 @@ TEST(TestContractionInterface, IncorrectNumDims)
{ {
std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}}; std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}}; std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
ContractionDeviceWrapper<F32, F32, F32, F32, 1> wrapper_1d(Dims[0], Strides[0]); ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper_2d(Dims[1], Strides[1]); ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
ContractionDeviceWrapper<F32, F32, F32, F32, 3> wrapper_3d(Dims[2], Strides[2]); ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;
EXPECT_FALSE(wrapper_1d.IsSupported()); EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0]));
EXPECT_TRUE(wrapper_2d.IsSupported()); EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1]));
EXPECT_FALSE(wrapper_3d.IsSupported()); EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2]));
} }
TEST(TestContractionInterface, IncorrectDataTypes) TEST(TestContractionInterface, IncorrectDataTypes)
{ {
std::vector<ck::index_t> Dims = {4, 4, 4, 4}; std::vector<ck::index_t> Dims = {4, 4, 4, 4};
std::vector<ck::index_t> Strides = {64, 16, 4, 1}; std::vector<ck::index_t> Strides = {64, 16, 4, 1};
ContractionDeviceWrapper<F32, F32, F64, F64, 2> wrapper_1(Dims, Strides); ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> wrapper_1;
ContractionDeviceWrapper<F64, F64, F32, F32, 2> wrapper_2(Dims, Strides); ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> wrapper_2;
EXPECT_FALSE(wrapper_1.IsSupported()); EXPECT_FALSE(wrapper_1.IsSupportedInstance(Dims, Strides));
EXPECT_FALSE(wrapper_2.IsSupported()); EXPECT_FALSE(wrapper_2.IsSupportedInstance(Dims, Strides));
} }
TEST(TestContractionInterface, GridwiseGemm) TEST(TestContractionSupportedArgs, ABMemoryAccess)
{ {
std::vector<ck::index_t> InDims = {1, 2, 3, 4}; std::vector<ck::index_t> Dims = {4, 4, 4, 4};
std::vector<ck::index_t> InStrides = {24, 12, 4, 1}; std::vector<ck::index_t> Strides = {64, 16, 4, 1};
std::vector<ck::index_t> OutDims = {4, 3, 2, 1}; std::vector<ck::index_t> StridesM1 = {4, 1, 64, 16};
std::vector<ck::index_t> OutStrides = {6, 2, 1, 1}; std::vector<ck::index_t> StridesK1 = {64, 16, 4, 1};
ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(InDims, OutDims, InStrides, OutStrides); std::vector<ck::index_t> InvalidStrides = {4, 4, 4, 4};
// Memory access to A
EXPECT_FALSE(wrapper.IsSupported()); ContractionInstanceWrapper<1, 2, 4> wrapperA1;
ContractionInstanceWrapper<2, 2, 4> wrapperA2;
EXPECT_FALSE(
wrapperA1.isSupported(Dims, Dims, Dims, Dims, InvalidStrides, Strides, Strides, Strides));
EXPECT_FALSE(
wrapperA2.isSupported(Dims, Dims, Dims, Dims, InvalidStrides, Strides, Strides, Strides));
EXPECT_TRUE(
wrapperA1.isSupported(Dims, Dims, Dims, Dims, StridesM1, Strides, Strides, Strides));
EXPECT_TRUE(
wrapperA2.isSupported(Dims, Dims, Dims, Dims, StridesK1, Strides, Strides, Strides));
// Memory access to B
ContractionInstanceWrapper<2, 1, 4> wrapperB1;
ContractionInstanceWrapper<2, 2, 4> wrapperB2;
EXPECT_FALSE(
wrapperB1.isSupported(Dims, Dims, Dims, Dims, Strides, InvalidStrides, Strides, Strides));
EXPECT_FALSE(
wrapperB2.isSupported(Dims, Dims, Dims, Dims, Strides, InvalidStrides, Strides, Strides));
EXPECT_TRUE(
wrapperB1.isSupported(Dims, Dims, Dims, Dims, Strides, StridesM1, Strides, Strides));
EXPECT_TRUE(
wrapperB2.isSupported(Dims, Dims, Dims, Dims, Strides, StridesK1, Strides, Strides));
} }
TEST(TestContractionInterface, MemoryAccess) TEST(TestContractionSupportedArgs, DEMemoryAccess)
{ {
std::vector<ck::index_t> Dims = {4, 4, 4, 4}; std::vector<ck::index_t> Dims = {4, 4, 4, 4};
std::vector<ck::index_t> Strides = {4, 16, 64, 256}; std::vector<ck::index_t> Strides = {64, 16, 4, 1};
ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(Dims, Strides); std::vector<ck::index_t> InvalidStrides = {64, 16, 1, 4};
ContractionInstanceWrapper<2, 2, 4> wrapper;
EXPECT_FALSE(wrapper.IsSupported()); // Memory access to D
EXPECT_FALSE(
wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, InvalidStrides, Strides));
EXPECT_TRUE(wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, Strides));
// Memory access to E
EXPECT_FALSE(
wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, InvalidStrides));
EXPECT_TRUE(wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, Strides));
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment