Commit 760b0c75 authored by Jing Zhang

add wmma

parent 8831b0d8
......@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102" OR GPU_TARGETS MATCHES "gfx1200")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
......
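The CMake gate above now also builds the WMMA examples for gfx1200. At device-compile time the matching architecture family macro is what the source-level guards in this commit key on; a minimal sketch (assuming CK's __gfx12__ family macro, defined alongside __gfx11__, covers the gfx120x targets):

    #if defined(__gfx12__)
        // gfx12 (Navi4x) specific path, e.g. the new WMMA intrinsics
    #else
        // gfx11 or non-WMMA targets
    #endif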
......@@ -30,12 +30,13 @@ template <index_t BlockSize,
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto WmmaK = Number<16>{};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
// static constexpr auto WmmaK = Number<16>{};
static constexpr auto WmmaK = Number<8>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
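WmmaK is the K extent consumed per WMMA step as seen by the blockwise loop, so halving it from 16 to 8 doubles the K-loop trip count. A quick check with illustrative numbers (KPerBlock = 64 is an assumption, not from this commit):

    // KPerBlock / WmmaK WMMA steps per K-block:
    //   WmmaK = 16  ->  64 / 16 = 4 iterations
    //   WmmaK =  8  ->  64 /  8 = 8 iterations
    static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { /* one WMMA step */ });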
......@@ -141,6 +142,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
static_assert(WmmaK % A_K1 == 0, "");
static_assert(WmmaK % B_K1 == 0, "");
}
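The two new asserts require the per-thread K-vector widths to divide the WMMA K-slice, so a single vectorized A/B load never straddles two WMMA steps (illustrative values, assuming the new WmmaK = 8):

    // WmmaK = 8: A_K1 in {1, 2, 4, 8} passes; A_K1 = 16 would fire the assert
    static_assert(WmmaK % A_K1 == 0, "A_K1 must divide WmmaK");
    static_assert(WmmaK % B_K1 == 0, "B_K1 must divide WmmaK");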
// Thread level, register descriptor. Vector-write
......@@ -154,6 +158,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
static_assert(MSubGroup == 1, "");
static_assert(NThreadPerSubGroup == 1, "");
static_assert(MAccVgprs == 8, "");
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
......@@ -224,6 +232,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
// pick the loop-over direction for the basic intrinsic based on MRepeat vs. NRepeat
if constexpr(MRepeat < NRepeat)
{
static_assert(0, "");
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
......
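The bare static_assert(0, "") above is a tripwire marking the MRepeat < NRepeat path as unsupported during gfx12 bring-up, but a non-dependent false condition is ill-formed even in a discarded branch (compilers may or may not diagnose it). The portable idiom (a sketch, not part of this commit) makes the condition dependent so it only fires on instantiation:

    template <auto V>
    inline constexpr bool dependent_false = false;

    if constexpr(MRepeat < NRepeat)
    {
        static_assert(dependent_false<MRepeat>, "MRepeat < NRepeat not supported yet");
        ...
    }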
......@@ -411,7 +411,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_navi3_supported() || ck::is_navi4_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
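IsSupportedArgument now accepts Navi4x devices as well. For reference, a runtime check of this kind typically inspects the device architecture string; a minimal sketch (hypothetical helper, not CK's actual implementation of is_navi4_supported):

    #include <hip/hip_runtime.h>
    #include <cstring>

    inline bool is_gfx12_device()
    {
        int dev = 0;
        hipDeviceProp_t props{};
        if(hipGetDevice(&dev) != hipSuccess ||
           hipGetDeviceProperties(&props, dev) != hipSuccess)
            return false;
        // gcnArchName reads e.g. "gfx1200" / "gfx1201" on Navi4x parts
        return std::strncmp(props.gcnArchName, "gfx12", 5) == 0;
    }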
......@@ -49,7 +49,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
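The kernel-body guard gains __gfx12__. The pattern keeps the body visible to the host pass (so kernel-launch syntax still type-checks) and to supported device archs, and compiles it out everywhere else; schematically:

    #if(!defined(__HIP_DEVICE_COMPILE__) /* host pass */ \
        || defined(__gfx11__) || defined(__gfx12__) /* supported archs */)
        // real kernel body
    #else
        // other device archs: body elided, kernel becomes a stub
    #endif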
......@@ -480,6 +480,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
static_assert(MSubGroup == 2, "");
static_assert(NThreadPerSubGroup == 16, "");
static_assert(MAccVgprs == 8, "");
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
......
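The asserted shuffle-descriptor shape follows from the wave layout of the new wmma_type specialization added below (wave32 values; a quick constexpr cross-check):

    constexpr int wave_size = 32, m_per_wmma = 16, n_per_wmma = 16, acc_bytes = 4;
    constexpr int nthread_per_subgroup = n_per_wmma;                       // 16
    constexpr int msubgroup            = wave_size / nthread_per_subgroup; // 32 / 16 = 2
    constexpr int macc_vgprs = m_per_wmma * n_per_wmma * acc_bytes / wave_size / 4; // 8
    static_assert(msubgroup == 2 && nthread_per_subgroup == 16 && macc_vgprs == 8, "");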
......@@ -11,12 +11,15 @@ namespace ck {
enum struct WmmaInstr
{
// gfx11
wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
wmma_i32_16x16x16_iu4
wmma_i32_16x16x16_iu4,
// gfx12
wmma_f32_16x16x16_f16_gfx12,
};
/*
......@@ -117,6 +120,47 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
}
};
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Fixed properties of the instruction
// * Tile shape and data sizes
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
// * Thread mapping inside a wave; num_thread_per_subgroups is always along the N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave-mode-dependent properties
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed on Navi3x; will be wave-mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
// * num_acc_vgprs_per_wave along the M direction
// * num_subgroups along the M direction
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
}
}
};
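Note that the wave64 branch of run() above is an empty no-op. A defensive alternative (a sketch, not in this commit) turns an accidental wave64 instantiation into a compile error; the condition depends on the template parameter, so it only fires if the branch is actually instantiated:

    else if constexpr(wave_size == 64)
    {
        static_assert(WaveSize != 64, "wmma_f32_16x16x16_f16_gfx12: wave64 not implemented");
    }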
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
WaveSize,
......@@ -300,7 +344,11 @@ struct WmmaSelector
template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
#if 1
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
#else
return WmmaInstr::wmma_f32_16x16x16_f16;
#endif
}
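The #if 1 above hardwires the gfx12 variant for every build, which also retargets gfx11 compilations of this selector. A target-aware toggle (a sketch, assuming the same arch family macros the kernel guards use are visible during device compilation here):

    #if defined(__gfx12__)
            return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
    #else
            return WmmaInstr::wmma_f32_16x16x16_f16;
    #endif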
template <>
......@@ -397,6 +445,8 @@ struct WmmaGemm
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
static_assert(wmma_instr.num_acc_vgprs_per_wave == 8, "");
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(
......@@ -477,7 +527,8 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
return GetSwizzledLaneIdLow();
// return GetSwizzledLaneIdLow();
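// presumably to match the gfx12 operand layout: index the A operand by the
// plain lane position inside its sub-group instead of the gfx11 swizzle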
return GetLaneIdUnderSubGroup();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
......
......@@ -39,6 +39,31 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
// * Inline assembly is needed to eliminate duplicated data loads; the compiler
// won't delete them for you.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
......
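A minimal usage sketch of the new wrapper (assumes CK's half8_t / float8_t vector typedefs and ck::vector_type storage; illustrative only):

    using namespace ck;

    __device__ void wmma_tile(const half8_t& a_frag, const half8_t& b_frag,
                              vector_type<float, 8>& c_frag)
    {
        // one 16x16x16 f16 -> f32 WMMA; the wrapper itself is a no-op when
        // not compiling for gfx12
        intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>::Run(a_frag, b_frag, c_frag);
    }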