Unverified commit f9a7d9b3 authored by b8zhong, committed by GitHub

Support a server arg to override the KV cache dtype to BF16, to avoid slow cases (#11749)

parent a93f10a7
```diff
@@ -112,7 +112,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto |
 | `--quantization` | The quantization method. | None |
 | `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None |
-| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'fp8_e5m2' and 'fp8_e4m3' is supported for CUDA 11.8+. | auto |
+| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'bf16' or 'bfloat16' for BF16 KV cache. 'fp8_e5m2' and 'fp8_e4m3' are supported for CUDA 11.8+. | auto |
 | `--enable-fp32-lm-head` | If set, the LM head outputs (logits) are in FP32. | False |
 
 ## Memory and scheduling
```
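The documented strings resolve to torch dtypes when the server starts. As a quick illustration of that mapping (the helper function below is hypothetical; the real resolution lives in `ModelRunner`, shown in the next hunk, and also handles hardware-specific FP8 variants):

```python
# Illustrative only: a standalone mapping of the documented --kv-cache-dtype values
# to torch dtypes. The function name is hypothetical; the actual logic in ModelRunner
# additionally picks fnuz variants of FP8 on some hardware.
import torch

def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    if kv_cache_dtype == "auto":
        return model_dtype                     # follow the model's own dtype
    if kv_cache_dtype in ("bf16", "bfloat16"):
        return torch.bfloat16                  # override added by this commit
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    if kv_cache_dtype == "fp8_e4m3":
        return torch.float8_e4m3fn
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}.")

print(resolve_kv_cache_dtype("bf16", torch.float16))  # torch.bfloat16
```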
```diff
@@ -1567,6 +1567,8 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e4m3fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
+        elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
+            self.kv_cache_dtype = torch.bfloat16
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
```
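For context on why the KV cache dtype matters, a back-of-the-envelope sizing sketch follows. The layer/head/dim values are placeholders, not from this PR, and the formula ignores paging overhead; it only shows how per-token KV footprint scales with element size (it assumes `torch.dtype.itemsize`, available in PyTorch 2.1+). The commit title frames the motivation as avoiding slow cases: keeping the KV cache in BF16 sidesteps dtypes that can otherwise fall onto slower attention paths.

```python
# Rough per-token KV cache size: 2 tensors (K and V) * layers * kv_heads * head_dim * element size.
# All model shape numbers below are illustrative placeholders, not taken from this commit.
import torch

def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, dtype: torch.dtype) -> int:
    return 2 * num_layers * num_kv_heads * head_dim * dtype.itemsize

layers, kv_heads, head_dim = 32, 8, 128
for dtype in (torch.float32, torch.bfloat16, torch.float8_e4m3fn):
    kib = kv_bytes_per_token(layers, kv_heads, head_dim, dtype) / 1024
    print(f"{dtype}: {kib:.0f} KiB per token")  # e.g. 256, 128, 64 KiB for these placeholder shapes
```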
```diff
@@ -1652,8 +1652,8 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3", "bf16", "bfloat16"],
+            help='Data type for kv cache storage. "auto" will use model data type. "bf16" or "bfloat16" for BF16 KV cache. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
         )
         parser.add_argument(
             "--enable-fp32-lm-head",
```
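To make the user-visible change concrete, here is a minimal, standalone reproduction of the updated argparse definition (not the full `ServerArgs.add_cli_args`); it shows that the two new spellings are accepted while anything else is still rejected by the `choices` check:

```python
# Minimal sketch mirroring the argparse change above; the surrounding parser is illustrative.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--kv-cache-dtype",
    type=str,
    default="auto",
    choices=["auto", "fp8_e5m2", "fp8_e4m3", "bf16", "bfloat16"],
    help='Data type for kv cache storage. "auto" will use model data type. '
         '"bf16" or "bfloat16" for BF16 KV cache. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
)

print(parser.parse_args(["--kv-cache-dtype", "bf16"]).kv_cache_dtype)      # "bf16"
print(parser.parse_args(["--kv-cache-dtype", "bfloat16"]).kv_cache_dtype)  # "bfloat16"
# Before this commit, either value would have been rejected by the choices check.
```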