"vscode:/vscode.git/clone" did not exist on "e1a3fff67510be2af023b31587e411230b994631"
Commit 0c51a35e authored by aska-0096

fpAintB kernel compile pass

parent febd76e4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp"
struct ProblemSize final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
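// --- Illustrative usage (not part of the original header) -------------------
// A minimal, hypothetical driver showing how ProblemSize, ExecutionConfig and
// parse_cmd_args() fit together. The real driver is supplied by
// run_gemm_example.inc; this main() exists only for demonstration.
int main(int argc, char* argv[])
{
    ProblemSize problem_size; // defaults: 3840 x 4096 x 4096
    ExecutionConfig config;   // defaults: verify on, integer init, no timing

    if(!parse_cmd_args(argc, argv, problem_size, config))
    {
        return EXIT_FAILURE; // usage message already printed
    }

    std::cout << "M=" << problem_size.M << ", N=" << problem_size.N << ", K=" << problem_size.K
              << ", verify=" << config.do_verification << ", time_kernel=" << config.time_kernel
              << std::endl;
    return EXIT_SUCCESS;
}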
......@@ -37,7 +37,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_
BElementOp,
CElementOp,
GemmDefault,
2, // Prefetch stage
1, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // NPerBlock
......@@ -67,8 +67,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_
8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm<ADataType,
BDataType,
ScaleDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
#include "run_gemm_example.inc"
......
......@@ -27,6 +27,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
// assume scale tensor is [1, n]
Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{}));
switch(config.init_method)
{
......@@ -34,26 +36,32 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-5.f, 5.f}(scale_k_n);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
break;
case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-5.f, 5.f}(scale_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{1.f, 1.f}(scale_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-2.f, 2.f}(scale_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......@@ -61,6 +69,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
......@@ -77,10 +86,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
scale_k_n_device_buf.ToDevice(scale_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
......@@ -98,6 +109,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ScaleDataType*>(scale_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
......@@ -136,7 +148,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
a_m_k, b_k_n, scale_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
......
......@@ -20,7 +20,7 @@ template <index_t BlockSize,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
typename ScaleDesc,
typename ScaleBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
......@@ -73,8 +73,9 @@ struct Blockwise_fpAintB_GemmWMMA
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
// As Float DataType
static constexpr auto wmma_gemm =
WmmaGemm<ADataType, BDataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
WmmaGemm<ADataType, ADataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......@@ -178,9 +179,10 @@ struct Blockwise_fpAintB_GemmWMMA
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
__host__ __device__
Blockwise_fpAintB_GemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
: a_thread_copy_(a_origin), b_thread_copy_(b_origin), scale_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
......@@ -290,8 +292,12 @@ struct Blockwise_fpAintB_GemmWMMA
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
static constexpr ScaleBlockDesc scale_block_desc_1_n0_n1_n2_1;
template <typename ABlockBuffer, typename BBlockBuffer, typename ScaleBlockBuffer, typename CThreadBuffer>
template <typename ABlockBuffer,
typename BBlockBuffer,
typename ScaleBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
const ScaleBlockBuffer& scale_block_buf,
......@@ -305,8 +311,6 @@ struct Blockwise_fpAintB_GemmWMMA
scale_thread_desc_.GetElementSpaceSize());
auto converted_b_thread_buf = b_thread_buf;
static constexpr auto dequantizer = Dequantizer<ADataType, BDataType>{};
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
......@@ -333,21 +337,22 @@ struct Blockwise_fpAintB_GemmWMMA
b_thread_buf);
// read weight scale
scale_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
scale_block_desc_1_n0_n1_n2_1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_scale_thread_desc_,
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_scale_thread_buf);
scale_thread_buf);
// convert B from int8 to fp16
converted_b_thread_buf = type_convert(b_thread_buf);
// multiply scale
dequantize(converted_b_thread_buf, scale_thread_buf);
// convert B from int8 to fp16, multiply scale
static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
converted_b_thread_buf(i) =
scale_thread_buf[i / WmmaK] *
type_convert<ADataType>(b_thread_buf[i]);
});
vector_type<ADataType, WmmaK> a_thread_vec;
vector_type<BDataType, WmmaK> b_thread_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<ADataType>()(i) =
......@@ -358,7 +363,7 @@ struct Blockwise_fpAintB_GemmWMMA
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
b_thread_vec.template AsType<BDataType>()(i) =
b_thread_vec.template AsType<ADataType>()(i) =
converted_b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
......@@ -369,7 +374,7 @@ struct Blockwise_fpAintB_GemmWMMA
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<BDataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -396,6 +401,20 @@ struct Blockwise_fpAintB_GemmWMMA
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read weight scale
scale_thread_copy_.Run(
scale_block_desc_1_n0_n1_n2_1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
// convert B from int8 to fp16, multiply scale
static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] *
type_convert<ADataType>(b_thread_buf[i]);
});
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
......@@ -406,11 +425,11 @@ struct Blockwise_fpAintB_GemmWMMA
a_thread_buf);
vector_type<ADataType, WmmaK> a_thread_vec;
vector_type<BDataType, WmmaK> b_thread_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<BDataType>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
b_thread_vec.template AsType<ADataType>()(i) =
converted_b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
......@@ -428,7 +447,7 @@ struct Blockwise_fpAintB_GemmWMMA
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<BDataType, WmmaK>::type;
using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -472,6 +491,15 @@ struct Blockwise_fpAintB_GemmWMMA
Number<B_K1>{},
Number<1>{}));
static constexpr auto scale_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(I0, I1, I0, I0, I0, I0));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
......@@ -548,8 +576,42 @@ struct Blockwise_fpAintB_GemmWMMA
TransposeC ? true : false>;
};
template <bool EnableLds>
struct ScaleThreadCopySelector;
template <>
struct ScaleThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<ScaleDataType,
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
template <>
struct ScaleThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
ScaleDataType,
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
typename ScaleThreadCopySelector<BEnableLds>::type scale_thread_copy_;
};
} // namespace ck
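// --- Illustrative sketch (not part of the diff) ------------------------------
// ScaleThreadCopySelector above follows the same pattern as the existing
// A/BThreadCopySelector members: a bool non-type template parameter picks the
// transfer implementation at compile time (LDS-backed when BEnableLds is true,
// register-resident otherwise). A stripped-down, self-contained analogue with
// placeholder types; every name below is illustrative.
struct LdsScaleCopySketch  {}; // stands in for ThreadwiseTensorSliceTransfer_v4
struct VgprScaleCopySketch {}; // stands in for ThreadwiseTensorSliceTransfer_StaticToStatic

template <bool EnableLds>
struct ScaleCopySelectorSketch;

template <>
struct ScaleCopySelectorSketch<true>
{
    using type = LdsScaleCopySketch;
};

template <>
struct ScaleCopySelectorSketch<false>
{
    using type = VgprScaleCopySketch;
};

// usage, analogous to the member declared in Blockwise_fpAintB_GemmWMMA:
//   typename ScaleCopySelectorSketch<BEnableLds>::type scale_thread_copy_;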
......@@ -66,7 +66,7 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
ck::PipelineVersion PipelineVer = ck::PipelineVersion::dequant_v1>
struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
BLayout,
CLayout,
......@@ -183,7 +183,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
// When K = 1, it might be scale tensor.
assert(K % K1 == 0 && K != 1 );
assert(K % K1 == 0 && K != 1);
if constexpr(BEnableLds)
{
......@@ -241,8 +241,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm =
GridwiseFpAintBGemm_Wmma<BlockSize,
using GridwiseGemm = GridwiseFpAintBGemm_Wmma<
BlockSize,
ADataType,
BDataType,
ScaleDataType,
......@@ -252,6 +252,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
InMemoryDataOperationEnum::Set,
AGridDesc,
BGridDesc,
ScaleGridDesc,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
......@@ -295,7 +296,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
const ScaleDataType* p_scale,
const ScaleDataType* p_scale_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
......@@ -310,7 +311,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_scale_grid_{p_scale},
p_scale_grid_{p_scale_grid},
p_c_grid_{p_c_grid},
a_grid_desc_{},
b_grid_desc_{},
......@@ -329,7 +330,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(1, N, 1);
scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, 0);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
......@@ -347,7 +348,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const ScaleDataType* p_b_grid_;
const ScaleDataType* p_scale_grid_;
CDataType* p_c_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
......@@ -406,7 +407,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
CDataType,
remove_reference_t<DeviceOp::AGridDesc>,
remove_reference_t<DeviceOp::BGridDesc>,
remove_reference_t<DeviceOp::BGridDesc>,
remove_reference_t<DeviceOp::ScaleGridDesc>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
......@@ -422,9 +423,11 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_scale_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.scale_grid_desc_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
......@@ -536,10 +539,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_);
}
// polymorphic
......@@ -550,6 +551,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const ScaleDataType* p_scale,
CDataType* p_c,
index_t M,
index_t N,
......@@ -563,6 +565,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
return Argument{p_a,
p_b,
p_scale,
p_c,
M,
N,
......@@ -582,6 +585,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_scale,
void* p_c,
index_t M,
index_t N,
......@@ -595,6 +599,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const ScaleDataType*>(p_scale),
static_cast<CDataType*>(p_c),
M,
N,
......@@ -623,8 +628,10 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"},
{PipelineVersion::dequant_v1, "dequant_v1"}};
// clang-format off
str << "DeviceFpAintBGemm_Wmma_CShuffle"
......
......@@ -20,9 +20,11 @@ namespace ck {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename ScaleDataType,
typename CDataType,
typename AGridDesc,
typename BGridDesc,
typename ScaleGridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
......@@ -33,11 +35,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
const ScaleDataType* __restrict__ p_scale_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const BGridDesc b_grid_desc,
const ScaleGridDesc scale_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
......@@ -51,10 +55,12 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_scale_grid,
p_c_grid,
p_shared,
a_grid_desc,
b_grid_desc,
scale_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
......@@ -63,9 +69,11 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_scale_grid;
ignore = p_c_grid;
ignore = a_grid_desc;
ignore = b_grid_desc;
ignore = scale_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
......@@ -77,12 +85,14 @@ __global__ void
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ScaleDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc,
typename BGridDesc,
typename ScaleGridDesc,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
......@@ -119,7 +129,7 @@ template <index_t BlockSize,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::dequant_v1>
struct GridwiseFpAintBGemm_Wmma
{
static constexpr auto I0 = Number<0>{};
......@@ -140,7 +150,12 @@ struct GridwiseFpAintBGemm_Wmma
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1_dequant<NumGemmKPrefetchStage, AEnableLds, BEnableLds>;
using GridwiseGemmPipe =
remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopSched,
AEnableLds,
BEnableLds>())>;
// Describe how data store to (LDS/VGPR) buffer from Global memory
__host__ __device__ static constexpr auto MakeABlockDescriptor()
......@@ -237,6 +252,38 @@ struct GridwiseFpAintBGemm_Wmma
return b_block_desc;
}
__host__ __device__ static constexpr auto MakeScaleBlockDescriptor()
{
// Scale [1, N], all K related dimension reduce to 1
constexpr auto scale_block_desc = [&]() {
if constexpr(BEnableLds)
{
// K0->N->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(I0, I1, I0));
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(make_tuple(Number<KWmmaPerblock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(I0, I1, I0, I0, I0, I0, I0));
}
}();
return scale_block_desc;
}
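// --- Illustrative helper (not part of the diff) ------------------------------
// Why the zero strides above broadcast the scale: with stride 0 on every
// K-related dimension the linear offset no longer depends on k, so each k maps
// to the same [1, N] row of scale values. A hypothetical index helper, added
// only for illustration:
__host__ __device__ static constexpr index_t IllustrateScaleBroadcastOffset(index_t k, index_t n)
{
    constexpr index_t stride_k = 0; // K dimension collapsed, as in scale_block_desc
    constexpr index_t stride_n = 1;
    return k * stride_k + n * stride_n; // reduces to n for every k
}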
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
......@@ -537,9 +584,15 @@ struct GridwiseFpAintBGemm_Wmma
BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
static constexpr auto scale_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(
MakeScaleBlockDescriptor().GetElementSpaceSize(), max_lds_align)
: 0;
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned;
static constexpr auto scale_block_space_offset =
b_block_space_offset + b_block_space_size_aligned;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size =
......@@ -551,7 +604,8 @@ struct GridwiseFpAintBGemm_Wmma
static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType));
b_block_space_size_aligned * sizeof(BDataType) +
scale_block_space_size_aligned * sizeof(ScaleDataType));
};
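// --- Worked example (not part of the diff; KPerBlock and data types assumed) --
// Rough LDS budget for the configuration used in this commit (BlockSize 128,
// MPerBlock 128, NPerBlock 64), assuming KPerBlock = 64, fp16 A, int8 B and
// fp16 scales. All names below are illustrative.
struct LdsBudgetSketch
{
    static constexpr index_t a_bytes     = 128 * 64 * 2; // fp16 A tile: 16384 bytes
    static constexpr index_t b_bytes     = 64 * 64 * 1;  // int8 B tile:  4096 bytes
    static constexpr index_t scale_bytes = 64 * 2;       // fp16 scales:   128 bytes
    // the scale descriptor's K strides are 0, so only NPerBlock elements are
    // kept; dequantization adds roughly 128 bytes of LDS on top of A and B
};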
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
......@@ -609,7 +663,7 @@ struct GridwiseFpAintBGemm_Wmma
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = MakeBBlockDescriptor();
constexpr auto scale_block_desc = MakeBBlockDescriptor();
constexpr auto scale_block_desc = MakeScaleBlockDescriptor();
auto a_block_trait = [&](){
// A matrix blockwise copy
......@@ -768,7 +822,7 @@ struct GridwiseFpAintBGemm_Wmma
get_thread_local_1d_id() % 16,
0));
return make_tuple(b_block_buf, b_blockwise_copy, scale_blockwise_copy);
return make_tuple(b_block_buf, b_blockwise_copy);
}
};
......@@ -776,13 +830,14 @@ struct GridwiseFpAintBGemm_Wmma
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto scale_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
SharedMemTrait::scale_block_space_size_aligned);
auto scale_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
ck::tensor_operation::element_wise::PassThrough,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
......@@ -802,10 +857,10 @@ struct GridwiseFpAintBGemm_Wmma
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
1>(
NumGemmKPrefetchStage>(
scale_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
ck::tensor_operation::element_wise::PassThrough{},
b_element_op,
scale_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
......@@ -815,13 +870,12 @@ struct GridwiseFpAintBGemm_Wmma
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
auto scale_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
scale_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto scale_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
ScaleDataType,
......@@ -894,9 +948,6 @@ struct GridwiseFpAintBGemm_Wmma
// gridwise GEMM pipeline
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
/*
scale_blockwise_copy
*/
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc,
a_blockwise_copy,
......
......@@ -12,6 +12,7 @@ enum struct PipelineVersion
{
v1,
v2,
dequant_v1,
};
template <PipelineVersion PipelineVer,
......@@ -36,6 +37,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
return GridwiseGemmPipeline_v2{};
}
else if constexpr(PipelineVer == PipelineVersion::dequant_v1)
{
return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
}
else
{
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
......
......@@ -600,9 +600,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
ScaleBlockTransfer& scale_blockwise_copy,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
......@@ -653,7 +653,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
}
}
};
......
......@@ -1090,6 +1090,16 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
// convert int8 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, int8_t>(int8_t x)
{
// TODO: replace it with fast_converter
float x_fp32 = static_cast<float>(x);
return type_convert<half_t>(x_fp32);
}
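// --- Illustrative sketch (not part of the diff) -------------------------------
// The dequant loop in Blockwise_fpAintB_GemmWMMA relies on this specialization:
// converted_b = scale * type_convert<ADataType>(b_int8). A scalar analogue with
// plain C++ types; the helper name is illustrative and not part of CK.
inline __host__ __device__ float dequantize_one_sketch(int8_t x, float scale)
{
    // widen to fp32 first, mirroring the int8 -> fp32 -> fp16 route above,
    // then apply the per-column weight scale
    return scale * static_cast<float>(x);
}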
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename ScaleDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferencefpAintBGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
scale_k_n_{scale_k_n},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
const Tensor<ScaleDataType>& scale_k_n_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferencefpAintBGemm::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
ScaleDataType v_scale;
ADataType v_converted_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
// same for scale matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_scale,
arg.scale_k_n_(k, n));
}
else
{
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
}
v_converted_b = type_convert<ADataType>(v_b) * v_scale;
v_acc += ck::type_convert<AccDataType>(v_a) *
ck::type_convert<AccDataType>(v_converted_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
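// --- Illustrative sketch (not part of the diff) -------------------------------
// The math the reference above computes, stripped of CK types: every B element
// is dequantized with its per-column scale (the scale tensor is [1, N],
// broadcast over K), then a plain GEMM accumulates in float. All names below
// are illustrative.
#include <cstdint>

inline void naive_fpAintB_gemm_sketch(const float* a,     // M x K, row-major
                                      const int8_t* b,    // K x N, row-major
                                      const float* scale, // N, broadcast over K
                                      float* c,           // M x N, row-major
                                      int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                // dequantize the int8 weight on the fly
                const float b_dequant = scale[n] * static_cast<float>(b[k * N + n]);
                acc += a[m * K + k] * b_dequant;
            }
            c[m * N + n] = acc;
        }
    }
}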
......@@ -29,8 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances =
std::tuple<
using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
......
......@@ -29,8 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances =
std::tuple<
using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
......
......@@ -29,8 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
......