Unverified commit 3737bb03, authored by zjing14, committed by GitHub

Add bfp16/int8 support into XDL GEMM operator (#50)



* init StaticBufferV2

* clean

* adopt old output stage for staticBufferV2

* clean

* remove hack

* clean

* clean

* add parameters

* clean code

* move c_buffer alloc into blockwise gemm

* add adaptors for m/n_thread_data_on_grid

* tweak gemm

* adjust blockwise_gemm_xdlops

* tweak

* update conv

* update script

* adding bwd 1x1

* update script

* adding 1x1 bwd

* debugging bwd 1x1 failure

* update script

* update script

* test

* test v100

* add bf16_1k

* clang-format

* clean

* add bfp16 for gfx908

* add verification

* clean up

* clean code

* restore bf16

* clean

* add bfp16 support into gemm_driver

* apply new generator to other drivers

* add int8 support

* clean

* clean

* clean

* clean
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Chao Liu <lc.roy86@gmail.com>
Co-authored-by: root <root@hayabusa6111.amd.com>
parent b491ebf3
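Reader's aid: the host-side changes below lean on float_to_bfloat16 / bfloat16_to_float. A minimal standalone sketch of those conversions (bf16 is the top 16 bits of an IEEE-754 float; this mirrors the helpers' names but is not CK's implementation, which may round rather than truncate):

#include <cstdint>
#include <cstring>

// bf16 <- f32 by truncation: keep sign, exponent, and top 7 mantissa bits.
static inline uint16_t float_to_bfloat16(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits)); // safe type-pun
    return static_cast<uint16_t>(bits >> 16);
}

// f32 <- bf16 by zero-filling the low 16 mantissa bits; always exact.
static inline float bfloat16_to_float(uint16_t h)
{
    uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}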
@@ -12,18 +12,19 @@ enum struct MfmaInstr
    mfma_f32_32x32x1xf32 = 0,
    mfma_f32_16x16x1xf32,
    mfma_f32_4x4x1xf32,
-   mfma_f32_32x32x2xf32, // k reduction
-   mfma_f32_16x16x4xf32, // k reduction
+   mfma_f32_32x32x2xf32,
+   mfma_f32_16x16x4xf32,
    mfma_f32_32x32x4f16,
    mfma_f32_16x16x4f16,
    mfma_f32_4x4x4f16,
-   mfma_f32_32x32x8f16, // k reduction
-   mfma_f32_16x16x16f16, // k reduction
-   mfma_f32_32x32x2bf16,
-   mfma_f32_16x16x2bf16,
-   mfma_f32_4x4x2bf16,
-   mfma_f32_32x32x4bf16, // k reduction
-   mfma_f32_16x16x8bf16, // k reduction
+   mfma_f32_32x32x8f16,
+   mfma_f32_16x16x16f16,
+   mfma_f32_32x32x8bf16_1k,
+   mfma_f32_16x16x16bf16_1k,
+   mfma_f32_32x32x4bf16,
+   mfma_f32_16x16x8bf16,
+   mfma_i32_32x32x8i8,
+   mfma_i32_16x16x16i8,
};

template <MfmaInstr instr>
@@ -250,9 +251,8 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
    }
};

-#if 0
template <>
-struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
{
    static constexpr index_t group_size = 4;
    static constexpr index_t num_groups_per_blk = 4;
@@ -260,26 +260,38 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
    static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size = 64;
    static constexpr index_t num_input_blks = 2;
-   static constexpr index_t num_output_blks = 2;
+   static constexpr index_t num_output_blks = 1;
    static constexpr index_t m_per_blk = 32;
    static constexpr index_t n_per_blk = 32;
-   static constexpr index_t k_per_blk = 2;
-   static constexpr bool is_k_reduction = false;
+   static constexpr index_t k_per_blk = 4;
+   static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops,
-             index_t NPerXdlops,
-             index_t AStride,
-             index_t BStride,
-             class FloatA,
-             class FloatB,
-             class FloatC>
-   __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+   __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-       const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-       return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
-           p_a, p_b, reg_c);
+       intrin_mfma_f32_32x32x8bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
+{
+   static constexpr index_t group_size = 4;
+   static constexpr index_t num_groups_per_blk = 1;
+   static constexpr index_t num_regs_per_blk = 4;
+   static constexpr index_t num_threads_per_blk = 16;
+   static constexpr index_t wave_size = 64;
+   static constexpr index_t num_input_blks = 4;
+   static constexpr index_t num_output_blks = 1;
+   static constexpr index_t m_per_blk = 16;
+   static constexpr index_t n_per_blk = 16;
+   static constexpr index_t k_per_blk = 4;
+   static constexpr bool is_k_reduction = true;
+
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+   __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+   {
+       intrin_mfma_f32_16x16x16bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+   }
+};
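A quick sanity check on the constants in the new mfma_type specializations above: a 32x32 f32 accumulator tile spread over a 64-lane wave leaves 16 accumulator registers per lane, and a 16x16 tile leaves 4, matching num_regs_per_blk. A worked check (values copied from the diff, not code from the PR):

constexpr int wave_size = 64;
// mfma_f32_32x32x8bf16_1k / mfma_i32_32x32x8i8: num_regs_per_blk = 16
static_assert(32 * 32 / wave_size == 16, "32x32 tile: 16 regs per lane");
// mfma_f32_16x16x16bf16_1k / mfma_i32_16x16x16i8: num_regs_per_blk = 4
static_assert(16 * 16 / wave_size == 4, "16x16 tile: 4 regs per lane");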
@@ -298,19 +310,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
    static constexpr index_t k_per_blk = 2;
    static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops,
-             index_t NPerXdlops,
-             index_t AStride,
-             index_t BStride,
-             class FloatA,
-             class FloatB,
-             class FloatC>
-   __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+   __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-       const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-       return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
+       intrin_mfma_f32_32x32x4bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
@@ -329,84 +332,56 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
    static constexpr index_t k_per_blk = 2;
    static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops,
-             index_t NPerXdlops,
-             index_t AStride,
-             index_t BStride,
-             class FloatA,
-             class FloatB,
-             class FloatC>
-   __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+   __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-       const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-       return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
+       intrin_mfma_f32_16x16x8bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};

template <>
-struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
{
    static constexpr index_t group_size = 4;
-   static constexpr index_t num_groups_per_blk = 1;
-   static constexpr index_t num_regs_per_blk = 4;
-   static constexpr index_t num_threads_per_blk = 16;
+   static constexpr index_t num_groups_per_blk = 4;
+   static constexpr index_t num_regs_per_blk = 16;
+   static constexpr index_t num_threads_per_blk = 32;
    static constexpr index_t wave_size = 64;
-   static constexpr index_t num_input_blks = 4;
-   static constexpr index_t num_output_blks = 4;
-   static constexpr index_t m_per_blk = 16;
-   static constexpr index_t n_per_blk = 16;
-   static constexpr index_t k_per_blk = 2;
-   static constexpr bool is_k_reduction = false;
+   static constexpr index_t num_input_blks = 2;
+   static constexpr index_t num_output_blks = 1;
+   static constexpr index_t m_per_blk = 32;
+   static constexpr index_t n_per_blk = 32;
+   static constexpr index_t k_per_blk = 4;
+   static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops,
-             index_t NPerXdlops,
-             index_t AStride,
-             index_t BStride,
-             class FloatA,
-             class FloatB,
-             class FloatC>
-   __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+   __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-       const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-       return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
+       intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};

template <>
-struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
{
    static constexpr index_t group_size = 4;
    static constexpr index_t num_groups_per_blk = 1;
    static constexpr index_t num_regs_per_blk = 4;
-   static constexpr index_t num_threads_per_blk = 64;
+   static constexpr index_t num_threads_per_blk = 16;
    static constexpr index_t wave_size = 64;
-   static constexpr index_t num_input_blks = 1;
+   static constexpr index_t num_input_blks = 4;
    static constexpr index_t num_output_blks = 1;
-   static constexpr index_t m_per_blk = 4;
-   static constexpr index_t n_per_blk = 64;
-   static constexpr index_t k_per_blk = 2;
-   static constexpr bool is_k_reduction = false;
+   static constexpr index_t m_per_blk = 16;
+   static constexpr index_t n_per_blk = 16;
+   static constexpr index_t k_per_blk = 4;
+   static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops,
-             index_t NPerXdlops,
-             index_t AStride,
-             index_t BStride,
-             class FloatA,
-             class FloatB,
-             class FloatC>
-   __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+   __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-       const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-       return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
+       intrin_mfma_i32_16x16x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
-#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
@@ -498,73 +473,37 @@ struct MfmaSelector
        return MfmaInstr::mfma_f32_4x4x4f16;
    }

-#if 0
-   template <>
-   static constexpr auto GetMfma<ushort, 128, 64>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 64, 128>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 64, 64>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 64, 32>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 32, 64>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 64, 16>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 16, 64>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 8, 64>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 4, 64>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 32, 32>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
-   }
-
-   template <>
-   static constexpr auto GetMfma<ushort, 16, 16>()
-   {
-       return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
-   }
-#endif
+   template <>
+   static constexpr auto GetMfma<ushort, 32, 32>()
+   {
+#if defined(CK_AMD_GPU_GFX90A)
+       return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+       return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
+   }
+
+   template <>
+   static constexpr auto GetMfma<ushort, 16, 16>()
+   {
+#if defined(CK_AMD_GPU_GFX90A)
+       return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+       return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+   }
+
+   template <>
+   static constexpr auto GetMfma<int8_t, 32, 32>()
+   {
+       return MfmaInstr::mfma_i32_32x32x8i8;
+   }
+
+   template <>
+   static constexpr auto GetMfma<int8_t, 16, 16>()
+   {
+       return MfmaInstr::mfma_i32_16x16x16i8;
+   }

    static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
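The GetMfma changes above select the bf16 instruction per compile target: gfx90a gets the *_1k variants (twice the K per issue, ushort4_t operands per the intrinsic declarations later in this diff), while gfx908 keeps the K=4/K=8 forms with ushort2_t operands. A standalone sketch of the same compile-time dispatch pattern (the enum is a stand-in, not CK's real MfmaInstr; only the CK_AMD_GPU_GFX90A guard is taken from the diff):

// Stand-in enum for illustration; CK's MfmaInstr lives in the patched file.
enum class BF16Instr
{
    f32_32x32x4bf16,    // gfx908: ushort2_t operands, K = 4 per issue
    f32_32x32x8bf16_1k  // gfx90a: ushort4_t operands, K = 8 per issue
};

constexpr BF16Instr pick_bf16_32x32()
{
#if defined(CK_AMD_GPU_GFX90A) // same guard the diff uses
    return BF16Instr::f32_32x32x8bf16_1k;
#else
    return BF16Instr::f32_32x32x4bf16;
#endif
}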
@@ -686,8 +625,8 @@ struct XdlopsGemm
    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
    {
        static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
-                         is_same<base_type, ushort>::value,
-                     "base base_type must be float, half, ushort!");
+                         is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
+                     "base_type must be float, half, ushort, or int8_t!");

        static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
            mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
...
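The static_for in the hunk above issues KPack / k_per_blk MFMA instructions per Run call. A worked check with illustrative values (KPack is a tuning parameter, assumed here; k_per_blk = 4 is the bf16_1k value from the diff):

constexpr int KPack = 8;     // assumed: K elements packed per thread slice
constexpr int k_per_blk = 4; // K consumed by one mfma_f32_32x32x8bf16_1k
static_assert(KPack % k_per_blk == 0, "loop bound must divide evenly");
static_assert(KPack / k_per_blk == 2, "two unrolled MFMA issues per Run");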
@@ -50,11 +50,24 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");

-__device__ int16_t
+__device__ ushort
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
                                index_t voffset,
                                index_t soffset,
-                               index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
+                               index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
+
+__device__ ushort2_t
+llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
+
+__device__ ushort4_t
+llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");

__device__ int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
                                index_t voffset,
@@ -133,12 +146,26 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");

__device__ void
-llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
+llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
                                 int32x4_t rsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");

+__device__ void
+llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
+
+__device__ void
+llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
+
__device__ void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
                                 int32x4_t rsrc,
@@ -228,6 +255,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
                      (is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+                     (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                  "wrong! not implemented");
@@ -326,6 +354,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
            return as_type<half8_t>(tmp);
        }
    }
+   else if constexpr(is_same<T, ushort>::value)
+   {
+       if constexpr(N == 1)
+       {
+           return llvm_amdgcn_raw_buffer_load_i16(
+               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+       }
+       else if constexpr(N == 2)
+       {
+           return llvm_amdgcn_raw_buffer_load_i16x2(
+               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+       }
+       else if constexpr(N == 4)
+       {
+           return llvm_amdgcn_raw_buffer_load_i16x4(
+               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+       }
+       else if constexpr(N == 8)
+       {
+           int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
+               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+           return as_type<ushort8_t>(tmp);
+       }
+   }
    else if constexpr(is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
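Note the N == 8 bf16 path above: there is no v8i16 raw buffer-load intrinsic, so eight bf16 values are fetched as one int32x4_t and reinterpreted. A host-side sketch of that reinterpretation, with std::bit_cast standing in for CK's as_type and std::array standing in for the ext_vector_type typedefs:

#include <array>
#include <bit>     // std::bit_cast (C++20)
#include <cstdint>

using int32x4 = std::array<int32_t, 4>;  // stand-in for int32x4_t
using ushort8 = std::array<uint16_t, 8>; // stand-in for ushort8_t

inline ushort8 as_bf16x8(const int32x4& raw)
{
    static_assert(sizeof(int32x4) == sizeof(ushort8), "same 16 bytes");
    return std::bit_cast<ushort8>(raw); // reinterpret bits, no data movement
}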
@@ -458,6 +511,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                      (is_same<T, double>::value && (N == 1 || N == 2)) ||
                      (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+                     (is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                  "wrong! not implemented");
@@ -552,6 +606,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                0);
        }
    }
+   else if constexpr(is_same<T, ushort>::value)
+   {
+       if constexpr(N == 1)
+       {
+           llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
+                                            dst_wave_buffer_resource,
+                                            dst_thread_addr_offset,
+                                            dst_wave_addr_offset,
+                                            0);
+       }
+       else if constexpr(N == 2)
+       {
+           // use the i16x2 variant declared above; the fp16x2 store takes
+           // half2_t and is not the right type for ushort2_t data
+           llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset,
+                                              0);
+       }
+       else if constexpr(N == 4)
+       {
+           llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset,
+                                              0);
+       }
+       else if constexpr(N == 8)
+       {
+           vector_type<ushort, 8> tmp{src_thread_data};
+
+           llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<0>{}],
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset,
+                                              0);
+           llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<1>{}],
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset + 4 * sizeof(ushort),
+                                              0);
+       }
+   }
    else if constexpr(is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
...
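As with the load side, the N == 8 store above splits into two 4-element stores because no v8i16 buffer-store intrinsic exists; the second store lands 4 * sizeof(ushort) bytes past the first. A plain-C++ sketch of the split (store4 is a hypothetical stand-in for llvm_amdgcn_raw_buffer_store_i16x4):

#include <cstdint>
#include <cstring>

// Hypothetical stand-in for the 4 x i16 buffer-store intrinsic.
static inline void store4(uint16_t* dst, const uint16_t* src)
{
    std::memcpy(dst, src, 4 * sizeof(uint16_t));
}

inline void store_bf16x8(uint16_t* dst, const uint16_t (&v)[8])
{
    store4(dst, v);         // elements 0..3 at the base offset
    store4(dst + 4, v + 4); // elements 4..7, +4 * sizeof(ushort) bytes
}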
@@ -6,6 +6,7 @@
namespace ck {

// A, B, C, cbsz, abid, blgp
+// fp32
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
    float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
@@ -21,6 +22,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
    float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");

+// fp16
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
    half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
@@ -36,6 +38,13 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
    half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");

+// bfp16
+extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
+    ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
+
+extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
+    ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
+
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
    ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
@@ -51,6 +60,23 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
    ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");

+// int8
+extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
+    int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
+
+extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
+    int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
+
+extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
+    int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
+
+extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
+    int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
+
+extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
+    int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
+
+// fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;
@@ -148,6 +174,7 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
    }
};

+// fp16
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16;
@@ -244,147 +271,102 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
    }
};

-#if 0
-template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16;
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
-{
-    __device__ static c_vec32_2_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        return reg_c;
-    }
-};
-
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
-                                                            const ushort2_t* reg_b,
-                                                            c_vec16_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
-    return reg_c;
-}
-
-__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
-                                                           const ushort2_t* reg_b,
-                                                           c_vec4_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
-    return reg_c;
-}
-
-template <index_t MPerWave, index_t NPerWave>
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
-                                                            const ushort2_t* reg_b,
-                                                            c_vec16_1_t::VecType reg_c);
-
-template <>
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
-                                                                    const ushort2_t* reg_b,
-                                                                    c_vec16_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
-    return reg_c;
-}
-
-template <>
-__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
-                                                                    const ushort2_t* reg_b,
-                                                                    c_vec16_1_t::VecType reg_c)
-{
-    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
-    return reg_c;
-}
-
-template <index_t MPerWave, index_t NPerWave>
-struct intrin_mfma_f32_4x4x2bf16;
-
-template <>
-struct intrin_mfma_f32_4x4x2bf16<4, 64>
-{
-    __device__ static c_vec4_1_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
-        return reg_c;
-    }
-};
-
-template <>
-struct intrin_mfma_f32_4x4x2bf16<8, 64>
-{
-    __device__ static c_vec4_2_t::VecType
-    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
-        return reg_c;
-    }
-};
-#endif
+// bfp16
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x8bf16_1k;
+
+template <>
+struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
+                reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x16bf16_1k;
+
+template <>
+struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
+                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x4bf16;
+
+template <>
+struct intrin_mfma_f32_32x32x4bf16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x8bf16;
+
+template <>
+struct intrin_mfma_f32_16x16x8bf16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_32x32x8i8;
+
+template <>
+struct intrin_mfma_i32_32x32x8i8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type<int>(reg_a),
+                                                  as_type<int>(reg_b),
+                                                  reg_c.template AsType<int32x16_t>()[Number<0>{}],
+                                                  0,
+                                                  0,
+                                                  0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_16x16x16i8;
+
+template <>
+struct intrin_mfma_i32_16x16x16i8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type<int>(reg_a),
+                                                   as_type<int>(reg_b),
+                                                   reg_c.template AsType<int32x4_t>()[Number<0>{}],
+                                                   0,
+                                                   0,
+                                                   0);
+    }
+};

} // namespace ck
#endif
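One detail worth calling out in the i8 wrappers above: the i8 MFMA intrinsics take their four 8-bit lanes packed into a single 32-bit int, hence the as_type<int>(reg_a) casts. A host-side sketch of the packing (std::bit_cast standing in for CK's as_type):

#include <array>
#include <bit>
#include <cstdint>

// Four int8 values occupy exactly one 32-bit MFMA operand.
inline int32_t pack_i8x4(const std::array<int8_t, 4>& v)
{
    static_assert(sizeof(v) == sizeof(int32_t), "4 x i8 == 1 x i32");
    return std::bit_cast<int32_t>(v);
}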
@@ -31,7 +31,7 @@ extern "C" {
#endif

#ifdef __HIP_PLATFORM_HCC__
-#define EXECUTION_SPECIFIER __device__
+#define EXECUTION_SPECIFIER __device__ __host__
#else
#define EXECUTION_SPECIFIER
#endif // MIOPEN_BACKEND_HIP
...
@@ -325,30 +325,30 @@ int main(int argc, char* argv[])
        // no initialization
        break;
    case 1:
-       out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
        break;
    case 2:
-       out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
        break;
    case 3:
-       out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
        break;
    case 4:
-       out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
        break;
    case 5:
-       out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{0.0, 1.0}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
        break;
    default:
-       out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{1, 5}, num_thread);
        auto gen_wei = [](auto... is) {
-           return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
+           return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
        };
        wei.GenerateTensorValue(gen_wei, num_thread);
    }
...
@@ -80,13 +80,29 @@ void host_convolution_forward(const Tensor<TIn>& in,
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[3])
                    {
-                       v += static_cast<const double>(in(n, c, hi, wi)) *
-                            static_cast<const double>(wei(k, c, y, x));
+                       if constexpr(is_same<TIn, ushort>::value)
+                       {
+                           v += bfloat16_to_float(in(n, c, hi, wi)) *
+                                bfloat16_to_float(wei(k, c, y, x));
+                       }
+                       else
+                       {
+                           v += static_cast<const double>(in(n, c, hi, wi)) *
+                                static_cast<const double>(wei(k, c, y, x));
+                       }
                    }
                }
            }
        }
-       out(n, k, ho, wo) = v;
+
+       if constexpr(is_same<TOut, ushort>::value)
+       {
+           out(n, k, ho, wo) = float_to_bfloat16(v);
+       }
+       else
+       {
+           out(n, k, ho, wo) = v;
+       }
    };

    auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
@@ -102,13 +118,28 @@ void host_convolution_forward(const Tensor<TIn>& in,
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[2])
                    {
-                       v += static_cast<const double>(in(n, hi, wi, c)) *
-                            static_cast<const double>(wei(k, y, x, c));
+                       if constexpr(is_same<TIn, ushort>::value)
+                       {
+                           v += bfloat16_to_float(in(n, hi, wi, c)) *
+                                bfloat16_to_float(wei(k, y, x, c));
+                       }
+                       else
+                       {
+                           v += static_cast<const double>(in(n, hi, wi, c)) *
+                                static_cast<const double>(wei(k, y, x, c));
+                       }
                    }
                }
            }
        }
-       out(n, ho, wo, k) = v;
+       if constexpr(is_same<TOut, ushort>::value)
+       {
+           out(n, ho, wo, k) = float_to_bfloat16(v);
+       }
+       else
+       {
+           out(n, ho, wo, k) = v;
+       }
    };

    if(layout == ConvTensorLayout::NCHW)
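The host reference above widens each bf16 operand to float, accumulates in a wider type, and rounds to bf16 once on output, so the reference rounds the same way the device path does. A self-contained sketch of that accumulate-then-round pattern (conversion helper inlined; it mirrors the earlier bf16 sketch, not CK's implementation):

#include <cstdint>
#include <cstring>

static inline float bf16_widen(uint16_t h) // mirrors bfloat16_to_float
{
    uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Dot product with double accumulation; the caller converts the result to
// bf16 exactly once, as the reference does with float_to_bfloat16.
inline double dot_bf16(const uint16_t* a, const uint16_t* b, int n)
{
    double acc = 0.0;
    for(int i = 0; i < n; ++i)
        acc += static_cast<double>(bf16_widen(a[i])) * static_cast<double>(bf16_widen(b[i]));
    return acc;
}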
@@ -226,10 +257,14 @@ int main(int argc, char* argv[])
    using in_data_t = float;
    using acc_data_t = float;
    using out_data_t = float;
-#elif 1
+#elif 0
    using in_data_t = half_t;
    using acc_data_t = float;
    using out_data_t = half_t;
+#elif 1
+    using in_data_t = ushort;
+    using acc_data_t = float;
+    using out_data_t = ushort;
#elif 1
    using in_data_t = int8_t;
    using acc_data_t = int32_t;
@@ -295,30 +330,30 @@ int main(int argc, char* argv[])
        // no initialization
        break;
    case 1:
-       in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
        break;
    case 2:
-       in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
        break;
    case 3:
-       in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
        break;
    case 4:
-       in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
        break;
    case 5:
-       in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
-       wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{0.0, 1.0}, num_thread);
+       wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
        break;
    default:
-       in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
        auto gen_wei = [](auto... is) {
-           return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
+           return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
        };
        wei.GenerateTensorValue(gen_wei, num_thread);
    }
...
@@ -297,30 +297,30 @@ int main(int argc, char* argv[])
        // no initialization
        break;
    case 1:
-       in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
        break;
    case 2:
-       in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
        break;
    case 3:
-       in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
        break;
    case 4:
-       in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
        break;
    case 5:
-       in.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
-       out.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.1, 0.1}, num_thread);
+       out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{-0.1, 0.1}, num_thread);
        break;
    default:
-       in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
+       in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
        auto gen_out = [](auto... is) {
-           return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
+           return GeneratorTensor_2<out_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
        };
        out.GenerateTensorValue(gen_out, num_thread);
    }
...
@@ -239,10 +239,14 @@ int main(int argc, char* argv[])
    using ab_data_t = float;
    using acc_data_t = float;
    using c_data_t = float;
-#elif 1
+#elif 0
    using ab_data_t = half_t;
    using acc_data_t = float;
    using c_data_t = half_t;
+#elif 1
+    using ab_data_t = ushort;
+    using acc_data_t = float;
+    using c_data_t = ushort;
#elif 1
    using ab_data_t = int8_t;
    using acc_data_t = int32_t;
@@ -321,24 +325,24 @@ int main(int argc, char* argv[])
        // no initialization
        break;
    case 1:
-       a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
+       b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
        break;
    case 2:
-       a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-       b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
+       b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
        break;
    case 3:
-       a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+       a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
+       b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
        break;
    case 4:
-       a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-       b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+       a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
+       b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
        break;
    default:
-       a.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
-       b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
+       a.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{0.0, 1.0}, num_thread);
+       b.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{-0.5, 0.5}, num_thread);
    }

#if USE_GEMM_XDL_MK_KN_MN
...
#pragma once

#include "host_tensor.hpp"
template <>
void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
const Tensor<ushort>& b,
Tensor<ushort>& c,
const GemmMatrixLayout layout)
{
if(layout == GemmMatrixLayout::MK_KN_MN)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
}
c(m, n) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_MN)
{
auto f_mk_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
}
c(m, n) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_MN)
{
auto f_km_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
}
c(m, n) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_MN)
{
auto f_km_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
}
c(m, n) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_KN_NM)
{
auto f_mk_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
}
c(n, m) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_NM)
{
auto f_mk_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
}
c(n, m) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_NM)
{
auto f_km_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
}
c(n, m) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_NM)
{
auto f_km_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
}
c(n, m) = float_to_bfloat16(v);
};
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
}
template <typename AType, typename BType, typename CType>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
...
@@ -321,4 +321,41 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
    std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
}

float bf16_to_f32(ushort src_val)
{
    // the two 16-bit halves must be distinct fields (an anonymous struct),
    // not two union members that alias the same bytes; on a little-endian
    // target, y is the high half that carries the bf16 payload
    typedef union
    {
        struct
        {
            ushort x; // low 16 bits, zero-filled
            ushort y; // high 16 bits, the bf16 value
        };
        float f32;
    } bf16_f32_t;

    bf16_f32_t v;
    v.x = 0;
    v.y = src_val;
    return v.f32;
}

template <>
void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result)
{
    float error = 0;
    float max_diff = -1;
    float ref_value = 0, result_value = 0;
    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
        error += diff;
        if(max_diff < diff)
        {
            max_diff = diff;
            ref_value = bf16_to_f32(ref.mData[i]);
            result_value = bf16_to_f32(result.mData[i]);
        }
    }

    std::cout << "error: " << error << std::endl;
    std::cout << "max_diff: " << max_diff << ", ref: " << ref_value << ", res: " << result_value
              << std::endl;
}

#endif
@@ -4,6 +4,7 @@
#include <cmath>
#include "config.hpp"

template <typename T>
struct GeneratorTensor_1
{
    int value = 1;
@@ -15,6 +16,30 @@ struct GeneratorTensor_1
    }
};
template <>
struct GeneratorTensor_1<ushort>
{
float value = 1.0;
template <typename... Is>
ushort operator()(Is...)
{
return float_to_bfloat16(value);
}
};
template <>
struct GeneratorTensor_1<int8_t>
{
int8_t value = 1;
template <typename... Is>
int8_t operator()(Is...)
{
return value;
}
};
struct GeneratorTensor_0
{
    int value = 0;
@@ -26,6 +51,7 @@ struct GeneratorTensor_0
    }
};

template <typename T>
struct GeneratorTensor_2
{
    int min_value = 0;
@@ -38,6 +64,33 @@ struct GeneratorTensor_2
    }
};
template <>
struct GeneratorTensor_2<ushort>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ushort operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return float_to_bfloat16(tmp);
}
};
template <>
struct GeneratorTensor_2<int8_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
int8_t operator()(Is...)
{
return (std::rand() % (max_value - min_value)) + min_value;
}
};
template <typename T>
struct GeneratorTensor_3
{
@@ -53,6 +106,39 @@ struct GeneratorTensor_3
    }
};
template <>
struct GeneratorTensor_3<ushort>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ushort operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return float_to_bfloat16(fp32_tmp);
}
};
template <>
struct GeneratorTensor_3<int8_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
int8_t operator()(Is...)
{
int8_t min_tmp = static_cast<int8_t>(min_value);
int8_t max_tmp = static_cast<int8_t>(max_value);
return (std::rand() % (max_tmp - min_tmp)) + min_tmp;
}
};
struct GeneratorTensor_Checkboard
{
    template <typename... Ts>
...
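A quick illustration of what the new GeneratorTensor_2<ushort> specialization above produces: a uniform integer in [min_value, max_value) stored as bf16. A standalone sketch (the truncating conversion mirrors the earlier bf16 sketch; the caller must seed rand() and pass max_value > min_value, exactly as the original requires):

#include <cstdint>
#include <cstdlib>
#include <cstring>

static inline uint16_t f32_to_bf16(float f) // mirrors float_to_bfloat16
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

// Same sampling the specialization performs: integer-valued bf16 samples.
inline uint16_t sample_bf16(int min_value, int max_value)
{
    float tmp = static_cast<float>(std::rand() % (max_value - min_value) + min_value);
    return f32_to_bf16(tmp);
}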