Commit 3a477917 authored by zhanghj2's avatar zhanghj2
Browse files

FLASH_MLA_BF16_TYPE控制bf16转换精度

parent 4c0bb04e
...@@ -167,7 +167,12 @@ flash_fwd_mla_combine_kernel(const CombineParams params) { ...@@ -167,7 +167,12 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
// } // }
auto float2bf16 = [] (float s) -> uint16_t { auto float2bf16 = [] (float s) -> uint16_t {
uint32_t x32 = reinterpret_cast<uint32_t const &>(s); uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32 += 0x8000u; x32 += 0x8000u;
#endif
return uint16_t(x32 >> 16); return uint16_t(x32 >> 16);
}; };
......
...@@ -290,7 +290,14 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso ...@@ -290,7 +290,14 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
#else #else
{ {
if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) { if constexpr (std::is_same_v<To_type, cutlass::bfloat16_t>) {
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 0
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_toward_zero> convert_op;
#else
cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel, cutlass::FloatRoundStyle::round_half_ulp_truncate> convert_op;
#endif
*result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); *result_ptr = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
} else { } else {
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
......
...@@ -19,9 +19,14 @@ def is_flag_set(flag: str) -> bool: ...@@ -19,9 +19,14 @@ def is_flag_set(flag: str) -> bool:
return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"] return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"]
def get_features_args(): def get_features_args():
bf16_type = os.getenv("FLASH_MLA_BF16_TYPE", "0")
assert bf16_type == "0" or bf16_type == "1", "bf16_type must be 0 or 1"
bf16_mode_names = {"0": "round_toward_zero", "1": "round_half_ulp_truncate"}
print(f"Using BFloat16 rounding mode: {bf16_mode_names.get(bf16_type, 'unknown')}")
features_args = [] features_args = []
if is_flag_set("FLASH_MLA_DISABLE_FP16"): if is_flag_set("FLASH_MLA_DISABLE_FP16"):
features_args.append("-DFLASH_MLA_DISABLE_FP16") features_args.append("-DFLASH_MLA_DISABLE_FP16")
features_args.append(f"-DFLASH_MLA_BF16_TYPE={bf16_type}")
return features_args return features_args
def get_arch_flags(): def get_arch_flags():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment