Unverified Commit 3c5717df authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into gemm_elementwise_gemm

parents 171b9030 d9f1ead3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
namespace {
struct GroupedGemmKernelParam
{
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 32;
static const ck_tile::index_t M_Warp = 2;
static const ck_tile::index_t N_Warp = 2;
static const ck_tile::index_t K_Warp = 1;
static const ck_tile::index_t M_Warp_Tile = 32;
static const ck_tile::index_t N_Warp_Tile = 32;
static const ck_tile::index_t K_Warp_Tile = 8;
};
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemmKernelParam::M_Tile,
GroupedGemmKernelParam::N_Tile,
GroupedGemmKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
GroupedGemmKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<ALayout, BLayout, CLayout>>;
}; // namespace
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
}
template <typename ALayout, typename BLayout, typename CLayout>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_)
{
using GroupedGemmKernel = ::Kernel<ALayout, BLayout, CLayout>;
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(
p_workspace_,
arguments.data(),
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GroupedGemmKernelParam::kBlockPerCu>(
GroupedGemmKernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
gemm_descs.size()));
return ave_time;
}
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
template <typename DataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using CDataType = ck_tile::half_t;
using AccDataType = float;
};
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("Ns", "", "N dimensions - empty by default.")
.insert("Ks", "", "K dimensions - empty by default.")
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
{
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
std::string op_name{"Grouped Gemm"};
std::size_t flop = 0, num_btype = 0;
for(int j = 0; j < group_count; ++j)
{
flop += std::size_t(2) * args[j].M * args[j].N * args[j].K;
num_btype += sizeof(ADataType) * args[j].M * args[j].K +
sizeof(BDataType) * args[j].K * args[j].N +
sizeof(CDataType) * args[j].M * args[j].N;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
int run_grouped_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
c_m_n_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
invoke_gemm<ALayout, BLayout, CLayout>(warmup, repeat, group_count, gemm_descs);
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
}
bool pass{true};
if(arg_parser.get_int("validate"))
{
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
// else if(a_layout == "R" && b_layout == "R")
// {
// return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
// }
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
set(TARGET_NAME tile_example_batched_transpose)
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})
# Batched Transpose
This folder contains example for batched Transpose using ck_tile tile-programming implementation. Currently, it supports the batched transpose with NCHW to NHWC or NHWC to NCHW. So in this way from NCHW you could transpose to either NHWC or NWCH(two transposes). Now the transpose read with single data point. We would soon put it in vectorized transpose.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Make the transpose executable
make tile_example_batched_transpose -j
```
This will result in an executable `build/bin/tile_example_batched_transpose`
## example
```
args:
-N input batch size (default:2)
-C input channel size. (default:16)
-H input height size. (default:1)
-W input width size. (default:16)
-v whether do CPU validation or not (default: 1)
-layout_in input tensor data layout - NCHW by default
-layout_out output tensor data layout - NHWC by default
-seed seed to be used, -1 means random every time (default:-1)
-k_name t to 1 will print kernel name (default:0)
```
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "batched_transpose_example.hpp"
#include <iostream>
template <typename ts_type,
ck_tile::index_t block_x,
ck_tile::index_t block_y,
ck_tile::index_t warp_x,
ck_tile::index_t warp_y,
ck_tile::index_t thread_x,
ck_tile::index_t thread_y>
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
{
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
uint32_t dim_stride = a.height * a.width;
a.dim_stride = dim_stride;
a.dim_block_h = dim_block_h;
a.dim_block_w = dim_block_w;
using block_tile = ck_tile::sequence<block_x, block_y>;
using warp_tile = ck_tile::sequence<warp_x, warp_y>;
using thread_tile = ck_tile::sequence<thread_x, thread_y>;
using ts_problem =
ck_tile::BatchedTransposeProblem<ts_type, block_tile, warp_tile, thread_tile>;
using ts_pipeline = ck_tile::BatchedTransposePipeline<ts_problem>;
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
#define FOREACH_TRANSPOSE_PARAM(F) \
F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \
F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \
F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \
F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1)
// Macro that defines one static function per line
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \
static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \
batched_transpose_kargs& a, ck_tile::stream_config& s) \
{ \
return batched_transpose_dispatch<REAL_TYPE, BX, BY, WX, WY, TX, TY>(a, s); \
}
FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s)
{
if(t.type == "fp16")
{
return transpose_fn_fp16_16_16_8_8_1_1(a, s);
}
else if(t.type == "bf16")
{
return transpose_fn_bf16_16_16_8_8_1_1(a, s);
}
else if(t.type == "fp32")
{
return transpose_fn_fp32_16_16_8_8_1_1(a, s);
}
else if(t.type == "int8")
{
return transpose_fn_int8_16_16_8_8_1_1(a, s);
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "batched_transpose_example.hpp"
#if 0
template <typename T>
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
{
auto len = x.get_lengths();
assert(len.size() == 4);
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
for(size_t j = 0; j < len[1]; j++)
{
std::cout << j << ": [";
for(size_t k = 0; k < len[2]; k++)
{
std::cout << k << ": [";
for(size_t v = 0; v < len[3]; v++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
{
auto m =
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
std::cout << m;
if(v != len[3] - 1)
std::cout << ",";
}
else
{
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
}
}
std::cout << "]" << std::endl;
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}
std::cout << "--------------------" << std::endl;
}
#endif
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("N", "2", "input batch size. ")
.insert("C", "16", "input channel size.")
.insert("H", "1", "input height size.")
.insert("W", "16", "input width size. ")
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "t to 1 will print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename Type>
bool run_batched_transpose(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string prec = args.get_str("pr");
int N = args.get_int("N");
int C = args.get_int("C");
int H = args.get_int("H");
int W = args.get_int("W");
std::string layout_in = args.get_str("layout_in");
std::string layout_out = args.get_str("layout_out");
int seed = args.get_int("seed");
int dim_in[4], dim_out[4];
int stride_dim_in[4], stride_dim_out[4];
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
assert(nchw2nhwc != nhwc2nchw);
(void)nhwc2nchw;
dim_in[0] = N;
dim_in[1] = nchw2nhwc ? C : H;
dim_in[2] = nchw2nhwc ? H : W;
dim_in[3] = nchw2nhwc ? W : C;
dim_out[0] = N;
dim_out[1] = nchw2nhwc ? H : C;
dim_out[2] = nchw2nhwc ? W : H;
dim_out[3] = nchw2nhwc ? C : W;
stride_dim_in[0] = C * H * W;
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
stride_dim_in[2] = nchw2nhwc ? W : C;
stride_dim_in[3] = 1;
stride_dim_out[0] = C * H * W;
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
stride_dim_out[2] = nchw2nhwc ? C : W;
stride_dim_out[3] = 1;
if(seed < 0)
{
seed = std::time(nullptr);
}
ck_tile::HostTensor<Type> x_host(
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
ck_tile::HostTensor<Type> y_host(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
auto trait = batched_transpose_trait{prec, layout_in};
uint32_t height = nchw2nhwc ? C : H * W;
uint32_t width = nchw2nhwc ? H * W : C;
batched_transpose_kargs karg = [&]() {
batched_transpose_kargs a_;
a_.p_input = x_dev.GetDeviceBuffer();
a_.p_output = y_dev.GetDeviceBuffer();
a_.batch = N;
a_.height = height;
a_.width = width;
return a_;
}();
ck_tile::stream_config sc{nullptr, true};
auto ms = batched_transpose(trait, karg, sc);
std::size_t num_operations = N * C * H * (W - 1);
std::size_t num_bytes = N * C * H * W * sizeof(Type);
float ave_time = ms * 1E-3;
float gb_per_sec = num_bytes / ms * 1.E-6;
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
<< gb_per_sec << " GB/s, " << std::endl;
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
prec.c_str(),
N,
C,
H,
W,
layout_in.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
y_dev.FromDevice(y_host.data());
bool rtn = true;
if(validate)
{
// this host buffer will not copy to GPU, so no need use stride
ck_tile::HostTensor<Type> y_ref(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
auto [rtol, atol] = get_elimit<Type>("");
rtn &= ck_tile::check_err(
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string prec = args.get_str("pr");
bool r = true;
if(prec.compare("fp32") == 0)
{
r &= run_batched_transpose<float>(args);
}
else if(prec.compare("fp16") == 0)
{
r &= run_batched_transpose<ck_tile::fp16_t>(args);
}
else if(prec.compare("bf16") == 0)
{
r &= run_batched_transpose<ck_tile::bf16_t>(args);
}
else if(prec.compare("int8") == 0)
{
r &= run_batched_transpose<ck_tile::int8_t>(args);
}
return r ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/batched_transpose.hpp"
#include <vector>
#include <string>
#pragma once
struct batched_transpose_trait
{
std::string type;
std::string layout;
};
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
{
};
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s);
#!/bin/sh
EXE=./build/bin/tile_example_batched_transpose
for pr in "fp32" "fp16" "int8" ; do
$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC'
done
...@@ -13,3 +13,8 @@ add_subdirectory(10_rmsnorm2d) ...@@ -13,3 +13,8 @@ add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant) add_subdirectory(12_smoothquant)
add_subdirectory(13_moe_sorting) add_subdirectory(13_moe_sorting)
add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(35_batched_transpose)
[Back to the main page](../../README.md)
# Composable Kernel supported operations
## Supported device operations
<!-- * [Average pooling](../../docs/markdown/tensor_operation/average_pooling.md) -->
<!-- * [Batched contraction](../../docs/markdown/tensor_operation/batched_contraction.md) -->
<!-- * [Batched gemm](../../docs/markdown/tensor_operation/batched_gemm.md) -->
<!-- * [Batchnorm](../../docs/markdown/tensor_operation/batchnorm.md) -->
<!-- * [CGEMM](../../docs/markdown/tensor_operation/cgemm.md) -->
<!-- * [Contraction](../../docs/markdown/tensor_operation/contraction.md) -->
<!-- * [Convolution](../../docs/markdown/tensor_operation/convolution.md) -->
<!-- * [Elementwise](../../docs/markdown/tensor_operation/elementwise.md) -->
* [GEMM](../../client_example/01_gemm/README.md)
* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md)
* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md)
* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md)
<!-- * [Grouped GEMM](../../docs/markdown/tensor_operation/grouped_gemm.md) -->
<!-- * [Image to Column and Column to Image](../../docs/markdown/tensor_operation/img2col.md) -->
<!-- * [Max pooling](../../docs/markdown/tensor_operation/max_pooling.md) -->
<!-- * [Reduce](../../docs/markdown/tensor_operation/reduce.md) -->
<!-- * [Normalization](../../docs/markdown/tensor_operation/normalization.md) -->
<!-- * [Permute](../../docs/markdown/tensor_operation/permute.md) -->
<!-- * [Put](../../docs/markdown/tensor_operation/put.md) -->
<!-- * [Softmax](../../docs/markdown/tensor_operation/softmax.md) -->
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/config.h" #include "ck/config.h"
#include "ck/utility/env.hpp" #include "ck/utility/env.hpp"
#ifndef CK_CODE_GEN_RTC
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
// environment variable to enable logging: // environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#endif
// to do: add various levels of logging with CK_LOG_LEVEL // to do: add various levels of logging with CK_LOG_LEVEL
#ifndef CK_TIME_KERNEL
#define CK_TIME_KERNEL 1 #define CK_TIME_KERNEL 1
#endif
// constant address space for kernel parameter // constant address space for kernel parameter
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
...@@ -53,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -53,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// define general macros for various architectures // define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) defined(__gfx942__) || defined(__gfx950__)
#define __gfx9__ #define __gfx9__
#endif #endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
...@@ -155,9 +157,22 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -155,9 +157,22 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly // LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set rounding to nearest even as default for bf16 conversions
#define CK_USE_RNE_BF16_CONVERSION 1
// set rounding to nearest even as default for f8 conversions // set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0 #define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f6 conversions
#define CK_USE_SR_F6_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
// shuffle pk_i4 values during conversion to optimize number of binary
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...@@ -230,13 +245,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -230,13 +245,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// workaround: compiler issue on gfx908 // workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1 #define CK_WORKAROUND_SWDEV_388832 1
// denorm test fix, required to work around dissue // denorm test fix, necessary for gfx90a
#ifndef CK_WORKAROUND_DENORM_FIX #ifndef CK_GFX90A_DENORM_WORKAROUND
#define CK_WORKAROUND_DENORM_FIX 0 #define CK_GFX90A_DENORM_WORKAROUND 0
#endif // CK_GFX90A_DENORM_WORKAROUND
// Enable only for gfx90a
#if defined(__gfx90a__)
#if CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 1
#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1
#else #else
// enable only for gfx90a #define CK_GFX90A_DENORM_WORKAROUND 0
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // gfx90a
#endif // CK_WORKAROUND_DENORM_FIX
// set flag to 1 to build deprecated instances // set flag to 1 to build deprecated instances
#define CK_BUILD_DEPRECATED 1 #define CK_BUILD_DEPRECATED 1
......
...@@ -97,6 +97,10 @@ ...@@ -97,6 +97,10 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif #endif
#ifndef CK_ENABLE_DPP_KERNELS
#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@
#endif
// //
// CK kernels which support XDL (MI series) // CK kernels which support XDL (MI series)
// //
...@@ -111,6 +115,26 @@ ...@@ -111,6 +115,26 @@
#cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@
#endif #endif
#ifndef CK_USE_GFX94
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif
#ifndef CK_USE_OCP_FP8
#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
#endif
#ifndef CK_USE_FNUZ_FP8
#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@
#endif
#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on // clang-format on
#endif // CK_CONFIG_H_IN #endif // CK_CONFIG_H_IN
...@@ -55,20 +55,21 @@ inline bool is_xdl_supported() ...@@ -55,20 +55,21 @@ inline bool is_xdl_supported()
{ {
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
} }
inline bool is_lds_direct_load_supported() inline bool is_lds_direct_load_supported()
{ {
// Check if direct loads from global memory to LDS are supported. // Check if direct loads from global memory to LDS are supported.
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
ck::get_device_name() == "gfx950";
} }
inline bool is_bf16_atomic_supported() inline bool is_bf16_atomic_supported()
{ {
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
} }
inline bool is_gfx101_supported() inline bool is_gfx101_supported()
......
...@@ -26,6 +26,7 @@ namespace utils { ...@@ -26,6 +26,7 @@ namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1) double get_relative_threshold(const int number_of_accumulations = 1)
{ {
using F4 = ck::f4_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
...@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> || static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, int>, is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"); "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0; double compute_error = 0;
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
...@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
} }
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> || static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, int>, is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!"); "Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0; double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> || static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, int>, is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!"); "Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) ...@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{ {
using F4 = ck::f4_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
...@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> || static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, int>, is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num)); auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0; double compute_error = 0;
...@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5; compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
} }
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> || static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, int>, is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"); "Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0; double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> || if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of ...@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
} }
double midway_error = std::max(compute_error, output_error); double midway_error = std::max(compute_error, output_error);
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> || static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, int>, is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"); "Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0; double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> || if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...@@ -450,5 +452,54 @@ check_err(const Range& out, ...@@ -450,5 +452,54 @@ check_err(const Range& out,
return res; return res;
} }
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f4_t>),
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 0.5,
double atol = 0.5)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
<< " number of errors: " << err_count << std::endl;
}
return res;
}
} // namespace utils } // namespace utils
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment