"...composable_kernel_rocm.git" did not exist on "6d4450ef155c39af9ede2cd171be40ee06db9939"
Commit e30e7a8c authored by muozturk's avatar muozturk
Browse files

test case for complex contraction bilinear

parent b45dd4d6
# Register the contraction-bilinear gtest executables once, provided that at
# least one requested GPU target appears in the list below.
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)

# `target` flips to 1 as soon as the executables have been added, so a second
# matching entry in GPU_TARGETS does not try to add them again.
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(target EQUAL 0 AND gpu IN_LIST gpu_list)
        # Build when DTYPES is not defined at all, or when it mentions fp32 or fp64.
        if(NOT DEFINED DTYPES OR DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64")
            add_gtest_executable(test_complex_contraction_bilinear test_complex_contraction_bilinear.cpp)
            target_link_libraries(test_complex_contraction_bilinear PRIVATE utility device_contraction_bilinear_instance)

            add_gtest_executable(test_complex_contraction_bilinear_interface test_complex_contraction_bilinear_interface.cpp)
            target_link_libraries(test_complex_contraction_bilinear_interface PRIVATE utility device_contraction_bilinear_instance)

            set(target 1)
        endif()
    endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <initializer_list>
#include <vector>
#include <tuple>
#include <gtest/gtest.h>
#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
// One contraction problem size. The fixture below reads entries [0] and [1]
// of each member, so every vector must hold at least two lengths.
// Member order matters: instances are brace-initialized positionally as {M, N, K}.
struct Dimensions
{
// Lengths combined with K for the A strides and with N for the C/D strides.
std::vector<ck::index_t> M;
// Lengths combined with K for the B strides and with M for the C/D strides.
std::vector<ck::index_t> N;
// Lengths combined with M for the A strides and with N for the B strides.
std::vector<ck::index_t> K;
};
/// Typed GoogleTest fixture that calls ck::profiler::profile_contraction_impl
/// with verification enabled for every entry of `dimension_list` combined with
/// every value of `init_methods`.
///
/// `Tuple` must provide, in this order: ALayout, BLayout, CDLayout, DataType,
/// DTupleDataType, ComputeDataType and CDElementOp.
template <typename Tuple>
class TestContraction : public ::testing::Test
{
    protected:
    using ALayout         = std::tuple_element_t<0, Tuple>;
    using BLayout         = std::tuple_element_t<1, Tuple>;
    using CDLayout        = std::tuple_element_t<2, Tuple>;
    using DataType        = std::tuple_element_t<3, Tuple>;
    using DTupleDataType  = std::tuple_element_t<4, Tuple>;
    using ComputeDataType = std::tuple_element_t<5, Tuple>;
    using CDElementOp     = std::tuple_element_t<6, Tuple>;

    // Problem sizes exercised by Run(); each entry supplies two lengths for M, N and K.
    std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
                                              {{16, 16}, {32, 32}, {16, 16}}};

    // Values forwarded unchanged as `init_method`; their meaning is defined by the profiler.
    std::vector<ck::index_t> init_methods = {1, 2};

    // Element-wise operation handed to the profiler. Must be assigned before Run().
    std::unique_ptr<CDElementOp> p_cd_element_op;

    /// Runs the profiler for every (problem size, init method) pair and records
    /// a non-fatal failure for each pair whose result is not `true`.
    ///
    /// Aborts the current test with a fatal failure if `p_cd_element_op` is
    /// unset or if any M/N/K vector has fewer than two entries.
    void Run()
    {
        // Fail with a readable message instead of dereferencing a null pointer below.
        ASSERT_TRUE(p_cd_element_op != nullptr)
            << "p_cd_element_op must be assigned before calling Run()";

        for(auto& dimension_params : dimension_list)
        {
            const auto& M = dimension_params.M;
            const auto& N = dimension_params.N;
            const auto& K = dimension_params.K;

            // Indices [0] and [1] are read below; guard against short vectors.
            ASSERT_TRUE(M.size() >= 2 && N.size() >= 2 && K.size() >= 2)
                << "each of M, N and K needs at least two lengths";

            std::vector<ck::index_t> StridesA;
            std::vector<ck::index_t> StridesB;
            std::vector<ck::index_t> StridesC;
            std::vector<ck::index_t> StridesD;

            // Strides are derived from the layout tag and the listed lengths by
            // the assign_default_strides helper from the profiler utilities.
            assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
            assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
            assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
            assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});

            for(const ck::index_t init_method : init_methods)
            {
                bool pass =
                    ck::profiler::profile_contraction_impl<ALayout,
                                                           BLayout,
                                                           CDLayout,
                                                           DataType,
                                                           ComputeDataType,
                                                           DTupleDataType,
                                                           CDElementOp>(true /*do_verification*/,
                                                                        init_method,
                                                                        false /*do_logs*/,
                                                                        false /*time_kernel*/,
                                                                        *p_cd_element_op,
                                                                        dimension_params.M,
                                                                        dimension_params.N,
                                                                        dimension_params.K,
                                                                        StridesA,
                                                                        StridesB,
                                                                        StridesC,
                                                                        StridesD);
                EXPECT_TRUE(pass);
            }
        }
    }
};
// Fixture for the bilinear typed test suite. It adds no members of its own and
// exists so the suite has a distinct name while reusing TestContraction<Tuple>.
template <typename Tuple>
class TestContractionBilinear : public TestContraction<Tuple>
{
};
// Expands to four std::tuple types covering every Row/Col pairing of the first
// two layout slots, with the third layout slot fixed to Row. The remaining
// tuple slots are filled from the macro arguments in the order given.
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>
// Eight type combinations in total: the four layout pairings for F32 and for F64,
// all using the Bilinear element-wise operation.
using BilinearKernelTypes =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;
// Instantiate the TestContractionBilinear fixture once per type listed above.
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
// Runs the full fixture sweep twice, once per pair of Bilinear constructor
// arguments: (1, 1) first, then (-0.5, 0.5).
TYPED_TEST(TestContractionBilinear, bilinear)
{
    // Installs a Bilinear operation built from the two values and runs the sweep.
    const auto run_with = [this](float first, float second) {
        this->p_cd_element_op = std::make_unique<Bilinear>(first, second);
        this->Run();
    };

    run_with(1.f, 1.f);
    run_with(-0.5f, 0.5f);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
#include "ck/library/utility/device_memory.hpp"
using Pass = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using F64 = double;
template <ck::index_t ABlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t CDEBlockTransferScalarPerVector>
class ContractionInstanceWrapper
{
public:
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr ck::index_t NumDim = 2;
// clang-format off
using ContractionDeviceInstance = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector, F32>;
// clang-format on
bool isSupported(std::vector<ck::index_t>& ADims,
std::vector<ck::index_t>& BDims,
std::vector<ck::index_t>& DDims,
std::vector<ck::index_t>& EDims,
std::vector<ck::index_t>& AStrides,
std::vector<ck::index_t>& BStrides,
std::vector<ck::index_t>& DStrides,
std::vector<ck::index_t>& EStrides) const
{
auto contraction = ContractionDeviceInstance{};
auto argument = contraction.MakeArgument(nullptr,
nullptr,
std::array<const void*, 1>{nullptr},
nullptr,
ADims,
AStrides,
BDims,
BStrides,
std::array<std::vector<ck::index_t>, 1>{DDims},
std::array<std::vector<ck::index_t>, 1>{DStrides},
EDims,
EStrides,
Pass{},
Pass{},
Bilinear{1.f, 1.f});
return contraction.IsSupportedArgument(argument);
}
};
template <typename DataTypeA,
typename DataTypeB,
typename DataTypeC,
typename DataTypeD,
ck::index_t NumDim>
class ContractionDeviceOpWrapper
{
protected:
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
NumDim,
NumDim,
DataTypeA,
DataTypeB,
ck::Tuple<DataTypeC>,
DataTypeD,
Pass,
Pass,
Bilinear>;
public:
bool IsSupportedInstance(std::vector<ck::index_t>& Dims,
std::vector<ck::index_t>& Strides) const
{
bool supported = false;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr =
op_ptr->MakeArgumentPointer(nullptr,
nullptr,
std::array<const void*, 1>{nullptr},
nullptr,
Dims,
Strides,
Dims,
Strides,
std::array<std::vector<ck::index_t>, 1>{Dims},
std::array<std::vector<ck::index_t>, 1>{Strides},
Dims,
Strides,
Pass{},
Pass{},
Bilinear{1.f, 1.f});
supported = supported || op_ptr->IsSupportedArgument(argument_ptr.get());
}
return supported;
}
};
// Queries the instance factory with NumDim = 1, 2 and 3, passing 2, 4 and 6
// lengths respectively; only the NumDim = 2 query is expected to be supported.
TEST(TestContractionInterface, IncorrectNumDims)
{
    std::vector<ck::index_t> lengths_2 = {4, 4};
    std::vector<ck::index_t> lengths_4 = {4, 4, 4, 4};
    std::vector<ck::index_t> lengths_6 = {4, 4, 4, 4, 4, 4};

    // Every stride is 1 in all three cases.
    std::vector<ck::index_t> strides_2 = {1, 1};
    std::vector<ck::index_t> strides_4 = {1, 1, 1, 1};
    std::vector<ck::index_t> strides_6 = {1, 1, 1, 1, 1, 1};

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> one_dim_wrapper;
    EXPECT_FALSE(one_dim_wrapper.IsSupportedInstance(lengths_2, strides_2));

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> two_dim_wrapper;
    EXPECT_TRUE(two_dim_wrapper.IsSupportedInstance(lengths_4, strides_4));

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> three_dim_wrapper;
    EXPECT_FALSE(three_dim_wrapper.IsSupportedInstance(lengths_6, strides_6));
}
// Mixed F32/F64 type combinations are expected to have no supporting instance.
TEST(TestContractionInterface, IncorrectDataTypes)
{
    std::vector<ck::index_t> lengths = {4, 4, 4, 4};
    std::vector<ck::index_t> strides = {64, 16, 4, 1};

    // A and B in F32 with the D-tuple element and output type in F64.
    ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> f32_inputs_f64_outputs;
    EXPECT_FALSE(f32_inputs_f64_outputs.IsSupportedInstance(lengths, strides));

    // The reverse: A and B in F64 with the D-tuple element and output type in F32.
    ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> f64_inputs_f32_outputs;
    EXPECT_FALSE(f64_inputs_f32_outputs.IsSupportedInstance(lengths, strides));
}
// Checks which A and B stride patterns are accepted for source vector
// dimension 1 versus 2, while all other tensors keep the reference strides.
TEST(TestContractionSupportedArgs, ABMemoryAccess)
{
    std::vector<ck::index_t> lengths     = {4, 4, 4, 4};
    std::vector<ck::index_t> ref_strides = {64, 16, 4, 1};
    // Unit stride at index 1; expected to be accepted with source vector dim 1.
    std::vector<ck::index_t> unit_at_index_1 = {4, 1, 64, 16};
    // Unit stride at index 3 (same values as ref_strides); expected with source vector dim 2.
    std::vector<ck::index_t> unit_at_index_3 = {64, 16, 4, 1};
    // No unit stride at all; expected to be rejected for either vector dim.
    std::vector<ck::index_t> no_unit_stride = {4, 4, 4, 4};

    // Tensor A: vary only the A strides.
    ContractionInstanceWrapper<1, 2, 4> a_vector_dim_1;
    ContractionInstanceWrapper<2, 2, 4> a_vector_dim_2;
    EXPECT_FALSE(a_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, no_unit_stride, ref_strides, ref_strides, ref_strides));
    EXPECT_FALSE(a_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, no_unit_stride, ref_strides, ref_strides, ref_strides));
    EXPECT_TRUE(a_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, unit_at_index_1, ref_strides, ref_strides, ref_strides));
    EXPECT_TRUE(a_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, unit_at_index_3, ref_strides, ref_strides, ref_strides));

    // Tensor B: vary only the B strides.
    ContractionInstanceWrapper<2, 1, 4> b_vector_dim_1;
    ContractionInstanceWrapper<2, 2, 4> b_vector_dim_2;
    EXPECT_FALSE(b_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, no_unit_stride, ref_strides, ref_strides));
    EXPECT_FALSE(b_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, no_unit_stride, ref_strides, ref_strides));
    EXPECT_TRUE(b_vector_dim_1.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, unit_at_index_1, ref_strides, ref_strides));
    EXPECT_TRUE(b_vector_dim_2.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, unit_at_index_3, ref_strides, ref_strides));
}
// Checks that D and E strides with the last two entries swapped are rejected,
// while the reference strides are accepted.
TEST(TestContractionSupportedArgs, DEMemoryAccess)
{
    std::vector<ck::index_t> lengths     = {4, 4, 4, 4};
    std::vector<ck::index_t> ref_strides = {64, 16, 4, 1};
    // Same as ref_strides but with the last two entries exchanged.
    std::vector<ck::index_t> swapped_tail = {64, 16, 1, 4};

    ContractionInstanceWrapper<2, 2, 4> instance;

    // Tensor D: vary only the D strides, then re-check the all-valid case.
    EXPECT_FALSE(instance.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, ref_strides, swapped_tail, ref_strides));
    EXPECT_TRUE(instance.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, ref_strides, ref_strides, ref_strides));

    // Tensor E: vary only the E strides, then re-check the all-valid case.
    EXPECT_FALSE(instance.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, ref_strides, ref_strides, swapped_tail));
    EXPECT_TRUE(instance.isSupported(
        lengths, lengths, lengths, lengths, ref_strides, ref_strides, ref_strides, ref_strides));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment