Commit 222e9688 authored by Jing Zhang

format

parent 2807c69e
@@ -52,9 +52,9 @@ using DeviceGemmV2Instance =
         1, 1, S<1, 16, 1, 8>, 4,
         ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
 #endif
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
                                                                         CDataType,
                                                                         AccDataType,
...
@@ -182,8 +182,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     auto invoker = gemm.MakeInvoker();
     float ave_time = 0;

-    auto argument = gemm.MakeArgument(
-        static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
         static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
         static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
         M,
@@ -252,7 +251,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     std::size_t flop = 2_uz * M * N * K;
     std::size_t num_btype =
-        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N / (ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) + sizeof(CDataType) * M * N;
+        sizeof(ADataType) * M * K +
+        sizeof(BDataType) * K * N /
+            (ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
+        sizeof(CDataType) * M * N;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...
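The reflowed `num_btype` expression makes the packed-int4 correction easier to see: `ck::pk_i4_t` stores two 4-bit values per byte, so when B holds packed int4 weights its traffic is `K * N / 2` bytes. A minimal host-side sketch of the same accounting, with assumed fp16 A/C element sizes and illustrative problem dimensions (nothing below is a CK API):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 3840, N = 4096, K = 4096;
    const bool b_is_pk_i4 = true; // B stores two int4 values per byte

    const std::size_t flop      = 2 * M * N * K;
    const std::size_t num_btype = 2 * M * K                          // A: fp16
                                  + 1 * K * N / (b_is_pk_i4 ? 2 : 1) // B: packed int4
                                  + 2 * M * N;                       // C: fp16

    const float ave_time_ms = 1.0f; // stand-in for the measured kernel time
    std::printf("%.2f TFLOPS, %.2f GB/s\n",
                flop / 1.0e9 / ave_time_ms,
                num_btype / 1.0e6 / ave_time_ms);
    return 0;
}

With time in milliseconds, `flop / 1e9 / t` yields TFLOPS and `bytes / 1e6 / t` yields GB/s, matching the example's own formulas.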
@@ -13,20 +13,20 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {

-__device__ inline half4_t pki4_to_half4(int q) {
+__device__ inline half4_t pki4_to_half4(int q)
+{
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;
     const int EX = 0x64006400;
     // Guarantee that the `(a & b) | c` operations are LOP3s.
-    //int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
-    //int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
     int lo = (q & LO) | EX;
     int hi = (q & HI) | EX;
     // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
     // directly into `SUB` and `ADD`.
-    const int SUB = 0xE408E408; //-8
-    const int MUL = 0x2c002c00; //1/16
-    const int ADD = 0xd480d480; //-79
+    const int SUB = 0xE408E408; // -8
+    const int MUL = 0x2c002c00; // 1/16
+    const int ADD = 0xd480d480; // -79

     vector_type<half_t, 4> res;
@@ -34,9 +34,7 @@ __device__ inline half4_t pki4_to_half4(int q) {
         amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
     res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
-        bit_cast<half2_t>(hi),
-        bit_cast<half2_t>(MUL),
-        bit_cast<half2_t>(ADD));
+        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));

     return res.template AsType<half4_t>()[Number<0>{}];
 }
...
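`pki4_to_half4` is the classic fp16 magic-number dequantization: OR-ing `EX = 0x6400` (the bit pattern of fp16 1024.0) over a 4-bit code in mantissa bits 0-3 produces the half `1024 + v`, and over the high nibble (bits 4-7) produces `1024 + 16 * v`. The packed add then computes `(1024 + v) - 1032 = v - 8`, and the packed FMA computes `(1024 + 16 * v) / 16 - 72 = v - 8`, fusing the signed zero point into the bias removal. Decoded per 16-bit half, `SUB` is -1032, `MUL` is 1/16, and `ADD` is -72 (the in-source `-79` comment does not match the bit pattern: 0xd480 is the half value -72). The commented-out `lop3` lines are residue of the CUDA original, where a `LOP3.LUT` immediate guarantees a single instruction per mask-or step. A standalone CPU check of the two identities (local names, not CK's):

#include <cstdio>

// Verifies the magic-number identities used by pki4_to_half4:
//   low nibble:  (1024 + v) - 1032         == v - 8
//   high nibble: (1024 + 16 * v) / 16 - 72 == v - 8
int main()
{
    for(int v = 0; v < 16; ++v)
    {
        const int lo_path = (1024 + v) - 1032;
        const int hi_path = (1024 + 16 * v) / 16 - 72; // division is exact here
        std::printf("v=%2d -> lo %3d, hi %3d\n", v, lo_path, hi_path);
    }
    return 0;
}

Both paths map the unsigned codes 0..15 onto the signed range -8..7, which is exactly the `-8` symmetric zero point the comment describes.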
@@ -398,7 +398,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
 #else
         const index_t N0 = N / NPerBlock;
         const index_t N1 = NPerBlock;
-        const auto b_grid_desc_n0_bk0_n1_bk1 = make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
+        const auto b_grid_desc_n0_bk0_n1_bk1 =
+            make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));

         const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
             b_grid_desc_n0_bk0_n1_bk1,
@@ -1337,8 +1338,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
             static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
-                                         a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
+            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
+                                                                            sizeof(ADataType) /
+                                                                            APackedSize),
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
...
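The reflowed expression computes where the B tile begins inside the shared LDS block: the A tile occupies the first `a_block_space_size_aligned * sizeof(ADataType) / APackedSize` bytes (the division by `APackedSize` converts a logical element count to storage bytes when two values share one byte), and B is placed immediately after. A standalone sketch of that offset arithmetic, with assumed tile sizes (the variable names are illustrative, not CK's):

#include <cassert>

int main()
{
    const int a_block_elems = 128 * 64; // assumed aligned A-tile element count
    const int sizeof_a      = 2;        // fp16 ADataType
    const int a_packed_size = 1;        // would be 2 if A were pk_i4_t

    // Byte offset of the B tile inside the shared LDS allocation.
    const int b_offset = a_block_elems * sizeof_a / a_packed_size;
    assert(b_offset == 16384);
    return 0;
}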
@@ -1023,7 +1023,8 @@ struct ThreadwiseTensorSliceTransfer_v4
         static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                       "wrong! Not divisible");
-        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)), "pk data N cannot be 1");
+        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)),
+                      "pk data N cannot be 1");
     }

     template <typename SrcRefToOriginDisplacement,
@@ -1129,7 +1130,8 @@ struct ThreadwiseTensorSliceTransfer_v4
             if constexpr(SrcBuffer::IsDynamicBuffer())
             {
                 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize, is_src_valid);
+                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
+                                                       is_src_valid);
             }
             else if constexpr(SrcBuffer::IsStaticBuffer())
             {
...
@@ -55,7 +55,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     static constexpr auto I0 = Number<0>{};

     static constexpr index_t PackedSize = []() {
-
         if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
             return 2;
@@ -78,8 +77,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
           src_element_op_(src_element_op),
           dst_element_op_(dst_element_op)
     {
-        static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>, "SrcData != DstData");
-        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector_ == 1 || DstScalarPerVector_ == 1)), "pk data N cannot be 1");
+        static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
+                      "SrcData != DstData");
+        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> &&
+                            (SrcScalarPerVector_ == 1 || DstScalarPerVector_ == 1)),
+                      "pk data N cannot be 1");
     }

     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
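Both new `static_assert`s guard the same invariant behind `PackedSize`: for `pk_i4_t`, `PackedSize` is 2, element offsets are divided by 2 to obtain storage offsets, and a scalar-per-vector of 1 would require addressing half a byte, hence "pk data N cannot be 1". A standalone sketch of the idea (CK's version uses the `if constexpr` lambda shown in the hunk above; the names here are illustrative):

#include <type_traits>

struct pk_i4_t { signed char storage; }; // two 4-bit values per byte

template <typename SrcData>
constexpr int packed_size =
    std::is_same_v<std::remove_cv_t<SrcData>, pk_i4_t> ? 2 : 1;

template <typename SrcData, int ScalarPerVector>
constexpr void check_transfer()
{
    // A vector of one packed element would split a byte: forbid it.
    static_assert(!(packed_size<SrcData> == 2 && ScalarPerVector == 1),
                  "pk data N cannot be 1");
}

int main()
{
    check_transfer<float, 1>();   // fine: unpacked data
    check_transfer<pk_i4_t, 4>(); // fine: element offset 4 -> storage offset 2
    // check_transfer<pk_i4_t, 1>(); // would fail to compile
    return 0;
}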
@@ -185,9 +187,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});

             // maintain a container record is_src_valid, waiting for RunWrite use.
-            //const bool is_src_valid =
-            //coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
-            //src_oob_thread_scratch_tuple_(thread_scratch_id)
+            // const bool is_src_valid =
+            // coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+            // src_oob_thread_scratch_tuple_(thread_scratch_id)
             //.template SetAsType<bool>(src_data_idx_seq, is_src_valid);

             using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
@@ -203,12 +205,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                     if constexpr(decltype(src_element_op_)::is_pack8_invocable)
                         return math::min(8, SrcScalarPerVector);
                 }
-                else if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
+                else if constexpr(is_detected<is_pack4_invocable_t,
+                                              decltype(src_element_op_)>::value)
                 {
                     if constexpr(decltype(src_element_op_)::is_pack4_invocable)
                         return math::min(4, SrcScalarPerVector);
                 }
-                else if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
+                else if constexpr(is_detected<is_pack2_invocable_t,
+                                              decltype(src_element_op_)>::value)
                 {
                     if constexpr(decltype(src_element_op_)::is_pack2_invocable)
                         return math::min(2, SrcScalarPerVector);
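The `is_detected<is_packN_invocable_t, ...>` chain is the detection idiom: it probes at compile time whether the element-wise operator declares an `is_packN_invocable` member and picks the widest pack width (8, then 4, then 2) the op supports, capped by `SrcScalarPerVector`. A reduced standalone sketch of the same pattern, with hand-rolled detection (`has_pack4`, `PackedDequant`, and `PassThrough` are illustrative, not CK types):

#include <algorithm>
#include <type_traits>

template <typename Op, typename = void>
struct has_pack4 : std::false_type {};
template <typename Op>
struct has_pack4<Op, std::void_t<decltype(Op::is_pack4_invocable)>>
    : std::bool_constant<Op::is_pack4_invocable> {};

struct PackedDequant { static constexpr bool is_pack4_invocable = true; };
struct PassThrough {}; // no pack hint: processed one scalar at a time

template <typename Op>
constexpr int elem_op_vec_len(int scalar_per_vector)
{
    if constexpr(has_pack4<Op>::value)
        return std::min(4, scalar_per_vector); // widest supported pack wins
    else
        return 1;
}

static_assert(elem_op_vec_len<PackedDequant>(8) == 4);
static_assert(elem_op_vec_len<PassThrough>(8) == 1);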
@@ -226,8 +230,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             static_assert(elem_op_vec_len == 1, "elem_op_vec_len != 1");

-            auto src_vector_container =
-                src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
+            auto src_vector_container = src_vector_type{
+                src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};

             static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
                 // apply the src elementwise op and convert to DstData under the hood if needed
@@ -350,16 +354,18 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 auto op_r_v = src_thread_scratch_tuple_(thread_scratch_id)
                                   .template GetAsType<vector_t>(src_data_idx_seq);

-                //const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
+                // const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
                 //.template GetAsType<bool>(src_data_idx_seq);

-                //auto op_r_v = is_src_valid ? op_r : vector_t(0);
+                // auto op_r_v = is_src_valid ? op_r : vector_t(0);

                 src_thread_scratch_tuple_(thread_scratch_id)
                     .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
             });

-        static_assert(!(is_same_v<pk_i4_t, remove_cvref_t<SrcData>> && SrcVectorDim != DstVectorDim), "pk_i4_t does not support transpose");
+        static_assert(
+            !(is_same_v<pk_i4_t, remove_cvref_t<SrcData>> && SrcVectorDim != DstVectorDim),
+            "pk_i4_t does not support transpose");

         // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
         // TODO make this logic more generic for more sub-dword datatype
@@ -537,8 +543,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             constexpr auto dst_data_idx_seq = generate_sequence_v2(
                 [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});

-            //const bool is_dst_valid =
-            //coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+            // const bool is_dst_valid =
+            // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

             using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
             using dst_vector_t = typename dst_vector_type::type;
@@ -872,12 +878,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                                          decltype(src_thread_scratch_desc_),
                                                          true>;

-    //using SrcOOBThreadScratch =
-    //StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
-    //bool, // apply data_convert with SrcThreadScratch
-    //1,
-    //decltype(src_oob_thread_scratch_desc_),
-    //true>;
+    // using SrcOOBThreadScratch =
+    // StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+    // bool, // apply data_convert with SrcThreadScratch
+    // 1,
+    // decltype(src_oob_thread_scratch_desc_),
+    // true>;

     using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                              DstData,
@@ -886,7 +892,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                                              true>;

     StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
-    //StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
+    // StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;

     DstThreadScratch dst_thread_scratch_;
...
@@ -11,16 +11,15 @@
 namespace ck {

-inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c) {
+inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
+{
     half2_t d;
-    asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
-                 : "=v"(d)
-                 : "v"(a), "v"(b), "v"(c));
+    asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
     return d;
 }

-inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) {
+inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
+{
     half2_t c;
     asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
     return c;
...
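These wrappers use inline assembly to pin exact `v_pk_fma_f16` / `v_pk_add_f16` instructions. Functionally, setting the codegen guarantee aside, they compute a packed fp16 fused multiply-add and a packed fp16 add; a sketch of equivalents via HIP's fp16 intrinsics, assuming `__half2` stands in for CK's `half2_t`:

#include <hip/hip_fp16.h>

// Functional stand-ins for the asm wrappers above (sketch only; the asm
// versions exist to force the packed VALU instructions to be emitted).
__device__ inline __half2 pk_fma_f16_ref(__half2 a, __half2 b, __half2 c)
{
    return __hfma2(a, b, c); // d = a * b + c, per 16-bit lane
}

__device__ inline __half2 pk_add_f16_ref(__half2 a, __half2 b)
{
    return __hadd2(a, b); // c = a + b, per 16-bit lane
}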
@@ -1054,12 +1054,12 @@ using bf8x32_t = typename vector_type<bf8_t, 32>::type;
 using bf8x64_t = typename vector_type<bf8_t, 64>::type;

 // u8
 // i8
-//using uint8x2_t = typename vector_type<uint8_t, 2>::type;
-//using uint8x4_t = typename vector_type<uint8_t, 4>::type;
-//using uint8x8_t = typename vector_type<uint8_t, 8>::type;
-//using uint8x16_t = typename vector_type<uint8_t, 16>::type;
-//using uint8x32_t = typename vector_type<uint8_t, 32>::type;
-//using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
+// using uint8x4_t = typename vector_type<uint8_t, 4>::type;
+// using uint8x8_t = typename vector_type<uint8_t, 8>::type;
+// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
+// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
+// using uint8x64_t = typename vector_type<uint8_t, 64>::type;

 using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
 using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
...
@@ -83,7 +83,7 @@ struct DynamicBuffer
             return 1;
         }();

-        //static_assert(element_space_size_ % PackedSize == 0, "");
+        // static_assert(element_space_size_ % PackedSize == 0, "");

         if constexpr(InvalidElementUseNumericalZeroValue)
         {
@@ -97,7 +97,11 @@ struct DynamicBuffer
                 return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                                t_per_x,
                                                                                coherence>(
-                    p_data_, i, is_valid_element, element_space_size_ / PackedSize, invalid_element_value_);
+                    p_data_,
+                    i,
+                    is_valid_element,
+                    element_space_size_ / PackedSize,
+                    invalid_element_value_);
             }
         }
         else
...
@@ -322,7 +322,8 @@ struct Tensor
     std::size_t GetElementSize() const { return mDesc.GetElementSize(); }

-    std::size_t GetElementSpaceSize() const {
+    std::size_t GetElementSpaceSize() const
+    {
         if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
             return mDesc.GetElementSpaceSize() / 2;
         else
...
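`GetElementSpaceSize` carries the same packed-type rule into host tensors: for `ck::pk_i4_t` the backing storage holds two logical values per element, so the space is halved. A standalone sketch of the branch (the `pk_i4_t` stand-in and function name below are illustrative, not CK's):

#include <cstddef>
#include <type_traits>

struct pk_i4_t { signed char storage; }; // stand-in: two int4 values per byte

template <typename T>
std::size_t element_space_size(std::size_t logical_elems)
{
    if constexpr(std::is_same_v<T, pk_i4_t>)
        return logical_elems / 2; // two int4 values share one storage element
    else
        return logical_elems;
}

int main()
{
    return element_space_size<pk_i4_t>(4096) == 2048 &&
                   element_space_size<float>(4096) == 4096
               ? 0
               : 1;
}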