Commit 261aa80c authored by zhanghj2
Browse files

FLASH_MLA_BF16_TYPE=1

parent fd342087
...@@ -19,7 +19,7 @@ def is_flag_set(flag: str) -> bool: ...@@ -19,7 +19,7 @@ def is_flag_set(flag: str) -> bool:
return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"] return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"]
def get_features_args(): def get_features_args():
bf16_type = os.getenv("FLASH_MLA_BF16_TYPE", "0") bf16_type = os.getenv("FLASH_MLA_BF16_TYPE", "1")
assert bf16_type == "0" or bf16_type == "1", "bf16_type must be 0 or 1" assert bf16_type == "0" or bf16_type == "1", "bf16_type must be 0 or 1"
bf16_mode_names = {"0": "round_toward_zero", "1": "round_half_ulp_truncate"} bf16_mode_names = {"0": "round_toward_zero", "1": "round_half_ulp_truncate"}
print(f"Using BFloat16 rounding mode: {bf16_mode_names.get(bf16_type, 'unknown')}") print(f"Using BFloat16 rounding mode: {bf16_mode_names.get(bf16_type, 'unknown')}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment