Unverified Commit 06487f12 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

refactor model loader: initial refactor (#664)

parent 39c57317
......@@ -91,4 +91,10 @@ python3 run_all.py
```
cd test/srt
python test_openai_server.py
```
\ No newline at end of file
```
## Format
pip3 install pre-commit
cd sglang
pre-commit install
pre-commit run --all-files
\ No newline at end of file
......@@ -123,6 +123,15 @@ class ModelRunner:
if self.model_config.model_overide_args is not None:
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
if (
self.server_args.efficient_weight_load
and "llama" in self.server_args.model_path.lower()
and self.server_args.quantization == "fp8"
):
from sglang.srt.model_loader.model_loader import get_model
else:
from vllm.model_executor.model_loader import get_model
self.model = get_model(
model_config=vllm_model_config,
device_config=device_config,
......@@ -237,7 +246,16 @@ class ModelRunner:
self.cuda_graph_runner = CudaGraphRunner(
self, max_batch_size_to_capture=max(batch_size_list)
)
self.cuda_graph_runner.capture(batch_size_list)
logger.info(f"Capture for batch sizes {batch_size_list}")
try:
self.cuda_graph_runner.capture(batch_size_list)
except:
raise Exception(
f"Capture cuda graph failed. Possible solutions:\n"
f"1. disable cuda graph by --disable-cuda-graph\n"
f"2. set --mem-fraction-static to a smaller value\n"
f"Open an issue on GitHub with reproducible scripts if you need help.\n"
)
@torch.inference_mode()
def forward_decode(self, batch: Batch):
......
......@@ -304,6 +304,12 @@ class ModelTpServer:
self.model_config.context_len - 1 - len(req.origin_input_ids),
self.max_total_num_tokens - 128 - len(req.origin_input_ids),
)
if req.sampling_params.max_new_tokens < 0:
req.origin_input_ids = req.origin_input_ids[
: self.max_total_num_tokens - 128
]
logger.error("Request longer than memory pool size, truncated!!!")
self.forward_queue.append(req)
def get_new_prefill_batch(self) -> Optional[Batch]:
......
......@@ -91,6 +91,7 @@ def _initialize_model(
config=model_config.hf_config,
cache_config=cache_config,
quant_config=quant_config,
efficient_weight_load=True,
**_get_model_initialization_kwargs(model_class, lora_config, multimodal_config),
)
......
......@@ -15,11 +15,6 @@ from vllm.distributed import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -32,6 +27,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
MergedColumnParallelLinear = None
QKVParallelLinear = None
RowParallelLinear = None
class LlamaMLP(nn.Module):
def __init__(
......@@ -267,7 +266,25 @@ class LlamaForCausalLM(nn.Module):
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
efficient_weight_load=False,
) -> None:
global MergedColumnParallelLinear
global QKVParallelLinear
global RowParallelLinear
if efficient_weight_load:
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
else:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
super().__init__()
self.config = config
self.quant_config = quant_config
......@@ -288,7 +305,30 @@ class LlamaForCausalLM(nn.Module):
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def get_module_name(self, name):
stacked_params_mapping = [
# (param_name, shard_name, shard_id, num_shard)
("qkv_proj", "q_proj", "q", 3),
("qkv_proj", "k_proj", "k", 3),
("qkv_proj", "v_proj", "v", 3),
("gate_up_proj", "gate_proj", 0, 2),
("gate_up_proj", "up_proj", 1, 2),
]
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
num_shard,
)
return name[: -len(".weight")], 1
def get_num_params(self):
params_dict = dict(self.named_parameters())
return len(params_dict)
def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -298,15 +338,14 @@ class LlamaForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
def load_weights_per_param(name, loaded_weight):
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
return
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
return
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
......@@ -323,12 +362,21 @@ class LlamaForCausalLM(nn.Module):
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
return
if name.startswith("model.vision_tower") and name not in params_dict:
continue
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if name is None or loaded_weight is None:
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
load_weights_per_param(name, loaded_weight)
else:
load_weights_per_param(name, loaded_weight)
EntryClass = LlamaForCausalLM
......@@ -57,6 +57,7 @@ class ServerArgs:
disable_disk_cache: bool = False
attention_reduce_in_fp32: bool = False
enable_p2p_check: bool = False
efficient_weight_load: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
......@@ -327,6 +328,11 @@ class ServerArgs:
action="store_true",
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
)
parser.add_argument(
"--efficient-weight-load",
action="store_true",
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment