Commit 1b616990 authored by aska-0096

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8_uc

parents af30d6b6 800cf897
......@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
template_args.insert(2, tuple()) # ds layout
template_args.insert(6, tuple()) # ds dtype
try:
new_instance = CKGemmOperation(
*template_args, # type: ignore[arg-type]
)
op_instances.append(new_instance)
except TypeError as e:
log.debug(f"{e} when parsing {line}")
return op_instances
......
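For context, the hunk above pads the positional template arguments with empty-tuple placeholders at fixed slots and skips instance lines whose arity no longer matches the constructor. A minimal sketch of that pattern, using a hypothetical stand-in for CKGemmOperation:

```python
from dataclasses import dataclass

@dataclass
class FakeOp:  # hypothetical stand-in for CKGemmOperation
    a_layout: str
    b_layout: str
    ds_layouts: tuple  # slot 2, filled with tuple() when the instance has no Ds
    c_layout: str

template_args = ["Row", "Col", "Row"]
template_args.insert(2, tuple())  # ds layout placeholder, as in the diff
try:
    op = FakeOp(*template_args)
except TypeError as e:  # wrong arity -> log and skip this instance line
    print(f"{e} when parsing instance")
```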
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
import logging
import unittest
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_library as gen_gemm_ops_library,
)
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_preselected as gen_gemm_ops_preselected,
)
from ck4inductor.grouped_conv_fwd.gen_instances import (
gen_conv_ops_library as gen_conv_ops_library,
)
from ck4inductor.batched_universal_gemm.gen_instances import (
gen_ops_library as gen_batched_gemm_ops_library,
)
log = logging.getLogger(__name__)
class TestGenInstances(unittest.TestCase):
def test_gen_gemm_instances(self):
instances = gen_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_preselected_gemm_instances(self):
instances = gen_gemm_ops_preselected()
log.debug("%d preselected gemm instances" % len(instances))
self.assertTrue(instances)
def test_gen_conv_instances(self):
instances = gen_conv_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_gen_batched_gemm_instances(self):
instances = gen_batched_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
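To surface the debug counts these tests log, one option (assuming the file is run directly as a module) is a standard unittest entry point with debug logging enabled:

```python
import logging
import unittest

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)  # make the instance counts visible
    unittest.main(verbosity=2)
```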
......@@ -15,7 +15,7 @@ else
fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
......
......@@ -149,6 +149,12 @@ def parse_logfile(logfile):
lst=line.split()
line_dict=dict(zip(lst[1:],lst))
res.append(line_dict['TFlops,'])
elif 'perf_tile_gemm_basic' in logfile or 'perf_tile_gemm_mem_pipeline' in logfile:
for line in open(logfile):
if 'TFlops' in line:
lst=line.split()
line_dict=dict(zip(lst[1:],lst))
res.append(line_dict['TFlops,'])
return res
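The `dict(zip(lst[1:], lst))` idiom above maps every whitespace-separated token to the token that precedes it, so indexing with the literal 'TFlops,' recovers the number printed just before that unit. A worked example on an assumed log-line format:

```python
line = "Perf: 1.234 ms, 95.67 TFlops, 820.1 GB/s"
lst = line.split()
# zip(lst[1:], lst) pairs each token with its predecessor:
# ('1.234', 'Perf:'), ('ms,', '1.234'), ('TFlops,', '95.67'), ...
line_dict = dict(zip(lst[1:], lst))
print(line_dict['TFlops,'])  # -> '95.67'
```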
......@@ -330,6 +336,14 @@ def main():
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_fmha_bwd_tflops"
if 'gemm_basic_fp16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_basic_fp16_tflops"
if 'gemm_mem_pipeline_fp16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_mem_pipeline_fp16_tflops"
tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
......
......@@ -43,3 +43,19 @@ file=./perf_fmha_bwd_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
fi
file=./perf_tile_gemm_basic_fp16_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx942.log
fi
file=./perf_tile_gemm_basic_fp16_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx90a.log
fi
file=./perf_tile_gemm_mem_pipeline_fp16_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx942.log
fi
file=./perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
fi
......@@ -52,3 +52,19 @@ file=./perf_fmha_bwd_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
fi
file=./perf_gemm_basic_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_basic_gfx942.log
fi
file=./perf_gemm_basic_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_basic_gfx90a.log
fi
file=./perf_gemm_mem_pipeline_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx942.log
fi
file=./perf_gemm_mem_pipeline_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx90a.log
fi
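The repeated exists-then-process blocks in these scripts follow one pattern; a hedged sketch of the same flow as a loop (same filenames as above, Python used for consistency with process_perf_data.py):

```python
import os
import subprocess

logs = [
    "perf_gemm_basic_gfx942.log",
    "perf_gemm_basic_gfx90a.log",
    "perf_gemm_mem_pipeline_gfx942.log",
    "perf_gemm_mem_pipeline_gfx90a.log",
]
for log in logs:
    if os.path.exists(log):  # mirrors `if [ -e "$file" ]`
        subprocess.run(["python3", "process_perf_data.py", log], check=True)
```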
......@@ -7,6 +7,34 @@ include(gtest)
add_custom_target(tests)
# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_TEST
set(REGRESSION_TESTS
test_gemm_standalone_xdl_fp16
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_batched_gemm_softmax_gemm_fp16
test_batched_gemm_softmax_gemm_permute_fp16
test_batched_gemm_bias_softmax_gemm_permute_fp16
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4
test_batchnorm_fwd_rank_4
test_batchnorm_bwd_rank_4
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
)
function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}")
set(result 1)
......@@ -43,6 +71,12 @@ function(add_test_executable TEST_NAME)
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
foreach(source IN LISTS ARGN)
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
message("removing dpp test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
......@@ -66,7 +100,7 @@ function(add_test_executable TEST_NAME)
if(ARGN MATCHES "_xdl")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
endif()
......@@ -82,6 +116,15 @@ function(add_test_executable TEST_NAME)
endif()
#message("add_test returns ${result}")
set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("adding to SMOKE TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
endfunction()
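With the labels above in place, the two suites can presumably be selected via CTest's label filter, e.g. `ctest -L SMOKE_TEST` or `ctest -L REGRESSION_TEST`, or driven through the `smoke` and `regression` custom targets these blocks attach dependencies to (assuming those targets are defined alongside `tests`).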
function(add_gtest_executable TEST_NAME)
......@@ -126,26 +169,38 @@ function(add_gtest_executable TEST_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
message("removing xdl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT TEST_TARGETS MATCHES "gfx95" AND source MATCHES "mx_")
message("removing microscaling test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
message("removing wmma test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left in the list
if(ARGN)
if(ARGN MATCHES "_xdl")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(ARGN MATCHES "_mx") #only build mx example for gfx950
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN})
......@@ -162,6 +217,15 @@ function(add_gtest_executable TEST_NAME)
endif()
#message("add_gtest returns ${result}")
set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("adding to smoke test FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
endfunction()
add_compile_options(-Wno-c++20-extensions)
......@@ -206,8 +270,11 @@ add_subdirectory(wrapper)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2
add_subdirectory(smfmac_op)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_subdirectory(mx_mfma_op)
endif()
add_subdirectory(position_embedding)
add_subdirectory(scatter_gather)
......@@ -2,3 +2,4 @@ add_subdirectory(image_to_column)
add_subdirectory(gemm)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(data_type)
......@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
......@@ -32,9 +32,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generated by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
......@@ -51,32 +48,12 @@ class TestCkTileBatchedGemm : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether to do the CShuffle (a transpose before writing to global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
......@@ -88,12 +65,26 @@ class TestCkTileBatchedGemm : public ::testing::Test
CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenGemmPipeline::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
......@@ -186,6 +177,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = 1;
args.M = M;
args.N = N;
args.K = K;
......
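The new `k_batch` argument threads split-K through both the kernel arguments (`args.k_batch = 1` disables splitting) and the launch-grid computation. A hedged sketch of what such a grid computation typically looks like; the tile sizes and axis ordering here are assumptions, not ck_tile's actual mapping:

```python
import math

def grid_size(M, N, k_batch, batch_count, m_per_block=128, n_per_block=128):
    # one block per output tile, replicated over split-K slices and batches
    mn_tiles = math.ceil(M / m_per_block) * math.ceil(N / n_per_block)
    return (mn_tiles, k_batch, batch_count)

print(grid_size(3840, 4096, 1, 16))  # (960, 1, 16) with the assumed tile sizes
```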
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"
using ck_tile::bf16_t;
using ck_tile::bf16x2_t;
using ck_tile::fp16x2_t;
using ck_tile::fp32x2_t;
using ck_tile::half_t;
using ck_tile::pk_int4_t;
TEST(PackedInt4, ConvertToFloat)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
constexpr float first_input_val = 7.f;
constexpr float second_input_val = -1.f;
#else
constexpr float first_input_val = -1.f;
constexpr float second_input_val = 7.f;
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
fp32x2_t out = ck_tile::pk_int4_t_to_fp32x2_t(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToHalf)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
const half_t first_input_val = ck_tile::type_convert<half_t>(7.f);
const half_t second_input_val = ck_tile::type_convert<half_t>(-1.f);
#else
const half_t first_input_val = ck_tile::type_convert<half_t>(-1.f);
const half_t second_input_val = ck_tile::type_convert<half_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
fp16x2_t out = ck_tile::pk_int4_t_to_halfx2_t(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToBHalf)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
const bf16_t first_input_val = ck_tile::type_convert<bf16_t>(7.f);
const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(-1.f);
#else
const bf16_t first_input_val = ck_tile::type_convert<bf16_t>(-1.f);
const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
bf16x2_t out = ck_tile::pk_int4_t_to_bfloat16x2_t(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
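As a cross-check of the expectations above, a small reference model (not the CK implementation) of unpacking a pk_int4_t byte into two signed 4-bit values; which value lands in .x versus .y flips with CK_TILE_USE_PK4_LAYOUT_SHUFFLE, as the #ifdef branches encode:

```python
def unpack_pk_int4(byte: int):
    def to_signed4(n: int) -> int:
        return n - 16 if n & 0x8 else n  # two's-complement 4-bit
    lo = to_signed4(byte & 0xF)          # low nibble:  0b0111 ->  7
    hi = to_signed4((byte >> 4) & 0xF)   # high nibble: 0b1111 -> -1
    return hi, lo

assert unpack_pk_int4(0b11110111) == (-1, 7)  # matches the {-1, 7} comment
```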
......@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// TODO: Enable the Memory pipeline once it has been updated for vector loads on non-K-major tensors.
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>;
// clang-format on
......
......@@ -10,7 +10,13 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
constexpr int K = 320;
for(int M : Ms)
{
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
else
this->Run(M, N, K);
}
}
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
......@@ -18,14 +24,29 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
constexpr int VecLoadSize = 8;
for(int M : Ms)
{
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
// TODO: Can we somehow deduce the vector load size used?
if(M % VecLoadSize == 0)
this->Run(M, N, K);
else
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
}
}
}
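A sketch of the validity rule the MidLargeM test encodes: with a column-major A, vector loads run along M, so M must be a multiple of the (hard-coded, assumed) vector load size or kernel-argument validation throws:

```python
VEC_LOAD_SIZE = 8  # hard-coded in the test; deducing it is an open TODO

def expect_success(M: int, a_is_col_major: bool) -> bool:
    return (not a_is_col_major) or (M % VEC_LOAD_SIZE == 0)

assert expect_success(312, True)       # 312 % 8 == 0 -> runs
assert not expect_success(127, True)   # -> EXPECT_THROW(std::runtime_error)
```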
TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{
std::vector<int> Ms{127};
std::vector<int> Ms{128};
constexpr int N = 1024;
constexpr int K = 432;
......