Commit b6f7cddd authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Merge remote-tracking branch 'origin/develop' into gfx950

parents 261d76c4 5dff1b14
...@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool kPadA = true; constexpr bool kPadM = true;
constexpr bool kPadB = true; constexpr bool kPadN = true;
constexpr bool kPadC = true; constexpr bool kPadK = true;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
...@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Lunching kernel with args:" std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl; << "}" << std::endl;
......
...@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK ...@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK
using KernelTypes_MK_KN = ::testing::Types< using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>, std::tuple< F8, F8, F8, BF16>,
...@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types< ...@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types<
using KernelTypes_MK_NK = ::testing::Types< using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>, std::tuple< F8, F8, F8, BF16>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment