// SPDX-License-Identifier: MIT // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include #include #include "gtest/gtest.h" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" #include "ck/library/utility/device_memory.hpp" using Pass = ck::tensor_operation::element_wise::PassThrough; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using F32 = float; using F64 = double; template class ContractionDeviceWrapper { protected: using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD, DataTypeD, Pass, Pass, Bilinear>; public: ContractionDeviceWrapper(std::vector& Dims, std::vector& Strides) : InputDims_(Dims), OutputDims_(Dims), InputStrides_(Strides), OutputStrides_(Strides) { } ContractionDeviceWrapper(std::vector& InDims, std::vector& OutDims, std::vector& InStrides, std::vector& OutStrides) : InputDims_(InDims), OutputDims_(OutDims), InputStrides_(InStrides), OutputStrides_(OutStrides) { } std::vector& InputDims_; std::vector& OutputDims_; std::vector& InputStrides_; std::vector& OutputStrides_; bool IsSupported() const { bool supported = false; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); for(auto& op_ptr : op_ptrs) { auto argument_ptr = op_ptr->MakeArgumentPointer(nullptr, nullptr, std::array{nullptr}, nullptr, InputStrides_, InputStrides_, InputStrides_, InputStrides_, std::array, 1>{InputStrides_}, std::array, 1>{InputStrides_}, OutputDims_, OutputStrides_, Pass{}, Pass{}, Bilinear{1.f, 1.f}); supported = supported || op_ptr->IsSupportedArgument(argument_ptr.get()); } return supported; } }; TEST(TestContractionInterface, IncorrectNumDims) { std::vector> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}}; std::vector> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}}; ContractionDeviceWrapper wrapper_1d(Dims[0], Strides[0]); ContractionDeviceWrapper wrapper_2d(Dims[1], Strides[1]); ContractionDeviceWrapper wrapper_3d(Dims[2], Strides[2]); EXPECT_FALSE(wrapper_1d.IsSupported()); EXPECT_TRUE(wrapper_2d.IsSupported()); EXPECT_FALSE(wrapper_3d.IsSupported()); } TEST(TestContractionInterface, IncorrectDataTypes) { std::vector Dims = {4, 4, 4, 4}; std::vector Strides = {64, 16, 4, 1}; ContractionDeviceWrapper wrapper_1(Dims, Strides); ContractionDeviceWrapper wrapper_2(Dims, Strides); EXPECT_FALSE(wrapper_1.IsSupported()); EXPECT_FALSE(wrapper_2.IsSupported()); } TEST(TestContractionInterface, GridwiseGemm) { std::vector InDims = {1, 2, 3, 4}; std::vector InStrides = {24, 12, 4, 1}; std::vector OutDims = {4, 3, 2, 1}; std::vector OutStrides = {6, 2, 1, 1}; ContractionDeviceWrapper wrapper(InDims, OutDims, InStrides, OutStrides); EXPECT_FALSE(wrapper.IsSupported()); } TEST(TestContractionInterface, MemoryAccess) { std::vector Dims = {4, 4, 4, 4}; std::vector Strides = {4, 16, 64, 256}; ContractionDeviceWrapper wrapper(Dims, Strides); EXPECT_FALSE(wrapper.IsSupported()); }