Unverified Commit cc0ffeb7 authored by Haocong WANG's avatar Haocong WANG Committed by GitHub
Browse files

Merge pull request #851 from ROCmSoftwarePlatform/perf_opt_fpAintB

New implementation fpAintB
parents 3ba0f0d7 bf75259f
......@@ -35,16 +35,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp,
CElementOp,
GemmDefault,
2, // Prefetch stage
1, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // NPerBlock
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
......
......@@ -21,7 +21,7 @@ using QuantDataType = int8_t;
using BDataType = uint8_t;
using ScaleDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CShuffleDataType = float;
using CDataType = ck::half_t;
using ALayout = Row;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#define CK_MNK_LOOP
namespace ck {
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ScaleDataType,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
typename ScaleBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct Blockwise_fpAintB_GemmWMMA
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static constexpr index_t WaveSize = 32;
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
// As Float DataType
static constexpr auto wmma_gemm =
WmmaGemm<ADataType, ADataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
Blockwise_fpAintB_GemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin), scale_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
// Thread level, register decriptor. Vector-write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
static constexpr ScaleBlockDesc scale_block_desc_1_n0_n1_n2_1;
template <typename ABlockBuffer,
typename BBlockBuffer,
typename ScaleBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
const ScaleBlockBuffer& scale_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_thread_desc_.GetElementSpaceSize());
auto scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
scale_thread_desc_.GetElementSpaceSize());
// auto converted_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
// b_thread_desc_.GetElementSpaceSize());
tensor_operation::element_wise::FastNumericArrayConverter<BDataType, ADataType, WmmaK>
fast_numeric_converter;
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read weight scale
scale_thread_copy_.Run(
scale_block_desc_1_n0_n1_n2_1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
vector_type<BDataType, WmmaK> b_int_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
b_int_vec.template AsType<BDataType>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
});
// convert B from uint8 to fp16, multiply scale
b_thread_vec = fast_numeric_converter(b_int_vec);
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<ADataType>()(i) =
scale_thread_buf[n0] *
b_thread_vec.template AsType<ADataType>()(i);
});
vector_type<ADataType, WmmaK> a_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<ADataType>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read weight scale
scale_thread_copy_.Run(scale_block_desc_1_n0_n1_n2_1,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<BDataType, WmmaK> b_int_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
b_int_vec.template AsType<BDataType>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
i / B_K1 / B_KRow, n0, 0, (i / B_K1) % B_KRow, 0, i % B_K1))>{}];
});
// convert B from uint8 to fp16, multiply scale
b_thread_vec = fast_numeric_converter(b_int_vec);
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<ADataType>()(i) =
scale_thread_buf[n0] * b_thread_vec.template AsType<ADataType>()(i);
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<ADataType, WmmaK> a_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<ADataType>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected:
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<WmmaK>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<WmmaK>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
static constexpr auto scale_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(
Number<WmmaK / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, Number<B_KRow>{}, I1, I1),
make_tuple(I0, I1, I0, I0, I0, I0));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
template <bool EnableLds>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<ADataType,
ADataType,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
};
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
ADataType,
ADataType,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
0x76543210,
0xfedcba98,
TransposeC ? false : true>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<BDataType,
BDataType,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
BDataType,
BDataType,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
0x76543210,
0xfedcba98,
TransposeC ? true : false>;
};
template <bool EnableLds>
struct ScaleThreadCopySelector;
template <>
struct ScaleThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<ScaleDataType,
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, 1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1>;
};
template <>
struct ScaleThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
ScaleDataType,
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
typename ScaleThreadCopySelector<BEnableLds>::type scale_thread_copy_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"
namespace ck {
/**
* @brief Blockwise data transfer with dequantization
*
* RunRead would load low-precision data and scale data.
* RunWrite would process dequantization process.
* Assume Scale is identical along K-dimension
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template <typename ThreadGroup,
typename SrcElementwiseOperation,
typename ScaleElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename BlockScaleSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename ScaleData,
typename DstData,
typename SrcDesc,
typename ScaleDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t ScaleScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t ScaleScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_dequant
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr auto scale_thread_slice_lengths =
BlockScaleSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const ScaleDesc& scale_desc,
const Index& scale_block_slice_origin,
const ScaleElementwiseOperation& scale_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
scale_desc,
make_zero_multi_index<nDim>(),
scale_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
dst_element_op)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<ScaleDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
is_same<BlockScaleSliceLengths,
decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetScaleSliceOrigin(
scale_desc, scale_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
// With the assumption, scale scratch is always one
template <typename ScaleBuffer>
__device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunScaleRead(scale_desc, scale_buf);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
// We don't prefer use this API directly
/*
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
*/
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
// With the assumption, scale buffer don't need move slice window method
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1_dequant<decltype(thread_slice_lengths),
decltype(scale_thread_slice_lengths),
SrcElementwiseOperation,
ScaleElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
ScaleData,
DstData,
SrcDesc,
ScaleDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
ScaleScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
ScaleScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
......@@ -66,7 +66,7 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::dequant_v1>
ck::PipelineVersion PipelineVer = ck::PipelineVersion::weight_only>
struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
BLayout,
CLayout,
......@@ -95,7 +95,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
// LDS bypass feature not checked.
// LDS bypass feature not implemented for dequantization pipeline.
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
......@@ -677,7 +677,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"},
{PipelineVersion::dequant_v1, "dequant_v1"}};
{PipelineVersion::weight_only, "weight_only"}};
// clang-format off
str << "DeviceFpAintBGemm_Wmma_CShuffle"
......
......@@ -405,10 +405,10 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n"
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n"
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
......
......@@ -9,8 +9,9 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -82,6 +83,7 @@ __global__ void
#endif // end of if (defined(__gfx1100__))
}
// Assume B is Col-Major
template <index_t BlockSize,
typename ADataType,
typename BDataType,
......@@ -129,7 +131,7 @@ template <index_t BlockSize,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::dequant_v1>
PipelineVersion PipelineVer = PipelineVersion::weight_only>
struct GridwiseFpAintBGemm_Wmma
{
static constexpr auto I0 = Number<0>{};
......@@ -252,38 +254,6 @@ struct GridwiseFpAintBGemm_Wmma
return b_block_desc;
}
__host__ __device__ static constexpr auto MakeScaleBlockDescriptor()
{
// Scale [1, N], all K related dimension reduce to 1
constexpr auto scale_block_desc = [&]() {
if constexpr(BEnableLds)
{
// K0->N->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, I1),
make_tuple(I0, I1, I0));
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(make_tuple(Number<KWmmaPerblock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
I1),
make_tuple(I0, I1, I0, I0, I0, I0, I0));
}
}();
return scale_block_desc;
}
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
......@@ -424,47 +394,6 @@ struct GridwiseFpAintBGemm_Wmma
return b_wave_desc;
}
template <typename ScaleBlockDesc_>
__host__ __device__ static constexpr auto MakeScaleWaveDescriptor(const ScaleBlockDesc_&)
{
constexpr auto scale_wave_desc = [&]() {
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = ScaleBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
ScaleBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ScaleBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ScaleBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = ScaleBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return make_naive_tensor_descriptor(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(I0, I1, I0, I0, I0, I0));
}
}();
return scale_wave_desc;
}
__host__ __device__ static constexpr auto
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
......@@ -613,8 +542,11 @@ struct GridwiseFpAintBGemm_Wmma
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and Dequantized B: be careful of DataType
// scale would not put into LDS.
using LDS_ADataType = ADataType;
using LDS_BDataType = ADataType;
using LDS_CDataType = CShuffleDataType;
static constexpr auto max_lds_align = K1;
static constexpr auto a_block_space_size_aligned =
......@@ -625,18 +557,13 @@ struct GridwiseFpAintBGemm_Wmma
BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
static constexpr auto scale_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(
MakeScaleBlockDescriptor().GetElementSpaceSize(), max_lds_align)
: 0;
static constexpr auto a_block_space_offset = 0;
// B would be dequantize to ADataType before enter LDS
// b_lds_offset = LDS size allocated for a in byte / LDS_BDataType
static constexpr auto b_block_space_offset =
(a_block_space_offset + a_block_space_size_aligned) * sizeof(ADataType) /
sizeof(BDataType);
static constexpr auto scale_block_space_offset =
(b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType) /
sizeof(ScaleDataType);
(a_block_space_offset + a_block_space_size_aligned) * sizeof(LDS_ADataType) /
sizeof(LDS_BDataType);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size =
......@@ -646,10 +573,9 @@ struct GridwiseFpAintBGemm_Wmma
static constexpr auto c_shuffle_block_space_offset = 0;
static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType) +
scale_block_space_size_aligned * sizeof(ScaleDataType));
math::max(c_shuffle_block_space_size * sizeof(LDS_CDataType),
a_block_space_size_aligned * sizeof(LDS_ADataType) +
b_block_space_size_aligned * sizeof(LDS_BDataType));
};
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
......@@ -707,7 +633,6 @@ struct GridwiseFpAintBGemm_Wmma
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = MakeBBlockDescriptor();
constexpr auto scale_block_desc = MakeScaleBlockDescriptor();
auto a_block_trait = [&](){
// A matrix blockwise copy
......@@ -795,35 +720,44 @@ struct GridwiseFpAintBGemm_Wmma
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
static_cast<ADataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
SharedMemTrait::b_block_space_size_aligned);
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc),
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
ThreadGroupTensorSliceTransfer_v4r1_dequant<ThisThreadBlock,
/* typename SrcElementwiseOperation, */ BElementwiseOperation,
/* typename ScaleElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, NPerBlock, K1>,
/* typename BlockScaleSliceLengths, */ Sequence<K0PerBlock, NPerBlock, I1>,
/* typename ThreadClusterLengths, */ BBlockTransferThreadClusterLengths_K0_N_K1,
/* typename ThreadClusterArrangeOrder, */ BBlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ BDataType,
/* typename ScaleData, */ ScaleDataType,
/* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(b_grid_desc),
/* typename ScaleDesc, */ decltype(scale_grid_desc),
/* typename DstDesc, */ decltype(b_block_desc),
/* typename SrcDimAccessOrder, */ BBlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
/* index_t SrcVectorDim, */ BBlockTransferSrcVectorDim,
/* index_t DstVectorDim, */ 2,
/* index_t SrcScalarPerVector, */ BBlockTransferSrcScalarPerVector,
/* index_t ScaleScalarPerVector, */ 1,
/* index_t DstScalarPerVector, */ BBlockTransferDstScalarPerVector_K1,
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t ScaleScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ BThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
NumGemmKPrefetchStage>(
b_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
scale_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
ck::tensor_operation::element_wise::PassThrough{},
b_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
......@@ -870,108 +804,22 @@ struct GridwiseFpAintBGemm_Wmma
}
};
auto scale_block_trait = [&](){
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto scale_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
SharedMemTrait::scale_block_space_size_aligned);
auto scale_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
// Reduce slice length K1 to 1
Sequence<K0PerBlock, NPerBlock, I1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
ScaleDataType,
ScaleDataType,
decltype(scale_grid_desc),
decltype(scale_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
1,
1,
1, // no effect
1, // no effect
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
scale_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
scale_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
return make_tuple(scale_block_buf, scale_blockwise_copy);
}
else
{
// Thread-wise copy
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
auto scale_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
scale_block_desc.GetElementSpaceSize());
auto scale_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
ScaleDataType,
decltype(scale_grid_desc),
decltype(scale_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
scale_grid_desc,
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(scale_block_buf, scale_blockwise_copy);
}
};
auto a_block_buf = a_block_trait()[I0];
auto a_blockwise_copy = a_block_trait()[I1];
auto b_block_buf = b_block_trait()[I0];
auto b_blockwise_copy = b_block_trait()[I1];
auto scale_block_buf = scale_block_trait()[I0];
auto scale_blockwise_copy = scale_block_trait()[I1];
/*******************************************************************************/
// GEMM
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
Blockwise_fpAintB_GemmWMMA<BlockSize,
BlockwiseGemmWMMA<BlockSize,
ADataType,
BDataType,
ScaleDataType,
ADataType, //Dequantized
AccDataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBWaveDescriptor(b_block_desc)),
decltype(MakeScaleWaveDescriptor(scale_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1006,10 +854,7 @@ struct GridwiseFpAintBGemm_Wmma
b_block_buf,
b_block_slice_copy_step,
scale_grid_desc,
scale_block_desc,
scale_blockwise_copy,
scale_grid_buf,
scale_block_buf,
blockwise_gemm,
c_thread_buf,
KBlockMainLoop);
......
......@@ -12,7 +12,7 @@ enum struct PipelineVersion
{
v1,
v2,
dequant_v1,
weight_only,
};
template <PipelineVersion PipelineVer,
......@@ -37,9 +37,9 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
return GridwiseGemmPipeline_v2{};
}
else if constexpr(PipelineVer == PipelineVersion::dequant_v1)
else if constexpr(PipelineVer == PipelineVersion::weight_only)
{
return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
}
else
{
......
......@@ -551,10 +551,10 @@ struct GridwiseGemmPipeline_v1<1, false, false>
};
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_dequant;
struct GridwiseGemmPipeline_v1_WeightOnly;
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, true>
struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -580,10 +580,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
......@@ -599,18 +596,16 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
// Global Prefetch Stage 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
scale_blockwise_copy.RunRead(scale_grid_desc, scale_grid_buf);
// Scale read once
b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
......@@ -619,8 +614,8 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// Dequantization fused in blockwise_copy
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
scale_blockwise_copy.RunWrite(scale_block_desc, scale_block_buf);
// main body
if constexpr(HasMainLoop)
......@@ -635,7 +630,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
......@@ -653,118 +648,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
}
}
};
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
scale_blockwise_copy.Run(
scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst_idle
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
// 5. Dequantization happened between read and write.
template <typename SliceLengths,
typename ScaleSliceLengths,
typename SrcElementwiseOperation,
typename ScaleElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename SrcData,
typename ScaleData,
typename DstData,
typename SrcDesc,
typename ScaleDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t ScaleScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t ScaleScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1_dequant
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const ScaleDesc& scale_desc,
const Index& scale_slice_origin,
const ScaleElementwiseOperation& scale_element_op,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
scale_element_op_(scale_element_op),
dst_element_op_(dst_element_op)
{
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc,
const Index& scale_slice_origin_idx)
{
scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
template <typename ScaleBuffer>
__device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
{
static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename ScaleBuffer::type>, remove_cvref_t<ScaleData>>::value,
"wrong! ScaleBuffer and ScaleData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scale_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});
constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access;
constexpr auto scale_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_scale_access_lengths =
container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order);
// make forward steps
const auto scale_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(scale_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto scale_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(scale_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_scale_access_lengths)>{}([&](auto ordered_scale_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_scale_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate scale data index
constexpr auto scale_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i]
: ordered_scale_access_lengths[i] - 1 -
ordered_scale_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) *
scale_scalar_per_access;
}();
constexpr auto scale_data_idx_seq =
generate_sequence_v2([&](auto i) { return Number<scale_data_idx[i]>{}; },
Number<scale_data_idx.Size()>{});
const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
scale_desc, scale_coord_);
using scale_vector_type = vector_type_maker_t<ScaleData, ScaleScalarPerVector>;
using scale_vector_t = typename scale_vector_type::type;
// copy data from scale_buf into scale_vector_container
auto scale_vector_container = scale_vector_type{
scale_buf.template Get<scale_vector_t>(scale_coord_.GetOffset(), is_scale_valid)};
// copy data from scale_vector_container into scale_thread_scratch_
scale_thread_scratch_.template SetAsType<scale_vector_t>(
scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) =
ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move scale coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(scale_desc,
scale_coord_,
scale_forward_steps[scale_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(scale_desc,
scale_coord_,
scale_backward_steps[scale_dim_access_order[i]]);
}
}
});
});
// don't need to move scale coordinate back to slice origin
/*
if constexpr(SrcResetCoordinateAfterRun)
{
const auto scale_reset_step =
make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep());
move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step);
}
*/
}
template <index_t ThreadScratchId>
__device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<SrcData>>::value &&
is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
static_assert(SrcVectorDim != DstVectorDim, "wrong");
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
// Do fast numeric convert
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using src_converted_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using src_converted_vector_t = typename src_converted_vector_type::type;
// Vector-wise type convert
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
auto src_vector_container = src_vector_type{
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<src_vector_t>(
access_idx)};
auto src_converted_vector_container =
src_converted_vector_type{fast_numeric_converter(src_vector_container)};
src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
access_idx,
src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
});
// Element-scale operation, expect packed multiplication
static_ford<SliceLengths>{}([&](auto idx) {
DstData dst_v;
constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{};
// printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
src_element_op_(dst_v,
src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
dst_thread_scratch_(idx) = dst_v;
});
#endif
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// if there is transpose, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
DstData dst_v;
// apply DstElementwiseOperation
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
dst_vector_container.template AsType<DstData>()(i) = dst_v;
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step;
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetScaleThreadScratchDescriptor()
{
constexpr auto scale_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});
constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access;
constexpr auto scale_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(scale_access_lengths), Number<ScaleScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(scale_access_lengths_and_vector_length[i],
scale_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(scale_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto scale_thread_scratch_desc_ =
decltype(GetScaleThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
/*
template <bool kLastDim>
struct ScaleThreadScratchDesc{};
*/
// Registers, contain raw data loaded from global buffer
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
// Registers, contain fast converted data
using SrcThreadConvertedScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
// Registers, contain scale data
using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
ScaleData,
ScaleScalarPerVector,
decltype(scale_thread_scratch_desc_),
true>;
// Registers, contain dequantized data
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(dst_thread_scratch_desc_),
true>;
using FastTypeConverter = tensor_operation::element_wise::
FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
SrcThreadConvertedScratch src_converted_thread_scratch_;
ScaleThreadScratch scale_thread_scratch_;
DstThreadScratch dst_thread_scratch_;
FastTypeConverter fast_numeric_converter;
SrcCoord src_coord_;
ScaleCoord scale_coord_;
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const ScaleElementwiseOperation scale_element_op_;
const DstElementwiseOperation dst_element_op_;
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment