"...composable_kernel-1.git" did not exist on "81497a93a0840d5a1b5e84c1e47a90ae39d0fee6"
Unverified Commit 919aeb1f authored by Haocong WANG's avatar Haocong WANG Committed by GitHub
Browse files

[Navi3x-LWPCK-545] Block-wise GEMM + Real GEMM_WMMA_FP16 (#541)

* wmma_op + unit test

* add arch limitation to wmma test

* change arch limitation

* Refactor + Add all type unit test(int4 compile failed)

* Add f32_16x16x16_bf16 unit test

* tempsave

* tempsave

* tempsave

* runtime bug, cannot find symbol

* workaround for incorrect HIP warpSize return value

* debugging

* tempsave

* Correctness OK, waiting for optimization

* Tidy up + format

* temp save

* temp save, reproduce the v_bfi_b32 issue

* add inline asm for wmmaop test

* tidy up

* clean some debug purpose code

* discard some codes

* clang format

* clang format

* compiler issue fixed + increase tile size
parent 715e8dd2
...@@ -35,3 +35,8 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) ...@@ -35,3 +35,8 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
This diff is collapsed.
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_wmma.hpp"
namespace ck {
enum struct WmmaInstr
{
wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
wmma_i32_16x16x16_iu4
};
/*
* WMMA Wave Tile Always MxNxK = 16x16x16
* WAVE32
-----------------------------------
|RC0| | | | | | | | | | | | | | | | SubGroup 0
|RC1| | | | | | | | | | | | | | | |
|RC2| | | | | | | | | | | | | | | |
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC6| | | | | | | | | | | | | | | |
|RC7| | | | | | | | | | | | | | | |
-----------------------------------
| | | | | | | | | | | | | | | | | SubGroup 1
| | | | | | | | | | | | | | | | |
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
-----------------------------------
* WAVE64
-----------------------------------
|RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
|RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
| 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
| 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
| 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
| 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
| | | | | | | | | | | | | | | | |
-----------------------------------
* RC = Register for storing accumalted result
* T = Thread ID
*/
template <WmmaInstr Instr, index_t WaveSize, typename = void>
struct wmma_type
{
};
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
// * Data Pixel
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f32_16x16x16_bf16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
}
};
#endif
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
bool neg_a,
bool neg_b,
bool clamp,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_i32_16x16x16_iu8_w64<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
}
};
template <typename src_type_a,
typename src_type_b,
typename dst_type,
index_t MPerWmma,
index_t NPerWmma>
struct WmmaSelector
{
template <typename src_type_a_,
typename src_type_b_,
typename dst_type_,
index_t MPerWmma_,
index_t NPerWmma_>
static constexpr auto GetWmma();
template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{
return WmmaInstr::wmma_f16_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{
return WmmaInstr::wmma_bf16_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu8;
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
#endif
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
static constexpr auto selected_wmma =
wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
__host__ __device__ constexpr WmmaSelector()
{
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
}
};
template <typename src_type_a,
typename src_type_b,
typename dst_type,
index_t MPerWmma,
index_t NPerWmma,
index_t KPack,
bool TransposeC = false>
struct WmmaGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__host__ __device__ constexpr WmmaGemm()
{
static_assert(NPerWmma == 16 && MPerWmma == 16,
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
}
// WMMA output supporting C = A * B
// Vector Write
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2, 6>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
}
__device__ static constexpr index_t GetRegSizePerWmma()
{
return wmma_instr.num_acc_vgprs_per_wave;
}
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, half_t>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, bhalf_t>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
is_same<dst_type, int32_t>::value)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
is_same<dst_type, int32_t>::value)
#endif
,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
}
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
__device__ static auto GetSubGroupId()
{
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
}
__device__ static auto GetLaneIdUnderSubGroup()
{
return GetLaneId() % wmma_instr.num_thread_per_subgroups;
}
__device__ static auto GetSwizzledLaneIdLow()
{
return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
return GetSwizzledLaneIdLow();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
return GetLaneIdUnderSubGroup();
}
__device__ static CIndex GetBeginOfThreadBlk()
{
index_t n_offset = GetLaneIdUnderSubGroup();
index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
}
static constexpr auto wmma =
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma;
__host__ __device__ static constexpr auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
}
};
} // namespace ck
...@@ -355,5 +355,11 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, ...@@ -355,5 +355,11 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
c3); c3);
} }
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
#ifndef CK_AMD_WMMA_HPP #ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP
#include "ck/utility/amd_inline_asm.hpp"
#include "data_type.hpp" #include "data_type.hpp"
// TODO: Add arch limitation // TODO: Add arch limitation
namespace ck { namespace ck {
// wave32 only /********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32 // src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32; struct intrin_wmma_f32_16x16x16_f16_w32;
...@@ -19,8 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> ...@@ -19,8 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( // * Inline assembly need to elimate the duplicated data load, compiler won't help you
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]); // delete them.
amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_c.template AsType<float8_t>()(Number<0>{}) =
// __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
// AsType<float8_t>()[Number<0>{}]);
} }
}; };
...@@ -98,5 +105,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> ...@@ -98,5 +105,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
} }
}; };
/********************************WAVE64 MODE***********************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
clamp);
}
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -97,6 +97,7 @@ builtin_wmma_naive_selector<int4x16_t, ...@@ -97,6 +97,7 @@ builtin_wmma_naive_selector<int4x16_t,
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num> template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c) __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
{ {
__shared__ src_t p_shared[16 * 16 * 2];
const int lIdx = threadIdx.x; const int lIdx = threadIdx.x;
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
// b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
...@@ -104,6 +105,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) ...@@ -104,6 +105,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
using src_vec = typename vector_type<src_t, 16>::type; using src_vec = typename vector_type<src_t, 16>::type;
src_vec a_frag = {}; src_vec a_frag = {};
src_vec b_frag = {}; src_vec b_frag = {};
src_vec a_temp = {};
src_vec b_temp = {};
// initialize c fragment to 0 // initialize c fragment to 0
using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>; using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
acc_vec c_thread_buf_; acc_vec c_thread_buf_;
...@@ -111,21 +115,57 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) ...@@ -111,21 +115,57 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11 // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
// see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482 // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
// TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101 // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
const int lane = lIdx % 16; const int lane = lIdx % 16;
const int lane_lo = lIdx / 2;
const int lane_hi = lIdx % 2;
for(int ele = 0; ele < 8; ++ele)
{
a_temp[ele] = a[8 * lane_hi + 16 * lane_lo + ele];
}
for(int ele = 0; ele < 8; ++ele)
{
b_temp[ele] = b[8 * lane_hi + 16 * lane_lo + ele];
}
__syncthreads();
for(int ele = 0; ele < 8; ++ele)
{
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] = a_temp[ele];
}
for(int ele = 0; ele < 8; ++ele)
{
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele];
}
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
for(int ele = 0; ele < 16; ++ele) for(int ele = 0; ele < 16; ++ele)
{ {
b_frag[ele] = b[16 * lane + ele]; b_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8 + 16 * 16];
} }
// follow origin design // follow origin design
for(int ele = 0; ele < 16; ++ele) for(int ele = 0; ele < 16; ++ele)
{ {
a_frag[ele] = a[16 * lane + ele]; a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8];
} }
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
// sync threads, similar to mma_sync // sync threads, similar to mma_sync
__syncthreads(); // __syncthreads();
builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_); builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
// since only fp16_fp32 asm wmma implemented for experiment purpose, restrict test case to fp16
// when enable this ck::amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag,
// c_thread_buf_.GetVectorTypeReference(Number<0>{}).template AsType<float8_t>()(Number<0>{}));
__syncthreads(); __syncthreads();
// wait for results, similar to mma_sync // wait for results, similar to mma_sync
static_for<0, 8, 1>{}([&](auto ele) { static_for<0, 8, 1>{}([&](auto ele) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment