Unverified Commit 252e0f7b authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

fix: small bug for llama-405b fp16 (#733)

parent 7f6f2f0f
...@@ -121,7 +121,7 @@ class ModelRunner: ...@@ -121,7 +121,7 @@ class ModelRunner:
skip_tokenizer_init=True, skip_tokenizer_init=True,
) )
if is_llama3_405b_fp8(self.model_config): if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8 self.model_config.hf_config.num_key_value_heads = 8
vllm_model_config.hf_config.num_key_value_heads = 8 vllm_model_config.hf_config.num_key_value_heads = 8
......
...@@ -626,6 +626,7 @@ def is_llama3_405b_fp8(model_config): ...@@ -626,6 +626,7 @@ def is_llama3_405b_fp8(model_config):
and model_config.hf_config.intermediate_size == 53248 and model_config.hf_config.intermediate_size == 53248
and model_config.hf_config.num_hidden_layers == 126 and model_config.hf_config.num_hidden_layers == 126
and model_config.hf_config.num_key_value_heads == 16 and model_config.hf_config.num_key_value_heads == 16
and hasattr(model_config.hf_config, "quantization_config")
and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8" and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
): ):
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment