Unverified Commit 06487f12 authored by Ying Sheng, committed by GitHub

refactor model loader: initial refactor (#664)

parent 39c57317
@@ -91,4 +91,10 @@ python3 run_all.py
 ```
 cd test/srt
 python test_openai_server.py
-```
\ No newline at end of file
+```
+
+## Format
+pip3 install pre-commit
+cd sglang
+pre-commit install
+pre-commit run --all-files
\ No newline at end of file
@@ -123,6 +123,15 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
+        if (
+            self.server_args.efficient_weight_load
+            and "llama" in self.server_args.model_path.lower()
+            and self.server_args.quantization == "fp8"
+        ):
+            from sglang.srt.model_loader.model_loader import get_model
+        else:
+            from vllm.model_executor.model_loader import get_model
+
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
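
For readability, here is a minimal, self-contained sketch of the loader-selection rule added above. The `_LoaderChoiceArgs` dataclass and the returned dotted paths are illustrative stand-ins for the relevant `ServerArgs` fields and the two imports; only the boolean condition mirrors the diff.

```python
from dataclasses import dataclass


@dataclass
class _LoaderChoiceArgs:
    # Hypothetical stand-in for the ServerArgs fields used by the check above.
    efficient_weight_load: bool
    model_path: str
    quantization: str


def pick_get_model(args: _LoaderChoiceArgs) -> str:
    """Return which get_model implementation the branch above would import."""
    if (
        args.efficient_weight_load
        and "llama" in args.model_path.lower()
        and args.quantization == "fp8"
    ):
        return "sglang.srt.model_loader.model_loader.get_model"
    return "vllm.model_executor.model_loader.get_model"


if __name__ == "__main__":
    # Llama + fp8 + flag enabled -> sglang's memory-efficient loader.
    print(pick_get_model(_LoaderChoiceArgs(True, "meta-llama/Llama-2-7b-hf", "fp8")))
    # Anything else falls back to the vLLM loader.
    print(pick_get_model(_LoaderChoiceArgs(True, "mistralai/Mistral-7B-v0.1", "fp8")))
```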
@@ -237,7 +246,16 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(
             self, max_batch_size_to_capture=max(batch_size_list)
         )
-        self.cuda_graph_runner.capture(batch_size_list)
+        logger.info(f"Capture for batch sizes {batch_size_list}")
+        try:
+            self.cuda_graph_runner.capture(batch_size_list)
+        except Exception as e:
+            raise Exception(
+                "CUDA graph capture failed. Possible solutions:\n"
+                "1. disable CUDA graph with --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "Open an issue on GitHub with a reproducible script if you need help.\n"
+            ) from e
 
     @torch.inference_mode()
     def forward_decode(self, batch: Batch):
...
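
The capture call is now wrapped so a failure surfaces actionable advice instead of a bare stack trace. Below is a hedged sketch (not part of the commit) of how a caller could turn that failure into an eager-execution fallback; `runner` is any object with a `capture(batch_size_list)` method.

```python
import logging

logger = logging.getLogger(__name__)


def capture_or_fall_back(runner, batch_size_list, fall_back_to_eager=True):
    """Hypothetical wrapper mirroring the error handling added above."""
    logger.info("Capture for batch sizes %s", batch_size_list)
    try:
        runner.capture(batch_size_list)
        return True  # CUDA graphs are available
    except Exception:
        logger.exception(
            "CUDA graph capture failed. Try --disable-cuda-graph or a smaller "
            "--mem-fraction-static."
        )
        if fall_back_to_eager:
            return False  # caller should run without CUDA graphs
        raise
```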
@@ -304,6 +304,12 @@ class ModelTpServer:
             self.model_config.context_len - 1 - len(req.origin_input_ids),
             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
+        if req.sampling_params.max_new_tokens < 0:
+            req.origin_input_ids = req.origin_input_ids[
+                : self.max_total_num_tokens - 128
+            ]
+            logger.error("Request longer than memory pool size, truncated!!!")
+
         self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
...
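
A worked example of the clamping and truncation arithmetic above: the new-token budget is the minimum of the requested value, the remaining context window, and the remaining memory pool; when that minimum is negative, the prompt itself does not fit, so the input ids are cut down. The toy function below restates only the math with made-up numbers; the real code mutates the request in place inside `ModelTpServer`.

```python
def clamp_new_tokens(requested, input_len, context_len, max_total_num_tokens):
    """Toy restatement of the clamping logic; returns (max_new_tokens, input_len)."""
    max_new_tokens = min(
        requested,
        context_len - 1 - input_len,
        max_total_num_tokens - 128 - input_len,
    )
    if max_new_tokens < 0:
        # The prompt alone exceeds the memory pool budget, so the input ids
        # are truncated to the first (max_total_num_tokens - 128) tokens.
        input_len = max_total_num_tokens - 128
    return max_new_tokens, input_len


# A 5000-token prompt against a 4096-token pool gives a negative budget,
# so the prompt is cut to 4096 - 128 = 3968 tokens.
print(clamp_new_tokens(256, 5000, 8192, 4096))  # (-1032, 3968)
```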
@@ -91,6 +91,7 @@ def _initialize_model(
         config=model_config.hf_config,
         cache_config=cache_config,
         quant_config=quant_config,
+        efficient_weight_load=True,
         **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config),
     )
...
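
Note that `efficient_weight_load=True` is passed unconditionally here, so every model class constructed through this loader must accept the keyword (as `LlamaForCausalLM` now does below). Purely as a hypothetical illustration of that contract, and not part of the diff, a defensive wrapper could forward the flag only when the constructor declares it:

```python
import inspect


def construct_model(model_class, efficient_weight_load=True, **kwargs):
    """Hypothetical guard: forward the flag only if the target __init__
    actually accepts an efficient_weight_load keyword."""
    if "efficient_weight_load" in inspect.signature(model_class.__init__).parameters:
        kwargs["efficient_weight_load"] = efficient_weight_load
    return model_class(**kwargs)
```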
...@@ -15,11 +15,6 @@ from vllm.distributed import ( ...@@ -15,11 +15,6 @@ from vllm.distributed import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -32,6 +27,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor ...@@ -32,6 +27,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.managers.controller.model_runner import InputMetadata
MergedColumnParallelLinear = None
QKVParallelLinear = None
RowParallelLinear = None
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__( def __init__(
@@ -267,7 +266,25 @@ class LlamaForCausalLM(nn.Module):
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
+        efficient_weight_load=False,
     ) -> None:
+        global MergedColumnParallelLinear
+        global QKVParallelLinear
+        global RowParallelLinear
+
+        if efficient_weight_load:
+            from sglang.srt.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+        else:
+            from vllm.model_executor.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+
         super().__init__()
         self.config = config
         self.quant_config = quant_config
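
The module-level `None` placeholders combined with `global` rebinding inside `__init__` let a single class definition pick its linear-layer implementation at construction time. A minimal sketch of the pattern, with toy classes standing in for the sglang and vLLM linear layers:

```python
LinearImpl = None  # module-level placeholder, rebound on construction


class _EfficientLinear:   # stand-in for the sglang.srt.layers.linear classes
    origin = "sglang"


class _DefaultLinear:     # stand-in for the vllm.model_executor.layers.linear classes
    origin = "vllm"


class ToyModel:
    def __init__(self, efficient_weight_load: bool = False):
        global LinearImpl
        # The real code rebinds the globals via a deferred `from ... import`;
        # plain assignment is the same idea in miniature.
        LinearImpl = _EfficientLinear if efficient_weight_load else _DefaultLinear
        self.proj = LinearImpl()


print(ToyModel(efficient_weight_load=True).proj.origin)  # sglang
print(ToyModel().proj.origin)                            # vllm
```

One caveat of this pattern is that the choice is process-global: constructing a second model with a different flag rebinds the placeholders for every later use.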
@@ -288,7 +305,30 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def get_module_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -298,15 +338,14 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
+
+        def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
+                return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
+                return
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -323,12 +362,21 @@ class LlamaForCausalLM(nn.Module):
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
-                    continue
+                    return
                 if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
+                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+        if name is None or loaded_weight is None:
+            if get_tensor_model_parallel_rank() == 0:
+                weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
 
 EntryClass = LlamaForCausalLM
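
To make the new helpers concrete, here is a standalone copy of the `get_module_name` mapping with a few example checkpoint keys (ordinary Hugging Face Llama names, chosen for illustration). The mapping tells a per-layer loader which fused module a checkpoint tensor belongs to and how many shards that module is assembled from, which is the information needed when quantizing module by module during loading.

```python
# Standalone copy of get_module_name() above; the real method lives on
# LlamaForCausalLM.

_STACKED_PARAMS = [
    # (param_name, shard_name, shard_id, num_shard)
    ("qkv_proj", "q_proj", "q", 3),
    ("qkv_proj", "k_proj", "k", 3),
    ("qkv_proj", "v_proj", "v", 3),
    ("gate_up_proj", "gate_proj", 0, 2),
    ("gate_up_proj", "up_proj", 1, 2),
]


def get_module_name(name: str):
    for param_name, weight_name, _shard_id, num_shard in _STACKED_PARAMS:
        if weight_name in name:
            return name.replace(weight_name, param_name)[: -len(".weight")], num_shard
    return name[: -len(".weight")], 1


# q/k/v all map to the fused qkv_proj module, which is built from 3 shards.
print(get_module_name("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj', 3)
print(get_module_name("model.layers.0.mlp.up_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj', 2)
print(get_module_name("model.embed_tokens.weight"))
# -> ('model.embed_tokens', 1)
```

The refactored `load_weights` keeps its old bulk behavior when called with only an iterator, and loads a single tensor when `name` and `loaded_weight` are given, which lets an external loader stream weights in one at a time.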
@@ -57,6 +57,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    efficient_weight_load: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -327,6 +328,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--efficient-weight-load",
+            action="store_true",
+            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
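
A usage sketch for the new flag, assuming the field and CLI names exactly as added above and the usual `sglang.srt.server_args` module path; the model id is an example, and per the `ModelRunner` check the flag only takes effect for Llama models with `--quantization fp8`.

```python
from sglang.srt.server_args import ServerArgs

# Programmatic equivalent of passing:
#   --efficient-weight-load --quantization fp8
server_args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-hf",  # example model id
    quantization="fp8",
    efficient_weight_load=True,
)
print(server_args.efficient_weight_load)  # True
```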