Unverified Commit 2a261afc authored by jakpiase's avatar jakpiase Committed by GitHub
Browse files

Added structural sparsity blockwise gemm (#1435)



* Implemented smfmac xdlops

* Added smfmac blockwise xdlops

* fixes

* add reviewers suggestions

---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
parent d09572e8
...@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16> ...@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
static constexpr index_t k_per_blk = 8; static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true; static constexpr bool is_k_reduction = true;
// Forwards one smfmac_f32_16x16x32f16 operation to its intrinsic wrapper.
//
// Template parameters:
//   MPerXdlops, NPerXdlops - select the intrin_smfmac_f32_16x16x32f16 specialization.
//   idx_part               - passed on as the second template argument of Run (named `abid`
//                            there), which selects one of the four 8-bit sets of sparse
//                            indices held in idx.
//   FloatA, FloatB, FloatC - operand and accumulator types, deduced from the arguments.
//
// Arguments:
//   a, b  - input operands, passed through unchanged.
//   idx   - register holding the sparse indices, passed through unchanged.
//   reg_c - accumulator, taken by non-const reference so Run can update it in place.
template <index_t MPerXdlops,
          index_t NPerXdlops,
          index_t idx_part,
          class FloatA,
          class FloatB,
          class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
    // The wrapper type depends on MPerXdlops/NPerXdlops, so Run is a dependent member
    // template; the `template` keyword is required before its explicit argument list.
    intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::template Run<FloatC, idx_part>(
        a, b, idx, reg_c);
}
}; };
...@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16> ...@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
static constexpr index_t k_per_blk = 16; static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true; static constexpr bool is_k_reduction = true;
// Forwards one smfmac_f32_32x32x16f16 operation to its intrinsic wrapper.
//
// Template parameters:
//   MPerXdlops, NPerXdlops - select the intrin_smfmac_f32_32x32x16f16 specialization.
//   idx_part               - passed on as the second template argument of Run (named `abid`
//                            there), which selects one of the four 8-bit sets of sparse
//                            indices held in idx.
//   FloatA, FloatB, FloatC - operand and accumulator types, deduced from the arguments.
//
// Arguments:
//   a, b  - input operands, passed through unchanged.
//   idx   - register holding the sparse indices, passed through unchanged.
//   reg_c - accumulator, taken by non-const reference so Run can update it in place.
template <index_t MPerXdlops,
          index_t NPerXdlops,
          index_t idx_part,
          class FloatA,
          class FloatB,
          class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
    // The wrapper type depends on MPerXdlops/NPerXdlops, so Run is a dependent member
    // template; the `template` keyword is required before its explicit argument list.
    intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::template Run<FloatC, idx_part>(
        a, b, idx, reg_c);
}
}; };
...@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16> ...@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
static constexpr index_t k_per_blk = 8; static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true; static constexpr bool is_k_reduction = true;
// Forwards one smfmac_f32_16x16x32bf16 operation to its intrinsic wrapper.
//
// Template parameters:
//   MPerXdlops, NPerXdlops - select the intrin_smfmac_f32_16x16x32bf16 specialization.
//   idx_part               - passed on as the second template argument of Run (named `abid`
//                            there), which selects one of the four 8-bit sets of sparse
//                            indices held in idx.
//   FloatA, FloatB, FloatC - operand and accumulator types, deduced from the arguments.
//
// Arguments:
//   a, b  - input operands, passed through unchanged.
//   idx   - register holding the sparse indices, passed through unchanged.
//   reg_c - accumulator, taken by non-const reference so Run can update it in place.
template <index_t MPerXdlops,
          index_t NPerXdlops,
          index_t idx_part,
          class FloatA,
          class FloatB,
          class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
    // The wrapper type depends on MPerXdlops/NPerXdlops, so Run is a dependent member
    // template; the `template` keyword is required before its explicit argument list.
    intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::template Run<FloatC, idx_part>(
        a, b, idx, reg_c);
}
}; };
...@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16> ...@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
static constexpr index_t k_per_blk = 16; static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true; static constexpr bool is_k_reduction = true;
// Forwards one smfmac_f32_32x32x16bf16 operation to its intrinsic wrapper.
//
// Template parameters:
//   MPerXdlops, NPerXdlops - select the intrin_smfmac_f32_32x32x16bf16 specialization.
//   idx_part               - passed on as the second template argument of Run (named `abid`
//                            there), which selects one of the four 8-bit sets of sparse
//                            indices held in idx.
//   FloatA, FloatB, FloatC - operand and accumulator types, deduced from the arguments.
//
// Arguments:
//   a, b  - input operands, passed through unchanged.
//   idx   - register holding the sparse indices, passed through unchanged.
//   reg_c - accumulator, taken by non-const reference so Run can update it in place.
template <index_t MPerXdlops,
          index_t NPerXdlops,
          index_t idx_part,
          class FloatA,
          class FloatB,
          class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
    // The wrapper type depends on MPerXdlops/NPerXdlops, so Run is a dependent member
    // template; the `template` keyword is required before its explicit argument list.
    intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::template Run<FloatC, idx_part>(
        a, b, idx, reg_c);
}
}; };
...@@ -305,8 +329,8 @@ struct SparseXdlopsGemm ...@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
"base base_type must be half or bfloat16!"); "base base_type must be half or bfloat16!");
static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
smfmac_instr.template run<MPerXdlops, NPerXdlops>( smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
p_a_wave[k], p_b_wave[k], idx[k], p_c_thread); p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
}); });
} }
......
...@@ -9,16 +9,18 @@ namespace ck { ...@@ -9,16 +9,18 @@ namespace ck {
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16; struct intrin_smfmac_f32_16x16x32f16;
// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
// indices from reg_idx
template <> template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16> struct intrin_smfmac_f32_16x16x32f16<16, 16>
{ {
template <class FloatC> template <class FloatC, index_t abid = 0>
__device__ static void __device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
#else #else
ignore = reg_a; ignore = reg_a;
ignore = reg_b; ignore = reg_b;
...@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16; ...@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
template <> template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16> struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{ {
template <class FloatC> template <class FloatC, index_t abid = 0>
__device__ static void __device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
#else #else
ignore = reg_a; ignore = reg_a;
ignore = reg_b; ignore = reg_b;
...@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16; ...@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
template <> template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32> struct intrin_smfmac_f32_32x32x16f16<32, 32>
{ {
template <class FloatC> template <class FloatC, index_t abid = 0>
__device__ static void __device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
#else #else
ignore = reg_a; ignore = reg_a;
ignore = reg_b; ignore = reg_b;
...@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16; ...@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
template <> template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32> struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{ {
template <class FloatC> template <class FloatC, index_t abid = 0>
__device__ static void __device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
#else #else
ignore = reg_a; ignore = reg_a;
ignore = reg_b; ignore = reg_b;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment