Unverified commit 5b10dae6, authored by Haocong WANG and committed by GitHub

Add gemm universal bf16 instances (#1484)

* Revert ckProfiler change

* Temporary save

* Add test and make it pass

* Tests pass

* Fix bug inside rotating buffer when tensor is not packed

* Bug fix

* clang-format

---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
parent 52410b49
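
The translation units below register the bf16 KM x NK -> MN universal GEMM instance lists: a compute-oriented ("comp") set and two memory-oriented ("mem") sets, each in default and padded flavors. A minimal consumer sketch follows; the include paths and the type aliases are assumptions made for illustration, while the factory-function names and the DeviceGemmV2 template arguments are taken from this commit.

// Illustrative consumer of the new instance factories (sketch only).
// Header paths and type aliases below are assumptions; the factory functions
// and the DeviceGemmV2 signature come from the diff itself.
#include <memory>
#include <vector>

#include "ck/ck.hpp"                                                    // assumed path
#include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp"  // assumed path

using BF16        = ck::bhalf_t;
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

int main()
{
    using ck::tensor_operation::device::DeviceGemmV2;
    namespace inst = ck::tensor_operation::device::instance;

    // One polymorphic container collects every registered bf16 KM x NK -> MN variant.
    std::vector<std::unique_ptr<DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16,
                                             PassThrough, PassThrough, PassThrough>>>
        instances;

    // Compute-oriented ("comp") instances; the padded lists register the same way.
    inst::add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(instances);
    inst::add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(instances);

    // Memory-oriented instances: mem_v1 uses Intrawave scheduling, mem_v2 uses Interwave.
    inst::add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(instances);
    inst::add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(instances);

    // A caller would then keep only the instances that report support for the
    // concrete problem size and benchmark those to pick the fastest one.
    return instances.empty() ? 1 : 0;
}

ckProfiler reaches these same instance lists through the KM_NK_MN and KM_KN_MN branches added to profile_gemm_universal further down in this commit.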
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
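// Compute-oriented ("comp") instances, default (no padding) variant:
// A is KM column-major, B is NK column-major, C is MN row-major.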
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances<GemmMKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances<GemmMPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
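// Memory-oriented ("mem") instances, v1: Intrawave scheduling, default (no padding) variant.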
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances<Intrawave, GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances<Intrawave,
GemmMKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
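// Memory-oriented ("mem") instances, v2: Interwave scheduling, default (no padding) variant.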
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances<Interwave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances<Interwave, GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances<Interwave,
GemmMKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -171,6 +171,14 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
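// New bf16 layouts: KM_NK_MN maps to column-major A / column-major B,
// KM_KN_MN maps to column-major A / row-major B; C stays row-major.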
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
@@ -28,6 +28,38 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
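// With A in the KM (column-major) layout, the A stride tracks M instead of K.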
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
@@ -56,6 +88,38 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
{
std::vector<int> Ms{127};
@@ -84,6 +148,38 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
{
std::vector<int> Ms{512};
@@ -111,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmUniversal_KM_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmUniversal_KM_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
{
int StrideA = M;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
@@ -40,6 +40,18 @@ class TestGemmUniversal_MK_NK
{
};
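// New fixtures for column-major A: KM_KN prepends (Col, Row) and KM_NK
// prepends (Col, Col) to the data-type tuple.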
template <typename Tuple>
class TestGemmUniversal_KM_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_KM_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
@@ -61,9 +73,22 @@ using KernelTypes_MK_NK = ::testing::Types<
#endif
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_KM_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_KM_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN);
TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK);
#include "test_gemm_universal_ut_cases.inc"