Commit 1d82d465 authored by Jing Zhang

add bf16 support

parent f03dda48
...@@ -30,6 +30,7 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
...
...@@ -11,18 +11,16 @@
namespace ck {
//https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__ __device__ inline half4_t pki4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = amd_assembly_and_or_b32(q, LO, EX);
int hi = amd_assembly_and_or_b32(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0xE408E408; //-8
const int MUL = 0x2c002c00; // 1/16
const int ADD = 0xd480d480; //-79
...@@ -40,17 +38,6 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 0
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
auto l_f16 = ck::type_convert<ck::half_t>(x_l - 8);
auto h_f16 = ck::type_convert<ck::half_t>(x_h - 8);
return {h_f16, l_f16};
#elif 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
int x_l = (x_u8 & 0x0f);
...@@ -62,12 +49,51 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
int lo = (x_l | x_h) | EX;
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#else
int32_t res = bit_cast<int8_t>(q);
return bit_cast<half2_t>(res);
#endif
}
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s)
{
uint32_t q = bit_cast<uint16_t>(i4s);
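// Spread the four 4-bit values of q into the four bytes of i8s (nibble i -> byte i).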
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
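// Each intermediate now holds float(2^23 + nibble); subtracting 2^23 + 8 recovers the
// signed int4 value (zero point 8) as a float.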
fp32_intermediates[0] -= 8388616.f;
fp32_intermediates[1] -= 8388616.f;
fp32_intermediates[2] -= 8388616.f;
fp32_intermediates[3] -= 8388616.f;
vector_type<bhalf_t, 4> res;
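// The upper 16 bits of each float are exactly its bf16 encoding; pack them in pairs.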
res.template AsType<bhalf2_t>()(Number<0>{}) = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
res.template AsType<bhalf2_t>()(Number<1>{}) = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{
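// Scalar path: unpack the two nibbles, apply the -8 zero point in float, and round
// each value to bf16. Element 0 receives the high nibble, element 1 the low nibble.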
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
float x_l = ((x_u8 & 0x0f) >> 0) - 8;
float x_h = ((x_u8 & 0xf0) >> 4) - 8;
vector_type<bhalf_t, 2> res;
res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_h);
res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_l);
return res.template AsType<bhalf2_t>()[Number<0>{}];
}
namespace tensor_operation {
namespace element_wise {
...@@ -102,6 +128,32 @@ struct PassThroughPack8
#endif
}
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
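// Fast path: bit_cast the four packed-int4 bytes to one 32-bit word and convert the
// low and high 16-bit halves (four nibbles each) with pki4_to_bhalf4.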
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
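// Reference path: convert one pk_i4_t (two nibbles) at a time with pki4_to_bhalf2.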
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<bhalf2_t>()(Number<0>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<bhalf2_t>()(Number<1>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<bhalf2_t>()(Number<2>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<bhalf2_t>()(Number<3>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
#endif
}
constexpr const static bool is_pack8_invocable = true;
};
...
...@@ -1147,36 +1147,7 @@ struct ThreadwiseTensorSliceTransfer_v4
});
}
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
             is_same<remove_cvref_t<DstData>, half_t>::value)
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, f8_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
...
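A minimal host-side sketch of what the packed int4 -> bf16 path above computes, useful as a
reference when validating the bit-twiddled device code. The helper names (ref_pki4_to_bf16,
bf16_truncate) are illustrative and not part of composable_kernel; the only assumptions taken
from the diff are the 4-bit packing and the fused zero point of 8.

#include <cstdint>
#include <cstring>
#include <cstdio>

// Truncate a float to its bf16 bit pattern (the upper 16 bits of the IEEE-754 binary32
// encoding); exact for the small integers produced by int4 dequantization.
static uint16_t bf16_truncate(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<uint16_t>(u >> 16);
}

// Reference conversion: each nibble of a 16-bit word holding four packed int4 values,
// minus the zero point 8, becomes one bf16 value. Element order here is plain nibble
// order (lowest nibble first), chosen only for illustration.
static void ref_pki4_to_bf16(uint16_t packed, uint16_t out[4])
{
    for(int i = 0; i < 4; ++i)
    {
        const int nibble = (packed >> (4 * i)) & 0xf;
        out[i]           = bf16_truncate(static_cast<float>(nibble - 8));
    }
}

int main()
{
    uint16_t out[4];
    ref_pki4_to_bf16(0x08f1, out); // nibbles 1, 15, 8, 0 -> -7, 7, 0, -8
    for(int i = 0; i < 4; ++i)
        std::printf("0x%04x\n", out[i]);
    return 0;
}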