Commit 222e9688 authored by Jing Zhang

format

parent 2807c69e
@@ -52,15 +52,15 @@ using DeviceGemmV2Instance =
        1, 1, S<1, 16, 1, 8>, 4,
        ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#endif
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        AccDataType,
                                                                        PassThrough,
                                                                        PassThrough,
                                                                        PassThrough>;

#include "run_gemm_example_v2.inc"
...
@@ -182,20 +182,19 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
    auto invoker = gemm.MakeInvoker();
    float ave_time = 0;

    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                      M,
                                      N,
                                      K,
                                      StrideA,
                                      StrideB,
                                      StrideC,
                                      KBatch,
                                      a_element_op,
                                      b_element_op,
                                      c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
@@ -207,42 +206,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
    bool pass = true;
    if(config.do_verification)
    {
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        pass &= ck::utils::check_err(c_m_n_device_result,
                                     c_m_n_host_result,
                                     "Error: Incorrect results!",
                                     get_rtol<CDataType>(),
                                     get_atol<CDataType>());

        std::cout << "c_m_n_device_result: " << std::endl;
        for(int i = 0; i < M; i++)
        {
            for(int j = 0; j < N; j++)
            {
                std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
            }
            std::cout << std::endl;
        }

        std::cout << "c_m_n_host_result: " << std::endl;
        for(int i = 0; i < M; i++)
        {
            for(int j = 0; j < N; j++)
            {
                std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
            }
            std::cout << std::endl;
        }
    }

    if(config.time_kernel)
@@ -252,7 +251,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
        std::size_t flop = 2_uz * M * N * K;
        std::size_t num_btype =
            sizeof(ADataType) * M * K +
            sizeof(BDataType) * K * N /
                (ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
            sizeof(CDataType) * M * N;

        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...
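Aside from the re-wrapping, the point of this hunk is the bytes-moved term for B: when BDataType is ck::pk_i4_t, two int4 values share one byte, so the B contribution is halved. Below is a small standalone sketch of the same accounting with made-up sizes and a placeholder kernel time; nothing in it is CK API.

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 3840, N = 4096, K = 4096;
    const double ave_time_ms = 0.85; // placeholder kernel time in milliseconds

    const std::size_t flop      = 2ull * M * N * K;
    const std::size_t num_btype = 2 * M * K          // fp16 A
                                  + 1 * K * N / 2    // packed int4 B: two values per byte
                                  + 2 * M * N;       // fp16 C

    const double tflops  = static_cast<double>(flop) / 1.0e9 / ave_time_ms;
    const double gb_per_s = static_cast<double>(num_btype) / 1.0e6 / ave_time_ms;
    std::printf("%.2f TFLOPS, %.2f GB/s\n", tflops, gb_per_s);
    return 0;
}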
@@ -13,31 +13,29 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {

__device__ inline half4_t pki4_to_half4(int q)
{
    const int LO = 0x000f000f;
    const int HI = 0x00f000f0;
    const int EX = 0x64006400;
    // Guarantee that the `(a & b) | c` operations are LOP3s.
    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);

    int lo = (q & LO) | EX;
    int hi = (q & HI) | EX;

    // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
    // directly into `SUB` and `ADD`.
    const int SUB = 0xE408E408; //-8
    const int MUL = 0x2c002c00; // 1/16
    const int ADD = 0xd480d480; //-79

    vector_type<half_t, 4> res;

    res.template AsType<half2_t>()(Number<0>{}) =
        amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));

    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));

    return res.template AsType<half4_t>()[Number<0>{}];
}

struct PassThroughPack8
@@ -46,14 +44,14 @@ struct PassThroughPack8
    __host__ __device__ void operator()(Y& y, const X& x) const;

    __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
    {
        vector_type<half_t, 8> result;

        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);

        y = result.template AsType<half8_t>()[Number<0>{}];
    }
    constexpr const static bool is_pack8_invocable = true;
};
@@ -70,21 +68,21 @@ struct PassThroughPack2
    }

    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
    {
#if 1
        uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
        uint8_t x_l  = (x_u8 & 0x0f) >> 0;
        uint8_t x_h  = (x_u8 & 0xf0) >> 4;

        auto l_f16 = ck::type_convert<ck::half_t>(x_l);
        auto h_f16 = ck::type_convert<ck::half_t>(x_h);

        y = {l_f16, h_f16};
#else
        uint32_t t = ck::bit_cast<uint8_t>(x);
        y          = ck::bit_cast<half2_t>(t);
#endif
    }
    constexpr const static bool is_pack2_invocable = true;
};
...
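The constants in pki4_to_half4 follow the usual fp16 magic-number dequantization: OR-ing a low nibble n into 0x6400 (the fp16 encoding of 1024.0) produces the value 1024 + n, so one packed add with 0xE408 (-1032) yields n - 8 directly; the high-nibble lanes follow the same idea with an extra multiply by 1/16 and a fused offset. A minimal host-side sketch of the low-nibble path, independent of CK and purely illustrative:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE fp16 bit pattern to float (reference only; handles normals and subnormals).
static float fp16_bits_to_float(uint16_t h)
{
    const int sign     = (h >> 15) & 0x1;
    const int exponent = (h >> 10) & 0x1f;
    const float frac   = (h & 0x3ff) / 1024.0f;
    const float mag    = (exponent == 0) ? frac * std::ldexp(1.0f, -14)
                                         : (1.0f + frac) * std::ldexp(1.0f, exponent - 15);
    return sign ? -mag : mag;
}

int main()
{
    for(uint16_t n = 0; n < 16; ++n)
    {
        const uint16_t bits = 0x6400 | n;                         // fp16 value 1024 + n
        const float dequant = fp16_bits_to_float(bits) - 1032.0f; // fused -(1024 + 8)
        std::printf("nibble %2d -> %5.1f\n", n, dequant);         // prints n - 8
    }
    return 0;
}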
@@ -398,7 +398,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
#else
        const index_t N0 = N / NPerBlock;
        const index_t N1 = NPerBlock;

        const auto b_grid_desc_n0_bk0_n1_bk1 =
            make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));

        const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
            b_grid_desc_n0_bk0_n1_bk1,
@@ -653,8 +654,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        // in some cases.
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
            constexpr auto MLdsLayer  = LdsSize < 1 ? 1 : LdsSize;
            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
@@ -788,8 +789,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            // NLdsLayer * K0 as logical Bank
            constexpr index_t LdsSize   = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
@@ -1337,8 +1338,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
                                                                            sizeof(ADataType) /
                                                                            APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1354,19 +1356,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                           KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);

        // shuffle C and write out
        {
@@ -1732,7 +1734,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                    a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
@@ -1740,7 +1742,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
...
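The LDS bookkeeping above carves one shared-memory allocation into an A tile followed by a B tile, and with a packed element type the A tile's byte footprint is its aligned element count scaled by sizeof(ADataType) / APackedSize. A small self-contained sketch of that pointer arithmetic; the function and parameter names are illustrative, not CK symbols:

#include <cstddef>

// Compute where the B tile starts inside a shared scratch block that holds the A tile first.
// For a packed type such as pk_i4_t, a_packed_size == 2 because two logical elements share a byte.
template <typename AType, typename BType>
BType* b_tile_start(void* scratch_base, std::size_t a_elements_aligned, std::size_t a_packed_size)
{
    const std::size_t a_bytes = a_elements_aligned * sizeof(AType) / a_packed_size;
    return reinterpret_cast<BType*>(static_cast<char*>(scratch_base) + a_bytes);
}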
@@ -1023,7 +1023,8 @@ struct ThreadwiseTensorSliceTransfer_v4
        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                      "wrong! Not divisible");

        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)),
                      "pk data N cannot be 1");
    }

    template <typename SrcRefToOriginDisplacement,
@@ -1129,7 +1130,8 @@ struct ThreadwiseTensorSliceTransfer_v4
                if constexpr(SrcBuffer::IsDynamicBuffer())
                {
                    src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                        src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
                                                           is_src_valid);
                }
                else if constexpr(SrcBuffer::IsStaticBuffer())
                {
@@ -1171,8 +1173,8 @@ struct ThreadwiseTensorSliceTransfer_v4
                });
            }
            else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
                              is_same<remove_cvref_t<DstData>, half_t>::value &&
                              SrcScalarPerVector % 2 == 0)
            {
                // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
                // DstData)
...
@@ -55,7 +55,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
    static constexpr auto I0 = Number<0>{};

    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
            return 2;
@@ -63,8 +62,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            return 1;
    }();

    static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
    static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
        const SrcDesc& src_desc,
@@ -78,8 +77,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
          src_element_op_(src_element_op),
          dst_element_op_(dst_element_op)
    {
        static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
                      "SrcData != DstData");

        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> &&
                        (SrcScalarPerVector_ == 1 || DstScalarPerVector_ == 1)),
                      "pk data N cannot be 1");
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -185,10 +187,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});

            // maintain a container to record is_src_valid, for later use by RunWrite.
            // const bool is_src_valid =
            // coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            // src_oob_thread_scratch_tuple_(thread_scratch_id)
            //.template SetAsType<bool>(src_data_idx_seq, is_src_valid);

            using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;
@@ -203,12 +205,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                    if constexpr(decltype(src_element_op_)::is_pack8_invocable)
                        return math::min(8, SrcScalarPerVector);
                }
                else if constexpr(is_detected<is_pack4_invocable_t,
                                              decltype(src_element_op_)>::value)
                {
                    if constexpr(decltype(src_element_op_)::is_pack4_invocable)
                        return math::min(4, SrcScalarPerVector);
                }
                else if constexpr(is_detected<is_pack2_invocable_t,
                                              decltype(src_element_op_)>::value)
                {
                    if constexpr(decltype(src_element_op_)::is_pack2_invocable)
                        return math::min(2, SrcScalarPerVector);
@@ -226,8 +230,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            static_assert(elem_op_vec_len == 1, "elem_op_vec_len != 1");

            auto src_vector_container = src_vector_type{
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};

            static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
                // apply the src elementwise op and convert to DstData under the hood if needed
@@ -348,18 +352,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;

            auto op_r_v = src_thread_scratch_tuple_(thread_scratch_id)
                              .template GetAsType<vector_t>(src_data_idx_seq);

            // const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
            //.template GetAsType<bool>(src_data_idx_seq);

            // auto op_r_v = is_src_valid ? op_r : vector_t(0);

            src_thread_scratch_tuple_(thread_scratch_id)
                .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
        });

        static_assert(
            !(is_same_v<pk_i4_t, remove_cvref_t<SrcData>> && SrcVectorDim != DstVectorDim),
            "pk_i4_t does not support transpose");

        // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
        // TODO make this logic more generic for more sub-dword datatype
@@ -432,9 +438,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        else
        {
            constexpr auto packed_per_access = generate_sequence(
                detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});

            constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;

            static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
                dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
@@ -537,8 +543,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            constexpr auto dst_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});

            // const bool is_dst_valid =
            // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
            using dst_vector_t    = typename dst_vector_type::type;
@@ -556,9 +562,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
            // copy data from dst_vector_container to dst_buf
            dst_buf.template Set<dst_vector_t>(
                dst_coord_.GetOffset() / PackedSize,
                true,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
            {
@@ -872,12 +878,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                                               decltype(src_thread_scratch_desc_),
                                                               true>;

    // using SrcOOBThreadScratch =
    // StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
    // bool, // apply data_convert with SrcThreadScratch
    // 1,
    // decltype(src_oob_thread_scratch_desc_),
    // true>;

    using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                             DstData,
@@ -886,7 +892,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                                             true>;

    StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
    // StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;

    DstThreadScratch dst_thread_scratch_;
...
@@ -11,19 +11,18 @@
namespace ck {

inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
    half2_t d;
    asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
    return d;
}

inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
    half2_t c;
    asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
    return c;
}

// c0 += inner_product(a, b0)
...
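For readers unfamiliar with the v_pk_* instructions: both helpers above act lane-wise on the two fp16 lanes of a half2_t. A portable reference of the intended semantics, widened to float for clarity; the struct and function names below are illustrative only, not CK code.

// Lane-wise reference semantics of the packed-fp16 helpers (illustrative stand-ins).
struct Half2Ref
{
    float x;
    float y;
};

inline Half2Ref pk_fma_f16_ref(Half2Ref a, Half2Ref b, Half2Ref c)
{
    // v_pk_fma_f16: per-lane fused multiply-add.
    return {a.x * b.x + c.x, a.y * b.y + c.y};
}

inline Half2Ref pk_add_f16_ref(Half2Ref a, Half2Ref b)
{
    // v_pk_add_f16: per-lane add.
    return {a.x + b.x, a.y + b.y};
}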
@@ -1054,12 +1054,12 @@ using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;

// u8
// i8
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
// using uint8x4_t = typename vector_type<uint8_t, 4>::type;
// using uint8x8_t = typename vector_type<uint8_t, 8>::type;
// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
// using uint8x64_t = typename vector_type<uint8_t, 64>::type;

using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
...
@@ -83,7 +83,7 @@ struct DynamicBuffer
            return 1;
    }();

    // static_assert(element_space_size_ % PackedSize == 0, "");

    if constexpr(InvalidElementUseNumericalZeroValue)
    {
@@ -97,7 +97,11 @@ struct DynamicBuffer
            return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                           t_per_x,
                                                                           coherence>(
                p_data_,
                i,
                is_valid_element,
                element_space_size_ / PackedSize,
                invalid_element_value_);
        }
    }
    else
...
@@ -86,8 +86,8 @@ struct ReferenceGemm : public device::BaseOperator
                    }
                    else if constexpr(is_same_v<BDataType, pk_i4_t>)
                    {
                        pk_i4_t i4x2 = arg.b_k_n_(k, n);
                        int8_t i4    = 0;
                        if(k % 2 == 1)
                            i4 = (i4x2 >> 0) & 0xf;
                        else
...
@@ -322,11 +322,12 @@ struct Tensor
    std::size_t GetElementSize() const { return mDesc.GetElementSize(); }

    std::size_t GetElementSpaceSize() const
    {
        if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
            return mDesc.GetElementSpaceSize() / 2;
        else
            return mDesc.GetElementSpaceSize();
    }

    std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
...
@@ -89,9 +89,9 @@ struct GeneratorTensor_1<ck::pk_i4_t>
    template <typename... Is>
    ck::pk_i4_t operator()(Is...)
    {
        int t         = value + 8;
        ck::pk_i4_t r = ((t << 4) + t) & 0xff;
        return r;
    }
};
@@ -144,8 +144,8 @@ struct GeneratorTensor_2<ck::pk_i4_t>
    template <typename... Is>
    ck::pk_i4_t operator()(Is...)
    {
        int hi = std::rand() % (max_value - min_value) + min_value + 8;
        int lo = std::rand() % (max_value - min_value) + min_value + 8;

        ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
        return r;
    }
...
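The generators above store each int4 value with a +8 offset and pack two of them into one byte, which is the layout pki4_to_half4 undoes and the reference-GEMM hunk indexes by k % 2. A tiny standalone sketch of that round trip, using hypothetical helper names and no CK dependencies:

#include <cstdint>
#include <cstdio>

// Pack two signed int4 values (range [-8, 7]) into one byte with a +8 offset,
// mirroring GeneratorTensor_2<ck::pk_i4_t>.
static uint8_t pack_i4x2(int hi, int lo) { return static_cast<uint8_t>(((hi + 8) << 4) | (lo + 8)); }

// Select a nibble by parity (as the reference GEMM does with k % 2) and remove the +8 offset.
static int unpack_i4(uint8_t packed, int use_high_nibble)
{
    const int nibble = use_high_nibble ? (packed >> 4) & 0xf : packed & 0xf;
    return nibble - 8;
}

int main()
{
    const uint8_t b = pack_i4x2(-7, 3);
    std::printf("high = %d, low = %d\n", unpack_i4(b, 1), unpack_i4(b, 0)); // high = -7, low = 3
    return 0;
}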