Commit dec32dc6 authored by ThomasNing's avatar ThomasNing
Browse files

Finish the feature and merge with develop on the computeV2

parents 71352c44 c5fff071
...@@ -28,6 +28,7 @@ enum struct GemmDataType ...@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6 F16_F16_F16_F8, // 6
F8_F8_BF16, // 7 F8_F8_BF16, // 7
INT8_INT8_BF16, // 8 INT8_INT8_BF16, // 8
F8_F8_F16, // 9
}; };
#define OP_NAME "gemm_multiply_multiply" #define OP_NAME "gemm_multiply_multiply"
...@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, " "f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->bf16)\n"); "comp f8; 8: int8->bf16; 9: f8->f16, comp f8;)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
...@@ -89,6 +90,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -89,6 +90,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
using F32 = float; using F32 = float;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using I8 = int8_t; using I8 = int8_t;
using I32 = int; using I32 = int;
...@@ -165,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -165,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return profile( return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{}); F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
} }
else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile( return profile(
......
...@@ -17,11 +17,11 @@ enum struct GemmMatrixLayout ...@@ -17,11 +17,11 @@ enum struct GemmMatrixLayout
enum struct GemmDataType enum struct GemmDataType
{ {
BF16_I8_BF16, // 0 BF16_I8_BF16, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
F16_F8_F16, // 2 F16_F8_F16, // 2
F16_I8_F16, // 3 F16_I8_F16, // 3
BF16_BF16_BF16 // 4
}; };
#define OP_NAME "grouped_gemm_fixed_nk" #define OP_NAME "grouped_gemm_fixed_nk"
...@@ -39,7 +39,6 @@ std::vector<int> argToIntArray(char* input) ...@@ -39,7 +39,6 @@ std::vector<int> argToIntArray(char* input)
{ {
out.push_back(std::stoi(item)); out.push_back(std::stoi(item));
} }
return out; return out;
} }
...@@ -83,14 +82,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -83,14 +82,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const auto StrideCs = argToIntArray(argv[13]); const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1; const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
using F32 = float;
using F16 = ck::half_t;
#if defined(CK_ENABLE_FP8)
using F8 = ck::f8_t;
#endif
using BF16 = ck::bhalf_t;
using I8 = int8_t;
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
if(argc == 17) if(argc == 17)
...@@ -99,13 +90,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -99,13 +90,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_iter = std::stoi(argv[16]); n_iter = std::stoi(argv[16]);
} }
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
I8, ck::half_t,
BF16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -123,12 +113,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -123,12 +113,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
I8, ck::half_t,
BF16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -146,14 +136,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -146,14 +136,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
#endif #if defined(CK_ENABLE_FP8)
#if defined(CK_ENABLE_FP16) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F16, ck::f8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -171,12 +160,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -171,12 +160,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F16, ck::f8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -194,14 +183,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -194,14 +183,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
#endif #endif // CK_ENABLE_FP8
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) #if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F8, int8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -219,12 +208,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -219,12 +208,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
F8, int8_t,
F16, ck::half_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -242,14 +231,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -242,14 +231,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
#endif #endif // CK_ENABLE_INT8
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8) #if defined(CK_ENABLE_BF16)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
I8, ck::bhalf_t,
F16, ck::bhalf_t,
F32, float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -267,12 +256,59 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -267,12 +256,59 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup, n_warmup,
n_iter); n_iter);
} }
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16, ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
I8, int8_t,
F16, ck::bhalf_t,
F32, float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
int8_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>( ck::tensor_layout::gemm::RowMajor>(
...@@ -286,11 +322,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -286,11 +322,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
1, kbatch,
n_warmup, n_warmup,
n_iter); n_iter);
} }
#endif #endif // CK_ENABLE_INT8
#endif // CK_ENABLE_BF16
else else
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
......
...@@ -21,16 +21,19 @@ dependencies = [] ...@@ -21,16 +21,19 @@ dependencies = []
"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues" "Bug Tracker" = "https://github.com/rocm/composable_kernel/issues"
[tool.setuptools] [tool.setuptools]
packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"] packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library", "ck4inductor.universal_gemm", "ck4inductor.batched_universal_gemm", "ck4inductor.grouped_conv_fwd"]
[tool.setuptools.package-dir] [tool.setuptools.package-dir]
ck4inductor = "python/ck4inductor" ck4inductor = "python/ck4inductor"
"ck4inductor.universal_gemm" = "python/ck4inductor/universal_gemm"
"ck4inductor.batched_universal_gemm" = "python/ck4inductor/batched_universal_gemm"
"ck4inductor.grouped_conv_fwd" = "python/ck4inductor/grouped_conv_fwd"
"ck4inductor.include" = "include" "ck4inductor.include" = "include"
"ck4inductor.library" = "library" "ck4inductor.library" = "library"
[tool.setuptools.package-data] [tool.setuptools.package-data]
"ck4inductor.include" = ["ck/**/*.hpp"] "ck4inductor.include" = ["ck/**/*.hpp"]
"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"] "ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp", "src/tensor_operation_instance/gpu/gemm_universal_batched/**/*.hpp", "include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/**/*.hpp"]
[tool.setuptools.dynamic] [tool.setuptools.dynamic]
version = { attr = "setuptools_scm.get_version" } version = { attr = "setuptools_scm.get_version" }
...@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]: ...@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
template_args.insert(2, tuple()) # ds layout template_args.insert(2, tuple()) # ds layout
template_args.insert(6, tuple()) # ds dtype template_args.insert(6, tuple()) # ds dtype
try:
new_instance = CKGemmOperation( new_instance = CKGemmOperation(
*template_args, # type: ignore[arg-type] *template_args, # type: ignore[arg-type]
) )
op_instances.append(new_instance)
op_instances.append(new_instance) except TypeError as e:
log.debug(f"{e} when parsing {line}")
return op_instances return op_instances
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
import logging
import unittest
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_library as gen_gemm_ops_library,
)
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_preselected as gen_gemm_ops_preselected,
)
from ck4inductor.grouped_conv_fwd.gen_instances import (
gen_conv_ops_library as gen_conv_ops_library,
)
from ck4inductor.batched_universal_gemm.gen_instances import (
gen_ops_library as gen_batched_gemm_ops_library,
)
log = logging.getLogger(__name__)
class TestGenInstances(unittest.TestCase):
def test_gen_gemm_instances(self):
instances = gen_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_preselected_gemm_instances(self):
instances = gen_gemm_ops_preselected()
log.debug("%d preselected gemm instances" % len(instances))
self.assertTrue(instances)
def test_gen_conv_instances(self):
instances = gen_conv_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_gen_batched_gemm_instances(self):
instances = gen_batched_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
...@@ -7,6 +7,34 @@ include(gtest) ...@@ -7,6 +7,34 @@ include(gtest)
add_custom_target(tests) add_custom_target(tests)
# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_TEST
set(REGRESSION_TESTS
test_gemm_standalone_xdl_fp16
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_batched_gemm_softmax_gemm_fp16
test_batched_gemm_softmax_gemm_permute_fp16
test_batched_gemm_bias_softmax_gemm_permute_fp16
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4
test_batchnorm_fwd_rank_4
test_batchnorm_bwd_rank_4
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
)
function(add_test_executable TEST_NAME) function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}") message("adding test ${TEST_NAME}")
set(result 1) set(result 1)
...@@ -88,6 +116,15 @@ function(add_test_executable TEST_NAME) ...@@ -88,6 +116,15 @@ function(add_test_executable TEST_NAME)
endif() endif()
#message("add_test returns ${result}") #message("add_test returns ${result}")
set(result ${result} PARENT_SCOPE) set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("adding to SMOKE TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
endfunction() endfunction()
function(add_gtest_executable TEST_NAME) function(add_gtest_executable TEST_NAME)
...@@ -168,6 +205,15 @@ function(add_gtest_executable TEST_NAME) ...@@ -168,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
endif() endif()
#message("add_gtest returns ${result}") #message("add_gtest returns ${result}")
set(result ${result} PARENT_SCOPE) set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("adding to smoke test FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
endfunction() endfunction()
add_compile_options(-Wno-c++20-extensions) add_compile_options(-Wno-c++20-extensions)
......
...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; ...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>, // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>, //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//, std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16> //std::tuple< Col, Col, Row, F16, F16, F32, F16>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <sstream> #include <sstream>
...@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t< using GemmEpilogue = std::conditional_t<
CShuffleEpilogue, CShuffleEpilogue,
...@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test ...@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
kOutputRank, kOutputRank,
1, 1,
0, 0,
TilePartitioner::kM, TilePartitioner::MPerBlock,
TilePartitioner::kN>>, TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
......
...@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; ...@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>; ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, // using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; // ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>; // using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>; using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
......
...@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM) ...@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
constexpr int K = 320; constexpr int K = 320;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K); {
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
else
this->Run(M, N, K);
}
} }
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
constexpr int K = 320; constexpr int K = 320;
constexpr int VecLoadSize = 8;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K); {
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
// TODO: Can we anyhow deduce used vector load size?
if(M % VecLoadSize == 0)
this->Run(M, N, K);
else
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
}
}
} }
TYPED_TEST(TestCkTileGemmPipeline, PaddK) TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{128};
constexpr int N = 1024; constexpr int N = 1024;
constexpr int K = 432; constexpr int K = 432;
......
...@@ -16,6 +16,7 @@ enum struct GemmPipelineType ...@@ -16,6 +16,7 @@ enum struct GemmPipelineType
Mem, Mem,
Comp Comp
}; };
template <typename Tuple> template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test class TestCkTileGemmPipeline : public ::testing::Test
{ {
...@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN; constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK; constexpr bool kPadK = PadK;
// TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// =============================================== // ===============================================
...@@ -59,20 +63,22 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -59,20 +63,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = std::conditional_t< using BaseGemmPipeline =
PipelineType == GemmPipelineType::Mem, std::conditional_t<PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem< ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>, ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
std::conditional_t<PipelineType == GemmPipelineType::Mem, BDataType,
ck_tile::GemmPipelineAgBgCrMem< AccDataType,
ck_tile::UniversalGemmPipelineProblem<ADataType, GemmShape,
BDataType, GemmUniversalTraits,
AccDataType, Scheduler,
GemmShape, has_hot_loop_v,
Traits, tail_number_v>;
Scheduler,
has_hot_loop_v, using GemmPipeline = std::conditional_t<
tail_number_v>>, PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrCompV3< ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
BDataType, ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
AccDataType, ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
GemmShape,
Traits,
Scheduler,
has_hot_loop_v,
tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
...@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test
if(has_hot_loop) if(has_hot_loop)
{ {
// Tail pipeline One to Seven if constexpr(PipelineType == GemmPipelineType::Comp)
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Full)
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{}); ck_tile::TailNumber::Full>{});
} }
} else
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{ {
Run(ck_tile::bool_constant<true>{}, std::ostringstream err;
ck_tile::integral_constant<ck_tile::TailNumber, err << "For compute pipeline tail number should always be Full, but have \""
ck_tile::TailNumber::Three>{}); << tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
if constexpr(PipelineType == GemmPipelineType::Mem)
{ {
if(tail_num == ck_tile::TailNumber::Four) // Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Four>{}); ck_tile::TailNumber::One>{});
} }
} else if(tail_num == ck_tile::TailNumber::Full)
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Five>{}); ck_tile::TailNumber::Full>{});
} }
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6) if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Six)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Two)
ck_tile::integral_constant<ck_tile::TailNumber, {
ck_tile::TailNumber::Six>{}); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{});
}
} }
} if constexpr(BaseGemmPipeline::PrefetchStages > 3)
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Three)
ck_tile::integral_constant<ck_tile::TailNumber, {
ck_tile::TailNumber::Seven>{}); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Seven>{});
}
} }
} }
} }
......
...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; ...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>, // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>, //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//, std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16> //std::tuple< Col, Col, Row, F16, F16, F32, F16>
......
...@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test ...@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
CodegenGemmShape, CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>; CodegenGemmTraits<ALayout, BLayout, CLayout>>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>, ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
CodegenGemmPolicy>;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
......
...@@ -49,3 +49,4 @@ if(result EQUAL 0) ...@@ -49,3 +49,4 @@ if(result EQUAL 0)
endif() endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp) add_gtest_executable(test_type_convert_const type_convert_const.cpp)
add_gtest_executable(test_bhalf test_bhalf.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bhalf_t;
using ck::type_convert;
TEST(BHALF_T, Nan)
{
const uint16_t binary_bhalf_nan = 0x7FC0;
const bhalf_t bhalf_nan = ck::bit_cast<bhalf_t>(binary_bhalf_nan);
EXPECT_EQ(bhalf_nan, type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN()));
}
TEST(BHALF_T, Inf)
{
const uint16_t binary_bhalf_inf = 0x7F80;
const bhalf_t bhalf_inf = ck::bit_cast<bhalf_t>(binary_bhalf_inf);
EXPECT_EQ(bhalf_inf, type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity()));
}
TEST(BHALF_T, MantisaOverflow)
{
const float abs_tol = std::pow(2, -7);
const uint32_t val = 0x81FFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_NEAR(float_val, type_convert<float>(type_convert<bhalf_t>(float_val)), abs_tol);
}
TEST(BHALF_T, ExpOverflow)
{
const uint32_t val = 0xFF800000;
const float float_val = ck::bit_cast<float>(val);
ASSERT_EQ(type_convert<float>(type_convert<bhalf_t>(float_val)), float_val);
}
TEST(BHALF_T, MantisaExpOverflow)
{
const uint32_t val = 0xFFFFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_TRUE(std::isnan(float_val));
ASSERT_TRUE(std::isnan(type_convert<float>(type_convert<bhalf_t>(float_val))));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment