Commit 1b616990 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8_uc

parents af30d6b6 800cf897
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
......@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
......@@ -41,49 +38,45 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......
......@@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("batch_stride_a", "32768", "Batch A stride")
.insert("batch_stride_b", "16384", "Batch B stride")
......@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
......@@ -17,6 +37,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
......@@ -24,6 +45,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
......@@ -79,6 +101,7 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
......@@ -159,6 +182,7 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_B,
batch_stride_C,
batch_count,
kbatch,
n_warmup,
n_repeat);
......@@ -175,8 +199,18 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_n_k, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
......@@ -236,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
}
......@@ -256,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
// if(a_layout == "R" && b_layout == "R")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
// }
if(a_layout == "R" && b_layout == "C")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
......@@ -15,18 +15,14 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
#include "utils.hpp"
namespace {
struct GroupedGemmKernelParam
{
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
static const bool kTilePermute = false;
static const ck_tile::index_t kOutputRank = 2;
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128;
......@@ -55,24 +51,6 @@ using CodegenGemmShape =
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename CLayout>
using GemmEpilogue = std::conditional_t<
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
GroupedGemmKernelParam::kTilePermute,
GroupedGemmKernelParam::kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType,
CDataType,
GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN>>>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
......@@ -89,20 +67,32 @@ using CodegenPipelineProblem =
CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
CodegenGemmPolicy>;
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<CLayout>>;
GemmEpilogue<ALayout, BLayout, CLayout>>;
}; // namespace
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
}
......
......@@ -41,7 +41,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
......@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(int n_warmup,
int n_repeat,
......@@ -11,7 +38,7 @@ float invoke_gemm(int n_warmup,
{
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
args,
......@@ -108,16 +135,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
a_m_k_tensors.push_back(
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
b_k_n_tensors.push_back(
ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
......@@ -157,12 +184,23 @@ int run_grouped_gemm_example_with_layouts(int argc,
{
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
......@@ -188,10 +226,10 @@ int run_grouped_gemm_example(int argc, char* argv[])
{
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
// else if(a_layout == "R" && b_layout == "R")
// {
// return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
// }
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename TLayout>
constexpr auto
f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
using namespace ck_tile::literals;
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
}
template <typename TLayout>
constexpr auto
f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
}
set(TARGET_NAME tile_example_batched_transpose)
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})
# Batched Transpose
This folder contains example for batched Transpose using ck_tile tile-programming implementation. Currently, it supports the batched transpose with NCHW to NHWC or NHWC to NCHW. So in this way from NCHW you could transpose to either NHWC or NWCH(two transposes). Now the transpose read with single data point. We would soon put it in vectorized transpose.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Make the transpose executable
make tile_example_batched_transpose -j
```
This will result in an executable `build/bin/tile_example_batched_transpose`
## example
```
args:
-N input batch size (default:2)
-C input channel size. (default:16)
-H input height size. (default:1)
-W input width size. (default:16)
-v whether do CPU validation or not (default: 1)
-layout_in input tensor data layout - NCHW by default
-layout_out output tensor data layout - NHWC by default
-seed seed to be used, -1 means random every time (default:-1)
-k_name t to 1 will print kernel name (default:0)
```
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "batched_transpose_example.hpp"
#include <iostream>
template <typename ts_type,
ck_tile::index_t block_x,
ck_tile::index_t block_y,
ck_tile::index_t warp_x,
ck_tile::index_t warp_y,
ck_tile::index_t thread_x,
ck_tile::index_t thread_y>
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
{
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
uint32_t dim_stride = a.height * a.width;
a.dim_stride = dim_stride;
a.dim_block_h = dim_block_h;
a.dim_block_w = dim_block_w;
using block_tile = ck_tile::sequence<block_x, block_y>;
using warp_tile = ck_tile::sequence<warp_x, warp_y>;
using thread_tile = ck_tile::sequence<thread_x, thread_y>;
using ts_problem =
ck_tile::BatchedTransposeProblem<ts_type, block_tile, warp_tile, thread_tile>;
using ts_pipeline = ck_tile::BatchedTransposePipeline<ts_problem>;
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
#define FOREACH_TRANSPOSE_PARAM(F) \
F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \
F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \
F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \
F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1)
// Macro that defines one static function per line
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \
static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \
batched_transpose_kargs& a, ck_tile::stream_config& s) \
{ \
return batched_transpose_dispatch<REAL_TYPE, BX, BY, WX, WY, TX, TY>(a, s); \
}
FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s)
{
if(t.type == "fp16")
{
return transpose_fn_fp16_16_16_8_8_1_1(a, s);
}
else if(t.type == "bf16")
{
return transpose_fn_bf16_16_16_8_8_1_1(a, s);
}
else if(t.type == "fp32")
{
return transpose_fn_fp32_16_16_8_8_1_1(a, s);
}
else if(t.type == "int8")
{
return transpose_fn_int8_16_16_8_8_1_1(a, s);
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "batched_transpose_example.hpp"
#if 0
template <typename T>
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
{
auto len = x.get_lengths();
assert(len.size() == 4);
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
for(size_t j = 0; j < len[1]; j++)
{
std::cout << j << ": [";
for(size_t k = 0; k < len[2]; k++)
{
std::cout << k << ": [";
for(size_t v = 0; v < len[3]; v++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
{
auto m =
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
std::cout << m;
if(v != len[3] - 1)
std::cout << ",";
}
else
{
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
}
}
std::cout << "]" << std::endl;
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}
std::cout << "--------------------" << std::endl;
}
#endif
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("N", "2", "input batch size. ")
.insert("C", "16", "input channel size.")
.insert("H", "1", "input height size.")
.insert("W", "16", "input width size. ")
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "t to 1 will print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename Type>
bool run_batched_transpose(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string prec = args.get_str("pr");
int N = args.get_int("N");
int C = args.get_int("C");
int H = args.get_int("H");
int W = args.get_int("W");
std::string layout_in = args.get_str("layout_in");
std::string layout_out = args.get_str("layout_out");
int seed = args.get_int("seed");
int dim_in[4], dim_out[4];
int stride_dim_in[4], stride_dim_out[4];
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
assert(nchw2nhwc != nhwc2nchw);
(void)nhwc2nchw;
dim_in[0] = N;
dim_in[1] = nchw2nhwc ? C : H;
dim_in[2] = nchw2nhwc ? H : W;
dim_in[3] = nchw2nhwc ? W : C;
dim_out[0] = N;
dim_out[1] = nchw2nhwc ? H : C;
dim_out[2] = nchw2nhwc ? W : H;
dim_out[3] = nchw2nhwc ? C : W;
stride_dim_in[0] = C * H * W;
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
stride_dim_in[2] = nchw2nhwc ? W : C;
stride_dim_in[3] = 1;
stride_dim_out[0] = C * H * W;
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
stride_dim_out[2] = nchw2nhwc ? C : W;
stride_dim_out[3] = 1;
if(seed < 0)
{
seed = std::time(nullptr);
}
ck_tile::HostTensor<Type> x_host(
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
ck_tile::HostTensor<Type> y_host(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
auto trait = batched_transpose_trait{prec, layout_in};
uint32_t height = nchw2nhwc ? C : H * W;
uint32_t width = nchw2nhwc ? H * W : C;
batched_transpose_kargs karg = [&]() {
batched_transpose_kargs a_;
a_.p_input = x_dev.GetDeviceBuffer();
a_.p_output = y_dev.GetDeviceBuffer();
a_.batch = N;
a_.height = height;
a_.width = width;
return a_;
}();
ck_tile::stream_config sc{nullptr, true};
auto ms = batched_transpose(trait, karg, sc);
std::size_t num_operations = N * C * H * (W - 1);
std::size_t num_bytes = N * C * H * W * sizeof(Type);
float ave_time = ms * 1E-3;
float gb_per_sec = num_bytes / ms * 1.E-6;
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
<< gb_per_sec << " GB/s, " << std::endl;
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
prec.c_str(),
N,
C,
H,
W,
layout_in.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
y_dev.FromDevice(y_host.data());
bool rtn = true;
if(validate)
{
// this host buffer will not copy to GPU, so no need use stride
ck_tile::HostTensor<Type> y_ref(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
auto [rtol, atol] = get_elimit<Type>("");
rtn &= ck_tile::check_err(
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string prec = args.get_str("pr");
bool r = true;
if(prec.compare("fp32") == 0)
{
r &= run_batched_transpose<float>(args);
}
else if(prec.compare("fp16") == 0)
{
r &= run_batched_transpose<ck_tile::fp16_t>(args);
}
else if(prec.compare("bf16") == 0)
{
r &= run_batched_transpose<ck_tile::bf16_t>(args);
}
else if(prec.compare("int8") == 0)
{
r &= run_batched_transpose<ck_tile::int8_t>(args);
}
return r ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/batched_transpose.hpp"
#include <vector>
#include <string>
#pragma once
struct batched_transpose_trait
{
std::string type;
std::string layout;
};
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
{
};
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s);
#!/bin/sh
EXE=./build/bin/tile_example_batched_transpose
for pr in "fp32" "fp16" "int8" ; do
$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC'
done
......@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(35_batched_transpose)
[Back to the main page](../../README.md)
# Composable Kernel supported operations
## Supported device operations
* [Average pooling]()
* [Batched contraction]()
* [Batched gemm]()
* [Batchnorm]()
* [CGEMM]()
* [Contraction]()
* [Convolution]()
* [Image to Column and Column to Image]()
* [Elementwise]()
* [GEMM]()
* [Max pooling]()
* [Reduce]()
* [Normalization]()
* [Permute]()
* [Put]()
* [Softmax]()
<!-- * [Average pooling](../../docs/markdown/tensor_operation/average_pooling.md) -->
<!-- * [Batched contraction](../../docs/markdown/tensor_operation/batched_contraction.md) -->
<!-- * [Batched gemm](../../docs/markdown/tensor_operation/batched_gemm.md) -->
<!-- * [Batchnorm](../../docs/markdown/tensor_operation/batchnorm.md) -->
<!-- * [CGEMM](../../docs/markdown/tensor_operation/cgemm.md) -->
<!-- * [Contraction](../../docs/markdown/tensor_operation/contraction.md) -->
<!-- * [Convolution](../../docs/markdown/tensor_operation/convolution.md) -->
<!-- * [Elementwise](../../docs/markdown/tensor_operation/elementwise.md) -->
* [GEMM](../../client_example/01_gemm/README.md)
* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md)
* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md)
* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md)
<!-- * [Grouped GEMM](../../docs/markdown/tensor_operation/grouped_gemm.md) -->
<!-- * [Image to Column and Column to Image](../../docs/markdown/tensor_operation/img2col.md) -->
<!-- * [Max pooling](../../docs/markdown/tensor_operation/max_pooling.md) -->
<!-- * [Reduce](../../docs/markdown/tensor_operation/reduce.md) -->
<!-- * [Normalization](../../docs/markdown/tensor_operation/normalization.md) -->
<!-- * [Permute](../../docs/markdown/tensor_operation/permute.md) -->
<!-- * [Put](../../docs/markdown/tensor_operation/put.md) -->
<!-- * [Softmax](../../docs/markdown/tensor_operation/softmax.md) -->
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/config.h"
#include "ck/utility/env.hpp"
#ifndef CK_CODE_GEN_RTC
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
......@@ -14,10 +14,12 @@
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#endif
// to do: add various levels of logging with CK_LOG_LEVEL
#ifndef CK_TIME_KERNEL
#define CK_TIME_KERNEL 1
#endif
// constant address space for kernel parameter
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
......@@ -53,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
defined(__gfx942__) || defined(__gfx950__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
......@@ -155,9 +157,22 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set rounding to nearest even as default for bf16 conversions
#define CK_USE_RNE_BF16_CONVERSION 1
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f6 conversions
#define CK_USE_SR_F6_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
// shuffle pk_i4 values during conversion to optimize number of binary
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......@@ -230,13 +245,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
// denorm test fix, necessary for gfx90a
#ifndef CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 0
#endif // CK_GFX90A_DENORM_WORKAROUND
// Enable only for gfx90a
#if defined(__gfx90a__)
#if CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 1
#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1
#else
// enable only for gfx90a
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX
#define CK_GFX90A_DENORM_WORKAROUND 0
#endif // gfx90a
// set flag to 1 to build deprecated instances
#define CK_BUILD_DEPRECATED 1
......
......@@ -97,6 +97,10 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
#ifndef CK_ENABLE_DPP_KERNELS
#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@
#endif
//
// CK kernels which support XDL (MI series)
//
......@@ -127,6 +131,10 @@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
......@@ -55,20 +55,21 @@ inline bool is_xdl_supported()
{
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
}
inline bool is_lds_direct_load_supported()
{
// Check if direct loads from global memory to LDS are supported.
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
ck::get_device_name() == "gfx950";
}
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
}
inline bool is_gfx101_supported()
......
......@@ -26,6 +26,7 @@ namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F4 = ck::f4_t;
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
......@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>,
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
......@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
}
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>,
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
......@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>,
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
......@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F4 = ck::f4_t;
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
......@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>,
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
......@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
}
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>,
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
......@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>,
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
......@@ -450,5 +452,54 @@ check_err(const Range& out,
return res;
}
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f4_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 0.5,
double atol = 0.5)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
<< " number of errors: " << err_count << std::endl;
}
return res;
}
} // namespace utils
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
else
os << delim;
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
using RangeType = ck::remove_cvref_t<decltype(v)>;
if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
std::is_same_v<RangeType, ck::bhalf_t>)
{
os << ck::type_convert<float>(v);
}
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
<< vector_of_floats.template AsType<float>()[ck::Number<1>{}];
}
else
{
os << static_cast<T>(v);
......@@ -266,18 +275,18 @@ struct Tensor
using Data = std::vector<T>;
template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
template <typename X, typename Y>
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
: mDesc(lens, strides), mData(GetElementSpaceSize())
{
}
template <typename Lengths>
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
......@@ -287,7 +296,7 @@ struct Tensor
{
}
Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType() const
......@@ -322,7 +331,17 @@ struct Tensor
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
else
{
return mDesc.GetElementSpaceSize();
}
}
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
......@@ -469,29 +488,64 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
}
template <typename... Is>
T& operator()(Is... is)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
template <typename... Is>
const T& operator()(Is... is) const
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
typename Data::iterator begin() { return mData.begin(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment