Commit f0bbc5db authored by Bartlomiej Kocot

[CK TILE] GEMM with packed i4

parent 0e5e29c4
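For context while reading the diff below: `pk_int4_t` stores two 4-bit values in one byte, so a tensor of K int4 elements occupies K/2 storage units and `PackedSize == 2`. A minimal sketch of the packing convention the change assumes (names here are illustrative, not the ck_tile API); how a nibble maps to a signed value (two's complement vs. bias-8) belongs to the surrounding quantization scheme and is left abstract:

#include <cstdint>

// Hypothetical illustration: two 4-bit lanes per byte, even-k element in the
// low nibble, odd-k element in the high one, matching the
// k % 2 == 0 -> .lo / k % 2 == 1 -> .hi selection used throughout the diff.
inline uint8_t pack_i4x2(uint8_t lo, uint8_t hi) { return (lo & 0xF) | (hi << 4); }
inline uint8_t nibble(uint8_t byte, int k) { return (k % 2 == 1) ? (byte >> 4) : (byte & 0xF); }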
@@ -281,18 +281,18 @@ struct HostTensor
     using Data = std::vector<T>;

     template <typename X>
-    HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
+    HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
     {
     }

     template <typename X, typename Y>
     HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.get_element_space_size())
+        : mDesc(lens, strides), mData(get_element_space_size())
     {
     }

     template <typename Lengths>
-    HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
+    HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
     {
     }
@@ -302,7 +302,7 @@ struct HostTensor
     {
     }

-    HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}
+    HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}

     template <typename OutT>
     HostTensor<OutT> CopyAsType() const
@@ -340,7 +340,11 @@ struct HostTensor
     std::size_t get_element_size() const { return mDesc.get_element_size(); }

-    std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
+    std::size_t get_element_space_size() const
+    {
+        constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
+        return mDesc.get_element_space_size() / PackedSize;
+    }

     std::size_t get_element_space_size_in_bytes() const
     {
@@ -463,29 +467,27 @@ struct HostTensor
     template <typename... Is>
     std::size_t GetOffsetFromMultiIndex(Is... is) const
     {
-        return mDesc.GetOffsetFromMultiIndex(is...);
+        constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
+        return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
     }

     template <typename... Is>
     T& operator()(Is... is)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        return mData[GetOffsetFromMultiIndex(is...)];
     }

     template <typename... Is>
     const T& operator()(Is... is) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        return mData[GetOffsetFromMultiIndex(is...)];
     }

-    T& operator()(std::vector<std::size_t> idx)
-    {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
-    }
+    T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }

     const T& operator()(std::vector<std::size_t> idx) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        return mData[GetOffsetFromMultiIndex(idx)];
     }

     HostTensor<T> transpose(std::vector<size_t> axes = {}) const
...
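The HostTensor change above makes the backing `std::vector` allocate one storage element per packed pair while the descriptor keeps counting logical elements. A rough standalone check of that accounting, assuming `PackedSize` is 2 for `pk_int4_t` and 1 otherwise:

#include <cstddef>

// Hypothetical check of the descriptor-vs-storage split above: the descriptor
// counts logical elements, mData now counts packed storage units.
constexpr std::size_t storage_elements(std::size_t M, std::size_t K, std::size_t packed_size)
{
    return M * K / packed_size; // what mData allocates after this change
}

static_assert(storage_elements(128, 64, 2) == 4096); // pk_int4_t: half the storage
static_assert(storage_elements(128, 64, 1) == 8192); // fp16 etc.: unchanged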
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -34,11 +34,35 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
         for(std::size_t k = 0; k < K; ++k)
         {
-            ADataType v_a = a_element_op(a_m_k(m, k));
-            BDataType v_b = b_element_op(b_k_n(k, n));
-            v_acc +=
-                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
+            AccDataType v_a;
+            AccDataType v_b;
+            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+            {
+                const pk_int4_t pk_val  = a_element_op(a_m_k(m, k));
+                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
+                if(k % 2 == 1)
+                    v_a = fp32_val.hi;
+                else
+                    v_a = fp32_val.lo;
+            }
+            else
+            {
+                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
+            }
+            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+            {
+                const pk_int4_t pk_val  = b_element_op(b_k_n(k, n));
+                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
+                if(k % 2 == 1)
+                    v_b = fp32_val.hi;
+                else
+                    v_b = fp32_val.lo;
+            }
+            else
+            {
+                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
+            }
+            v_acc += v_a * v_b;
         }

         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
@@ -73,6 +97,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
     AccDataType acc = 0.0;

     for(int k = 0; k < K; ++k)
     {
+        constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
+        constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
         // Adjust indexing based on matrix layout
         int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                           ? row * strideA + k
@@ -80,8 +106,34 @@ __global__ void naive_gemm_kernel(ADataType* A,
         int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                           ? col * strideB + k
                           : k * strideB + col;
-        acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
-               ck_tile::type_convert<AccDataType>(B[b_index]);
+
+        AccDataType v_a;
+        AccDataType v_b;
+        if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+        {
+            const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
+            if(k % 2 == 1)
+                v_a = fp32_val.hi;
+            else
+                v_a = fp32_val.lo;
+        }
+        else
+        {
+            v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
+        }
+        if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+        {
+            const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
+            if(k % 2 == 1)
+                v_b = fp32_val.hi;
+            else
+                v_b = fp32_val.lo;
+        }
+        else
+        {
+            v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
+        }
+        acc += v_a * v_b;
     }

     int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
...
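Both reference paths above share one access pattern: consecutive k indices map to the two nibbles of a single byte. A scalar sketch of that pattern, with a decode callback standing in for `pk_int4_t_to_fp32x2_t` (the nibble-to-value mapping is a property of the quantization scheme, so it is left as a parameter):

#include <cstdint>
#include <cstddef>

struct float2x { float lo, hi; };

// Illustrative scalar dot product over packed int4 A and float B.
// Decode: uint8_t -> float2x, hypothetical stand-in for pk_int4_t_to_fp32x2_t.
template <typename Decode>
float dot_i4(const uint8_t* a, const float* b, std::size_t K, Decode decode)
{
    float acc = 0.f;
    for(std::size_t k = 0; k < K; ++k)
    {
        const float2x v = decode(a[k / 2]);          // one byte holds elements k and k^1
        acc += ((k % 2 == 1) ? v.hi : v.lo) * b[k];  // odd k selects the high nibble
    }
    return acc;
}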
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -9,20 +9,166 @@
 namespace ck_tile {
 namespace element_wise {

-#if 0
+// Fast int4x4 to fp16x8_t data type conversion based on paper
+// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
+// (https://arxiv.org/abs/2211.10017) and implementation:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
+{
+    const int LO = 0x000f000f;
+    const int HI = 0x00f000f0;
+    const int EX = 0x64006400;
+
+    int lo;
+    int hi;
+    // Extract the two int4 at the low bits and create two fp16 numbers.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
+    // Extract the two int4 at the high bits and create two fp16 numbers.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
+
+    const int SUB = 0xE408E408; // half2 {-1032, -1032}
+    const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
+    const int ADD = 0xd480d480; // half2 {-72, -72}
+
+    fp16x4_t res;
+
+    // for the two fp16 from the low bits, subtract 1032 to get the correct fp16 value
+    asm volatile("v_pk_add_f16 %0, %1, %2"
+                 : "=v"(res.lo)
+                 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));

+    // for the two fp16 from the high bits, divide by 16 and subtract 72 to get the correct value
+    asm volatile(
+        "v_pk_fma_f16 %0, %1, %2, %3"
+        : "=v"(res.hi)
+        : "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
+
+    return res;
+}
+
+CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
+{
+    const int LO = 0x000f000f;
+    const int HI = 0x00f000f0;
+    const int EX = 0x64006400;
+
+    int lo;
+    int hi;
+    // Extract the two int4 at the low bits and create two fp16 numbers.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
+    // Extract the two int4 at the high bits and create two fp16 numbers.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
+
+    const int SUB = 0xE408E408; // half2 {-1032, -1032}
+    const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
+    const int ADD = 0xd480d480; // half2 {-72, -72}
+
+    fp16x4_t res;
+
+    asm volatile("v_pk_add_f16 %0, %1, %2"
+                 : "=v"(res.lo)
+                 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
+
+    asm volatile(
+        "v_pk_fma_f16 %0, %1, %2, %3"
+        : "=v"(res.hi)
+        : "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
+
+    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
+    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
+
+    return res;
+}
+
+CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
+{
+    uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
+
+    static constexpr uint32_t fp32_base = 0x4B000000;
+
+    float fp32_intermediates[4];
+
+    uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
+    fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
+    fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
+    fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
+    fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
+
+    fp32_intermediates[0] -= 8388616.f;
+    fp32_intermediates[1] -= 8388616.f;
+    fp32_intermediates[2] -= 8388616.f;
+    fp32_intermediates[3] -= 8388616.f;
+
+    bf16x4_t res;
+    res.lo = bit_cast<bf16x2_t>(
+        __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
+    res.hi = bit_cast<bf16x2_t>(
+        __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
+
+    return res;
+}
+
+struct PassThroughPack8
+{
+    template <typename Y, typename X>
+    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
+
+    CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
+    {
+        y.lo = i4_to_half4(bit_cast<int>(x));
+        y.hi = i4_to_half4(bit_cast<int>(x) >> 8);
+    }
+
+    CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
+    {
+        y.lo = i4_to_bhalf4(bit_cast<int>(x));
+        y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
+    }
+    constexpr const static bool is_pack8_invocable = true;
+};
+
+struct DequantPack8
+{
+    template <typename Y, typename X, typename Z>
+    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
+
+    CK_TILE_HOST_DEVICE constexpr void
+    operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
+    {
+        y.lo = i4_to_half4_scale(bit_cast<int>(x), z);
+        y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
+    }
+    constexpr const static bool is_pack8_invocable = true;
+};
+
 struct PassThroughPack2
 {
     template <typename Y, typename X>
     CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
-    CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
+#if 0
+    CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
     {
         auto t = type_convert<float2_t>(x);
-        y = type_convert<half2_t>(t);
+        y = type_convert<fp16x2_t>(t);
     }
+#endif
+
+    CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
+    {
+        uint8_t x_u8 = bit_cast<uint8_t>(x);
+        uint8_t x_l  = (x_u8 & 0x0f) >> 0;
+        uint8_t x_h  = (x_u8 & 0xf0) >> 4;
+
+        y.lo = type_convert<half_t>(x_l);
+        y.hi = type_convert<half_t>(x_h);
+    }
     constexpr const static bool is_pack2_invocable = true;
 };
-#endif

 struct PassThrough
 {
...
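A note on the magic constants in `i4_to_half4` above (my reading of the trick, not spelled out in the diff): 0x6400 is fp16 1024.0, and at that exponent the fp16 mantissa ulp is exactly 1.0, so OR-ing a nibble q into the mantissa yields 1024 + q; subtracting 1032 (0xE408 is fp16 -1032) leaves q - 8, i.e. the nibble read with a bias of 8. The high-nibble lanes carry an extra factor of 16 from their bit position, hence the fused multiply by 1/16 (0x2c00) and add of -72 (0xd480): (1024 + 16q)/16 - 72 = q - 8. The bf16 path plays the same game in fp32: 0x4B000000 is 2^23, and subtracting 8388616 = 2^23 + 8 again leaves q - 8. A plain-C++ check of that arithmetic (exact in float for these small integers):

#include <cassert>

int main()
{
    for(int q = 0; q < 16; ++q)
    {
        // low-nibble lane: mantissa OR gives 1024 + q, then subtract 1032
        assert((1024.0f + q) - 1032.0f == static_cast<float>(q - 8));
        // high-nibble lane: bits carry a *16 factor; fma by 1/16 then -72
        assert((1024.0f + 16.0f * q) * (1.0f / 16.0f) - 72.0f == static_cast<float>(q - 8));
        // bf16 path via fp32: 2^23 + q, subtract 2^23 + 8
        assert((8388608.0f + q) - 8388616.0f == static_cast<float>(q - 8));
    }
    return 0;
}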
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/ops/elementwise.hpp"

 namespace ck_tile {
@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
     template <typename PipelineProblem_, typename GemmPolicy_>
     struct GemmTraits_
     {
         using Problem   = remove_cvref_t<PipelineProblem_>;
         using Policy    = remove_cvref_t<GemmPolicy_>;
         using ADataType = remove_cvref_t<typename Problem::ADataType>;
         using BDataType = remove_cvref_t<typename Problem::BDataType>;
-        using CDataType      = remove_cvref_t<typename Problem::CDataType>;
-        using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
+        using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
+        using CDataType       = remove_cvref_t<typename Problem::CDataType>;
+        using BlockGemmShape  = remove_cvref_t<typename Problem::BlockGemmShape>;

         static constexpr index_t kBlockSize = Problem::kBlockSize;
         static constexpr auto Scheduler     = Problem::Scheduler;
@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
         using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
             typename WarpGemm::BWarpDstrEncoding{}))>;

-        using AWarpTile =
-            remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
-        using BWarpTile =
-            remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
+        using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
+            AWarpTileDistr{}))>;
+        using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
+            BWarpTileDistr{}))>;

         // TODO: Should we have two policies? Interwave & Intrawave ??
         static constexpr index_t InterWaveSchedulingMacClusters = 1;
@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
     public:
     using Traits = GemmTraits_<Problem_, Policy_>;

-    using ADataType = remove_cvref_t<typename Traits::ADataType>;
-    using BDataType = remove_cvref_t<typename Traits::BDataType>;
-    using CDataType = remove_cvref_t<typename Traits::CDataType>;
+    using ADataType       = remove_cvref_t<typename Traits::ADataType>;
+    using BDataType       = remove_cvref_t<typename Traits::BDataType>;
+    using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
+    using CDataType       = remove_cvref_t<typename Traits::CDataType>;

     using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -105,6 +108,11 @@ struct BlockUniversalGemmAsBsCr
     static constexpr auto Scheduler = Traits::Scheduler;

+    static constexpr index_t APackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
+    static constexpr index_t BPackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
+
     using I0 = number<0>;
     using I1 = number<1>;
@@ -208,6 +216,8 @@ struct BlockUniversalGemmAsBsCr
         });

         using CWarpDstr   = typename WarpGemm::CWarpDstr;
+        using AWarpTensor = typename WarpGemm::AWarpTensor;
+        using BWarpTensor = typename WarpGemm::BWarpTensor;
         using CWarpTensor = typename WarpGemm::CWarpTensor;

         constexpr auto c_warp_y_lengths =
@@ -217,10 +227,58 @@ struct BlockUniversalGemmAsBsCr
         // hot loop:
         static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
             static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
+                AWarpTensor a_warp_tile;
+                if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        AWarpTensor::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
+
+                    static_assert(
+                        GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(a_warp_tile.get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
+                }

                 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
+                    BWarpTensor b_warp_tile;
+                    if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+                    {
+                        constexpr index_t UnaryOpSize = 8;
+                        const element_wise::PassThroughPack8 elementwise_op{};
+                        const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
+                        constexpr index_t thread_buffer_size =
+                            BWarpTensor::get_thread_buffer_size() / UnaryOpSize;
+
+                        static_assert(
+                            GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+
+                        using ComputeVectorType =
+                            ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                        static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                            elementwise_op(b_warp_tile.get_thread_buffer()
+                                               .template get_as<ComputeVectorType>()(i),
+                                           in_dstr_tensors.get_thread_buffer()
+                                               .template get_as<pk_int4x4_t>()[i]);
+                        });
+                    }
+                    else
+                    {
+                        b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
+                    }

                     // read C warp tensor from C block tensor-
                     CWarpTensor c_warp_tensor;
@@ -342,11 +400,59 @@ struct BlockUniversalGemmAsBsCr
         static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
             static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                 // read A warp tensor from A block window
-                load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
+                if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
+
+                    static_assert(
+                        GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(a_warp_tiles_(mIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
+                }
             });

             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                 // read B warp tensor from B Block window
-                load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
+                if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::BWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
+
+                    static_assert(
+                        GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(b_warp_tiles_(nIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
+                }
             });
         });
     }
@@ -504,12 +610,59 @@ struct BlockUniversalGemmAsBsCr
         // TODO check if a_warp_tiles has same desc as a_warp_window
         static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
             static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                // read A warp tensor from A block window
-                load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
+                if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
+
+                    static_assert(
+                        GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(a_warp_tiles_(mIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
+                }
             });

             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                 // read B warp tensor from B Block window
-                load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
+                if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::BWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
+
+                    static_assert(
+                        GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(b_warp_tiles_(nIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
+                }
             });
         });
     }
...
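Every pk_int4 branch above follows one pattern: load the packed tile, then expand each `pk_int4x4_t` (four bytes, eight nibbles) into an 8-wide compute vector with `PassThroughPack8`. The exact nibble ordering comes from how the data was pre-shuffled, so the sketch below only checks the buffer-size bookkeeping those branches rely on (tile size is an illustrative assumption):

#include <cstdint>

// Illustrative bookkeeping behind the conversion loops above (not the real
// tile types): a warp tile of N fp16 compute elements is backed by N/2 bytes
// of packed int4 data; each loop iteration consumes one 4-byte pk_int4x4_t
// group and produces one 8-element compute vector.
constexpr int warp_tile_elements = 32;                                  // assumed fragment size
constexpr int unary_op_size      = 8;                                   // elements per op call
constexpr int n_groups           = warp_tile_elements / unary_op_size;  // loop trip count
constexpr int packed_bytes       = warp_tile_elements / 2;              // input bytes

static_assert(warp_tile_elements % unary_op_size == 0); // mirrors the static_assert above
static_assert(n_groups * 4 == packed_bytes);            // 4 bytes consumed per iteration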
@@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
     using CDataType      = remove_cvref_t<typename Problem::CDataType>;
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

+    static constexpr index_t APackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
+    static constexpr index_t BPackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
+
     using ALayout = remove_cvref_t<typename Problem::ALayout>;
     using BLayout = remove_cvref_t<typename Problem::BLayout>;
     using CLayout = remove_cvref_t<typename Problem::CLayout>;
@@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
             // A/B split schedule
             // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
-            constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
-                                                    ? A_LDS_Read_Inst_Num
-                                                    : A_LDS_Read_Inst_Num / 2;
-            constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
-                                                    ? B_LDS_Read_Inst_Num
-                                                    : B_LDS_Read_Inst_Num / 2;
+            constexpr auto num_ds_read_inst_a =
+                A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
+                                                                         : A_LDS_Read_Inst_Num / 2;
+            constexpr auto num_ds_read_inst_b =
+                B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
+                                                                         : B_LDS_Read_Inst_Num / 2;

             constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
             constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
@@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
             constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
             constexpr auto ds_read_a_issue_cycle =
-                A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
+                A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
             constexpr auto ds_read_b_issue_cycle =
-                B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
+                B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
             constexpr auto ds_read_a_mfma_rate =
                 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
             constexpr auto ds_read_b_mfma_rate =
...
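The `/ APackedSize` adjustments above keep the scheduling heuristic working in bytes actually moved: for `pk_int4_t`, `sizeof` reports the storage byte while one logical element is half a byte. A worked instance of the 16-byte condition, under the assumptions that the LDS read width counts logical elements and `sizeof(pk_int4_t) == 1`:

// Illustrative check of the ds_read-width condition above.
constexpr int bytes_read(int lds_read_width, int sizeof_t, int packed_size)
{
    return lds_read_width * sizeof_t / packed_size;
}
static_assert(bytes_read(8, 2, 1) == 16);  // fp16: 8 elements -> 16-byte ds_read
static_assert(bytes_read(32, 1, 2) == 16); // pk_int4_t: 32 elements -> same 16 bytes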
@@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
-        constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
-                                        MakeALdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
+        constexpr index_t smem_size_a =
+            sizeof(typename Problem::ADataType) *
+            MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
         return smem_size_a;
     }

     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
     {
-        constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
-                                        MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
+        constexpr index_t smem_size_b =
+            sizeof(typename Problem::BDataType) *
+            MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
         return smem_size_b;
     }
@@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         using AccDataType = float;
         using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
         using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
-        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ADataType,
-                                                   typename Problem::BDataType,
+        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
+                                                   typename Problem::ComputeDataType,
                                                    AccDataType,
                                                    WarpTile::at(I0),
                                                    WarpTile::at(I1),
...
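The same correction applies to the LDS sizing above: `get_element_space_size()` on the LDS descriptor counts logical elements, so the byte count is elements * sizeof / PackedSize. For instance, assuming `sizeof(pk_int4_t) == 1` and `PackedSize == 2`:

// Illustrative LDS sizing for a 128 x 64 block tile:
constexpr int lds_bytes(int elements, int sizeof_t, int packed_size)
{
    return sizeof_t * elements / packed_size;
}
static_assert(lds_bytes(128 * 64, 2, 1) == 16384); // fp16 tile
static_assert(lds_bytes(128 * 64, 1, 2) == 4096);  // pk_int4 tile: half an int8 tile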
@@ -13,14 +13,16 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          typename Traits_>
+          typename Traits_,
+          typename ComputeDataType_ = ADataType_>
 struct GemmPipelineProblemBase
 {
     using Traits = remove_cvref_t<Traits_>;

     using ADataType = remove_cvref_t<ADataType_>;
     using BDataType = remove_cvref_t<BDataType_>;
     using CDataType = remove_cvref_t<CDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;

     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
@@ -53,13 +55,15 @@ struct GemmPipelineProblemBase
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
     {
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
         if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
         {
             constexpr index_t pixels_per_thread =
                 BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
-            return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
+            return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
                        ? pixels_per_thread
-                       : VectorLoadSize / sizeof(ADataType);
+                       : PackedSize * VectorLoadSize / sizeof(ADataType);
         }
         else
         {
@@ -69,17 +73,19 @@ struct GemmPipelineProblemBase
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
     {
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
         if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
         {
             constexpr index_t pixels_per_thread =
                 BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
-            return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
+            return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
                        ? pixels_per_thread
-                       : VectorLoadSize / sizeof(BDataType);
+                       : PackedSize * VectorLoadSize / sizeof(BDataType);
         }
         else
         {
-            return VectorLoadSize / sizeof(BDataType);
+            return PackedSize * VectorLoadSize / sizeof(BDataType);
         }
     }
@@ -143,9 +149,14 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          typename Traits_>
-using GemmPipelineProblem =
-    GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
+          typename Traits_,
+          typename ComputeDataType_ = ADataType_>
+using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
+                                                    BDataType_,
+                                                    CDataType_,
+                                                    BlockGemmShape_,
+                                                    Traits_,
+                                                    ComputeDataType_>;

 template <typename ADataType_,
           typename BDataType_,
@@ -154,14 +165,16 @@ template <typename ADataType_,
           typename Traits_,
           GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
           bool HasHotLoop_ = true,
-          TailNumber TailNum_ = TailNumber::Full>
+          TailNumber TailNum_ = TailNumber::Full,
+          typename ComputeDataType_ = ADataType_>
 struct UniversalGemmPipelineProblem
 {
     using Traits = remove_cvref_t<Traits_>;

     using ADataType = remove_cvref_t<ADataType_>;
     using BDataType = remove_cvref_t<BDataType_>;
     using CDataType = remove_cvref_t<CDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;

     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
...
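GetAlignmentA/B above return a vector width in logical elements, so the packed case scales up: one VectorLoadSize-byte load covers PackedSize * VectorLoadSize / sizeof(T) elements. A quick check, assuming VectorLoadSize is 16 bytes:

// Elements covered by one 16-byte global load:
constexpr int alignment_elems(int vector_load_size, int sizeof_t, int packed_size)
{
    return packed_size * vector_load_size / sizeof_t;
}
static_assert(alignment_elems(16, 2, 1) == 8);  // fp16: 8 elements per load
static_assert(alignment_elems(16, 1, 2) == 32); // pk_int4_t: 32 int4 elements per load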
@@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy
         constexpr index_t BlockSize = Problem::kBlockSize;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
         constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;

         // Assume DataType is even!
-        if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
-                     elements_per_thread % (16 / sizeof(DataType)) == 0)
-        {
-            return (16 / sizeof(DataType));
-        }
-        else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (8 / sizeof(DataType)) == 0)
-        {
-            return (8 / sizeof(DataType));
-        }
-        else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (4 / sizeof(DataType)) == 0)
-        {
-            return (4 / sizeof(DataType));
-        }
-        else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (2 / sizeof(DataType)) == 0)
-        {
-            return (2 / sizeof(DataType));
-        }
+        if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
+                     elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
+                     PackedSize == 2)
+        {
+            return (PackedSize * 32 / sizeof(DataType));
+        }
+        else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
+        {
+            return (PackedSize * 16 / sizeof(DataType));
+        }
+        else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
+        {
+            return (PackedSize * 8 / sizeof(DataType));
+        }
+        else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
+                          XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
+        {
+            return (PackedSize * 4 / sizeof(DataType));
+        }
+        else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
+                          XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
+        {
+            return (PackedSize * 2 / sizeof(DataType));
+        }
         else
         {
-            return 1;
+            return PackedSize;
         }
     }
@@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
     {
         using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
         using WarpTile   = typename Problem::BlockGemmShape::WarpTile;
-        using WarpGemm   = WarpGemmMfmaDispatcher<typename Problem::ADataType,
-                                                  typename Problem::BDataType,
+        using WarpGemm   = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
+                                                  typename Problem::ComputeDataType,
                                                   typename Problem::CDataType,
                                                   WarpTile::at(I0),
                                                   WarpTile::at(I1),
...
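The rewritten cascade above picks the widest per-thread vector, in logical elements, that divides both the tile's contiguous extent and the per-thread element count, with byte widths now scaled by PackedSize; the `PackedSize == 2` guard restricts the new 32-byte tier to packed types. A worked instance of the first two tiers (sizeof values are the usual assumptions):

// Tier width in logical elements = PackedSize * tier_bytes / sizeof(DataType).
constexpr int tier_elems(int packed_size, int tier_bytes, int sizeof_t)
{
    return packed_size * tier_bytes / sizeof_t;
}
static_assert(tier_elems(2, 32, 1) == 64); // pk_int4_t: 64 int4 elements = 32 bytes
static_assert(tier_elems(1, 16, 2) == 8);  // fp16 falls to the 16-byte tier: 8 elements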