Commit 897e0bce authored by Jing Zhang

fixed vector_load_size check

parent d1a50f9f
@@ -615,11 +615,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             return false;
         }
 
-        bool supported = true;
-
         // If we use padding we do not support vector loads for dimensions not divisible by vector
         // load size.
-        if constexpr(GemmSpec != GemmSpecialization::Default)
+        // if constexpr(GemmSpec != GemmSpecialization::Default)
         {
             // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
             // thus we have to adapt it to the {M,K} or {N,K} layout.
@@ -631,12 +629,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                 const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
                 const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
 
-                supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
-                supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
+                if(a_vector_dim % ABlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+
+                if(b_vector_dim % BBlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
             }
         }
 
-        return supported;
+        return true;
     }
 
     // polymorphic
...
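For context: the updated check requires that the raw extent read along each matrix's contiguous dimension be divisible by the configured vector load width, and with the `if constexpr` guard commented out this now applies for every GemmSpecialization, not only the padded ones. A minimal standalone sketch of that rule follows (illustration only, not part of the commit; the helper name is invented):

// Illustration only (not part of the commit). `vector_load_supported` is a hypothetical
// helper mirroring the a_vector_dim / b_vector_dim checks above: a vectorized load of
// `scalar_per_vector` elements is valid only if the raw dimension length is a multiple of it.
#include <cstdint>

inline bool vector_load_supported(std::int64_t raw_dim_length, std::int64_t scalar_per_vector)
{
    // e.g. raw K = 188 with 8-wide loads: 188 % 8 == 4, so the configuration is rejected
    return raw_dim_length % scalar_per_vector == 0;
}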
@@ -3,13 +3,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_custom_target(test_grouped_gemm)
-        add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
+        add_gtest_executable(test_grouped_gemm test_grouped_gemm.cpp)
         add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp)
-        target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
+        target_link_libraries(test_grouped_gemm PRIVATE utility device_grouped_gemm_instance)
         target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
-        add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface)
+        add_dependencies(test_grouped_gemm test_grouped_gemm_interface)
         set(target 1)
     endif()
 endforeach()
...
@@ -10,7 +10,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "test_grouped_gemm_util.hpp"
 
-class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
+class TestGGemmInterface_MKNKMN : public ::testing::Test
 {
 protected:
     using Row = ck::tensor_layout::gemm::RowMajor;
@@ -29,20 +29,20 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
         ck::index_t BBlockTransferSrcScalarPerVector,
         ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
     using GGemmInstance =
-        ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
-                                                         BLayout,
-                                                         ELayout,
-                                                         GemmSpec,
-                                                         KPerBlock,
-                                                         K1,
-                                                         ABlockTransferSrcScalarPerVector,
-                                                         BBlockTransferSrcScalarPerVector,
-                                                         CDEBlockTransferScalarPerVector_NPerBlock>;
+        ck::test::DeviceGroupedGemmInstanceWrapper<ALayout,
+                                                   BLayout,
+                                                   ELayout,
+                                                   GemmSpec,
+                                                   KPerBlock,
+                                                   K1,
+                                                   ABlockTransferSrcScalarPerVector,
+                                                   BBlockTransferSrcScalarPerVector,
+                                                   CDEBlockTransferScalarPerVector_NPerBlock>;
 
     using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 8>;
 };
 
-TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
+TEST_F(TestGGemmInterface_MKNKMN, TileSize)
 {
     std::vector<int> Ms{128, 256, 188, 512};
     constexpr int N = 256;
@@ -63,7 +63,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
     EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
 }
 
-TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
+TEST_F(TestGGemmInterface_MKNKMN, VectorLoadWidth)
 {
     static constexpr auto GemmMNKPadding =
         ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -93,33 +93,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
     EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
 }
 
-#if 0
-TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
-{
-    std::vector<int> Ms{128, 256, 256, 512};
-    constexpr int N = 256;
-    constexpr int K = 128;
-    constexpr int kbatch = 4;
-
-    std::vector<int> Ns(Ms.size(), N);
-    std::vector<int> Ks(Ms.size(), K);
-    std::vector<int> StrideAs(Ms.size(), K);
-    std::vector<int> StrideBs(Ms.size(), K);
-    std::vector<int> StrideCs(Ms.size(), N);
-
-    // kloops % 2
-    Ks = std::vector<int>{256, 512, 320, 768};
-    EXPECT_TRUE(
-        DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
-
-    // Not all gemms have same value for main_k0_block_loop!
-    Ks = std::vector<int>{256, 512, 512, 512};
-    EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
-                 std::runtime_error);
-}
-#endif
-
-class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
+class TestGGemmInterface_KMKNNM : public ::testing::Test
 {
 protected:
     using Row = ck::tensor_layout::gemm::RowMajor;
@@ -138,20 +112,20 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
         ck::index_t BBlockTransferSrcScalarPerVector,
         ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
     using GGemmInstance =
-        ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
-                                                         BLayout,
-                                                         ELayout,
-                                                         GemmSpec,
-                                                         KPerBlock,
-                                                         K1,
-                                                         ABlockTransferSrcScalarPerVector,
-                                                         BBlockTransferSrcScalarPerVector,
-                                                         CDEBlockTransferScalarPerVector_NPerBlock>;
+        ck::test::DeviceGroupedGemmInstanceWrapper<ALayout,
+                                                   BLayout,
+                                                   ELayout,
+                                                   GemmSpec,
+                                                   KPerBlock,
+                                                   K1,
+                                                   ABlockTransferSrcScalarPerVector,
+                                                   BBlockTransferSrcScalarPerVector,
+                                                   CDEBlockTransferScalarPerVector_NPerBlock>;
 
     using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 4>;
 };
 
-TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
+TEST_F(TestGGemmInterface_KMKNNM, TileSize)
 {
     std::vector<int> Ms{128, 256, 188, 512};
     constexpr int N = 256;
@@ -172,7 +146,7 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
     EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
 }
 
-TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
+TEST_F(TestGGemmInterface_KMKNNM, VectorLoadWidth)
 {
     static constexpr auto GemmMNKPadding =
         ck::tensor_operation::device::GemmSpecialization::MNKPadding;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_splitk_util.hpp"
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN,
RRR_F16_F16_F16_LargeK,
testing::Values(32, 64));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK,
RCR_F16_F16_F16_LargeK,
testing::Values(32, 64));
#include "test_grouped_gemm_ut_cases.inc"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_splitk_impl.hpp"
namespace ck {
namespace test {
template <typename Range>
std::string serialize_range(const Range& range)
{
std::stringstream ss;
for(auto& r : range)
{
ss << r << ", ";
}
std::string str = ss.str();
return std::string(str.begin(), str.end() - 2);
}
template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {}
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1)
{
bool pass = ck::profiler::profile_grouped_gemm_splitk_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(
verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
EXPECT_TRUE(pass);
}
};
template <typename ALayout,
typename BLayout,
typename ELayout,
tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector,
index_t CDEBlockTransferScalarPerVector_NPerBlock>
struct DeviceGroupedGemmSplitkInstanceWrapper
{
using F16 = half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using EmptyTuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t N>
using I = ck::Number<N>;
using ABlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
using ABlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
using BBlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
using BBlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
using DeviceGroupedGemmSplitKInstance =
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
ALayout,
BLayout,
EmptyTuple,
ELayout,
F16,
F16,
F32,
F16,
EmptyTuple,
F16,
PassThrough,
PassThrough,
PassThrough,
GemmSpec,
1,
128,
128,
128,
KPerBlock,
K1,
K1,
32,
32,
4,
2,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value,
S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1::value,
BBlockLdsAddExtraM::value,
1,
1,
S<1, 16, 1, 8>,
CDEBlockTransferScalarPerVector_NPerBlock>;
bool IsSupported(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
}
float Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
}
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
};
} // namespace test
} // namespace ck
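A hedged usage sketch of the wrapper defined above (illustration only, not part of the commit; the test name, layouts, and problem sizes are invented): probe IsSupported() for a set of group descriptors before invoking Run().

// Hypothetical usage of DeviceGroupedGemmSplitkInstanceWrapper, not part of the commit.
#include <vector>
#include <gtest/gtest.h>
#include "test_grouped_gemm_splitk_util.hpp"

namespace {
using Row      = ck::tensor_layout::gemm::RowMajor;
using Col      = ck::tensor_layout::gemm::ColumnMajor;
using GemmSpec = ck::tensor_operation::device::GemmSpecialization;

// Template arguments mirror DefaultGGemmInstance in the interface tests:
// GemmSpec, KPerBlock, K1, ABlockTransferSrcScalarPerVector,
// BBlockTransferSrcScalarPerVector, CDEBlockTransferScalarPerVector_NPerBlock.
using Wrapper = ck::test::
    DeviceGroupedGemmSplitkInstanceWrapper<Row, Col, Row, GemmSpec::Default, 32, 8, 4, 8, 8>;
} // namespace

TEST(GroupedGemmSplitkWrapperExample, ProbeThenRun)
{
    const std::vector<int> Ms{128, 256};
    const std::vector<int> Ns(Ms.size(), 256);
    const std::vector<int> Ks(Ms.size(), 128);
    const std::vector<int> StrideAs(Ms.size(), 128); // row-major A: stride == K
    const std::vector<int> StrideBs(Ms.size(), 128); // column-major B: stride == K
    const std::vector<int> StrideCs(Ms.size(), 256); // row-major E: stride == N
    const int kbatch = 4;

    Wrapper wrapper{};
    if(wrapper.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch))
    {
        wrapper.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
    }
}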
@@ -12,7 +12,7 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
     const std::vector<int> StrideBs(Ms.size(), N);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RRR_F16_F16_F16, SmallCases)
@@ -27,7 +27,7 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
     const std::vector<int> StrideBs(Ms.size(), N);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RRR_F16_F16_F16, MidCases)
@@ -42,7 +42,7 @@ TEST_P(RRR_F16_F16_F16, MidCases)
     const std::vector<int> StrideBs(Ms.size(), N);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RRR_F16_F16_F16, Regular)
@@ -57,7 +57,7 @@ TEST_P(RRR_F16_F16_F16, Regular)
     const std::vector<int> StrideBs(Ms.size(), N);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RRR_F16_F16_F16, MNKPadded)
@@ -72,7 +72,7 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
     const std::vector<int> StrideBs(Ms.size(), N);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RCR_F16_F16_F16, TinyCases)
@@ -86,7 +86,7 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
     const std::vector<int> StrideAs(Ms.size(), K);
     const std::vector<int> StrideBs(Ms.size(), K);
     const std::vector<int> StrideCs(Ms.size(), N);
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RCR_F16_F16_F16, SmallCases)
@@ -101,7 +101,7 @@ TEST_P(RCR_F16_F16_F16, SmallCases)
     const std::vector<int> StrideBs(Ms.size(), K);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RCR_F16_F16_F16, MidCases)
@@ -116,7 +116,7 @@ TEST_P(RCR_F16_F16_F16, MidCases)
     const std::vector<int> StrideBs(Ms.size(), K);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RCR_F16_F16_F16, Regular)
@@ -131,7 +131,7 @@ TEST_P(RCR_F16_F16_F16, Regular)
     const std::vector<int> StrideBs(Ms.size(), K);
    const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RCR_F16_F16_F16, MNKPadded)
@@ -146,7 +146,7 @@ TEST_P(RCR_F16_F16_F16, MNKPadded)
     const std::vector<int> StrideBs(Ms.size(), K);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
@@ -161,7 +161,7 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
     const std::vector<int> StrideBs(Ms.size(), N);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
 
 TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
@@ -176,5 +176,5 @@ TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
     const std::vector<int> StrideBs(Ms.size(), K);
     const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
 }
...
@@ -14,7 +14,7 @@
 #include "ck/stream_config.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/utility/data_type.hpp"
@@ -85,7 +85,7 @@ template <typename ALayout,
           ck::index_t ABlockTransferSrcScalarPerVector,
           ck::index_t BBlockTransferSrcScalarPerVector,
           index_t CDEBlockTransferScalarPerVector_NPerBlock>
-struct DeviceGroupedGemmSplitkInstanceWrapper
+struct DeviceGroupedGemmInstanceWrapper
 {
     using F16 = half_t;
     using F32 = float;
@@ -102,68 +102,67 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
     using I = ck::Number<N>;
 
     using ABlockTransferThreadClusterArrageOrder =
-        std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
+        std::conditional_t<std::is_same_v<ALayout, Row>, S<1, 0, 2>, S<0, 2, 1>>;
     using ABlockTransferSrcAccessOrder =
-        std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
+        std::conditional_t<std::is_same_v<ALayout, Row>, S<1, 0, 2>, S<0, 2, 1>>;
     using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
     using ABlockTransferDstScalarPerVector_K1 =
         std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
     using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
 
     using BBlockTransferThreadClusterArrageOrder =
-        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
+        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 2, 1>, S<1, 0, 2>>;
     using BBlockTransferSrcAccessOrder =
-        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
+        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 2, 1>, S<1, 0, 2>>;
     using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
     using BBlockTransferDstScalarPerVector_K1 =
        std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
     using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
 
-    using DeviceGroupedGemmSplitKInstance =
-        tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
-            ALayout,
-            BLayout,
-            EmptyTuple,
-            ELayout,
-            F16,
-            F16,
-            F32,
-            F16,
-            EmptyTuple,
-            F16,
-            PassThrough,
-            PassThrough,
-            PassThrough,
-            GemmSpec,
-            1,
-            128,
-            128,
-            128,
-            KPerBlock,
-            K1,
-            K1,
-            32,
-            32,
-            4,
-            2,
-            S<1, 4, 16, 1>,
-            ABlockTransferThreadClusterArrageOrder,
-            ABlockTransferSrcAccessOrder,
-            ABlockTransferSrcVectorDim::value,
-            ABlockTransferSrcScalarPerVector,
-            ABlockTransferDstScalarPerVector_K1::value,
-            ABlockLdsAddExtraM::value,
-            S<1, 4, 16, 1>,
-            BBlockTransferThreadClusterArrageOrder,
-            BBlockTransferSrcAccessOrder,
-            BBlockTransferSrcVectorDim::value,
-            BBlockTransferSrcScalarPerVector,
-            BBlockTransferDstScalarPerVector_K1::value,
-            BBlockLdsAddExtraM::value,
-            1,
-            1,
-            S<1, 16, 1, 8>,
-            CDEBlockTransferScalarPerVector_NPerBlock>;
+    using DeviceGroupedGemmInstance =
+        tensor_operation::device::DeviceGroupedGemm_Xdl<ALayout,
+                                                        BLayout,
+                                                        EmptyTuple,
+                                                        ELayout,
+                                                        F16,
+                                                        F16,
+                                                        F32,
+                                                        F16,
+                                                        EmptyTuple,
+                                                        F16,
+                                                        PassThrough,
+                                                        PassThrough,
+                                                        PassThrough,
+                                                        GemmSpec,
+                                                        1,
+                                                        128,
+                                                        128,
+                                                        128,
+                                                        KPerBlock,
+                                                        K1,
+                                                        K1,
+                                                        32,
+                                                        32,
+                                                        4,
+                                                        2,
+                                                        S<4, 16, 1>,
+                                                        ABlockTransferThreadClusterArrageOrder,
+                                                        ABlockTransferSrcAccessOrder,
+                                                        ABlockTransferSrcVectorDim::value,
+                                                        ABlockTransferSrcScalarPerVector,
+                                                        ABlockTransferDstScalarPerVector_K1::value,
+                                                        ABlockLdsAddExtraM::value,
+                                                        S<4, 16, 1>,
+                                                        BBlockTransferThreadClusterArrageOrder,
+                                                        BBlockTransferSrcAccessOrder,
+                                                        BBlockTransferSrcVectorDim::value,
+                                                        BBlockTransferSrcScalarPerVector,
+                                                        BBlockTransferDstScalarPerVector_K1::value,
+                                                        BBlockLdsAddExtraM::value,
+                                                        1,
+                                                        1,
+                                                        S<1, 16, 1, 8>,
+                                                        CDEBlockTransferScalarPerVector_NPerBlock>;
 
     bool IsSupported(const std::vector<int>& Ms,
                      const std::vector<int>& Ns,
@@ -190,7 +189,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
         std::vector<void*> p_Cs(n_groups, nullptr);
         auto p_Ds = std::vector<std::array<const void*, 0>>{};
 
-        auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
+        auto ggemm_instance = DeviceGroupedGemmInstance{};
         auto argument = ggemm_instance.MakeArgument(
             p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
@@ -222,7 +221,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
         std::vector<void*> p_Cs(n_groups, nullptr);
         auto p_Ds = std::vector<std::array<const void*, 0>>{};
 
-        auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
+        auto ggemm_instance = DeviceGroupedGemmInstance{};
         auto argument = ggemm_instance.MakeArgument(
             p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
...