Commit 0c51a35e authored by aska-0096

fpAintB kernel compile pass

parent febd76e4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp"
struct ProblemSize final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
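For orientation, a minimal sketch of how this common header is assumed to be wired up by the example's entry point (the actual main() lives in the example's .cpp next to run_gemm_example.inc; the sketch only uses the functions shown above):

// Assumed wiring, for illustration only: parse the command line into the structs above,
// then hand them to run_gemm() from run_gemm_example.inc.
int main(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    // parse_cmd_args() prints the usage text above and returns false on unexpected input
    if(!parse_cmd_args(argc, argv, problem_size, config))
    {
        return EXIT_FAILURE;
    }

    // run_gemm() returns true when the run (and optional verification) succeeds
    return run_gemm(problem_size, config) ? EXIT_SUCCESS : EXIT_FAILURE;
}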
@@ -37,7 +37,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_
BElementOp,
CElementOp,
GemmDefault,
-2, // Prefetch stage
+1, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // NPerBlock
@@ -67,8 +67,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_
8>;
// clang-format on
-using ReferenceGemmInstance = ck::tensor_operation::host::
-ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm<ADataType,
+BDataType,
+ScaleDataType,
+CDataType,
+AccDataType,
+AElementOp,
+BElementOp,
+CElementOp>;
#include "run_gemm_example.inc"
...
@@ -27,6 +27,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+// assume scale tensor is [1, n]
+Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{}));
switch(config.init_method)
{
@@ -34,26 +36,32 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-5.f, 5.f}(scale_k_n);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
break;
case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-5.f, 5.f}(scale_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
+ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{1.f, 1.f}(scale_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
+ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-2.f, 2.f}(scale_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
@@ -61,6 +69,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
@@ -77,10 +86,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
+DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+scale_k_n_device_buf.ToDevice(scale_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
@@ -98,6 +109,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+static_cast<ScaleDataType*>(scale_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
@@ -136,7 +148,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
-a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+a_m_k, b_k_n, scale_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
...
@@ -20,7 +20,7 @@ template <index_t BlockSize,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
-typename ScaleDesc,
+typename ScaleBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
@@ -73,8 +73,9 @@ struct Blockwise_fpAintB_GemmWMMA
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
+// As Float DataType
static constexpr auto wmma_gemm =
-WmmaGemm<ADataType, BDataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
+WmmaGemm<ADataType, ADataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
@@ -178,9 +179,10 @@ struct Blockwise_fpAintB_GemmWMMA
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
-__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
+__host__ __device__
+Blockwise_fpAintB_GemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
-: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
+: a_thread_copy_(a_origin), b_thread_copy_(b_origin), scale_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
@@ -290,8 +292,12 @@ struct Blockwise_fpAintB_GemmWMMA
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
+static constexpr ScaleBlockDesc scale_block_desc_1_n0_n1_n2_1;
-template <typename ABlockBuffer, typename BBlockBuffer, typename ScaleBlockBuffer, typename CThreadBuffer>
+template <typename ABlockBuffer,
+typename BBlockBuffer,
+typename ScaleBlockBuffer,
+typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
const ScaleBlockBuffer& scale_block_buf,
@@ -305,8 +311,6 @@ struct Blockwise_fpAintB_GemmWMMA
scale_thread_desc_.GetElementSpaceSize());
auto converted_b_thread_buf = b_thread_buf;
-static constexpr auto dequantizer = Dequantizer<ADataType, BDataType>{};
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
@@ -333,21 +337,22 @@ struct Blockwise_fpAintB_GemmWMMA
b_thread_buf);
// read weight scale
scale_thread_copy_.Run(
-b_block_desc_k0_n0_n1_n2_k1,
+scale_block_desc_1_n0_n1_n2_1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
-b_block_buf,
+scale_block_buf,
-b_scale_thread_desc_,
+scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
-b_scale_thread_buf);
+scale_thread_buf);
-// convert B from int8 to fp16
-converted_b_thread_buf = type_convert(b_thread_buf);
-// multiply scale
-dequantize(converted_b_thread_buf, scale_thread_buf);
+// convert B from int8 to fp16, multiply scale
+static_for<0, b_thread_buf.size(), 1>{}([&](auto i) {
+converted_b_thread_buf(i) =
+scale_thread_buf[i / WmmaK] *
+type_convert<ADataType>(b_thread_buf[i]);
+});
vector_type<ADataType, WmmaK> a_thread_vec;
-vector_type<BDataType, WmmaK> b_thread_vec;
+vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<ADataType>()(i) =
@@ -358,7 +363,7 @@ struct Blockwise_fpAintB_GemmWMMA
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
-b_thread_vec.template AsType<BDataType>()(i) =
+b_thread_vec.template AsType<ADataType>()(i) =
converted_b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
@@ -369,7 +374,7 @@ struct Blockwise_fpAintB_GemmWMMA
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
-using wmma_input_type_b = typename vector_type<BDataType, WmmaK>::type;
+using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -396,6 +401,20 @@ struct Blockwise_fpAintB_GemmWMMA
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
+// read weight scale
+scale_thread_copy_.Run(
+scale_block_desc_1_n0_n1_n2_1,
+make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+scale_block_buf,
+scale_thread_desc_,
+make_tuple(I0, n0, I0, I0, I0, I0),
+scale_thread_buf);
+// convert B from int8 to fp16, multiply scale
+static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
+converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] *
+type_convert<ADataType>(b_thread_buf[i]);
+});
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
@@ -406,11 +425,11 @@ struct Blockwise_fpAintB_GemmWMMA
a_thread_buf);
vector_type<ADataType, WmmaK> a_thread_vec;
-vector_type<BDataType, WmmaK> b_thread_vec;
+vector_type<ADataType, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
-b_thread_vec.template AsType<BDataType>()(i) =
-b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+b_thread_vec.template AsType<ADataType>()(i) =
+converted_b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
@@ -428,7 +447,7 @@ struct Blockwise_fpAintB_GemmWMMA
});
using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
-using wmma_input_type_b = typename vector_type<BDataType, WmmaK>::type;
+using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -472,6 +491,15 @@ struct Blockwise_fpAintB_GemmWMMA
Number<B_K1>{},
Number<1>{}));
+static constexpr auto scale_thread_desc_ =
+make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
+Number<NRepeat>{},
+I1,
+Number<B_KRow>{},
+I1,
+Number<B_K1>{}),
+make_tuple(I0, I1, I0, I0, I0, I0));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
@@ -548,8 +576,42 @@ struct Blockwise_fpAintB_GemmWMMA
TransposeC ? true : false>;
};
+template <bool EnableLds>
+struct ScaleThreadCopySelector;
+template <>
+struct ScaleThreadCopySelector<true>
+{
+using type =
+ThreadwiseTensorSliceTransfer_v4<ScaleDataType,
+ScaleDataType,
+decltype(scale_block_desc_1_n0_n1_n2_1),
+decltype(scale_thread_desc_),
+Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+Sequence<0, 1, 2, 3, 4, 5>,
+5,
+B_K1,
+B_K1>;
+};
+template <>
+struct ScaleThreadCopySelector<false>
+{
+using type = ThreadwiseTensorSliceTransfer_StaticToStatic<
+ScaleDataType,
+ScaleDataType,
+decltype(scale_block_desc_1_n0_n1_n2_1),
+decltype(scale_thread_desc_),
+tensor_operation::element_wise::PassThrough,
+Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+Sequence<0, 1, 2, 3, 4, 5>,
+5,
+1>;
+};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+typename ScaleThreadCopySelector<BEnableLds>::type scale_thread_copy_;
};
} // namespace ck
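The core of the blockwise change above is the in-register dequantization: after the per-thread B fragment and its scales are loaded, each int8 element is converted to the A data type and multiplied by the scale picked by i / WmmaK. A standalone sketch of just that step (plain C++, float standing in for the fp16 types, illustrative names; not CK code):

#include <array>
#include <cstdint>

// Every WmmaK consecutive elements of the per-thread B fragment share one scale,
// mirroring the scale_thread_buf[i / WmmaK] indexing used above.
template <int WmmaK, int NumElems>
void dequantize_b_fragment(const std::array<int8_t, NumElems>& b_thread_buf,
                           const std::array<float, NumElems / WmmaK>& scale_thread_buf,
                           std::array<float, NumElems>& converted_b_thread_buf)
{
    for(int i = 0; i < NumElems; ++i)
    {
        converted_b_thread_buf[i] =
            scale_thread_buf[i / WmmaK] * static_cast<float>(b_thread_buf[i]);
    }
}

With WmmaK = 16 and a 32-element fragment, elements 0..15 use scale_thread_buf[0] and elements 16..31 use scale_thread_buf[1].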
@@ -66,7 +66,7 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
-ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
+ck::PipelineVersion PipelineVer = ck::PipelineVersion::dequant_v1>
struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
BLayout,
CLayout,
@@ -183,7 +183,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
// When K = 1, it might be scale tensor.
-assert(K % K1 == 0 && K != 1 );
+assert(K % K1 == 0 && K != 1);
if constexpr(BEnableLds)
{
@@ -241,8 +241,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
-using GridwiseGemm =
-GridwiseFpAintBGemm_Wmma<BlockSize,
+using GridwiseGemm = GridwiseFpAintBGemm_Wmma<
+BlockSize,
ADataType,
BDataType,
ScaleDataType,
@@ -252,6 +252,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
InMemoryDataOperationEnum::Set,
AGridDesc,
BGridDesc,
+ScaleGridDesc,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
@@ -295,7 +296,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
-const ScaleDataType* p_scale,
+const ScaleDataType* p_scale_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
@@ -310,7 +311,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
-p_scale_grid_{p_scale},
+p_scale_grid_{p_scale_grid},
p_c_grid_{p_c_grid},
a_grid_desc_{},
b_grid_desc_{},
@@ -329,7 +330,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
-scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(1, N, 1);
+scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, 0);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
@@ -347,7 +348,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
-const ScaleDataType* p_b_grid_;
+const ScaleDataType* p_scale_grid_;
CDataType* p_c_grid_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
@@ -406,7 +407,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
CDataType,
remove_reference_t<DeviceOp::AGridDesc>,
remove_reference_t<DeviceOp::BGridDesc>,
-remove_reference_t<DeviceOp::BGridDesc>,
+remove_reference_t<DeviceOp::ScaleGridDesc>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
@@ -422,9 +423,11 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
0,
arg.p_a_grid_,
arg.p_b_grid_,
+arg.p_scale_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
+arg.scale_grid_desc_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
@@ -536,10 +539,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
}
}
-return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
-arg.b_grid_desc_,
-arg.c_grid_desc_m_n_,
-arg.block_2_ctile_map_);
+return GridwiseGemm::CheckValidity(
+arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_);
}
// polymorphic
@@ -550,6 +551,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
+const ScaleDataType* p_scale,
CDataType* p_c,
index_t M,
index_t N,
@@ -563,6 +565,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
return Argument{p_a,
p_b,
+p_scale,
p_c,
M,
N,
@@ -582,6 +585,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
+const void* p_scale,
void* p_c,
index_t M,
index_t N,
@@ -595,6 +599,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
+static_cast<const ScaleDataType*>(p_scale),
static_cast<CDataType*>(p_c),
M,
N,
@@ -623,8 +628,10 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
-std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
-{PipelineVersion::v2, "v2"}};
+std::map<PipelineVersion, std::string> PipelineVersionToString{
+{PipelineVersion::v1, "v1"},
+{PipelineVersion::v2, "v2"},
+{PipelineVersion::dequant_v1, "dequant_v1"}};
// clang-format off
str << "DeviceFpAintBGemm_Wmma_CShuffle"
...
@@ -20,9 +20,11 @@ namespace ck {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
+typename ScaleDataType,
typename CDataType,
typename AGridDesc,
typename BGridDesc,
+typename ScaleGridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -33,11 +35,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
+kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
+const ScaleDataType* __restrict__ p_scale_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const BGridDesc b_grid_desc,
+const ScaleGridDesc scale_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
@@ -51,10 +55,12 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
+p_scale_grid,
p_c_grid,
p_shared,
a_grid_desc,
b_grid_desc,
+scale_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
@@ -63,9 +69,11 @@ __global__ void
#else
ignore = p_a_grid;
ignore = p_b_grid;
+ignore = p_scale_grid;
ignore = p_c_grid;
ignore = a_grid_desc;
ignore = b_grid_desc;
+ignore = scale_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
@@ -77,12 +85,14 @@ __global__ void
template <index_t BlockSize,
typename ADataType,
typename BDataType,
+typename ScaleDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc,
typename BGridDesc,
+typename ScaleGridDesc,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -119,7 +129,7 @@ template <index_t BlockSize,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
-PipelineVersion PipelineVer = PipelineVersion::v1>
+PipelineVersion PipelineVer = PipelineVersion::dequant_v1>
struct GridwiseFpAintBGemm_Wmma
{
static constexpr auto I0 = Number<0>{};
@@ -140,7 +150,12 @@ struct GridwiseFpAintBGemm_Wmma
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
-using GridwiseGemmPipe = GridwiseGemmPipeline_v1_dequant<NumGemmKPrefetchStage, AEnableLds, BEnableLds>;
+using GridwiseGemmPipe =
+remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
+NumGemmKPrefetchStage,
+LoopSched,
+AEnableLds,
+BEnableLds>())>;
// Describe how data store to (LDS/VGPR) buffer from Global memory
__host__ __device__ static constexpr auto MakeABlockDescriptor()
@@ -237,6 +252,38 @@ struct GridwiseFpAintBGemm_Wmma
return b_block_desc;
}
+__host__ __device__ static constexpr auto MakeScaleBlockDescriptor()
+{
+// Scale [1, N], all K related dimension reduce to 1
+constexpr auto scale_block_desc = [&]() {
+if constexpr(BEnableLds)
+{
+// K0->N->K1 Per Block
+constexpr auto K0PerBlock = KPerBlock / K1;
+return make_naive_tensor_descriptor(
+make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+make_tuple(I0, I1, I0));
+}
+else
+{
+constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+constexpr auto K0PerWmma = WmmaK / 2 / K1;
+// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
+return make_naive_tensor_descriptor(make_tuple(Number<KWmmaPerblock>{},
+Number<NRepeat>{},
+I1,
+Number<K0PerWmma>{},
+I1,
+I1,
+K1),
+make_tuple(I0, I1, I0, I0, I0, I0, I0));
+}
+}();
+return scale_block_desc;
+}
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
@@ -537,9 +584,15 @@ struct GridwiseFpAintBGemm_Wmma
BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
+static constexpr auto scale_block_space_size_aligned =
+BEnableLds ? math::integer_least_multiple(
+MakeScaleBlockDescriptor().GetElementSpaceSize(), max_lds_align)
+: 0;
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned;
+static constexpr auto scale_block_space_offset =
+b_block_space_offset + b_block_space_size_aligned;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size =
@@ -551,7 +604,8 @@ struct GridwiseFpAintBGemm_Wmma
static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
-b_block_space_size_aligned * sizeof(BDataType));
+b_block_space_size_aligned * sizeof(BDataType) +
+scale_block_space_size_aligned * sizeof(ScaleDataType));
};
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
@@ -609,7 +663,7 @@ struct GridwiseFpAintBGemm_Wmma
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = MakeBBlockDescriptor();
-constexpr auto scale_block_desc = MakeBBlockDescriptor();
+constexpr auto scale_block_desc = MakeScaleBlockDescriptor();
auto a_block_trait = [&](){
// A matrix blockwise copy
@@ -768,7 +822,7 @@ struct GridwiseFpAintBGemm_Wmma
get_thread_local_1d_id() % 16,
0));
-return make_tuple(b_block_buf, b_blockwise_copy, scale_blockwise_copy);
+return make_tuple(b_block_buf, b_blockwise_copy);
}
};
@@ -776,13 +830,14 @@ struct GridwiseFpAintBGemm_Wmma
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto scale_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
SharedMemTrait::scale_block_space_size_aligned);
auto scale_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-ck::tensor_operation::element_wise::PassThrough,
+BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
@@ -802,10 +857,10 @@ struct GridwiseFpAintBGemm_Wmma
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
-1>(
+NumGemmKPrefetchStage>(
scale_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
-ck::tensor_operation::element_wise::PassThrough{},
+b_element_op,
scale_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
@@ -815,13 +870,12 @@ struct GridwiseFpAintBGemm_Wmma
else
{
// Thread-wise copy
-// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
+// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
auto scale_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
scale_block_desc.GetElementSpaceSize());
-// Limitation: NumDim of Src and Dst descriptor should be identical
auto scale_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
ScaleDataType,
@@ -894,9 +948,6 @@ struct GridwiseFpAintBGemm_Wmma
// gridwise GEMM pipeline
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
-/*
-scale_blockwise_copy
-*/
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc,
a_blockwise_copy,
...
@@ -12,6 +12,7 @@ enum struct PipelineVersion
{
v1,
v2,
+dequant_v1,
};
template <PipelineVersion PipelineVer,
@@ -36,6 +37,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
return GridwiseGemmPipeline_v2{};
}
+else if constexpr(PipelineVer == PipelineVersion::dequant_v1)
+{
+return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
+}
else
{
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
...
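The dequant_v1 entry added above is consumed by the gridwise GEMM through remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<...>())>. As a side note, a minimal generic sketch of that C++17 selector pattern (illustrative stand-in types, not the CK implementation): an if constexpr chain returns objects of different types and the caller recovers the chosen type with decltype:

#include <type_traits>

// Stand-ins for the real pipeline structs.
struct PipelineV1        { static constexpr int id = 1; };
struct PipelineDequantV1 { static constexpr int id = 3; };

enum struct Version { v1, dequant_v1 };

// Each active branch returns a different type; the auto return type is deduced per instantiation.
template <Version Ver>
constexpr auto pipeline_selector()
{
    if constexpr(Ver == Version::v1) { return PipelineV1{}; }
    else                             { return PipelineDequantV1{}; }
}

// Recover the selected type the same way the gridwise GEMM does.
template <Version Ver>
using SelectedPipeline =
    std::remove_cv_t<std::remove_reference_t<decltype(pipeline_selector<Ver>())>>;

static_assert(SelectedPipeline<Version::dequant_v1>::id == 3,
              "dequant_v1 selects the dequant pipeline");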
@@ -600,9 +600,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
-ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
+ScaleBlockTransfer& scale_blockwise_copy,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
@@ -653,7 +653,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
{
block_sync_lds();
-blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
}
}
};
...
@@ -1090,6 +1090,16 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
+// convert int8 to fp16 via fp32
+template <>
+inline __host__ __device__ constexpr half_t type_convert<half_t, int8_t>(int8_t x)
+{
+// TODO: replace it with fast_converter
+float x_fp32 = static_cast<float>(x);
+return type_convert<half_t>(x_fp32);
+}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename ScaleDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferencefpAintBGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
scale_k_n_{scale_k_n},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
const Tensor<ScaleDataType>& scale_k_n_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferencefpAintBGemm::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
ScaleDataType v_scale;
ADataType v_converted_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
// same for scale matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_scale,
arg.scale_k_n_(k, n));
}
else
{
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
}
v_converted_b = type_convert<ADataType>(v_b) * v_scale;
v_acc += ck::type_convert<AccDataType>(v_a) *
ck::type_convert<AccDataType>(v_converted_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
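For readers skimming the diff, a standalone sketch (plain C++, no CK types, illustrative names) of the math this reference implements: C[m][n] = sum_k A[m][k] * (scale[n] * float(B[k][n])), where B holds int8 weights and one scale per column of B is broadcast along K (the example builds the scale tensor with a zero K stride):

#include <cstdint>
#include <vector>

void fpAintB_gemm_reference(const std::vector<float>& a,      // [M, K], row-major fp values
                            const std::vector<int8_t>& b,     // [K, N], row-major int8 weights
                            const std::vector<float>& scale,  // [N], one scale per column of B
                            std::vector<float>& c,            // [M, N], row-major output
                            int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                // dequantize B on the fly, then accumulate in fp32
                acc += a[m * K + k] * (scale[n] * static_cast<float>(b[k * N + n]));
            }
            c[m * N + n] = acc;
        }
    }
}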
@@ -29,8 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
-using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances =
-std::tuple<
+using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
...
@@ -29,8 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
-using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances =
-std::tuple<
+using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
...
@@ -29,8 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
-using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances =
-std::tuple<
+using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
//######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
...