"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "546d736d66f48b2a2536d0bedd71823e88100b04"
Commit f0bbc5db authored by Bartlomiej Kocot

[CK TILE] GEMM with packed i4

parent 0e5e29c4
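The packed-int4 type used throughout this change, ck_tile::pk_int4_t, stores two signed 4-bit values in a single int8_t (PackedSize = 2), which is why element counts and byte sizes below are repeatedly divided by PackedSize. A minimal host-side sketch of that packing convention, with illustrative helper names that are not part of the commit:

#include <cstdint>
// Pack two 4-bit values into one byte: "hi" occupies bits 7..4 and "lo" bits 3..0,
// matching the (hi << 4) | lo layout used by permute_tensor_b further down.
inline int8_t pack_int4_pair(int8_t hi, int8_t lo)
{
    return static_cast<int8_t>(((hi & 0xF) << 4) | (lo & 0xF));
}
// Recover the two nibbles (unsigned view; sign handling is left to the caller).
inline void unpack_int4_pair(int8_t packed, int8_t& hi, int8_t& lo)
{
    hi = static_cast<int8_t>((packed >> 4) & 0xF);
    lo = static_cast<int8_t>(packed & 0xF);
}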
@@ -3,3 +3,4 @@ add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
target_compile_options(tile_example_gemm_universal PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
add_executable(tile_example_gemm_universal_pk_int4 EXCLUDE_FROM_ALL universal_gemm_pk_int4.cpp)
@@ -35,7 +35,7 @@
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
-template <typename DataType>
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmBasicTypeConfig;
template <>
@@ -75,6 +75,15 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
using CDataType = ck_tile::half_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
@@ -114,6 +123,12 @@ struct DataTypeTraits<ck_tile::bf8_t>
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
......
@@ -29,6 +29,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename Tensor>
void permute_tensor_b(Tensor& tensor)
{
const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1);
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int8_t input[8];
for(int k = 0; k < 4; k++)
{
int8_t i4x2 = tensor(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int8_t hi = input[2];
int8_t lo = input[0];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 0, i) = i4x2;
}
{
int8_t hi = input[6];
int8_t lo = input[4];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 2, i) = i4x2;
}
{
int8_t hi = input[3];
int8_t lo = input[1];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 4, i) = i4x2;
}
{
int8_t hi = input[7];
int8_t lo = input[5];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 6, i) = i4x2;
}
}
}
}
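The added permute_tensor_b reshuffles every group of eight int4 values along K (stored as four packed bytes) from logical order 0,1,2,3,4,5,6,7 into 2,0,6,4,3,1,7,5, the order the pk_int4 device kernel expects when it unpacks B. A rough standalone sketch of the same byte-level shuffle, assuming the (hi << 4) | lo nibble layout used above (helper name is illustrative, not part of the commit):

#include <cstdint>
inline void permute_pk_int4_group(int8_t bytes[4]) // four packed bytes = eight int4 values
{
    int v[8];
    for(int k = 0; k < 4; ++k)
    {
        v[2 * k + 0] = (bytes[k] >> 4) & 0xF; // high nibble
        v[2 * k + 1] = bytes[k] & 0xF;        // low nibble
    }
    // 01234567 -> 20643175, two values per output byte
    bytes[0] = static_cast<int8_t>((v[2] << 4) | v[0]);
    bytes[1] = static_cast<int8_t>((v[6] << 4) | v[4]);
    bytes[2] = static_cast<int8_t>((v[3] << 4) | v[1]);
    bytes[3] = static_cast<int8_t>((v[7] << 4) | v[5]);
}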
template <typename ADataType,
typename BDataType,
@@ -83,7 +137,12 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
-template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -94,10 +153,9 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
-using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
-using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
-using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
-using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using AccDataType = typename GemmBasicTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
constexpr ck_tile::index_t PackedSizeA = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr ck_tile::index_t PackedSizeB = ck_tile::numeric_traits<BDataType>::PackedSize;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
@@ -107,10 +165,10 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -123,16 +181,23 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
-if (init_method == 0) {
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
-} else if (init_method == 1) {
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
-} else if (init_method == 2) {
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
-} else {
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
}
@@ -142,7 +207,17 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
-b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
permute_tensor_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
@@ -188,6 +263,11 @@ int run_gemm_example_with_layouts(int argc,
}
else if(arg_parser.get_int("v") == 2)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Restore input for B for gpu reference
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
@@ -198,17 +278,17 @@ int run_gemm_example_with_layouts(int argc,
BDataType* d_B;
CDataType* d_C;
-ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
-ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType) / PackedSizeA));
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType) / PackedSizeB));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
-M * K * sizeof(ADataType),
M * K * sizeof(ADataType) / PackedSizeA,
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
-N * K * sizeof(BDataType),
N * K * sizeof(BDataType) / PackedSizeB,
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
......
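The divisions by PackedSizeA and PackedSizeB in the reference-GPU path above reflect that M * K and N * K count logical values while pk_int4 stores two of them per byte-sized element. A small sketch of the size arithmetic, with an illustrative helper that is not part of the example:

#include <cstddef>
#include <cstdint>
// Bytes needed for `elements` logical values when `packed_size` of them share one element of size `elem_size`.
constexpr std::size_t packed_bytes(std::size_t elements, std::size_t elem_size, std::size_t packed_size)
{
    return elements * elem_size / packed_size;
}
// For an fp16 A tensor nothing changes; a pk_int4 B tensor needs only half the bytes.
static_assert(packed_bytes(4096 * 4096, sizeof(std::int8_t), 2) == 4096 * 4096 / 2, "two int4 values per byte");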
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
// ===============================================
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
if(has_hot_loop)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
{
// Tail number always Full - #PrefetchStages
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
return ave_time;
}
#include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -1309,7 +1309,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -156,7 +156,7 @@ struct vector_traits;
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
{
-using scalar_type = T;
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;
};
......
@@ -382,8 +382,9 @@ struct numeric_traits;
template <>
struct numeric_traits<bfloat16_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int PackedSize = 1;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
......
@@ -225,6 +225,7 @@ struct numeric_traits<fp8_t>
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
static constexpr int PackedSize = 1;
};
template <>
@@ -242,6 +243,7 @@ struct numeric_traits<bf8_t>
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
static constexpr int PackedSize = 1;
};
// below is sw fp8 conversion, not utilizing hw instruction
......
@@ -241,6 +241,7 @@ struct numeric_traits<half_t>
static constexpr uint16_t NegInf = 0xFC00;
static constexpr uint16_t NaN = 0x7C01;
static constexpr uint16_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
@@ -383,4 +384,24 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
} // namespace ck_tile
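The pk_add_f16 overloads added above provide a packed fp16 add: the host version adds the two lanes in plain C++, while the device version lowers to a single v_pk_add_f16 instruction through inline assembly. A hedged host-side usage sketch (it assumes the half.hpp header compiles standalone; the helper below is illustrative, not part of the commit):

#include "ck_tile/core/numeric/half.hpp"
// Both lanes are added independently; the device overload behaves the same way per lane.
inline ck_tile::fp16x2_t demo_pk_add()
{
    ck_tile::fp16x2_t a = {(_Float16)1.0f, (_Float16)2.0f};
    ck_tile::fp16x2_t b = {(_Float16)0.5f, (_Float16)0.25f};
    return ck_tile::pk_add_f16(a, b); // lanes: {1.5, 2.25}
}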
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
@@ -91,6 +91,7 @@ struct numeric_traits<int8_t>
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
#endif
......
@@ -94,6 +94,7 @@ struct numeric_traits<float>
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr int PackedSize = 1;
using bitwise_type = uint32_t;
};
......
@@ -21,8 +21,8 @@ struct pk_int4_t
{
using type = int8_t;
type data;
-__host__ __device__ constexpr pk_int4_t() : data{type{}} {}
-__host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
};
// limits
@@ -91,6 +91,19 @@ struct numeric<pk_int4_t>
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<pk_int4_t>
{
static constexpr int PackedSize = 2;
};
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
......
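pk_int4_t_to_fp32x2_t above widens one packed byte into two floats. A rough host-only equivalent of the unpacking, assuming two's-complement nibbles; the lane order is an assumption here and the header remains the authority:

#include <cstdint>
#include <utility>
inline std::pair<float, float> unpack_pk_int4_to_float(uint8_t packed)
{
    int hi = (packed >> 4) & 0xF;
    int lo = packed & 0xF;
    if(hi >= 8) hi -= 16; // sign-extend 4-bit two's complement
    if(lo >= 8) lo -= 16;
    return {static_cast<float>(hi), static_cast<float>(lo)};
}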
@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -34,7 +35,11 @@ template <typename T_, index_t N_> ...@@ -34,7 +35,11 @@ template <typename T_, index_t N_>
struct ext_vector struct ext_vector
{ {
static constexpr index_t N = N_; static constexpr index_t N = N_;
using value_type = typename native_t<remove_cvref_t<T_>>::type; // struct type is not supported for ext_vector
using value_type =
std::conditional_t<std::is_same_v<typename native_t<remove_cvref_t<T_>>::type, pk_int4_t>,
int8_t,
typename native_t<remove_cvref_t<T_>>::type>;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
@@ -58,7 +63,8 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
template <typename T>
struct vector_traits
{
-using scalar_type = remove_cvref_t<T>;
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
static constexpr index_t vector_size = 1;
};
@@ -66,7 +72,7 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
-using scalar_type = T;
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;
};
@@ -200,21 +206,12 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
-CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
-{
-fp16x2_t vector_res;
-vector_res.x = x.x + y.x;
-vector_res.y = x.y + y.y;
-return vector_res;
-}
-CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
-{
-fp16x2_t c;
-asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
-return c;
-}
// pk_int4_t
// using pk_int4_t
using pk_int4x2_t = int8_t __attribute((ext_vector_type(2)));
using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
using pk_int4x64_t = int8_t __attribute((ext_vector_type(64)));
} // namespace ck_tile
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -231,6 +231,8 @@ struct buffer_view<address_space_enum::global,
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0};
static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{
@@ -255,7 +257,8 @@ struct buffer_view<address_space_enum::global,
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
-cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
cached_buf_res_ =
make_wave_buffer_resource(p_data_, (buffer_size_ / PackedSize) * sizeof(type));
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
@@ -307,7 +310,7 @@ struct buffer_view<address_space_enum::global,
t_per_x,
Coherence,
oob_conditional_check>(
-p_data_, i + linear_offset, is_valid_element, buffer_size_);
p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
}
else
{
@@ -318,7 +321,7 @@ struct buffer_view<address_space_enum::global,
oob_conditional_check>(p_data_,
i + linear_offset,
is_valid_element,
-buffer_size_,
buffer_size_ / PackedSize,
invalid_element_value_);
}
}
@@ -533,7 +536,7 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
-x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
x, p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
}
else
{
@@ -569,7 +572,7 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
x, p_data_, i, linear_offset, is_valid_element, buffer_size_ / PackedSize);
}
template <typename X,
@@ -614,7 +617,7 @@ struct buffer_view<address_space_enum::global,
if constexpr(use_amd_buffer_addressing)
{
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
x, p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
}
else
{
@@ -654,7 +657,7 @@ struct buffer_view<address_space_enum::global,
Coherence,
oob_conditional_check,
pre_nop>(
-x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
x, p_data_, i, linear_offset, is_valid_element, buffer_size_ / PackedSize);
}
template <typename X,
@@ -688,7 +691,7 @@ struct buffer_view<address_space_enum::global,
if constexpr(use_amd_buffer_addressing)
{
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
-x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
x, p_data_, i + linear_offset, is_valid_element, buffer_size_ / PackedSize);
}
else if(is_valid_element)
{
@@ -897,83 +900,124 @@ struct buffer_view<address_space_enum::lds,
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
-static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8_t>::value) ||
-(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
-(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
-(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
-(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
-(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
-(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
-(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x16_t>::value),
-"wrong! not implemented for this combination, please add "
-"implementation");
-if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8_t>::value)
static_assert(
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4x4_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4x8_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4x16_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>::value),
"wrong! not implemented for this combination, please add "
"implementation");
if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x);
}
-else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x2_t>::value)
else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>,
thread_buffer<pk_int4_t, 2>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x);
}
-else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x4_t>::value)
else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>,
thread_buffer<pk_int4_t, 4>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
-else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x8_t>::value)
else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>,
thread_buffer<pk_int4_t, 8>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
-else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x16_t>::value)
else if constexpr((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4_t>::value &&
std::is_same<remove_cvref_t<X>,
thread_buffer<pk_int4_t, 16>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
-else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x4_t>::value)
else if constexpr((std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4x4_t>::value &&
std::is_same<remove_cvref_t<X>,
thread_buffer<pk_int4_t, 4>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
-else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x8_t>::value)
else if constexpr((std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4x8_t>::value &&
std::is_same<remove_cvref_t<X>,
thread_buffer<pk_int4_t, 8>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
-else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
-std::is_same<remove_cvref_t<X>, int8x16_t>::value)
else if constexpr((std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(std::is_same<remove_cvref_t<T>, pk_int4x16_t>::value &&
std::is_same<remove_cvref_t<X>,
thread_buffer<pk_int4_t, 16>>::value))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
......
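Each buffer_size_ handed to the amd_buffer_* helpers above is divided by PackedSize so the buffer resource describes the number of stored elements rather than logical int4 values. A condensed sketch of that conversion (simplified and illustrative only, not code from the commit):

#include <cstdint>
// Logical element count -> stored element count used to build the wave buffer resource.
template <int PackedSize>
constexpr std::int64_t storage_elements(std::int64_t logical_elements)
{
    return logical_elements / PackedSize;
}
static_assert(storage_elements<2>(1024) == 512, "pk_int4: two values per stored byte");
static_assert(storage_elements<1>(1024) == 1024, "unpacked types are unaffected");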
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -27,6 +27,8 @@ struct static_distributed_tensor
using ThreadTensorDesc =
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
@@ -59,7 +61,7 @@ struct static_distributed_tensor
CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
{
-return kThreadElementSpaceSize;
return kThreadElementSpaceSize / PackedSize;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths>
@@ -79,8 +81,9 @@ struct static_distributed_tensor
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
-sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
-thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
sliced_thread_data(
number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
});
return sliced_thread_data;
@@ -101,8 +104,9 @@ struct static_distributed_tensor
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
-thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
-sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
PackedSize>{}];
});
}
@@ -115,7 +119,7 @@ struct static_distributed_tensor
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
-return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
}
template <typename TileDistributedIndices>
@@ -127,11 +131,11 @@ struct static_distributed_tensor
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
-return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
}
//
-thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
};
template <typename DataType, typename StaticTileDistribution>
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -45,6 +45,8 @@ struct tensor_view
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
static constexpr auto DstInMemOp = DstInMemOp_;
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
@@ -81,8 +83,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
@@ -99,8 +101,8 @@ struct tensor_view
bool is_valid_element, // flag
bool_constant<oob_conditional_check> = {}) const
{
-return buf_.template get<X>(coord.get_offset(),
-linear_offset,
return buf_.template get<X>(coord.get_offset() / PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
@@ -122,8 +124,8 @@ struct tensor_view
{
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
@@ -142,8 +144,12 @@ struct tensor_view
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
-return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
-dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(dst,
coord.get_offset() /
PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<pre_nop>{});
}
template <typename X,
@@ -159,8 +165,8 @@ struct tensor_view
{
return buf_.template async_get<X>(
smem,
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
@@ -178,8 +184,8 @@ struct tensor_view
bool is_valid_element) const
{
return buf_.template async_get<X>(smem,
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
@@ -198,8 +204,8 @@ struct tensor_view
{
return buf_.template async_get_raw<X>(
smem,
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
@@ -217,8 +223,11 @@ struct tensor_view
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
-return buf_.template async_get_raw<X>(
-smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
return buf_.template async_get_raw<X>(smem,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<pre_nop>{});
}
// X is vector of DataType.
@@ -236,8 +245,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -272,8 +281,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -292,7 +301,7 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
-coord.get_offset(), linear_offset, is_valid_element, x);
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
// X is vector of DataType.
@@ -310,8 +319,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -330,7 +339,7 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
-coord.get_offset(), linear_offset, is_valid_element, x);
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
// X is vector of DataType.
@@ -350,8 +359,8 @@ struct tensor_view
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
-coord.get_offset(),
-linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -372,7 +381,7 @@ struct tensor_view
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
-coord.get_offset(), linear_offset, is_valid_element, x);
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
......
@@ -97,13 +97,15 @@ struct tile_window_with_static_distribution
}
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector = static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>(); get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>; // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type; // using vector_t = typename vector_type_t::type;
using vector_t = thread_buffer<DataType, ScalarPerVector>; using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private: private:
static constexpr auto scalars_per_access_ = [] { static constexpr auto scalars_per_access_ = [] {
...@@ -336,7 +338,7 @@ struct tile_window_with_static_distribution ...@@ -336,7 +338,7 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{}); bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1 #if 1
// write into distributed tensor // write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple( constexpr auto idx_ys = generate_tuple(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
...@@ -345,10 +347,11 @@ struct tile_window_with_static_distribution ...@@ -345,10 +347,11 @@ struct tile_window_with_static_distribution
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() = dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j]; vec_value.template get_as<DataType>()[j / Traits::PackedSize];
}); });
#else #else
constexpr index_t d = constexpr index_t d =
...@@ -390,8 +393,9 @@ struct tile_window_with_static_distribution ...@@ -390,8 +393,9 @@ struct tile_window_with_static_distribution
using SFC_Ys = typename Traits::SFC_Ys; using SFC_Ys = typename Traits::SFC_Ys;
static constexpr index_t YElementSize = static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % Traits::ScalarPerVector == 0); static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>; using vectorized_tbuf =
array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
// StaticBuffer<address_space_enum::vgpr, // StaticBuffer<address_space_enum::vgpr,
// vector_t, // vector_t,
// YElementSize / Traits::ScalarPerVector, // YElementSize / Traits::ScalarPerVector,
...@@ -419,7 +423,8 @@ struct tile_window_with_static_distribution ...@@ -419,7 +423,8 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr index_t d = constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
Traits::PackedSize;
static_assert(d % Traits::ScalarPerVector == 0); static_assert(d % Traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>( get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
...@@ -632,7 +637,7 @@ struct tile_window_with_static_distribution ...@@ -632,7 +637,7 @@ struct tile_window_with_static_distribution
// vector_type_t vec; // vector_type_t vec;
vector_t vec_value; vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple( constexpr auto idx_ys = generate_tuple(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
...@@ -641,9 +646,10 @@ struct tile_window_with_static_distribution ...@@ -641,9 +646,10 @@ struct tile_window_with_static_distribution
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
}); });
...@@ -698,7 +704,7 @@ struct tile_window_with_static_distribution ...@@ -698,7 +704,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor // read from distributed tensor
vector_t vec_value; vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple( constexpr auto idx_ys = generate_tuple(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
...@@ -706,8 +712,9 @@ struct tile_window_with_static_distribution ...@@ -706,8 +712,9 @@ struct tile_window_with_static_distribution
}, },
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
vec_value.template get_as<DataType>()(j) = Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
}); });
...@@ -759,7 +766,7 @@ struct tile_window_with_static_distribution ...@@ -759,7 +766,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor // read from distributed tensor
vector_t vec_value; vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple( constexpr auto idx_ys = generate_tuple(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
...@@ -768,9 +775,10 @@ struct tile_window_with_static_distribution ...@@ -768,9 +775,10 @@ struct tile_window_with_static_distribution
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
}); });
...@@ -825,7 +833,7 @@ struct tile_window_with_static_distribution ...@@ -825,7 +833,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor // read from distributed tensor
vector_t vec_value; vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple( constexpr auto idx_ys = generate_tuple(
[&](auto jj) { [&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
...@@ -834,9 +842,10 @@ struct tile_window_with_static_distribution ...@@ -834,9 +842,10 @@ struct tile_window_with_static_distribution
number<NDimY>{}); number<NDimY>{});
constexpr index_t d = constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j) = vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>(); dstr_tensor.get_thread_buffer().template at<d>();
}); });
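The tile window changes above follow the same packing rule: a vector of ScalarPerVector logical scalars occupies only ScalarPerVector / PackedSize stored elements, so the per-vector static_for now steps by PackedSize and the packed buffer is indexed with j / PackedSize. A standalone sketch of that loop shape, using example values (ScalarPerVector = 8, PackedSize = 2) and std::array as a stand-in for thread_buffer:

#include <array>
#include <cstdint>
#include <cstdio>

struct pk_int4_t
{
    int8_t data;
};

int main()
{
    constexpr int ScalarPerVector = 8; // logical int4 scalars moved per access (example value)
    constexpr int PackedSize      = 2; // int4 values stored per pk_int4_t byte

    // With packing, one "vector" of 8 logical scalars occupies only 4 stored
    // elements, which is why vector_t shrinks to ScalarPerVector / PackedSize.
    std::array<pk_int4_t, ScalarPerVector / PackedSize> vec_value{};
    std::array<pk_int4_t, ScalarPerVector / PackedSize> thread_buf{};

    // Mirror of static_for<0, ScalarPerVector, PackedSize>: j walks logical
    // scalar positions, while j / PackedSize indexes the packed storage.
    for(int j = 0; j < ScalarPerVector; j += PackedSize)
    {
        thread_buf[j / PackedSize] = vec_value[j / PackedSize];
    }

    std::printf("copied %zu packed elements\n", thread_buf.size());
    return 0;
}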
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include "ck_tile/core/arch/arch.hpp"
@@ -151,11 +151,13 @@ struct tile_window_linear
     }
     public:
+    static constexpr index_t PackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
     static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
     static constexpr index_t ScalarPerVector =
         get_vector_dim_y_scalar_per_vector().template at<1>();
-    using vector_t = thread_buffer<DataType, ScalarPerVector>;
+    using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
     private:
     static constexpr auto scalars_per_access_ = [] {
@@ -498,17 +500,18 @@ struct tile_window_linear
         // data index [y0, y1, ...]
         constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
         // write into distributed tensor
-        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+        static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
             constexpr auto idx_ys = generate_tuple(
                 [&](auto jj) {
                     return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                 },
                 number<NDimY>{});
-            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                                  traits::PackedSize;
             dst_tensor.get_thread_buffer().template at<d>() =
-                vec_value.template get_as<DataType>()[j];
+                vec_value.template get_as<DataType>()[j / traits::PackedSize];
         });
 #else
         constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
@@ -556,17 +559,18 @@ struct tile_window_linear
         // data index [y0, y1, ...]
         constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
         // write into distributed tensor
-        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+        static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
             constexpr auto idx_ys = generate_tuple(
                 [&](auto jj) {
                     return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                 },
                 number<NDimY>{});
-            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                                  traits::PackedSize;
             dst_tensor.get_thread_buffer().template at<d>() =
-                vec_value.template get_as<DataType>()[j];
+                vec_value.template get_as<DataType>()[j / traits::PackedSize];
         });
 #else
         constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
@@ -595,8 +599,9 @@ struct tile_window_linear
         using SFC_Ys = typename traits::SFC_Ys;
         static constexpr index_t YElementSize =
             TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
-        static_assert(YElementSize % traits::ScalarPerVector == 0);
-        using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;
+        static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0);
+        using vectorized_tbuf =
+            array<vector_t, YElementSize / (traits::PackedSize * traits::ScalarPerVector)>;
         constexpr auto tile_dstr = TileDstr{};
@@ -620,7 +625,9 @@ struct tile_window_linear
         // data index [y0, y1, ...]
         constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
-        constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
+        constexpr index_t d =
+            tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
+            traits::PackedSize;
         static_assert(d % traits::ScalarPerVector == 0);
         get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
@@ -804,16 +811,17 @@ struct tile_window_linear
         // read from distributed tensor
         vector_t vec_value;
-        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+        static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
            constexpr auto idx_ys = generate_tuple(
                 [&](auto jj) {
                     return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                 },
                 number<NDimY>{});
-            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-            vec_value.template get_as<DataType>()(j) =
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                                  traits::PackedSize;
+            vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                 dstr_tensor.get_thread_buffer().template at<d>();
         });
@@ -852,14 +860,15 @@ struct tile_window_linear
         // read from distributed tensor
         vector_t vec_value;
-        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+        static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
            constexpr auto idx_ys = generate_tuple(
                 [&](auto jj) {
                     return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                 },
                 number<NDimY>{});
-            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-            vec_value.template get_as<DataType>()(j) =
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                                  traits::PackedSize;
+            vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                 dstr_tensor.get_thread_buffer().template at<d>();
         });
@@ -897,16 +906,17 @@ struct tile_window_linear
         // read from distributed tensor
         vector_t vec_value;
-        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+        static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
            constexpr auto idx_ys = generate_tuple(
                 [&](auto jj) {
                     return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                 },
                 number<NDimY>{});
-            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-            vec_value.template get_as<DataType>()(j) =
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                                  traits::PackedSize;
+            vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                 dstr_tensor.get_thread_buffer().template at<d>();
         });
@@ -948,16 +958,17 @@ struct tile_window_linear
         // read from distributed tensor
         vector_t vec_value;
-        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+        static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
            constexpr auto idx_ys = generate_tuple(
                 [&](auto jj) {
                     return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                 },
                 number<NDimY>{});
-            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-            vec_value.template get_as<DataType>()(j) =
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
+                                  traits::PackedSize;
+            vec_value.template get_as<DataType>()(j / traits::PackedSize) =
                 dstr_tensor.get_thread_buffer().template at<d>();
         });
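For the linear tile window, the vectorized thread buffer is sized the same way; assuming YElementSize counts logical scalars, the buffer holds YElementSize / (PackedSize * ScalarPerVector) vectors of ScalarPerVector / PackedSize stored elements each. A worked sizing check with example values (not taken from any particular kernel configuration):

#include <array>
#include <cstdint>

struct pk_int4_t
{
    int8_t data;
};

// Stand-in for ck_tile::thread_buffer, sized in stored (packed) elements.
template <typename T, int N>
using thread_buffer = std::array<T, N>;

int main()
{
    constexpr int YElementSize    = 32; // logical scalars owned by one thread (example value)
    constexpr int ScalarPerVector = 8;
    constexpr int PackedSize      = 2;

    // Same divisibility requirement as the static_assert in the diff above.
    static_assert(YElementSize % (PackedSize * ScalarPerVector) == 0);

    // Each vector_t holds ScalarPerVector / PackedSize stored elements, and the
    // thread keeps YElementSize / (PackedSize * ScalarPerVector) such vectors.
    using vector_t = thread_buffer<pk_int4_t, ScalarPerVector / PackedSize>;
    constexpr std::array<vector_t, YElementSize / (PackedSize * ScalarPerVector)> tbuf{};
    static_assert(tbuf.size() == 2); // 32 / (2 * 8)
    (void)tbuf;
    return 0;
}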
......
@@ -29,11 +29,12 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     using I8 = int8_t;
     using I32 = int32_t;
-    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
-                  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
+    static_assert(
+        is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
+        "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
     double compute_error = 0;
-    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }
@@ -42,11 +43,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
         compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
     }
-    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the relative threshold!");
     double output_error = 0;
-    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }
@@ -56,11 +57,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     }
     double midway_error = std::max(compute_error, output_error);
-    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the relative threshold!");
     double acc_error = 0;
-    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }
@@ -82,12 +83,13 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     using I8 = int8_t;
     using I32 = int32_t;
-    static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
-                  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
+    static_assert(
+        is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
+        "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
     auto expo = std::log2(std::abs(max_possible_num));
     double compute_error = 0;
-    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }
@@ -96,11 +98,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
         compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
     }
-    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");
     double output_error = 0;
-    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }
@@ -110,11 +112,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     }
     double midway_error = std::max(compute_error, output_error);
-    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
+    static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");
     double acc_error = 0;
-    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
+    if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
     {
         return 0;
     }
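In the threshold helpers above, pk_int4_t is simply added to the integer bucket: integer arithmetic carries no rounding error, so the relative and absolute tolerances collapse to zero whenever the compute, output, or accumulator type is integral. A simplified sketch of that branch, with a stand-in is_any_of and a placeholder value for the floating-point path:

#include <cstdint>
#include <type_traits>

struct pk_int4_t
{
    int8_t data;
};

// Simplified stand-in for the ck_tile is_any_of trait used in the diff above.
template <typename T, typename... Ts>
struct is_any_of : std::disjunction<std::is_same<T, Ts>...>
{
};

// Sketch of the relative-threshold logic: integer-like compute types (now
// including pk_int4_t) get a zero tolerance; the floating-point branch stands
// in for the mantissa-based formula in the real helper.
template <typename ComputeDataType>
constexpr double relative_threshold_sketch()
{
    if constexpr(is_any_of<ComputeDataType, pk_int4_t, int8_t, int32_t, int>::value)
    {
        return 0.0;
    }
    else
    {
        return 1e-3; // placeholder, not the real mantissa-derived value
    }
}

static_assert(relative_threshold_sketch<pk_int4_t>() == 0.0);
static_assert(relative_threshold_sketch<float>() > 0.0);

int main() { return 0; }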
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -282,7 +282,14 @@ struct FillMonotonicSeq
     {
         std::generate(first, last, [=, n = init_value_]() mutable {
             auto tmp = n;
-            n += step_;
+            if constexpr(std::is_same_v<decltype(tmp), pk_int4_t>)
+            {
+                n.data += step_.data;
+            }
+            else
+            {
+                n += step_;
+            }
             return tmp;
         });
     }
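The FillMonotonicSeq change above special-cases pk_int4_t, presumably because the packed type does not provide a usable operator+=; the generator steps the raw .data byte instead. A self-contained sketch of the same branch with a stand-in pk_int4_t (only the .data member is assumed) and an illustrative fill_monotonic helper:

#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <vector>

// Stand-in for ck_tile::pk_int4_t; only the .data member is assumed here.
struct pk_int4_t
{
    int8_t data{};
};

// Free-function sketch of the FillMonotonicSeq generator above: the packed
// type has no operator+=, so its raw byte is stepped directly.
template <typename T>
void fill_monotonic(std::vector<T>& v, T init, T step)
{
    std::generate(v.begin(), v.end(), [=, n = init]() mutable {
        auto tmp = n;
        if constexpr(std::is_same_v<decltype(tmp), pk_int4_t>)
        {
            n.data += step.data;
        }
        else
        {
            n += step;
        }
        return tmp;
    });
}

int main()
{
    std::vector<pk_int4_t> packed(4);
    fill_monotonic(packed, pk_int4_t{0}, pk_int4_t{1}); // raw bytes become 0,1,2,3

    std::vector<int> plain(4);
    fill_monotonic(plain, 0, 2); // 0,2,4,6 via the generic += branch
    return 0;
}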
......