Commit f0bbc5db authored by Bartlomiej Kocot

[CK TILE] GEMM with packed i4

parent 0e5e29c4
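
The packed-int4 format this change targets stores two 4-bit values per byte, so logical element counts, byte offsets, and LDS sizes throughout the diff are divided by ck_tile::numeric_traits<pk_int4_t>::PackedSize (expected to be 2), and k % 2 selects the low or high nibble of a packed byte. A minimal, self-contained sketch of that index mapping (hypothetical helper names, not part of this commit):

#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical helpers illustrating the packed-int4 addressing this commit assumes:
// logical element k lives in byte k / 2, and k % 2 picks the low or high nibble.
constexpr std::pair<std::size_t, unsigned> pk_i4_locate(std::size_t k)
{
    return {k / 2, static_cast<unsigned>(k % 2)};
}

constexpr std::uint8_t pk_i4_extract(std::uint8_t packed, unsigned nibble)
{
    return nibble ? static_cast<std::uint8_t>(packed >> 4)
                  : static_cast<std::uint8_t>(packed & 0x0fu);
}

static_assert(pk_i4_locate(5).first == 2 && pk_i4_locate(5).second == 1, "k = 5 -> byte 2, high nibble");
static_assert(pk_i4_extract(0xA3, 0) == 0x3 && pk_i4_extract(0xA3, 1) == 0xA, "nibble extraction");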
......@@ -281,18 +281,18 @@ struct HostTensor
using Data = std::vector<T>;
template <typename X>
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
{
}
template <typename X, typename Y>
HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.get_element_space_size())
: mDesc(lens, strides), mData(get_element_space_size())
{
}
template <typename Lengths>
HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
{
}
......@@ -302,7 +302,7 @@ struct HostTensor
{
}
HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}
HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}
template <typename OutT>
HostTensor<OutT> CopyAsType() const
......@@ -340,7 +340,11 @@ struct HostTensor
std::size_t get_element_size() const { return mDesc.get_element_size(); }
std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
std::size_t get_element_space_size() const
{
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
return mDesc.get_element_space_size() / PackedSize;
}
std::size_t get_element_space_size_in_bytes() const
{
......@@ -463,29 +467,27 @@ struct HostTensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
}
template <typename... Is>
T& operator()(Is... is)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
return mData[GetOffsetFromMultiIndex(is...)];
}
template <typename... Is>
const T& operator()(Is... is) const
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
return mData[GetOffsetFromMultiIndex(is...)];
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
return mData[GetOffsetFromMultiIndex(idx)];
}
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
......
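With these HostTensor changes, pk_int4_t tensors are stored and addressed in packed units: mData now holds mDesc.get_element_space_size() / PackedSize entries, and GetOffsetFromMultiIndex divides the logical offset by PackedSize before indexing. As a worked example (assuming the default packed row-major strides), a {4, 8} tensor allocates 32 / 2 = 16 packed entries; GetOffsetFromMultiIndex(1, 5) = 1 * 8 + 5 = 13 logically, which maps to packed entry 13 / 2 = 6, so indices (1, 4) and (1, 5) return the same packed byte and callers distinguish the two values by k % 2.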
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -34,11 +34,35 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
for(std::size_t k = 0; k < K; ++k)
{
ADataType v_a = a_element_op(a_m_k(m, k));
BDataType v_b = b_element_op(b_k_n(k, n));
v_acc +=
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_acc += v_a * v_b;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
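In the reference kernel the packed operand is decoded on the fly: since HostTensor already divides offsets by PackedSize, a_m_k(m, k) and a_m_k(m, k + 1) (for even k) return the same packed byte; pk_int4_t_to_fp32x2_t expands that byte into both values, and k % 2 selects fp32_val.lo for even k or fp32_val.hi for odd k. For example, k = 6 and k = 7 both read packed byte 3 of the row, with k = 6 taking .lo and k = 7 taking .hi.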
......@@ -73,6 +97,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
......@@ -80,8 +106,34 @@ __global__ void naive_gemm_kernel(ADataType* A,
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
ck_tile::type_convert<AccDataType>(B[b_index]);
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc += v_a * v_b;
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,20 +9,166 @@
namespace ck_tile {
namespace element_wise {
#if 0
// Fast int4x4 to fp16x8_t data type conversion based on paper
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
int lo;
int hi;
// Extract the two int4 values at the low bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
// Extract the two int4 values at the high bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
const int ADD = 0xd480d480; // half2 {-72, -72}
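// These constants implement the exponent-bias decode: OR-ing a low nibble v into
// EX = 0x6400 (fp16 1024.0) yields the value 1024 + v, and adding SUB (-1032) gives
// v - 8; a high nibble contributes 16 * v, so multiplying by MUL (1/16) and adding
// ADD (-72) also gives v - 8. Both paths thus decode nibbles stored with a +8 bias.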
fp16x4_t res;
// for the two fp16 values from the low bits, subtract 1032 to get the correct fp16 value
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(res.lo)
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
// for the two fp16 values from the high bits, divide by 16 and subtract 72 to get the correct fp16 value
asm volatile(
"v_pk_fma_f16 %0, %1, %2, %3"
: "=v"(res.hi)
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
return res;
}
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
int lo;
int hi;
// Extract the two int4 values at the low bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
// Extract the two int4 values at the high bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
const int ADD = 0xd480d480; // half2 {-72, -72}
fp16x4_t res;
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(res.lo)
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
asm volatile(
"v_pk_fma_f16 %0, %1, %2, %3"
: "=v"(res.hi)
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
return res;
}
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
static constexpr uint32_t fp32_base = 0x4B000000;
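// Same bias trick in fp32: fp32_base is 2^23 (8388608.0f). Each __byte_perm below
// drops one nibble-valued byte into the low mantissa byte, so the intermediate
// equals 8388608 + v, and subtracting 8388616 (= 2^23 + 8) recovers v - 8, matching
// the fp16 path above.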
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388616.f;
fp32_intermediates[1] -= 8388616.f;
fp32_intermediates[2] -= 8388616.f;
fp32_intermediates[3] -= 8388616.f;
bf16x4_t res;
res.lo = bit_cast<bf16x2_t>(
__byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
res.hi = bit_cast<bf16x2_t>(
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
return res;
}
struct PassThroughPack8
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
{
y.lo = i4_to_half4(bit_cast<int>(x));
y.hi = i4_to_half4(bit_cast<int>(x) >> 8);
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
{
y.lo = i4_to_bhalf4(bit_cast<int>(x));
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
}
constexpr const static bool is_pack8_invocable = true;
};
struct DequantPack8
{
template <typename Y, typename X, typename Z>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
CK_TILE_HOST_DEVICE constexpr void
operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
{
y.lo = i4_to_half4_scale(bit_cast<int>(x), z);
y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
}
constexpr const static bool is_pack8_invocable = true;
};
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
#if 0
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
y = type_convert<fp16x2_t>(t);
}
#endif
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
{
uint8_t x_u8 = bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
y.lo = type_convert<half_t>(x_l);
y.hi = type_convert<half_t>(x_h);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
......@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
......@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
typename WarpGemm::BWarpDstrEncoding{}))>;
using AWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
using BWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
AWarpTileDistr{}))>;
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
BWarpTileDistr{}))>;
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
......@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
......@@ -105,6 +108,11 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
......@@ -208,6 +216,8 @@ struct BlockUniversalGemmAsBsCr
});
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
......@@ -217,10 +227,58 @@ struct BlockUniversalGemmAsBsCr
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
AWarpTensor a_warp_tile;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size =
AWarpTensor::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
static_assert(
GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType =
ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(a_warp_tile.get_thread_buffer()
.template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer()
.template get_as<pk_int4x4_t>()[i]);
});
}
else
{
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
}
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
BWarpTensor b_warp_tile;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
constexpr index_t thread_buffer_size =
BWarpTensor::get_thread_buffer_size() / UnaryOpSize;
static_assert(
GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType =
ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(b_warp_tile.get_thread_buffer()
.template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer()
.template get_as<pk_int4x4_t>()[i]);
});
}
else
{
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
}
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
......@@ -342,11 +400,59 @@ struct BlockUniversalGemmAsBsCr
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size =
GemmTraits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
static_assert(
GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType =
ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(a_warp_tiles_(mIter)(kIter)
.get_thread_buffer()
.template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer()
.template get_as<pk_int4x4_t>()[i]);
});
}
else
{
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
}
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size =
GemmTraits::BWarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
static_assert(
GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType =
ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(b_warp_tiles_(nIter)(kIter)
.get_thread_buffer()
.template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer()
.template get_as<pk_int4x4_t>()[i]);
});
}
else
{
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
}
});
});
}
......@@ -504,12 +610,59 @@ struct BlockUniversalGemmAsBsCr
// TODO check if a_warp_tiles has same desc as a_warp_window
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size =
GemmTraits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
static_assert(
GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType =
ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(a_warp_tiles_(mIter)(kIter)
.get_thread_buffer()
.template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer()
.template get_as<pk_int4x4_t>()[i]);
});
}
else
{
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
}
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size =
GemmTraits::BWarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
static_assert(
GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType =
ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(b_warp_tiles_(nIter)(kIter)
.get_thread_buffer()
.template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer()
.template get_as<pk_int4x4_t>()[i]);
});
}
else
{
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
}
});
});
}
......
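The three pk_int4_t branches added above repeat one unpack idiom: load the packed warp tile as usual, then widen every pk_int4x4_t (four bytes, eight nibbles) of the thread buffer into an 8-wide vector of ComputeDataType through element_wise::PassThroughPack8. A condensed sketch of that shared pattern, reusing the same ck_tile helpers the diff uses (illustrative only; the commit keeps these loops inline rather than factoring out such a helper):

// Illustrative helper (not part of the commit): widens a packed-int4 source tile
// into a compute-type destination tile, eight values per PassThroughPack8 call.
template <typename ComputeDataType, typename DstTile, typename SrcTile>
CK_TILE_DEVICE void unpack_pk_int4_tile(DstTile& dst, const SrcTile& src)
{
    constexpr index_t UnaryOpSize = 8;
    static_assert(DstTile::get_thread_buffer_size() % UnaryOpSize == 0);
    using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
    const element_wise::PassThroughPack8 elementwise_op{};
    static_for<0, DstTile::get_thread_buffer_size() / UnaryOpSize, 1>{}([&](auto i) {
        elementwise_op(dst.get_thread_buffer().template get_as<ComputeVectorType>()(i),
                       src.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
    });
}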
......@@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
......@@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_a =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
......@@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
......
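Note that the LDS read widths here are in logical elements, so dividing by the packed size converts them to bytes: with pk_int4_t, an A_LDS_Read_Width of 32 packed int4 values is a 16-byte ds_read and keeps the full instruction count, while 16 values (8 bytes) falls into the halved count, mirroring the ds_read2 pairing assumption in the comment above.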
......@@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
constexpr index_t smem_size_a =
sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
constexpr index_t smem_size_b =
sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
return smem_size_b;
}
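With the packed-size correction, LDS for a packed operand is sized by its real byte footprint: for example, assuming sizeof(pk_int4_t) == 1 and PackedSize == 2, an A LDS block descriptor covering 8192 logical int4 elements now reserves 8192 * 1 / 2 = 4096 bytes, half of what the uncorrected formula would have requested.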
......@@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
......
......@@ -13,14 +13,16 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_>
typename Traits_,
typename ComputeDataType_ = ADataType_>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
......@@ -53,13 +55,15 @@ struct GemmPipelineProblemBase
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: VectorLoadSize / sizeof(ADataType);
: PackedSize * VectorLoadSize / sizeof(ADataType);
}
else
{
......@@ -69,17 +73,19 @@ struct GemmPipelineProblemBase
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: VectorLoadSize / sizeof(BDataType);
: PackedSize * VectorLoadSize / sizeof(BDataType);
}
else
{
return VectorLoadSize / sizeof(BDataType);
return PackedSize * VectorLoadSize / sizeof(BDataType);
}
}
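Read these as logical elements per vector access: if VectorLoadSize is 16 bytes, fp16 still gets 16 / 2 = 8 elements per load, while pk_int4_t (PackedSize == 2, sizeof assumed to be 1) gets 2 * 16 / 1 = 32 logical int4 values, which is the same 16-byte transaction.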
......@@ -143,9 +149,14 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_>
using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
typename Traits_,
typename ComputeDataType_ = ADataType_>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
template <typename ADataType_,
typename BDataType_,
......@@ -154,14 +165,16 @@ template <typename ADataType_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
......
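The new trailing ComputeDataType_ parameter defaults to ADataType_, so existing instantiations are unaffected; packed-int4 users override it so the warp GEMM is dispatched on the unpacked compute type while the operand stays packed in memory. A hedged example instantiation (not taken from the commit; BlockGemmShape and GemmTraits are assumed to be defined by the caller):

// Illustrative: fp16 activations, packed int4 weights, fp16 output, with the MFMA
// warp GEMM running on fp16 (here spelled out, though it is also the default).
using PkInt4WeightGemmProblem = ck_tile::GemmPipelineProblem<
    ck_tile::half_t,      // ADataType
    ck_tile::pk_int4_t,   // BDataType: packed int4 in global memory and LDS
    ck_tile::half_t,      // CDataType
    BlockGemmShape,       // tile shape, assumed defined by the caller
    GemmTraits,           // pipeline traits, assumed defined by the caller
    ck_tile::half_t>;     // ComputeDataType: what A/B are widened to before MFMA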
......@@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assume DataType is even!
if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
elements_per_thread % (16 / sizeof(DataType)) == 0)
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
PackedSize == 2)
{
return (16 / sizeof(DataType));
return (PackedSize * 32 / sizeof(DataType));
}
else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
elements_per_thread % (8 / sizeof(DataType)) == 0)
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (8 / sizeof(DataType));
return (PackedSize * 16 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
elements_per_thread % (4 / sizeof(DataType)) == 0)
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
{
return (4 / sizeof(DataType));
return (PackedSize * 8 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
elements_per_thread % (2 / sizeof(DataType)) == 0)
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
{
return (2 / sizeof(DataType));
return (PackedSize * 4 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
{
return (PackedSize * 2 / sizeof(DataType));
}
else
{
return 1;
return PackedSize;
}
}
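Expressed in logical elements, the ladder is unchanged for standard types: fp16 (PackedSize == 1, sizeof == 2) can still return 8, 4, or 1 elements per access. For pk_int4_t (PackedSize == 2, sizeof assumed to be 1) the counts double, and the new top branch returns 64 logical int4 values, a 32-byte access, when both XPerTile and the per-thread element count are divisible by it; the final fallback returns PackedSize so a packed pair is never split.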
......@@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
......