Commit cc485d80 authored by mtgu0705

Enable pk_i4_t, based on commit "remove gfx12 targets from daily builds with rocm6.2 (#1560)"

parent cfac9497
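
This commit wires pk_i4_t — a type that packs two signed 4-bit values into one byte — through the threadwise transfer and the host-side tensor utilities. The recurring pattern in the diff below is a compile-time PackedSize of 2 for pk_i4_t (1 for everything else): logical element offsets and scalar-per-vector counts are divided by PackedSize before they address the packed storage. A minimal standalone sketch of that convention (the helper names here are illustrative, not CK APIs):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Pack two signed 4-bit values (each in [-8, 7]) into one byte: hi nibble | lo nibble.
inline uint8_t pack_int4_pair(int hi, int lo)
{
    return static_cast<uint8_t>(((hi & 0xF) << 4) | (lo & 0xF));
}

// Unpack the two signed 4-bit values again (sign-extend each nibble).
inline void unpack_int4_pair(uint8_t byte, int& hi, int& lo)
{
    hi = static_cast<int8_t>(byte) >> 4;                             // hi nibble, sign-extended
    lo = static_cast<int8_t>(static_cast<uint8_t>(byte << 4)) >> 4;  // lo nibble, sign-extended
}

int main()
{
    constexpr int PackedSize = 2;          // two int4 elements per stored byte
    const std::size_t logical_offset = 17; // element index in the un-packed tensor
    const std::size_t byte_offset    = logical_offset / PackedSize; // what the kernel addresses
    assert(byte_offset == 8);

    int hi = 0, lo = 0;
    unpack_int4_pair(pack_int4_pair(-3, 5), hi, lo);
    assert(hi == -3 && lo == 5);
    return 0;
}
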
@@ -31,8 +31,8 @@ template <typename SliceLengths,
 typename DstDimAccessOrder,
 index_t SrcVectorDim,
 index_t DstVectorDim,
- index_t SrcScalarPerVector,
- index_t DstScalarPerVector,
+ index_t SrcScalarPerVector_,
+ index_t DstScalarPerVector_,
 index_t SrcScalarStrideInVector,
 index_t DstScalarStrideInVector,
 bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -55,6 +55,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 static constexpr auto I0 = Number<0>{};
+ static constexpr index_t PackedSize = []() {
+     if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+         return 2;
+     else
+         return 1;
+ }();
+ static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
+ static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
 __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
 const SrcDesc& src_desc,
 const Index& src_slice_origin,
@@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 src_element_op_(src_element_op),
 dst_element_op_(dst_element_op)
 {
+     if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+     {
+         static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
+                       "SrcData != DstData");
+         static_assert(
+             SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
+             "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
+         static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
+     }
 }
 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 // scalar per access on each dim
 // TODO: don't use lambda_scalar_per_access
 constexpr auto src_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
- static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+ static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
               "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
@@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
 using src_vector_t = typename src_vector_type::type;
- auto src_vector_container =
-     src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
 using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
 using dst_vector_t = typename dst_vector_type::type;
 dst_vector_type op_r_v;
@@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 if constexpr(decltype(src_element_op_)::is_pack8_invocable)
     return math::min(8, SrcScalarPerVector);
 }
- if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
+ else if constexpr(is_detected<is_pack4_invocable_t,
+                   decltype(src_element_op_)>::value)
 {
 if constexpr(decltype(src_element_op_)::is_pack4_invocable)
     return math::min(4, SrcScalarPerVector);
 }
- if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
+ else if constexpr(is_detected<is_pack2_invocable_t,
+                   decltype(src_element_op_)>::value)
 {
 if constexpr(decltype(src_element_op_)::is_pack2_invocable)
     return math::min(2, SrcScalarPerVector);
 }
- return 1;
+ else
+ {
+     return 1;
+ }
 };
 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
@@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
 using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
+ auto src_vector_container = src_vector_type{
+     src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
 static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
 // apply the src elementwise op and convert to DstData under the hood if needed
 src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
@@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
 });
 #else
 // OOB Check
 constexpr auto src_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 (is_same<f8_t, remove_cvref_t<DstData>>::value &&
  SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
 {
+     static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
+                   "in-register transpose is not supported for pk_i4_t");
 // each transpose does
 // DstScalarPerVector # of src vectors in src_thread_scratch_
 // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 }
 else
 {
- static_ford<SliceLengths>{}([&](auto idx) {
+ constexpr auto packed_per_access = generate_sequence(
+     detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
+ constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
+ static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
 });
 }
@@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 // src scalar per access on each dim
 // TODO: don't use this
 constexpr auto dst_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 // apply DstElementwiseOperation
 dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
- dst_vector_container.template AsType<DstData>()(i) = dst_v;
 });
 // copy data from dst_vector_container to dst_buf
 dst_buf.template Set<dst_vector_t>(
-     dst_coord_.GetOffset(),
+     dst_coord_.GetOffset() / PackedSize,
     is_dst_valid,
     dst_vector_container.template AsType<dst_vector_t>()[I0]);
@@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 // scalar per access on each dim
 // TODO: don't use lambda_scalar_per_access
 constexpr auto src_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 // scalar per access on each dim
 // TODO: don't use lambda_scalar_per_access
 constexpr auto dst_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 __device__ static constexpr auto GetSrcThreadScratchDescriptor()
 {
 constexpr auto src_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
 {
 constexpr auto src_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 {
 // 1st stage of transforms
 constexpr auto dst_scalar_per_access = generate_sequence(
-     detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+     detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector_,
index_t DstScalarPerVector_,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
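// PackedSize: number of logical elements per stored element. pk_i4_t packs two 4-bit
// values into one byte, so per-vector scalar counts and offsets are divided by it below.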
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
// maintain a container record is_src_valid, waiting for RunWrite use.
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
dst_vector_type op_r_v;
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack4_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack2_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
else
{
return 1;
}
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
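// src_coord_.GetOffset() counts logical elements; dividing by PackedSize yields the
// offset into the packed storage (i.e. the byte offset when SrcData is pk_i4_t)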
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<dst_vector_t>(src_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
template <index_t ThreadScratchId>
__device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
#else
// OOB Check
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<vector_t>(src_data_idx_seq);
const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<bool>(src_data_idx_seq);
auto op_r_v = is_src_valid ? op_r : vector_t(0);
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
});
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
"in-register transpose is not supported for pk_i4_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
static_assert(SrcVectorDim != DstVectorDim, "wrong");
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
}
#endif
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// if there is transpose, it's done here
// if there is oob check, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
// dst scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
DstData dst_v;
// apply DstElementwiseOperation
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset() / PackedSize,
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after the last iteration in RunRead(), if it has not been reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step;
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after the last iteration in RunWrite(), if it has not been reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
return make_naive_tensor_descriptor_packed(src_access_lengths);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto src_oob_thread_scratch_desc_ =
decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, // apply data_convert with SrcThreadScratch
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
using SrcOOBThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
bool, // apply data_convert with SrcThreadScratch
1,
decltype(src_oob_thread_scratch_desc_),
true>;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(dst_thread_scratch_desc_),
true>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
DstThreadScratch dst_thread_scratch_;
SrcCoord src_coord_;
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
};
} // namespace ck
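
For reference, a self-contained sketch of the compile-time width arithmetic the class above performs; transfer_widths and pk_i4_stub are illustrative stand-ins, not CK definitions:

#include <type_traits>

// Stand-in for ck::pk_i4_t: one byte carrying two signed 4-bit values.
struct pk_i4_stub { signed char data; };

template <typename SrcData, int SrcScalarPerVector_>
struct transfer_widths
{
    // Mirrors the PackedSize lambda: 2 logical elements per stored byte for the packed type.
    static constexpr int PackedSize = std::is_same_v<SrcData, pk_i4_stub> ? 2 : 1;

    // Effective number of stored elements per vector access.
    static constexpr int SrcScalarPerVector = SrcScalarPerVector_ / PackedSize;

    static_assert(SrcScalarPerVector_ % PackedSize == 0,
                  "scalar-per-vector must be a multiple of PackedSize");
};

// 16 logical int4 values are carried by 8 stored bytes per access.
static_assert(transfer_widths<pk_i4_stub, 16>::SrcScalarPerVector == 8);
// For non-packed types the count is unchanged.
static_assert(transfer_widths<float, 16>::SrcScalarPerVector == 16);

int main() { return 0; }
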
@@ -429,7 +429,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
 (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
 (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
 (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
- (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+ (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+ (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
 "wrong! not implemented");
 using r_t = typename vector_type<T, N>::type;
...
@@ -13,6 +13,15 @@ using int4_t = _BitInt(4);
 using f8_t = _BitInt(8);
 using bf8_t = unsigned _BitInt(8);
+
+ // custom data type - pack int4 data
+ struct pk_i4_t
+ {
+     using type = int8_t;
+     type data;
+     __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
+     __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
+ };
 // vector_type
 template <typename T, index_t N>
 struct vector_type;
@@ -149,6 +158,13 @@ struct scalar_type<int4_t>
 };
 #endif
+
+ template <>
+ struct scalar_type<pk_i4_t>
+ {
+     using type = pk_i4_t;
+     static constexpr index_t vector_size = 1;
+ };
 template <>
 struct scalar_type<f8_t>
 {
@@ -990,6 +1006,20 @@ struct vector_type<T, 256>
 }
 };
+
+ template <>
+ struct nnvb_data_t_selector<pk_i4_t>
+ {
+     using type = pk_i4_t::type;
+ };
+
+ template <index_t N>
+ struct scalar_type<non_native_vector_base<pk_i4_t, N>>
+ {
+     using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
+     static constexpr index_t vector_size = N;
+ };
 using int64_t = long;
 // fp64
@@ -1060,6 +1090,11 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
 using uint8x32_t = typename vector_type<uint8_t, 32>::type;
 using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+
+ // pack int4
+ using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
+ using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
+ using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
 template <typename T>
 struct NumericLimits
 {
...
@@ -324,6 +324,18 @@ struct Tensor
- std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+ std::size_t GetElementSpaceSize() const
+ {
+     if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+     {
+         return (mDesc.GetElementSpaceSize() + 1) / 2;
+     }
+     else
+     {
+         return mDesc.GetElementSpaceSize();
+     }
+ }
 std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
 void SetZero() { ck::ranges::fill<T>(mData, 0); }
@@ -469,29 +481,64 @@ struct Tensor
 template <typename... Is>
 std::size_t GetOffsetFromMultiIndex(Is... is) const
 {
-     return mDesc.GetOffsetFromMultiIndex(is...);
+     if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+     {
+         return mDesc.GetOffsetFromMultiIndex(is...) / 2;
+     }
+     else
+     {
+         return mDesc.GetOffsetFromMultiIndex(is...);
+     }
 }
 template <typename... Is>
 T& operator()(Is... is)
 {
-     return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+     if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
+     }
+     else
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+     }
 }
 template <typename... Is>
 const T& operator()(Is... is) const
 {
-     return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+     if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
+     }
+     else
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+     }
 }
 T& operator()(std::vector<std::size_t> idx)
 {
-     return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+     if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
+     }
+     else
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+     }
 }
 const T& operator()(std::vector<std::size_t> idx) const
 {
-     return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+     if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
+     }
+     else
+     {
+         return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+     }
 }
 typename Data::iterator begin() { return mData.begin(); }
...
@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
 }
 };
+
+ template <>
+ struct GeneratorTensor_1<ck::pk_i4_t>
+ {
+     int8_t value = 1;
+
+     template <typename... Is>
+     ck::pk_i4_t operator()(Is...)
+     {
+         int t = value + 8;
+         ck::pk_i4_t r = ((t << 4) + t) & 0xff;
+         return r;
+     }
+ };
 template <typename T>
 struct GeneratorTensor_2
 {
@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
 }
 };
+
+ template <>
+ struct GeneratorTensor_2<ck::pk_i4_t>
+ {
+     int min_value = 0;
+     int max_value = 1;
+
+     template <typename... Is>
+     ck::pk_i4_t operator()(Is...)
+     {
+         int hi = std::rand() % (max_value - min_value) + min_value + 8;
+         int lo = std::rand() % (max_value - min_value) + min_value + 8;
+         ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
+         return r;
+     }
+ };
 #if defined CK_ENABLE_FP8
 template <>
 struct GeneratorTensor_2<ck::f8_t>
...
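
The two generator specializations above add 8 to each drawn 4-bit value before packing it into a nibble, so recovering the generated values on the host means subtracting 8 from each nibble again. A small illustrative sketch of that inverse step (unpack_generated_pair is a hypothetical helper, not part of this commit):

#include <cassert>
#include <cstdint>

// Invert the packing used by the GeneratorTensor specializations above:
// each generated value v is stored as the unsigned nibble (v + 8).
inline void unpack_generated_pair(uint8_t packed, int& hi, int& lo)
{
    hi = ((packed >> 4) & 0xF) - 8;
    lo = (packed & 0xF) - 8;
}

int main()
{
    // GeneratorTensor_1<ck::pk_i4_t> with value = 1 packs t = 9 into both nibbles: 0x99.
    int hi = 0, lo = 0;
    unpack_generated_pair(0x99, hi, lo);
    assert(hi == 1 && lo == 1);
    return 0;
}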