"example/vscode:/vscode.git/clone" did not exist on "0bff049ad49c48b189ebe5c142d1a149d10e4cce"
Commit 0c8b4bbf authored by Adam Osewski

Add functional tests for grouped_gemm with different kbatch values.

parent 1945c26b
CMakeLists.txt
add_custom_target(test_grouped_gemm)
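# Umbrella target that aggregates the grouped GEMM test executables (via add_dependencies below).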
add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
target_link_libraries(test_grouped_gemm_fp16 PRIVATE utility device_grouped_gemm_instance)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_fp16 test_grouped_gemm_splitk)
test_grouped_gemm_splitk.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
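// Test suite aliases; the tuple order is <ALayout, BLayout, ELayout, ADataType, BDataType, EDataType>
// (see TestGroupedGemm in test_grouped_gemm_util.hpp).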
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
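// Split-K (kbatch) factors swept by each parameterized test suite below.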
const std::vector<int> KBATCH{1, 2, 4, 6, 8, 10, 12, 14, 16, 32, 64, 128};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
#include "test_grouped_gemm_ut_cases.inc"
test_grouped_gemm_ut_cases.inc
#pragma once
TEST_P(RCR_F16_F16_F16, TinyCases)
{
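// Most groups have M == 0 so that handling of empty groups is exercised.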
const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ks(Ms.size(), 4068);
const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, SmallCases)
{
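// Very small per-group M values, including one empty group (M == 0).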
const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ks(Ms.size(), 4068);
const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases)
{
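// Moderate, non-uniform per-group M values.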
const std::vector<int> Ms{
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
const std::vector<int> Ns(Ms.size(), 768);
const std::vector<int> Ks(Ms.size(), 4068);
const std::vector<int> StrideAs(Ms.size(), 4608);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 768);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
const std::vector<int> Ns(Ms.size(), 4608);
const std::vector<int> Ks(Ms.size(), 384);
const std::vector<int> StrideAs(Ms.size(), 384);
const std::vector<int> StrideBs(Ms.size(), 4608);
const std::vector<int> StrideCs(Ms.size(), 4608);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
test_grouped_gemm_util.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
namespace ck {
namespace test {
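// Parameterized fixture: the GTest value parameter (int) is the split-K factor (kbatch);
// the Tuple type parameters select the layouts and data types.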
template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
bool verify_ = true;
int init_method_ = 2; // decimal value initialization
bool log_ = false;
bool bench_ = false; // measure kernel performance
void SetUp() override {}
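// Runs the available grouped GEMM instances for the given group sizes via the profiler;
// results are checked against the reference when verify_ is set.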
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1)
{
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(
verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck