Unverified commit d63f13c1, authored by Ying Sheng, committed by GitHub

fix: fp8 config (#723)

parent fded6744
@@ -15,6 +15,7 @@ from flashinfer import (
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from torch.nn.parameter import Parameter
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
@@ -22,6 +23,7 @@ from vllm.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
@@ -38,6 +40,18 @@ from sglang.srt.utils import (
logger = logging.getLogger("srt.model_runner")

def is_llama3_405b_fp8(model_config):
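    # Identify the Llama 3 405B fbgemm_fp8 checkpoint by its architecture,
    # hidden/intermediate sizes, layer count, and quantization method.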
    if (
        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
        and model_config.hf_config.hidden_size == 16384
        and model_config.hf_config.intermediate_size == 53248
        and model_config.hf_config.num_hidden_layers == 126
        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
    ):
        return True
    return False

class ModelRunner:
    def __init__(
        self,
@@ -118,6 +132,9 @@ class ModelRunner:
            seed=42,
            skip_tokenizer_init=True,
        )
        if is_llama3_405b_fp8(self.model_config):
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    return model_arch_name_to_cls[model_arch]

def get_original_weight(loaded_weight, head_dim):
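    # The fp8 checkpoint stores each KV head twice, interleaved along dim 0
    # ([h0, h0, h1, h1, ...]); keep the first copy of each head block.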
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    dim = loaded_weight.shape[1]
    for i in range(n_kv_head):
        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
            2 * i * head_dim : (2 * i + 1) * head_dim, :
        ]
    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
    return original_kv_weight

def get_weight_loader_srt(weight_loader):
    def weight_loader_srt(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
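        # A k/v shard twice the expected height carries duplicated KV heads;
        # de-duplicate it before handing off to the original vLLM loader.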
        if (
            loaded_shard_id in ["k", "v"]
            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
        ):
            loaded_weight = get_original_weight(loaded_weight, self.head_size)
        weight_loader(self, param, loaded_weight, loaded_shard_id)

    return weight_loader_srt

# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
original_weight_loader = QKVParallelLinear.weight_loader
setattr(
    QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
)
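
For reference, a minimal sketch of the de-duplication that get_original_weight performs. The toy shapes (two KV heads, head_dim of 4) are invented for illustration, and the function is assumed to be in scope from the patched module:

import torch

head_dim, n_kv_head, dim = 4, 2, 8

# Build a projection whose KV heads are stored twice, interleaved along dim 0:
# [h0, h0, h1, h1], each block head_dim rows tall.
heads = [torch.full((head_dim, dim), float(i)) for i in range(n_kv_head)]
duplicated = torch.cat([h for head in heads for h in (head, head)], dim=0)
assert duplicated.shape == (2 * n_kv_head * head_dim, dim)

# get_original_weight mutates its argument in place, so pass a clone.
restored = get_original_weight(duplicated.clone(), head_dim)
assert restored.shape == (n_kv_head * head_dim, dim)
assert torch.equal(restored, torch.cat(heads, dim=0))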