"docs/en/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "e417035f5d473b9f85d15ba01267d48d7f30e71e"
Unverified Commit b8addae2 authored by Dan Yao, committed by GitHub
Browse files

[CK_TILE] float -> bf16 inline asm rtn (#1482)



* asm rtn

* add asm rtn macro

* reorder macro

---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
parent 461ec98d
...@@ -46,6 +46,7 @@ ...@@ -46,6 +46,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
......
...@@ -17,6 +17,7 @@ enum class bf16_rounding_mode ...@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
standard = 0, // rtn standard = 0, // rtn
truncate_with_nan, truncate_with_nan,
truncate, truncate,
standard_asm,
}; };
template <bf16_rounding_mode rounding = template <bf16_rounding_mode rounding =
...@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f) ...@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
// Host overload of float_to_bf16_rtn_asm.
// No inline assembly is used on this path; the conversion is delegated
// unchanged to float_to_bf16_rtn_raw and its result is returned as-is.
CK_TILE_HOST
constexpr uint16_t float_to_bf16_rtn_asm(float f)
{
    return float_to_bf16_rtn_raw(f);
}
// Device overload of float_to_bf16_rtn_asm: converts a float to raw bf16 bits
// with a hand-written AMDGPU instruction sequence.
//
// Steps performed by the assembly (operands: %0 = check_nan, %1 = tmp,
// %2 = the float's bits, %3 = ROUND_BIAS_FOR_BF16, %4 = FP32_NAN):
//   1. v_cmp_u_f32   : compare the input with itself "unordered" -> NaN mask.
//   2. v_bfe_u32     : extract bit 16 of the input into tmp.
//   3. v_add3_u32    : tmp = input + tmp + 0x7fff (rounding bias).
//   4. v_cndmask_b32 : pick FP32_NAN where the NaN mask is set, else tmp.
//   5. v_lshrrev_b32 : shift right by 16 so the result sits in the low half.
//
// @param f  value to convert.
// @return   low 16 bits of the shifted result, i.e. the raw bf16 encoding.
CK_TILE_DEVICE
uint16_t float_to_bf16_rtn_asm(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    static constexpr uint32_t FP32_NAN            = 0x7fff0000;
    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;

    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan; // scalar-register mask written by v_cmp_u_f32
    uint32_t tmp;         // scratch VGPR; only ever written before it is read

    // tmp is a write-only early-clobber output ("=&v"): binding it as "+v"
    // would make the asm formally read an uninitialized variable, and the '&'
    // is required because tmp is written (instruction 2) before the inputs
    // %3 and %4 have been consumed (instructions 3 and 4).
    asm volatile("\n"
                 "v_cmp_u_f32 %0, %2, %2 \n"
                 "v_bfe_u32 %1, %2, 16, 1 \n"
                 "v_add3_u32 %1, %2, %1, %3 \n"
                 "v_cndmask_b32 %2, %1, %4, %0 \n"
                 "v_lshrrev_b32 %2, 16, %2 \n"
                 : "=s"(check_nan), "=&v"(tmp), "+v"(u.fp32)
                 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));

    return uint16_t(u.int32);
}
// Truncate instead of rounding, preserving SNaN // Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f) constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
...@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round ...@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
{ {
if constexpr(rounding == bf16_rounding_mode::standard) if constexpr(rounding == bf16_rounding_mode::standard)
return float_to_bf16_rtn_raw(f); return float_to_bf16_rtn_raw(f);
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan) else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f); return float_to_bf16_truc_nan_raw(f);
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment