Commit da0bb989 authored by ltqin's avatar ltqin
Browse files

add bfloat16 to xdlops

parent 36ca02f3
...@@ -62,7 +62,7 @@ using QKVElementOp = PassThrough; ...@@ -62,7 +62,7 @@ using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using DataType = F16; using DataType = F16;
using GemmDataType = F16; using GemmDataType = BF16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
......
...@@ -40,6 +40,12 @@ struct PassThrough ...@@ -40,6 +40,12 @@ struct PassThrough
y = x; y = x;
} }
// PassThrough specialization for bfloat16_t: copies the input to the output
// unchanged.
//   y - output reference, overwritten with x
//   x - input value
template <>
__host__ __device__ void operator()<bfloat16_t, bfloat16_t>(bfloat16_t& y, const bfloat16_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{ {
......
...@@ -524,6 +524,27 @@ struct MfmaSelector ...@@ -524,6 +524,27 @@ struct MfmaSelector
#endif #endif
} }
// Selects the MFMA instruction for bfloat16_t operands with a 32x32 tile.
// When CK_USE_AMD_MFMA_BF16_1K_OP is defined the "_1k" variant
// (mfma_f32_32x32x8bf16_1k) is returned; otherwise mfma_f32_32x32x4bf16.
template <>
static constexpr auto GetMfma<bfloat16_t, 32, 32>()
{
#ifndef CK_USE_AMD_MFMA_BF16_1K_OP
    return MfmaInstr::mfma_f32_32x32x4bf16;
#else
    return MfmaInstr::mfma_f32_32x32x8bf16_1k;
#endif
}
// Selects the MFMA instruction for bfloat16_t operands with a 16x16 tile.
// When CK_USE_AMD_MFMA_BF16_1K_OP is defined the "_1k" variant
// (mfma_f32_16x16x16bf16_1k) is returned; otherwise mfma_f32_16x16x8bf16.
template <>
static constexpr auto GetMfma<bfloat16_t, 16, 16>()
{
#ifndef CK_USE_AMD_MFMA_BF16_1K_OP
    return MfmaInstr::mfma_f32_16x16x8bf16;
#else
    return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#endif
}
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() static constexpr auto GetMfma<int8_t, 32, 32>()
{ {
...@@ -735,7 +756,7 @@ struct XdlopsGemm ...@@ -735,7 +756,7 @@ struct XdlopsGemm
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value|| is_same<base_type, bfloat16_t>::value ||
is_same<base_type, int8_t>::value, is_same<base_type, int8_t>::value,
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, and int8_t!");
......
...@@ -215,6 +215,13 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> ...@@ -215,6 +215,13 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
// Run overload taking bfloat16x4_t operands.
// Calls __builtin_amdgcn_mfma_f32_32x32x8bf16_1k with reg_a, reg_b and the
// float16_t view of reg_c at Number<0>, and stores the builtin's result back
// into that same element of reg_c. The three trailing builtin arguments are
// all passed as 0.
//   reg_a, reg_b - input operand vectors
//   reg_c        - accumulator, read and then overwritten
template <class FloatC>
__device__ static void Run(const bfloat16x4_t& reg_a, const bfloat16x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
}
}; };
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
...@@ -243,6 +250,12 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32> ...@@ -243,6 +250,12 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
// Run overload taking bfloat16x2_t operands.
// Calls __builtin_amdgcn_mfma_f32_32x32x4bf16 with reg_a, reg_b and the
// float16_t view of reg_c at Number<0>, and stores the builtin's result back
// into that same element of reg_c. The three trailing builtin arguments are
// all passed as 0.
//   reg_a, reg_b - input operand vectors
//   reg_c        - accumulator, read and then overwritten
template <class FloatC>
__device__ static void Run(const bfloat16x2_t& reg_a, const bfloat16x2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
}
}; };
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
......
...@@ -13,6 +13,8 @@ using half_t = _Float16; ...@@ -13,6 +13,8 @@ using half_t = _Float16;
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
#endif #endif
// bfloat16_t is a plain alias of int16_t, not a distinct type. Any template
// specialization or overload written for bfloat16_t is therefore also the
// specialization/overload for int16_t.
// NOTE(review): presumably this holds raw bfloat16 bit patterns - confirm that
// no caller relies on numeric int16_t behavior through the same templates.
using bfloat16_t = int16_t;
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
struct vector_type; struct vector_type;
...@@ -119,6 +121,13 @@ struct scalar_type<bhalf_t> ...@@ -119,6 +121,13 @@ struct scalar_type<bhalf_t>
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
// scalar_type traits for bfloat16_t: the element type is bfloat16_t itself
// and it counts as a single-element (non-vector) type.
template <>
struct scalar_type<bfloat16_t>
{
    static constexpr index_t vector_size = 1;
    using type                           = bfloat16_t;
};
template <> template <>
struct scalar_type<int32_t> struct scalar_type<int32_t>
{ {
...@@ -926,6 +935,13 @@ using bhalf16_t = typename vector_type<bhalf_t, 16>::type; ...@@ -926,6 +935,13 @@ using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type; using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type; using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// bfloat16_t vector aliases: bfloat16x<N>_t names vector_type<bfloat16_t, N>::type
// for N = 2, 4, 8, 16, 32, 64.
using bfloat16x2_t = typename vector_type<bfloat16_t, 2>::type;
using bfloat16x4_t = typename vector_type<bfloat16_t, 4>::type;
using bfloat16x8_t = typename vector_type<bfloat16_t, 8>::type;
using bfloat16x16_t = typename vector_type<bfloat16_t, 16>::type;
using bfloat16x32_t = typename vector_type<bfloat16_t, 32>::type;
using bfloat16x64_t = typename vector_type<bfloat16_t, 64>::type;
// i32 // i32
using int32x2_t = typename vector_type<int32_t, 2>::type; using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type; using int32x4_t = typename vector_type<int32_t, 4>::type;
...@@ -1023,6 +1039,45 @@ inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x) ...@@ -1023,6 +1039,45 @@ inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
// Convert bfloat16_t to fp32.
// The 16-bit input is placed in the upper half of a 32-bit word (the lower
// 16 bits become zero) and that word is reinterpreted as a float through a
// union.
template <>
inline __host__ __device__ constexpr float type_convert<float, bfloat16_t>(bfloat16_t x)
{
    // Conversion to uint32_t happens before the shift, so any sign-extension
    // bits are shifted out of the 32-bit result.
    const uint32_t widened_bits = uint32_t(x) << 16;

    union
    {
        uint32_t bits;
        float value;
    } pun = {widened_bits};

    return pun.value;
}
// Convert fp32 to bfloat16_t.
// Keeps the upper 16 bits of the fp32 bit pattern (truncation, no rounding).
// Only mantissa bits 22..16 survive the truncation. A NaN whose payload lives
// entirely in the discarded low 16 bits would therefore collapse to the Inf
// encoding; for such inputs the lowest surviving mantissa bit is set so the
// result is still a NaN. All other inputs produce the same bits as plain
// truncation.
template <>
inline __host__ __device__ constexpr bfloat16_t type_convert<bfloat16_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

    // NaN: exponent all ones with a non-zero mantissa (sign bit ignored).
    const bool is_nan = (u.int32 & 0x7fffffffu) > 0x7f800000u;

    // Surviving mantissa bits are all zero -> truncation would yield +/-Inf.
    if(is_nan && (u.int32 & 0x007f0000u) == 0u)
    {
        u.int32 |= 0x00010000u;
    }

    return uint16_t(u.int32 >> 16);
}
// Convert fp16 to bfloat16_t.
// The input is first widened to fp32 with static_cast, then the upper 16 bits
// of the fp32 bit pattern are kept (truncation, no rounding).
// Only mantissa bits 22..16 survive the truncation. If the widened value is a
// NaN whose payload sits entirely in the discarded bits, plain truncation
// would produce the Inf encoding; in that case the lowest surviving mantissa
// bit is set so the result is still a NaN. All other inputs produce the same
// bits as plain truncation.
// NOTE(review): whether the fp16 -> fp32 cast already quiets NaNs (making the
// guard a no-op) depends on the platform - confirm.
template <>
inline __host__ __device__ bfloat16_t type_convert<bfloat16_t, half_t>(half_t x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {static_cast<float>(x)};

    // NaN: exponent all ones with a non-zero mantissa (sign bit ignored).
    const bool is_nan = (u.int32 & 0x7fffffffu) > 0x7f800000u;

    // Surviving mantissa bits are all zero -> truncation would yield +/-Inf.
    if(is_nan && (u.int32 & 0x007f0000u) == 0u)
    {
        u.int32 |= 0x00010000u;
    }

    return uint16_t(u.int32 >> 16);
}
template <> template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x) inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment