"...llama_fastertransformer.git" did not exist on "a929d1c69a0b3e08ed7e70c752bdc2874d92a256"
Unverified Commit d3cd6f41 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-987

parents e84c2a33 98fd41f5
...@@ -10,9 +10,12 @@ ...@@ -10,9 +10,12 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/profile_contraction_impl.hpp" #include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
using F32 = float; using F16 = ck::half_t;
using F64 = double; using BF16 = ck::bhalf_t;
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -20,49 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -20,49 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
struct MemoryParams struct Dimensions
{ {
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
}; };
template <typename Tuple> template <typename Tuple>
class TestContraction : public ::testing::Test class TestContraction : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CDLayout = std::tuple_element_t<2, Tuple>; using CDLayout = std::tuple_element_t<2, Tuple>;
using DataType = std::tuple_element_t<3, Tuple>; using DataType = std::tuple_element_t<3, Tuple>;
using DTupleDataType = std::tuple_element_t<4, Tuple>; using DTupleDataType = std::tuple_element_t<4, Tuple>;
using CDElementOp = std::tuple_element_t<5, Tuple>; using ComputeDataType = std::tuple_element_t<5, Tuple>;
using CDElementOp = std::tuple_element_t<6, Tuple>;
std::vector<MemoryParams> list_of_memory_params = {{{32, 32},
{32, 32}, std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
{32, 32}, {{16, 16}, {32, 32}, {16, 16}}};
{32768, 1024, 32, 1},
{32768, 1024, 32, 1}, std::vector<ck::index_t> init_methods = {1, 2};
{32768, 1024, 32, 1},
{32768, 1024, 32, 1}},
{{16, 16},
{32, 32},
{16, 16},
{4096, 256, 16, 1},
{16, 1, 8192, 256},
{16384, 1024, 32, 1},
{16384, 1024, 32, 1}}};
std::vector<ck::index_t> init_methods = {0, 1, 2};
std::unique_ptr<CDElementOp> p_cd_element_op; std::unique_ptr<CDElementOp> p_cd_element_op;
void Run() void Run()
{ {
for(auto& memory_params : list_of_memory_params) for(auto& dimension_params : dimension_list)
{ {
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
for(const ck::index_t init_method : init_methods) for(const ck::index_t init_method : init_methods)
{ {
bool pass = bool pass =
...@@ -70,19 +73,20 @@ class TestContraction : public ::testing::Test ...@@ -70,19 +73,20 @@ class TestContraction : public ::testing::Test
BLayout, BLayout,
CDLayout, CDLayout,
DataType, DataType,
ComputeDataType,
DTupleDataType, DTupleDataType,
CDElementOp>(true /*do_verification*/, CDElementOp>(true /*do_verification*/,
init_method, init_method,
false /*do_logs*/, false /*do_logs*/,
false /*time_kernel*/, false /*time_kernel*/,
*p_cd_element_op, *p_cd_element_op,
memory_params.M, dimension_params.M,
memory_params.N, dimension_params.N,
memory_params.K, dimension_params.K,
memory_params.StridesA, StridesA,
memory_params.StridesB, StridesB,
memory_params.StridesC, StridesC,
memory_params.StridesD); StridesD);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
} }
...@@ -99,24 +103,18 @@ class TestContractionBilinear : public TestContraction<Tuple> ...@@ -99,24 +103,18 @@ class TestContractionBilinear : public TestContraction<Tuple>
{ {
}; };
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>
using BilinearKernelTypes = using BilinearKernelTypes =
::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<F32>, Bilinear>, ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
std::tuple<Row, Col, Row, F32, ck::Tuple<F32>, Bilinear>, ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;
std::tuple<Col, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Col, Row, F32, ck::Tuple<F32>, Bilinear>, using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale),
std::tuple<Row, Row, Row, F64, ck::Tuple<F32>, Bilinear>, ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>;
std::tuple<Row, Col, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Col, Row, F64, ck::Tuple<F32>, Bilinear>>;
using ScaleKernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F64, ck::Tuple<>, Scale>>;
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
...@@ -136,3 +134,46 @@ TYPED_TEST(TestContractionScale, scale) ...@@ -136,3 +134,46 @@ TYPED_TEST(TestContractionScale, scale)
this->p_cd_element_op = std::make_unique<Scale>(0.5f); this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run(); this->Run();
} }
template <typename Tuple>
class TestContractionScaleMixedPrecision : public TestContraction<Tuple>
{
};
template <typename Tuple>
class TestContractionBilinearMixedPrecision : public TestContraction<Tuple>
{
};
using BilinearKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, BF16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<F16>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<BF16>, F32, Bilinear)>;
using ScaleKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F16, Scale),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>;
TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision);
TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision);
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
{
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run();
}
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
{
this->p_cd_element_op = std::make_unique<Scale>(1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run();
}
...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper ...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper
static constexpr ck::index_t NumDim = 2; static constexpr ck::index_t NumDim = 2;
// clang-format off // clang-format off
using ContractionDeviceInstance = ck::tensor_operation::device:: using ContractionDeviceInstance = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector, F32>;
// clang-format on // clang-format on
bool isSupported(std::vector<ck::index_t>& ADims, bool isSupported(std::vector<ck::index_t>& ADims,
......
...@@ -108,6 +108,10 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) ...@@ -108,6 +108,10 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
// kloops % 2 // kloops % 2
Ks = std::vector<int>{256, 512, 320, 768}; Ks = std::vector<int>{256, 512, 320, 768};
EXPECT_FALSE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
Ks = std::vector<int>{256, 512, 384, 768};
EXPECT_TRUE( EXPECT_TRUE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch)); DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment