Commit 222e9688 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent 2807c69e
......@@ -52,15 +52,15 @@ using DeviceGemmV2Instance =
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#endif
// clang-format on
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
#include "run_gemm_example_v2.inc"
......
......@@ -182,20 +182,19 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -207,42 +206,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
bool pass = true;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
}
if(config.time_kernel)
......@@ -252,7 +251,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N / (ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) + sizeof(CDataType) * M * N;
sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N /
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
......@@ -13,31 +13,29 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
__device__ inline half4_t pki4_to_half4(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
//int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
//int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = (q & LO) | EX;
int hi = (q & HI) | EX;
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0xE408E408; //-8
const int MUL = 0x2c002c00; //1/16
const int ADD = 0xd480d480; //-79
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi),
bit_cast<half2_t>(MUL),
bit_cast<half2_t>(ADD));
return res.template AsType<half4_t>()[Number<0>{}];
__device__ inline half4_t pki4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = (q & LO) | EX;
int hi = (q & HI) | EX;
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0xE408E408; //-8
const int MUL = 0x2c002c00; // 1/16
const int ADD = 0xd480d480; //-79
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
return res.template AsType<half4_t>()[Number<0>{}];
}
struct PassThroughPack8
......@@ -46,14 +44,14 @@ struct PassThroughPack8
__host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
{
{
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
y = result.template AsType<half8_t>()[Number<0>{}];
}
}
constexpr const static bool is_pack8_invocable = true;
};
......@@ -70,21 +68,21 @@ struct PassThroughPack2
}
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
{
{
#if 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
auto l_f16 = ck::type_convert<ck::half_t>(x_l);
auto h_f16 = ck::type_convert<ck::half_t>(x_h);
auto l_f16 = ck::type_convert<ck::half_t>(x_l);
auto h_f16 = ck::type_convert<ck::half_t>(x_h);
y = {l_f16, h_f16};
y = {l_f16, h_f16};
#else
uint32_t t = ck::bit_cast<uint8_t>(x);
y = ck::bit_cast<half2_t>(t);
y = ck::bit_cast<half2_t>(t);
#endif
}
}
constexpr const static bool is_pack2_invocable = true;
};
......
......@@ -398,7 +398,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
#else
const index_t N0 = N / NPerBlock;
const index_t N1 = NPerBlock;
const auto b_grid_desc_n0_bk0_n1_bk1 = make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
const auto b_grid_desc_n0_bk0_n1_bk1 =
make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n0_bk0_n1_bk1,
......@@ -653,8 +654,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
......@@ -788,8 +789,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
......@@ -1337,8 +1338,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
......@@ -1354,19 +1356,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
......@@ -1732,7 +1734,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(static_cast<char*>(p_shared_0) +
a_block_space_size_aligned * sizeof(ADataType)),
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
......@@ -1740,7 +1742,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
a_block_space_size_aligned * sizeof(ADataType)),
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
......
......@@ -1023,7 +1023,8 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)), "pk data N cannot be 1");
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)),
"pk data N cannot be 1");
}
template <typename SrcRefToOriginDisplacement,
......@@ -1129,7 +1130,8 @@ struct ThreadwiseTensorSliceTransfer_v4
if constexpr(SrcBuffer::IsDynamicBuffer())
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize, is_src_valid);
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
is_src_valid);
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
......@@ -1171,8 +1173,8 @@ struct ThreadwiseTensorSliceTransfer_v4
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0)
is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
......
......@@ -55,7 +55,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto I0 = Number<0>{};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
......@@ -63,8 +62,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
return 1;
}();
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc,
......@@ -78,8 +77,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>, "SrcData != DstData");
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector_ == 1 || DstScalarPerVector_ == 1)), "pk data N cannot be 1");
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> &&
(SrcScalarPerVector_ == 1 || DstScalarPerVector_ == 1)),
"pk data N cannot be 1");
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
......@@ -185,10 +187,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
// maintain a container record is_src_valid, waiting for RunWrite use.
//const bool is_src_valid =
//coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
//src_oob_thread_scratch_tuple_(thread_scratch_id)
//.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
// const bool is_src_valid =
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// src_oob_thread_scratch_tuple_(thread_scratch_id)
//.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
......@@ -203,12 +205,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
else if constexpr(is_detected<is_pack4_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
else if constexpr(is_detected<is_pack2_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
......@@ -226,8 +230,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static_assert(elem_op_vec_len == 1, "elem_op_vec_len != 1");
auto src_vector_container =
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
......@@ -348,18 +352,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
auto op_r_v = src_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<vector_t>(src_data_idx_seq);
.template GetAsType<vector_t>(src_data_idx_seq);
//const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
//.template GetAsType<bool>(src_data_idx_seq);
// const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
//.template GetAsType<bool>(src_data_idx_seq);
//auto op_r_v = is_src_valid ? op_r : vector_t(0);
// auto op_r_v = is_src_valid ? op_r : vector_t(0);
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
});
static_assert(!(is_same_v<pk_i4_t, remove_cvref_t<SrcData>> && SrcVectorDim != DstVectorDim), "pk_i4_t does not support transpose");
static_assert(
!(is_same_v<pk_i4_t, remove_cvref_t<SrcData>> && SrcVectorDim != DstVectorDim),
"pk_i4_t does not support transpose");
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
......@@ -432,9 +438,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
else
{
constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
......@@ -537,8 +543,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
//const bool is_dst_valid =
//coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// const bool is_dst_valid =
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
......@@ -556,9 +562,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset() / PackedSize,
true,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_coord_.GetOffset() / PackedSize,
true,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
......@@ -872,12 +878,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
decltype(src_thread_scratch_desc_),
true>;
//using SrcOOBThreadScratch =
//StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
//bool, // apply data_convert with SrcThreadScratch
//1,
//decltype(src_oob_thread_scratch_desc_),
//true>;
// using SrcOOBThreadScratch =
// StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
// bool, // apply data_convert with SrcThreadScratch
// 1,
// decltype(src_oob_thread_scratch_desc_),
// true>;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
......@@ -886,7 +892,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
true>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
//StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
// StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
DstThreadScratch dst_thread_scratch_;
......
......@@ -11,19 +11,18 @@
namespace ck {
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c) {
half2_t d;
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
return d;
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
half2_t d;
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
return d;
}
inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) {
half2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
return c;
inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
half2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
return c;
}
// c0 += inner_product(a, b0)
......
......@@ -1054,12 +1054,12 @@ using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8
// i8
//using uint8x2_t = typename vector_type<uint8_t, 2>::type;
//using uint8x4_t = typename vector_type<uint8_t, 4>::type;
//using uint8x8_t = typename vector_type<uint8_t, 8>::type;
//using uint8x16_t = typename vector_type<uint8_t, 16>::type;
//using uint8x32_t = typename vector_type<uint8_t, 32>::type;
//using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
// using uint8x4_t = typename vector_type<uint8_t, 4>::type;
// using uint8x8_t = typename vector_type<uint8_t, 8>::type;
// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
// using uint8x64_t = typename vector_type<uint8_t, 64>::type;
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
......
......@@ -83,7 +83,7 @@ struct DynamicBuffer
return 1;
}();
//static_assert(element_space_size_ % PackedSize == 0, "");
// static_assert(element_space_size_ % PackedSize == 0, "");
if constexpr(InvalidElementUseNumericalZeroValue)
{
......@@ -97,7 +97,11 @@ struct DynamicBuffer
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
t_per_x,
coherence>(
p_data_, i, is_valid_element, element_space_size_ / PackedSize, invalid_element_value_);
p_data_,
i,
is_valid_element,
element_space_size_ / PackedSize,
invalid_element_value_);
}
}
else
......
......@@ -86,8 +86,8 @@ struct ReferenceGemm : public device::BaseOperator
}
else if constexpr(is_same_v<BDataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0;
pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
......
......@@ -322,11 +322,12 @@ struct Tensor
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const {
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
return mDesc.GetElementSpaceSize() / 2;
return mDesc.GetElementSpaceSize() / 2;
else
return mDesc.GetElementSpaceSize();
return mDesc.GetElementSpaceSize();
}
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
......
......@@ -89,9 +89,9 @@ struct GeneratorTensor_1<ck::pk_i4_t>
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int t = value + 8;
int t = value + 8;
ck::pk_i4_t r = ((t << 4) + t) & 0xff;
return r;
return r;
}
};
......@@ -144,8 +144,8 @@ struct GeneratorTensor_2<ck::pk_i4_t>
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int hi = std::rand() % (max_value - min_value) + min_value + 8;
int lo = std::rand() % (max_value - min_value) + min_value + 8;
int hi = std::rand() % (max_value - min_value) + min_value + 8;
int lo = std::rand() % (max_value - min_value) + min_value + 8;
ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
return r;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment