Commit bf75259f authored by aska-0096

New implementation of fp16Aint8B Gemm; achieves math throughput similar to native fp16 Gemm

parent 061009a3
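
For orientation (not part of the commit), a minimal host-side sketch of what a weight-only fp16 x int8 dequant GEMM computes: C = A * dequant(B), where B is stored as int8 and rescaled to fp16 on the fly. The per-column scale layout, the function name, and the float accumulation below are illustrative assumptions only; the kernel changed in this commit performs the dequantization in registers while loading the B tile (see ThreadwiseTensorSliceTransfer_v3r1_dequant further down).

#include <cstdint>
#include <vector>

// reference_fp16Aint8B_gemm is a hypothetical helper name, not an API from this repo.
void reference_fp16Aint8B_gemm(int M, int N, int K,
                               const std::vector<float>& A,     // fp16 activations, widened to float for the host check
                               const std::vector<int8_t>& B,    // quantized weights, row-major K x N
                               const std::vector<float>& scale, // assumed: one dequant scale per column of B
                               std::vector<float>& C)           // row-major M x N output
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * (static_cast<float>(B[k * N + n]) * scale[n]);
            C[m * N + n] = acc;
        }
}
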
@@ -57,7 +57,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr auto scale_thread_slice_lengths =
BlockScaleSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
@@ -92,7 +93,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
is_same<BlockScaleSliceLengths,
decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
@@ -108,8 +110,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetScaleSliceOrigin(
scale_desc, scale_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
@@ -129,8 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
// With the assumption, scale scratch is always one
template <typename ScaleBuffer>
__device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
...
@@ -677,7 +677,6 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"},
{PipelineVersion::dequant_v1, "dequant_v1"},
{PipelineVersion::weight_only, "weight_only"}};
// clang-format off
...
@@ -404,18 +404,11 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
// static constexpr ck::half_t fp16_subtract = -1152;
// Output.template AsType<ck::half_t>()(Number<0>{}) += fp16_subtract;
// Output.template AsType<ck::half_t>()(Number<1>{}) += fp16_subtract;
// Output.template AsType<ck::half_t>()(Number<2>{}) += fp16_subtract;
// Output.template AsType<ck::half_t>()(Number<3>{}) += fp16_subtract;
// inline assembly get very poor performance as no chance to global scheduling
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
...
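
The converter above uses the familiar magic-number route from packed uint8 to packed fp16: each byte is permuted into the low mantissa bits of a half whose upper bits form a fixed biased-exponent pattern, so the packed half decodes to bias + byte, and a packed subtract of the same bias (the v_pk_add_f16 with neg_lo/neg_hi here) recovers the byte as an fp16 value with no int-to-float instruction. Below is a hedged host-side sketch of that arithmetic only; the 0x6400 bias, the helper name, and the loop are illustrative assumptions, and the kernel's actual constants (fp16_adder, byte_selector_01/23, 0x64806480, i.e. 1152.0 per half) suggest an extra 128 offset is folded in, a detail that should be treated as an assumption here.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE fp16 bit pattern to float (normal values only, which is all this trick uses).
static float half_bits_to_float(uint16_t h)
{
    const int sign     = (h >> 15) ? -1 : 1;
    const int exponent = (h >> 10) & 0x1F;
    const int mantissa = h & 0x3FF;
    return sign * std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
}

int main()
{
    const uint16_t kBiasBits = 0x6400; // fp16 bit pattern of 1024.0 (assumed bias for this sketch)
    const float    kBias     = half_bits_to_float(kBiasBits);

    for(unsigned b = 0; b < 256; b += 51)
    {
        const uint16_t packed = static_cast<uint16_t>(kBiasBits | b); // byte lands in the low mantissa bits
        const float    value  = half_bits_to_float(packed) - kBias;   // analogue of the packed subtract above
        std::printf("b = %3u -> %.1f\n", b, value);                   // recovers b exactly
    }
    return 0;
}
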
@@ -12,7 +12,6 @@ enum struct PipelineVersion
{
v1,
v2,
dequant_v1,
weight_only,
};
@@ -38,10 +37,6 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
return GridwiseGemmPipeline_v2{};
}
else if constexpr(PipelineVer == PipelineVersion::dequant_v1)
{
return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
}
else if constexpr(PipelineVer == PipelineVersion::weight_only)
{
return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
...
@@ -550,225 +550,6 @@ struct GridwiseGemmPipeline_v1<1, false, false>
}
};
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_dequant;
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
scale_blockwise_copy.RunRead(scale_grid_desc, scale_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
scale_blockwise_copy.RunWrite(scale_block_desc, scale_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
}
}
};
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
scale_blockwise_copy.Run(
scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
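
Both deleted specializations above implement the same single-prefetch schedule: preload the first K tile, then in the main body overlap the global read of tile i+1 with the blockwise GEMM on tile i, and finish the last tile in a separate tail step. A minimal host-side sketch of that schedule follows; plain ints stand in for the A/B tile buffers, the accumulator, and the blockwise GEMM, so every name in it is a stand-in rather than the kernel's API.

#include <cstdio>
#include <vector>

int main()
{
    const int num_loop = 4;                    // number of K tiles (assumes num_loop > 1, as CalculateHasMainLoop requires)
    std::vector<int> global_tiles{1, 2, 3, 4}; // stand-in for the A/B grid buffers

    int reg_buf = global_tiles[0]; // preload of tile 0 (RunRead + RunWrite in the kernel)
    int acc     = 0;               // stand-in for c_thread_buf, cleared up front
    int next    = 1;

    // main body: the load of tile i+1 is issued alongside the math on tile i
    for(int i = 0; i < num_loop - 1; ++i)
    {
        const int prefetched = global_tiles[next++]; // RunRead of the next tile
        acc += reg_buf;                              // blockwise_gemm.Run on the current tile
        reg_buf = prefetched;                        // RunWrite / buffer switch for the next iteration
    }

    // tail: math on the last prefetched tile, no further load
    acc += reg_buf;

    std::printf("acc = %d\n", acc); // 1 + 2 + 3 + 4 = 10
    return 0;
}
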
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_WeightOnly;
...
@@ -114,7 +114,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc,
const Index& scale_slice_origin_idx)
{
scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx);
}
@@ -274,8 +275,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
}
template <typename ScaleBuffer>
__device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
{
static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -358,11 +358,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
scale_scalar_per_access;
}();
constexpr auto scale_data_idx_seq =
generate_sequence_v2([&](auto i) { return Number<scale_data_idx[i]>{}; },
Number<scale_data_idx.Size()>{});
const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
scale_desc, scale_coord_);
using scale_vector_type = vector_type_maker_t<ScaleData, ScaleScalarPerVector>;
using scale_vector_t = typename scale_vector_type::type;
@@ -372,8 +373,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
scale_buf.template Get<scale_vector_t>(scale_coord_.GetOffset(), is_scale_valid)};
// copy data from scale_vector_container into scale_thread_scratch_
scale_thread_scratch_.template SetAsType<scale_vector_t>(
scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
@@ -381,7 +381,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) =
ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
@@ -399,13 +400,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(scale_desc,
scale_coord_,
scale_forward_steps[scale_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(scale_desc,
scale_coord_,
scale_backward_steps[scale_dim_access_order[i]]);
}
}
});
@@ -500,20 +503,46 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
// do data transpose
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
// do fast numeric convert
src_converted_thread_scratch_.template SetAsType<SrcThreadConvertedScratch::V>(access_idx,
fast_numeric_converter(
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<SrcThreadScratch::V>(access_idx)));
});
}
// Do fast numeric convert
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using src_converted_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using src_converted_vector_t = typename src_converted_vector_type::type;
// Vector-wise type convert
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
auto src_vector_container = src_vector_type{
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<src_vector_t>(
access_idx)};
auto src_converted_vector_container =
src_converted_vector_type{fast_numeric_converter(src_vector_container)};
src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
access_idx,
src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
});
// Element-scale operation, expect packed multiplication
static_ford<SliceLengths>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
// Scale is dynamic, could not implement through element_op.
DstData dst_v;
constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{};
// printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
src_element_op_(dst_v,
src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
dst_thread_scratch_(idx) = dst_v;
});
#endif
@@ -978,13 +1007,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto scale_thread_scratch_desc_ =
decltype(GetScaleThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
/*
template <bool kLastDim>
struct ScaleThreadScratchDesc{};
*/
// Registers, contain raw data loaded from global buffer
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
@@ -994,7 +1024,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
true>;
// Registers, contain fast converted data
using SrcThreadConvertedScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
@@ -1014,7 +1045,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
decltype(dst_thread_scratch_desc_),
true>;
using FastTypeConverter = tensor_operation::element_wise::
FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
SrcThreadConvertedScratch src_converted_thread_scratch_;
...