// SPDX-License-Identifier: MIT // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include #include #include #include #include #include #include "profiler/profile_contraction_impl.hpp" #include "profiler/profile_contraction_utils.hpp" using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; using F64 = double; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Scale = ck::tensor_operation::element_wise::Scale; template struct Dimensions { constexpr static ck::index_t NumDimMNK = NDims; std::vector M; std::vector N; std::vector K; }; template class TestContraction : public ::testing::Test { protected: using ALayout = std::tuple_element_t<0, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>; using CDLayout = std::tuple_element_t<2, Tuple>; using DataType = std::tuple_element_t<3, Tuple>; using DTupleDataType = std::tuple_element_t<4, Tuple>; using ComputeDataType = std::tuple_element_t<5, Tuple>; using CDElementOp = std::tuple_element_t<6, Tuple>; std::vector init_methods = {1, 2}; std::unique_ptr p_cd_element_op; template void Run(Dimensions dimension_params) { constexpr ck::index_t NumDimMNK = ck::remove_cvref_t::NumDimMNK; std::vector StridesA(2 * NumDim); std::vector StridesB(2 * NumDim); std::vector StridesC(2 * NumDim); std::vector StridesD(2 * NumDim); const auto& M = dimension_params.M; const auto& N = dimension_params.N; const auto& K = dimension_params.K; auto merge_dims = [](const std::vector& dims01, const std::vector& dims23) { std::vector dims_szt(dims01.begin(), dims01.end()); dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end()); return dims_szt; }; assign_default_strides(ALayout{}, StridesA, merge_dims(M, K)); assign_default_strides(BLayout{}, StridesB, merge_dims(N, K)); assign_default_strides(CDLayout{}, StridesC, merge_dims(M, N)); assign_default_strides(CDLayout{}, StridesD, merge_dims(M, N)); for(const ck::index_t init_method : init_methods) { bool pass = ck::profiler::profile_contraction_impl(true /*do_verification*/, init_method, false /*do_logs*/, false /*time_kernel*/, *p_cd_element_op, dimension_params.M, dimension_params.N, dimension_params.K, StridesA, StridesB, StridesC, StridesD); EXPECT_TRUE(pass); } } }; template class TestContractionScale : public TestContraction { }; template class TestContractionBilinear : public TestContraction { }; #define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \ std::tuple, \ std::tuple, \ std::tuple, \ std::tuple using BilinearKernelTypes = ::testing::Types, F32, Bilinear), ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F64, Bilinear)>; using ScaleKernelTypes = ::testing::Types, F32, Scale), ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>; TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); TYPED_TEST(TestContractionBilinear, bilinear) { this->p_cd_element_op = std::make_unique(1.f, 1.f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); this->p_cd_element_op = std::make_unique(-0.5f, 0.5f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); } TYPED_TEST(TestContractionScale, scale) { this->p_cd_element_op = std::make_unique(1.f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); this->p_cd_element_op = std::make_unique(0.5f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); } template class TestContractionScaleMixedPrecision : public TestContraction { }; template class TestContractionBilinearMixedPrecision : public TestContraction { }; using BilinearKernelTypesMixedPrecision = ::testing::Types, F16, Bilinear), ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple, BF16, Bilinear), ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F32, Bilinear), ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple, F32, Bilinear), ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple, F32, Bilinear)>; using ScaleKernelTypesMixedPrecision = ::testing::Types, F16, Scale), ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale), ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale), ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale), ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>; TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision); TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision); TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear) { this->p_cd_element_op = std::make_unique(1.f, 1.f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); this->p_cd_element_op = std::make_unique(-0.5f, 0.5f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); } TYPED_TEST(TestContractionScaleMixedPrecision, scale) { this->p_cd_element_op = std::make_unique(1.f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); this->p_cd_element_op = std::make_unique(0.5f); this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); // special cases this->template Run<2>({{1, 1}, {16, 8}, {8, 16}}); this->template Run<2>({{8, 16}, {16, 8}, {1, 1}}); this->template Run<2>({{8, 16}, {1, 1}, {8, 16}}); this->template Run<2>({{1, 1}, {1, 1}, {1, 1}}); }