Commit efab74a3 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge branch 'gfx950' into lwpck-2619

parents 86950b3a bcef33c1
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include "profiler/profile_gemm_b_scale_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F16_F16, // 4
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
F16_I4_F16, // 8
};
enum struct BScaleBlockTile
{
K_64, // 0
K_128, // 1
};
#define OP_NAME "gemm_b_scale"
#define OP_DESC "Int4-dequant GEMM"
int profile_gemm_b_scale(int argc, char* argv[])
{
if(argc != 16 && argc != 19)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: f16@i4)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: B scale block tile (0: 64, 1: 128):\n");
printf("arg5: verification (0: no; 1: yes)\n");
printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg7: print tensor value (0: no; 1: yes)\n");
printf("arg8: time kernel (0=no, 1=yes)\n");
printf("arg9 to 14: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg15: split k into mulitiple batch\n");
printf("optional:\n");
printf("arg16: number of warm-up cycles (default 1)\n");
printf("arg17: number of iterations (default 10)\n");
printf("arg18: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
printf("Start profiling\n");
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto B_scale_block = static_cast<BScaleBlockTile>(std::stoi(argv[4]));
const bool do_verification = std::stoi(argv[5]);
const int init_method = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]);
const int M = std::stoi(argv[9]);
const int N = std::stoi(argv[10]);
const int K = std::stoi(argv[11]);
const int StrideA = std::stoi(argv[12]);
const int StrideB = std::stoi(argv[13]);
const int StrideC = std::stoi(argv[14]);
const int KBatch = std::stoi(argv[15]);
printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n",
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch);
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
if(argc == 19)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
rotating = std::stoull(argv[18]) * 1024 * 1024;
printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating);
}
using F32 = float;
using F16 = ck::half_t;
using I4 = ck::pk_i4_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto b_scale_type,
auto comp_type,
auto acc_type,
auto c_type,
auto scale_block_k,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using BScaleDataType = decltype(b_scale_type);
using ComputeDataType = decltype(comp_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_b_scale_impl<ADataType,
BDataType,
BScaleDataType,
ComputeDataType,
AccDataType,
CDataType,
scale_block_k,
ALayout,
BLayout,
CLayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch,
n_warmup,
n_iter,
rotating);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
B_scale_block == BScaleBlockTile::K_128)
{
printf("F16_I4_F16 MK_NK_MN K_128\n");
return profile(
F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_b_scale);
......@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
INT8_INT8_BF16, // 8
F8_F8_F16, // 9
};
#define OP_NAME "gemm_multiply_multiply"
......@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->bf16)\n");
"comp f8; 8: int8->bf16; 9: f8->f16, comp f8;)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
......@@ -89,6 +90,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using I32 = int;
......@@ -165,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_universal_impl.hpp"
#include "profiler_operation_registry.hpp"
......@@ -27,6 +27,8 @@ enum struct GemmDataType
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
F16_I4_F16, // 8
BF16_I4_BF16, // 9
};
#define OP_NAME "gemm_universal"
......@@ -39,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8)\n");
"comp f8; 8: f16@i4; 9: bf16@i4\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
......@@ -103,6 +105,7 @@ int profile_gemm_universal(int argc, char* argv[])
using BF16 = ck::bhalf_t;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t;
using I4 = ck::pk_i4_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor;
......@@ -207,6 +210,14 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_I4_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, I4{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
#endif
else
{
......
......@@ -85,6 +85,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t;
......@@ -165,6 +166,22 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
#endif
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
......
......@@ -21,16 +21,19 @@ dependencies = []
"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues"
[tool.setuptools]
packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"]
packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library", "ck4inductor.universal_gemm", "ck4inductor.batched_universal_gemm", "ck4inductor.grouped_conv_fwd"]
[tool.setuptools.package-dir]
ck4inductor = "python/ck4inductor"
"ck4inductor.universal_gemm" = "python/ck4inductor/universal_gemm"
"ck4inductor.batched_universal_gemm" = "python/ck4inductor/batched_universal_gemm"
"ck4inductor.grouped_conv_fwd" = "python/ck4inductor/grouped_conv_fwd"
"ck4inductor.include" = "include"
"ck4inductor.library" = "library"
[tool.setuptools.package-data]
"ck4inductor.include" = ["ck/**/*.hpp"]
"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp", "src/tensor_operation_instance/gpu/gemm_universal_batched/**/*.hpp", "include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/**/*.hpp"]
[tool.setuptools.dynamic]
version = { attr = "setuptools_scm.get_version" }
......@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
template_args.insert(2, tuple()) # ds layout
template_args.insert(6, tuple()) # ds dtype
try:
new_instance = CKGemmOperation(
*template_args, # type: ignore[arg-type]
)
op_instances.append(new_instance)
except TypeError as e:
log.debug(f"{e} when parsing {line}")
return op_instances
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
import logging
import unittest
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_library as gen_gemm_ops_library,
)
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_preselected as gen_gemm_ops_preselected,
)
from ck4inductor.grouped_conv_fwd.gen_instances import (
gen_conv_ops_library as gen_conv_ops_library,
)
from ck4inductor.batched_universal_gemm.gen_instances import (
gen_ops_library as gen_batched_gemm_ops_library,
)
log = logging.getLogger(__name__)
class TestGenInstances(unittest.TestCase):
def test_gen_gemm_instances(self):
instances = gen_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_preselected_gemm_instances(self):
instances = gen_gemm_ops_preselected()
log.debug("%d preselected gemm instances" % len(instances))
self.assertTrue(instances)
def test_gen_conv_instances(self):
instances = gen_conv_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_gen_batched_gemm_instances(self):
instances = gen_batched_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
......@@ -15,7 +15,7 @@ else
fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
......
......@@ -149,6 +149,12 @@ def parse_logfile(logfile):
lst=line.split()
line_dict=dict(zip(lst[1:],lst))
res.append(line_dict['TFlops,'])
elif 'perf_tile_gemm_basic' in logfile or 'perf_tile_gemm_mem_pipeline' in logfile:
for line in open(logfile):
if 'TFlops' in line:
lst=line.split()
line_dict=dict(zip(lst[1:],lst))
res.append(line_dict['TFlops,'])
return res
......@@ -330,6 +336,14 @@ def main():
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_fmha_bwd_tflops"
if 'gemm_basic_fp16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_basic_fp16_tflops"
if 'gemm_mem_pipeline_fp16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_mem_pipeline_fp16_tflops"
tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
......
......@@ -43,3 +43,19 @@ file=./perf_fmha_bwd_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
fi
file=./perf_tile_gemm_basic_fp16_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx942.log
fi
file=./perf_tile_gemm_basic_fp16_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx90a.log
fi
file=./perf_tile_gemm_mem_pipeline_fp16_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx942.log
fi
file=./perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
fi
......@@ -52,3 +52,19 @@ file=./perf_fmha_bwd_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
fi
file=./perf_gemm_basic_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_basic_gfx942.log
fi
file=./perf_gemm_basic_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_basic_gfx90a.log
fi
file=./perf_gemm_mem_pipeline_gfx942.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx942.log
fi
file=./perf_gemm_mem_pipeline_gfx90a.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx90a.log
fi
......@@ -7,6 +7,34 @@ include(gtest)
add_custom_target(tests)
# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_TEST
set(REGRESSION_TESTS
test_gemm_standalone_xdl_fp16
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_batched_gemm_softmax_gemm_fp16
test_batched_gemm_softmax_gemm_permute_fp16
test_batched_gemm_bias_softmax_gemm_permute_fp16
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4
test_batchnorm_fwd_rank_4
test_batchnorm_bwd_rank_4
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
)
function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}")
set(result 1)
......@@ -43,6 +71,12 @@ function(add_test_executable TEST_NAME)
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
foreach(source IN LISTS ARGN)
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
message("removing dpp test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
......@@ -82,6 +116,15 @@ function(add_test_executable TEST_NAME)
endif()
#message("add_test returns ${result}")
set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("adding to SMOKE TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
endfunction()
function(add_gtest_executable TEST_NAME)
......@@ -162,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
endif()
#message("add_gtest returns ${result}")
set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("adding to smoke test FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
endfunction()
add_compile_options(-Wno-c++20-extensions)
......@@ -206,7 +258,7 @@ add_subdirectory(wrapper)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2
add_subdirectory(smfmac_op)
endif()
add_subdirectory(position_embedding)
......
......@@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
......@@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
......@@ -185,21 +182,23 @@ class TestCkTileBatchedGemm : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(kargs,
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = 1;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = StrideA;
args.stride_B = StrideB;
args.stride_C = StrideC;
args.batch_stride_A = BatchStrideA;
args.batch_stride_B = BatchStrideB;
args.batch_stride_C = BatchStrideC;
args.batch_count = BatchCount;
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
ck_tile::stream_config{nullptr, false});
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
......
......@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
// TODO: expose tile size through test t-param ?
struct gemm_args
{
const void* p_a;
const void* p_b;
void* p_c;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
template <bool PadM, bool PadN, bool PadK>
void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128;
......@@ -88,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
......@@ -117,17 +105,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v,
tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
......@@ -319,11 +299,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
gemm_args args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
args.kbatch = kbatch;
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
......
......@@ -12,6 +12,7 @@ endif()
add_custom_target(test_fp8)
if (CK_USE_OCP_FP8)
# add test for ocp data types
add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8_ocp PRIVATE utility)
......@@ -62,16 +63,28 @@ if(GPU_TARGETS MATCHES "gfx950")
endif()
add_dependencies(test_mx_data_types test_bf6)
add_gtest_executable(test_mx_fp8 test_mx_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_mx_fp8 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_mx_fp8)
add_gtest_executable(test_mx_bf8 test_mx_bf8.cpp)
if(result EQUAL 0)
target_link_libraries(test_mx_bf8 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_mx_bf8)
add_gtest_executable(test_e8m0 test_e8m0.cpp)
if(result EQUAL 0)
target_link_libraries(test_e8m0 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_e8m0)
endif()
add_gtest_executable(test_custom_type test_custom_type.cpp)
if(result EQUAL 0)
target_link_libraries(test_custom_type PRIVATE utility)
endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
add_gtest_executable(test_bhalf test_bhalf.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bhalf_t;
using ck::type_convert;
TEST(BHALF_T, Nan)
{
const uint16_t binary_bhalf_nan = 0x7FC0;
const bhalf_t bhalf_nan = ck::bit_cast<bhalf_t>(binary_bhalf_nan);
EXPECT_EQ(bhalf_nan, type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN()));
}
TEST(BHALF_T, Inf)
{
const uint16_t binary_bhalf_inf = 0x7F80;
const bhalf_t bhalf_inf = ck::bit_cast<bhalf_t>(binary_bhalf_inf);
EXPECT_EQ(bhalf_inf, type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity()));
}
TEST(BHALF_T, MantisaOverflow)
{
const float abs_tol = std::pow(2, -7);
const uint32_t val = 0x81FFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_NEAR(float_val, type_convert<float>(type_convert<bhalf_t>(float_val)), abs_tol);
}
TEST(BHALF_T, ExpOverflow)
{
const uint32_t val = 0xFF800000;
const float float_val = ck::bit_cast<float>(val);
ASSERT_EQ(type_convert<float>(type_convert<bhalf_t>(float_val)), float_val);
}
TEST(BHALF_T, MantisaExpOverflow)
{
const uint32_t val = 0xFFFFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_TRUE(std::isnan(float_val));
ASSERT_TRUE(std::isnan(type_convert<float>(type_convert<bhalf_t>(float_val))));
}
......@@ -60,8 +60,8 @@ TEST(FP8OCP, ConvertFP32Nearest)
float neg_float = -0.015625f; //-2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
// positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f;
// positive subnorm fp8 value to fp8 and back, check if holds
pos_float = 0.00390625f; // 2^-8
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
// min subnorm fp8 value to fp8 and back, check if holds
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::bf8_ocp_t;
using ck::bf8x16_ocp_t;
using ck::bf8x2_ocp_t;
using ck::bf8x32_ocp_t;
using ck::e8m0_bexp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;
constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
/**
* @brief Tests conversion of BF8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from BF8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and BF8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note First 256*256 conversions are for all possible combinations of E8M0 and BF8 values that are
* stored in memory sequentially with BF8 values varying faster.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and BF8 values. [256x256]
* - Vector conversions bf8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> bf8x2 rne. [2]
* - Vector conversions f32x2 -> bf8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void
test_mx_bf8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
// All possible combinations of E8M0 and BF8
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto v = scaled_type_convert<float>(e8m0_bexp_t(exp_id), bf8_ocp_t{bf8_uid});
p_test[i] = v;
i++;
if(i >= N)
{
return;
}
}
}
/// Test vector conversions
// bf8x2 -> f32x2
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
auto scale = e8m0_bexp_t(8.0f);
float2_t f32x2 = scaled_type_convert<float2_t>(scale, bf8x2);
p_test[i++] = f32x2[0];
if(i >= N)
{
return;
}
p_test[i++] = f32x2[1];
if(i >= N)
{
return;
}
// f32x2 -> bf8x2
f32x2 = {-8.0f, 4.0f};
auto scale2 = e8m0_bexp_t(2.0f);
bf8x2 = mxf8_convert_rne<bf8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
if(i >= N)
{
return;
}
auto scale4 = e8m0_bexp_t(4.0f);
bf8x2 = mxf8_convert_sr<bf8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-2f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 1f
if(i >= N)
{
return;
}
/// Test round to nearest even
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(1024.0f, 4.0f)); // 1024/4
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(
std::numeric_limits<float>::infinity(), 2.0f)); // => BF8 Inf on device
if(i >= N)
{
return;
}
// 31000/0.5 > 57344 => BF8 Inf on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(31000.0f, 0.5f));
if(i >= N)
{
return;
}
// -31000/0.5 < -57344 => -BF8 Inf on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(-31000.0f, 0.5f));
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<bf8_ocp_t>(powf(2.0f, 16.0f), 4.0f)); // 2^16/4 = 65536/4
if(i >= N)
{
return;
}
}
TEST(MXBF8, HostScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
uint64_t completed = 0;
test_mx_bf8_scaled_convert(test_size, out.data(), &completed);
// V = X * P; X - E8M0 scale, P - BF8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
auto idx = e8m0_nan_id * 256 + bf8_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> bf8_spec_ids;
bf8_spec_ids.insert(0b11111111); // -NaN
bf8_spec_ids.insert(0b01111111); // +NaN
bf8_spec_ids.insert(0b11111101); // -NaN
bf8_spec_ids.insert(0b01111101); // +NaN
bf8_spec_ids.insert(0b11111110); // -NaN
bf8_spec_ids.insert(0b01111110); // +NaN
bf8_spec_ids.insert(0b11111100); // -inf
bf8_spec_ids.insert(0b01111100); // +inf
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto bf8_spec_id : bf8_spec_ids)
{
auto idx = exp_id * 256 + bf8_spec_id;
if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
{
ASSERT_TRUE(std::isnan(out[idx]))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
else
{
ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
}
}
// V = X * P; X, P - finite
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
continue;
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto idx = exp_id * 256 + bf8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(bf8_ocp_t{bf8_uid}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// bf8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
EXPECT_EQ(out[i++], powf(2.0f, -13.0f));
// f32x2 -> bf8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__global__ void test_mx_bf8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_mx_bf8_scaled_convert(N, p_test, p_completed);
}
TEST(MXBF8, DeviceScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8_device_scaled_convert<<<1, 1>>>(
test_size,
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// V = X * P; X - E8M0 scale, P - BF8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
auto idx = e8m0_nan_id * 256 + bf8_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> bf8_spec_ids;
bf8_spec_ids.insert(0b11111111); //-NaN
bf8_spec_ids.insert(0b01111111); // +NaN
bf8_spec_ids.insert(0b11111101); //-NaN
bf8_spec_ids.insert(0b01111101); // +NaN
bf8_spec_ids.insert(0b11111110); //-NaN
bf8_spec_ids.insert(0b01111110); // +NaN
bf8_spec_ids.insert(0b11111100); //-inf
bf8_spec_ids.insert(0b01111100); // +inf
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto bf8_spec_id : bf8_spec_ids)
{
auto idx = exp_id * 256 + bf8_spec_id;
if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
{
ASSERT_TRUE(std::isnan(out[idx]))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
else
{
ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
}
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
{
if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
continue;
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
auto idx = exp_id * 256 + bf8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(bf8_ocp_t{bf8_uid}))
<< "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(bf8_ocp_t{bf8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// bf8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
EXPECT_EQ(out[i++], powf(2.0f, -13.0f));
// f32x2 -> bf8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
#else
// NOTE: Host and Device have different behavior.
// Device returns Infs, while Host returns Max (saturation to finite value).
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
#endif
EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__host__ __device__ float vec16_generator(ck::index_t i) { return powf(-1.0f, i) * powf(2.0f, i); }
__global__ void test_mx_bf8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 16;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
bf8x16_ocp_t bf8x16{};
float16_t float16{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });
bf8x16 = scaled_type_convert<bf8x16_ocp_t>(scale2, float16);
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x16.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x16ToBF8x16ScaledConvert)
{
constexpr int N = 16;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x16_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__host__ __device__ float vec32_generator(ck::index_t i)
{
if(i < 16)
{
return vec16_generator(i % 16);
}
else
{
return 1.5f * vec16_generator(i % 16);
}
}
__global__ void test_mx_bf8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
bf8x32 = mxf8_convert_rne<bf8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_bf8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(8.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
bf8x32 = mxf8_convert_sr<bf8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvertSR)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_bf8x32_device_scaled_convert_sr<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(4.0f);
bf8x32_ocp_t bf8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}([&](auto ii) {
bf8x32.AsType<bf8_ocp_t>()(ii) = type_convert<bf8_ocp_t>(vec32_generator(ii) / 16.0f);
});
float32 = scaled_type_convert<float32_t>(scale2, bf8x32);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}
TEST(MXBF8, DeviceBF8x32ToF32x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using ck::e8m0_bexp_t;
using ck::f8_ocp_t;
using ck::f8x16_ocp_t;
using ck::f8x2_ocp_t;
using ck::f8x32_ocp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;
using ck::fp8_impl::fp8x2_storage_t;
constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
/**
* @brief Tests conversion of FP8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from FP8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note First 256*256 conversions are for all possible combinations of E8M0 and FP8 values that are
* stored in memory sequentially with FP8 values varying faster.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and FP8 values. [256x256]
* - Vector conversions f8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> f8x2 rne. [2]
* - Vector conversions f32x2 -> f8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void
test_mx_fp8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
// All possible combinations of E8M0 and FP8
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto v = scaled_type_convert<float>(e8m0_bexp_t(exp_id), f8_ocp_t{fp8_uid});
p_test[i] = v;
i++;
if(i >= N)
{
return;
}
}
}
/// Test vector conversions
// f8x2 -> f32x2
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
auto scale2 = e8m0_bexp_t(2.0f);
float2_t f32x2 = scaled_type_convert<float2_t>(scale2, fp8x2);
p_test[i++] = f32x2[0];
if(i >= N)
{
return;
}
p_test[i++] = f32x2[1];
if(i >= N)
{
return;
}
// f32x2 -> f8x2
f32x2 = {-8.0f, 4.0f};
fp8x2 = mxf8_convert_rne<f8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
if(i >= N)
{
return;
}
auto scale4 = e8m0_bexp_t(4.0f);
fp8x2 = mxf8_convert_sr<f8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-2f
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 1f
if(i >= N)
{
return;
}
/// Test round to nearest even
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(1024.0f, 4.0f)); // 1024/4
if(i >= N)
{
return;
}
p_test[i++] = type_convert<float>(
mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
if(i >= N)
{
return;
}
// Inf/2 > 448 => NaN on device
p_test[i++] = type_convert<float>(
mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity(), 2.0f));
if(i >= N)
{
return;
}
// 256/0.5 > 448 => NaN on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(256.0f, 0.5f));
if(i >= N)
{
return;
}
// -256/0.5 < -448 => NaN on device
p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(-256.0f, 0.5f));
if(i >= N)
{
return;
}
// proper scale selection 2^13 < 10000; 2^8 < 448 => scale = 2^(13-8) = 2^5
p_test[i++] =
type_convert<float>(mxf8_convert_rne<f8_ocp_t>(10000.0f, 32.0f)); // 10000/32 = 312.5
if(i >= N)
{
return;
}
}
TEST(MXFP8, HostScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
uint64_t completed = 0;
test_mx_fp8_scaled_convert(test_size, out.data(), &completed);
// V = X * P; X - E8M0 scale, P - FP8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
auto idx = e8m0_nan_id * 256 + fp8_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = exp_id * 256 + fp8_nan_id;
ASSERT_TRUE(std::isnan(out[idx]));
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = exp_id * 256 + fp8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f8_ocp_t{fp8_uid}))
<< "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(f8_ocp_t{fp8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// f8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
// f32x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__global__ void test_mx_fp8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_mx_fp8_scaled_convert(N, p_test, p_completed);
}
TEST(MXFP8, DeviceScaledConvert)
{
std::vector<float> out(test_size, -1.0f);
DeviceMem device_out(test_size * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8_device_scaled_convert<<<1, 1>>>(
test_size,
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
// V = X * P; X - E8M0 scale, P - FP8
// If X = NaN, then V = NaN regardless of P
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
auto idx = e8m0_nan_id * 256 + fp8_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
// If P in {Inf, NaN}, then V = P
std::set<uint8_t> fp8_nan_ids;
fp8_nan_ids.insert(0b11111111); //-NaN
fp8_nan_ids.insert(0b01111111); // +NaN
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(auto fp8_nan_id : fp8_nan_ids)
{
auto idx = exp_id * 256 + fp8_nan_id;
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
}
}
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
{
if(exp_id == e8m0_nan_id)
continue;
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
{
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
continue;
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
auto idx = exp_id * 256 + fp8_uid;
ASSERT_FLOAT_EQ(out[idx],
type_convert<float>(e8m0_bexp_t(exp_id)) *
type_convert<float>(f8_ocp_t{fp8_uid}))
<< "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
<< type_convert<float>(f8_ocp_t{fp8_uid});
}
}
/// Test vector conversions
auto i = 256 * 256;
// f8x2 -> f32x2
EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
// f32x2 -> fp8x2
// RNE
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 2.0f);
// SR
EXPECT_EQ(out[i++], -2.0f);
EXPECT_EQ(out[i++], 1.0f);
/// Test round to nearest even
EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#else
// NOTE: Host and Device have different behavior.
// Device returns NaN, while Host returns Max (saturation to finite value).
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
<< "out[i-1]: " << out[i - 1];
#endif
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
<< "out[i-1]: " << out[i - 1];
EXPECT_EQ(test_size, completed);
EXPECT_EQ(test_size, i);
}
__host__ __device__ float vec16_generator(ck::index_t i)
{
return (i < 8 ? -1.0 : 1.0) * powf(2.0f, i % 8);
}
__global__ void test_mx_fp8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 16;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
f8x16_ocp_t fp8x16{};
float16_t float16{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });
fp8x16 = scaled_type_convert<ck::f8x16_ocp_t>(scale2, float16);
ck::static_for<0, N, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(fp8x16.AsType<f8_ocp_t>()(ck::Number<ii>{}));
});
}
TEST(MXFP8, DeviceF32x16ToF8x16ScaledConvert)
{
constexpr int N = 16;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x16_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__host__ __device__ float vec32_generator(ck::index_t i)
{
if(i < 16)
{
return vec16_generator(i % 16);
}
else
{
return 1.5f * vec16_generator(i % 16);
}
}
__global__ void test_mx_fp8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(2.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
fp8x32 = mxf8_convert_rne<f8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}(
[&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_fp8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(8.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}(
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
fp8x32 = mxf8_convert_sr<f8x32_ocp_t>(float32, type_convert<float>(scale2));
ck::static_for<0, N, 1>{}(
[&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}
TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvertSR)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_fp8x32_device_scaled_convert_sr<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
auto scale2 = e8m0_bexp_t(4.0f);
f8x32_ocp_t fp8x32{};
float32_t float32{};
ck::static_for<0, N, 1>{}([&](auto ii) {
fp8x32.AsType<f8_ocp_t>()(ii) = type_convert<f8_ocp_t>(vec32_generator(ii) / 16.0f);
});
float32 = scaled_type_convert<float32_t>(scale2, fp8x32);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}
TEST(MXFP8, DeviceF8x32ToF32x32ScaledConvert)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
auto i = 0;
ck::static_for<0, N, 1>{}([&](auto ii) {
EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
});
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
......@@ -43,7 +43,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
return true;
}
}
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
// on gfx11 only support for 3d is implemented
......@@ -143,19 +142,23 @@ using KernelTypes2d = ::testing::Types<
std::tuple<float, float, float, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<float, float, float, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCHW, GKYXC, NGKHW, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NGCHW, GKYXC, NGKHW, ck::Number<2>>>;
using KernelTypes3d = ::testing::Types<
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<int8_t, int8_t, int8_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t, NGCDHW, GKZYXC, NGKDHW, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NGCDHW, GKZYXC, NGKDHW, ck::Number<3>>>;
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
......@@ -179,6 +182,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
this->conv_params.clear();
this->conv_params.push_back(
{2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment