Commit bf75259f authored by aska-0096

New implementation of fp16Aint8B Gemm; achieves math throughput similar to the native fp16 Gemm.

parent 061009a3
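
In essence, the new blockwise kernel computes C = A x dequant(B): each int8 weight is converted to fp16 in registers and multiplied by a per-column scale right before it is fed to the WMMA instruction, so the math path stays fp16. A scalar reference of that semantics (illustrative only; the function below is not part of this commit, and the stored-weight signedness/zero-point handling is simplified):

#include <cstdint>
#include <vector>

// Reference semantics of the fp16A-int8B GEMM, written in fp32 for clarity:
//   C[m][n] = sum_k A[m][k] * (scale[n] * dequant(B[k][n]))
void reference_fpAintB_gemm(const std::vector<float>& a,     // M x K activations (fp16 in the kernel)
                            const std::vector<int8_t>& b,    // K x N quantized weights
                            const std::vector<float>& scale, // one dequant scale per output column
                            std::vector<float>& c,           // M x N output
                            int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int k = 0; k < K; ++k)
            {
                // dequantize on the fly, then accumulate
                acc += a[m * K + k] * (scale[n] * static_cast<float>(b[k * N + n]));
            }
            c[m * N + n] = acc;
        }
    }
}
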
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#define CK_MNK_LOOP
namespace ck {
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ScaleDataType,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
typename ScaleBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false>
/* Option: Read from LDS, a big buffer holds the data required by all threads
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, a small buffer holds only each thread's own required data (skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
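// Illustrative tile-size example (assumed numbers, not a required configuration):
// MPerBlock = NPerBlock = 128, MPerWMMA = NPerWMMA = 16, MRepeat = NRepeat = 2
//   => MWaves = NWaves = 128 / (2 * 16) = 4, and the block must launch
//      MWaves * NWaves * WaveSize = 4 * 4 * 32 = 512 threads (see the static_assert below).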
struct Blockwise_fpAintB_GemmWMMA
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// WaveSize is hardcoded because the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
// When LDS is used, each row (16 consecutive lanes) reads the whole slice from the source buffer.
// When LDS is skipped, each row reads half of the slice and the halves are exchanged between rows
// via lane permutation.
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
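// e.g. with WaveSize = 32 and LDS skipped, lanes 0..15 and 16..31 form two K-rows; each lane
// loads only half of the WmmaK slice and obtains the other half from the neighbouring row via
// the lane-permute copy (see AThreadCopySelector<false> / BThreadCopySelector<false> below).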
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
// Both WMMA inputs use ADataType: B is dequantized to the float type before the WMMA.
static constexpr auto wmma_gemm =
WmmaGemm<ADataType, ADataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
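// e.g. MWaves = NWaves = 2, WaveSize = 32 (BlockSize = 128): threads 0-31 map to wave (0, 0),
// 32-63 to (0, 1), 64-95 to (1, 0) and 96-127 to (1, 1).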
// Default: block buffer in LDS, thread-level offset enabled.
__device__ static auto CalculateAThreadOriginDataIndex()
{
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
Blockwise_fpAintB_GemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin), scale_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
// Thread-level register descriptor, vector write.
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Describes how data is laid out in the thread-copy source buffers.
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
static constexpr ScaleBlockDesc scale_block_desc_1_n0_n1_n2_1;
template <typename ABlockBuffer,
typename BBlockBuffer,
typename ScaleBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
const ScaleBlockBuffer& scale_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_thread_desc_.GetElementSpaceSize());
auto scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
scale_thread_desc_.GetElementSpaceSize());
// auto converted_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
// b_thread_desc_.GetElementSpaceSize());
tensor_operation::element_wise::FastNumericArrayConverter<BDataType, ADataType, WmmaK>
fast_numeric_converter;
// Two loop orderings below: the first (disabled) variant keeps MRepeat outermost; the enabled
// variant keeps NRepeat outermost so each dequantized B vector is reused across all MRepeat
// iterations.
if constexpr( 0 )
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read weight scale
scale_thread_copy_.Run(
scale_block_desc_1_n0_n1_n2_1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
vector_type<BDataType, WmmaK> b_int_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
b_int_vec.template AsType<BDataType>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
});
// convert B from uint8 to fp16, multiply scale
b_thread_vec = fast_numeric_converter(b_int_vec);
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<ADataType>()(i) =
scale_thread_buf[n0] *
b_thread_vec.template AsType<ADataType>()(i);
});
vector_type<ADataType, WmmaK> a_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<ADataType>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read weight scale
scale_thread_copy_.Run(scale_block_desc_1_n0_n1_n2_1,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<BDataType, WmmaK> b_int_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
b_int_vec.template AsType<BDataType>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
i / B_K1 / B_KRow, n0, 0, (i / B_K1) % B_KRow, 0, i % B_K1))>{}];
});
// convert B from uint8 to fp16, multiply scale
b_thread_vec = fast_numeric_converter(b_int_vec);
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<ADataType>()(i) =
scale_thread_buf[n0] * b_thread_vec.template AsType<ADataType>()(i);
});
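// Net per-element effect: b_thread_vec[i] = scale[n0] * fp16(b_int_vec[i]). The fast converter
// relies on byte permutes plus packed fp16 adds (see FastNumericArrayConverter) rather than
// per-element casts, and the dequantized vector is reused by every m0 iteration below.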
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<ADataType, WmmaK> a_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<ADataType>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected:
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<WmmaK>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<WmmaK>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
static constexpr auto scale_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(
Number<WmmaK / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, Number<B_KRow>{}, I1, I1),
make_tuple(I0, I1, I0, I0, I0, I0));
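// The zero strides broadcast a single scale value per NRepeat slot across the whole WmmaK
// slice, i.e. one per-column scale multiplies all K elements of that tile.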
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
template <bool EnableLds>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<ADataType,
ADataType,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
};
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
ADataType,
ADataType,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
0x76543210,
0xfedcba98,
TransposeC ? false : true>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<BDataType,
BDataType,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
BDataType,
BDataType,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
0x76543210,
0xfedcba98,
TransposeC ? true : false>;
};
template <bool EnableLds>
struct ScaleThreadCopySelector;
template <>
struct ScaleThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<ScaleDataType,
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, 1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1>;
};
template <>
struct ScaleThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
ScaleDataType,
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
typename ScaleThreadCopySelector<BEnableLds>::type scale_thread_copy_;
};
} // namespace ck
@@ -57,7 +57,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
-    static constexpr auto scale_thread_slice_lengths = BlockScaleSliceLengths{} / ThreadClusterLengths{};
+    static constexpr auto scale_thread_slice_lengths =
+        BlockScaleSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;
@@ -92,7 +93,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
-                is_same<BlockScaleSliceLengths, decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{} ,
+                is_same<BlockScaleSliceLengths,
+                        decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
@@ -108,8 +110,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
-            threadwise_transfer_.SetScaleSliceOrigin(scale_desc,
-                                                     scale_block_slice_origin + thread_data_idx_begin);
+            threadwise_transfer_.SetScaleSliceOrigin(
+                scale_desc, scale_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
@@ -129,8 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_dequant
    // With the assumption, scale scratch is always one
    template <typename ScaleBuffer>
-    __device__ void RunScaleRead(const ScaleDesc& scale_desc,
-                                 const ScaleBuffer& scale_buf)
+    __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
...
@@ -677,7 +677,6 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"},
            {PipelineVersion::v2, "v2"},
-            {PipelineVersion::dequant_v1, "dequant_v1"},
            {PipelineVersion::weight_only, "weight_only"}};
        // clang-format off
...
@@ -404,18 +404,11 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
        half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
        half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);

-        // static constexpr ck::half_t fp16_subtract = -1152;
-        // Output.template AsType<ck::half_t>()(Number<0>{}) += fp16_subtract;
-        // Output.template AsType<ck::half_t>()(Number<1>{}) += fp16_subtract;
-        // Output.template AsType<ck::half_t>()(Number<2>{}) += fp16_subtract;
-        // Output.template AsType<ck::half_t>()(Number<3>{}) += fp16_subtract;
-        // inline assembly get very poor performance as no chance to global scheduling
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
-        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n"
+        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[0])
                     : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
-        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n"
+        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[1])
                     : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
...
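For context on the magic number kept above: packing a uint8 value u into the low byte of an fp16 whose high byte is 0x64 encodes exactly 1024 + u (the fp16 ulp is 1.0 in [1024, 2048)), and 0x6480 encodes 1152 = 1024 + 128. The packed add with neg_lo/neg_hi therefore yields u - 128 per lane, i.e. the stored uint8 appears to be treated as offset-binary with a zero-point of 128. A scalar sketch of the trick (illustrative only, not part of the diff):

#include <cstdint>

// Emulates the byte-pack + subtract trick in ordinary arithmetic.
// fp16 bit pattern (0x64 << 8) | u encodes exactly 1024 + u for 0 <= u <= 255.
inline float fast_u8_to_half_like(uint8_t u)
{
    float packed = 1024.0f + static_cast<float>(u); // what the v_perm byte-pack encodes
    return packed - 1152.0f;                        // 0x6480 == 1152, so the result is u - 128
}
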
@@ -12,7 +12,6 @@ enum struct PipelineVersion
{
    v1,
    v2,
-    dequant_v1,
    weight_only,
};
@@ -38,10 +37,6 @@ constexpr auto GridwiseGemmPipeline_Selector()
    {
        return GridwiseGemmPipeline_v2{};
    }
-    else if constexpr(PipelineVer == PipelineVersion::dequant_v1)
-    {
-        return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
-    }
    else if constexpr(PipelineVer == PipelineVersion::weight_only)
    {
        return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
...
@@ -550,225 +550,6 @@ struct GridwiseGemmPipeline_v1<1, false, false>
    }
};

(The remainder of this hunk is code removed by the commit: the GridwiseGemmPipeline_v1_dequant specializations.)
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_dequant;
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
scale_blockwise_copy.RunRead(scale_grid_desc, scale_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
scale_blockwise_copy.RunWrite(scale_block_desc, scale_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
}
}
};
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
scale_blockwise_copy.Run(
scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_WeightOnly;
...
@@ -114,7 +114,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

-    __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc, const Index& scale_slice_origin_idx)
+    __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc,
+                                        const Index& scale_slice_origin_idx)
    {
        scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx);
    }
@@ -274,8 +275,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
    }

    template <typename ScaleBuffer>
-    __device__ void RunScaleRead(const ScaleDesc& scale_desc,
-                                 const ScaleBuffer& scale_buf)
+    __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
    {
        static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -358,11 +358,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                scale_scalar_per_access;
        }();

-        constexpr auto scale_data_idx_seq = generate_sequence_v2(
-            [&](auto i) { return Number<scale_data_idx[i]>{}; }, Number<scale_data_idx.Size()>{});
+        constexpr auto scale_data_idx_seq =
+            generate_sequence_v2([&](auto i) { return Number<scale_data_idx[i]>{}; },
+                                 Number<scale_data_idx.Size()>{});

-        const bool is_scale_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(scale_desc, scale_coord_);
+        const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
+            scale_desc, scale_coord_);

        using scale_vector_type = vector_type_maker_t<ScaleData, ScaleScalarPerVector>;
        using scale_vector_t = typename scale_vector_type::type;
@@ -372,8 +373,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
            scale_buf.template Get<scale_vector_t>(scale_coord_.GetOffset(), is_scale_valid)};

        // copy data from scale_vector_container into scale_thread_scratch_
-        scale_thread_scratch_
-            .template SetAsType<scale_vector_t>(
-                scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);
+        scale_thread_scratch_.template SetAsType<scale_vector_t>(
+            scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);

        constexpr auto move_on_dim = [&]() constexpr
@@ -381,7 +381,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
            StaticallyIndexedArray<bool, nDim> move_on_dim_;

            static_for<0, nDim, 1>{}([&](auto i) {
-                move_on_dim_(i) = ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;
+                move_on_dim_(i) =
+                    ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;

                static_for<i + 1, nDim, 1>{}([&](auto j) {
                    move_on_dim_(i) &=
@@ -399,13 +400,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                {
                    if constexpr(forward_sweep[i])
                    {
-                        move_tensor_coordinate(
-                            scale_desc, scale_coord_, scale_forward_steps[scale_dim_access_order[i]]);
+                        move_tensor_coordinate(scale_desc,
+                                               scale_coord_,
+                                               scale_forward_steps[scale_dim_access_order[i]]);
                    }
                    else
                    {
-                        move_tensor_coordinate(
-                            scale_desc, scale_coord_, scale_backward_steps[scale_dim_access_order[i]]);
+                        move_tensor_coordinate(scale_desc,
+                                               scale_coord_,
+                                               scale_backward_steps[scale_dim_access_order[i]]);
                    }
                }
            });
@@ -500,20 +503,46 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                // do data transpose
                transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                    src_vector_refs, dst_vector_refs);
-
-                // do fast numeric convert
-                src_converted_thread_scratch_.template SetAsType<SrcThreadConvertedScratch::V>(access_idx,
-                    fast_numeric_converter(
-                        src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<SrcThreadScratch::V>(access_idx)));
            });
        }

+        // Do fast numeric convert
+        constexpr auto scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
+                                                                  SrcScalarPerVector,
+                                                                  DstVectorDim,
+                                                                  DstScalarPerVector>{},
+            Number<nDim>{});
+
+        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
+
+        using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
+        using src_vector_t    = typename src_vector_type::type;
+
+        using src_converted_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
+        using src_converted_vector_t    = typename src_converted_vector_type::type;
+
+        // Vector-wise type convert
+        static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
+            auto src_vector_container = src_vector_type{
+                src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<src_vector_t>(
+                    access_idx)};
+
+            auto src_converted_vector_container =
+                src_converted_vector_type{fast_numeric_converter(src_vector_container)};
+
+            src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
+                access_idx,
+                src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
+        });
+
+        // Element-scale operation, expect packed multiplication
        static_ford<SliceLengths>{}([&](auto idx) {
-            // apply the src elementwise op and convert to DstData under the hood if needed
+            // Scale is dynamic, could not implement through element_op.
            DstData dst_v;
            constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{};
-            src_element_op_(dst_v, src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
+            // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
+            // *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
+            src_element_op_(dst_v,
+                            src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
            dst_thread_scratch_(idx) = dst_v;
        });
#endif
@@ -978,13 +1007,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
    private:
    static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
-    static constexpr auto scale_thread_scratch_desc_ = decltype(GetScaleThreadScratchDescriptor()){};
+    static constexpr auto scale_thread_scratch_desc_ =
+        decltype(GetScaleThreadScratchDescriptor()){};
    static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

    /*
    template <bool kLastDim>
    struct ScaleThreadScratchDesc{};
    */

    // Registers, contain raw data loaded from global buffer
    using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
@@ -994,7 +1024,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                                                             true>;

    // Registers, contain fast converted data
-    using SrcThreadConvertedScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+    using SrcThreadConvertedScratch =
+        StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                        DstData,
                                        SrcScalarPerVector,
                                        decltype(src_thread_scratch_desc_),
@@ -1014,7 +1045,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
                                        decltype(dst_thread_scratch_desc_),
                                        true>;

-    using FastTypeConverter = tensor_operation::element_wise::FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;
+    using FastTypeConverter = tensor_operation::element_wise::
+        FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;

    StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
    SrcThreadConvertedScratch src_converted_thread_scratch_;
...
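The dequant epilogue added above is split into two passes: the whole source scratch is first converted vector-by-vector with the fast converter into SrcThreadConvertedScratch, and only then is the per-column scale applied element-wise (the scale is a runtime value, so it cannot be folded into the usual element_op). A simplified standalone sketch of that structure (assumed shapes and plain arrays, not the actual scratch buffers):

#include <cstdint>
#include <array>

// Two-pass dequant in miniature: bulk type conversion first, per-column scaling second.
template <int K, int N>
void dequant_two_pass(const std::array<int8_t, K * N>& src,
                      const std::array<float, N>& scale,
                      std::array<float, K * N>& dst)
{
    std::array<float, K * N> converted{};

    // pass 1: type conversion only (the kernel does this vector-wise via FastNumericArrayConverter)
    for(int i = 0; i < K * N; ++i)
        converted[i] = static_cast<float>(src[i]);

    // pass 2: per-element scale; the scale index depends only on the column coordinate
    for(int k = 0; k < K; ++k)
        for(int n = 0; n < N; ++n)
            dst[k * N + n] = converted[k * N + n] * scale[n];
}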