Unverified Commit 72c77763 authored by Lianmin Zheng, committed by GitHub

Fix linear.py and improve weight loading (#2851)


Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent 4093aa46
......@@ -39,7 +39,7 @@ python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-r
For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.
### Example with OpenAI API
### Example: Sending requests with OpenAI API
```python3
import openai
......@@ -58,7 +58,8 @@ response = client.chat.completions.create(
)
print(response)
```
### Example serving with 2 H20*8
### Example: Serving with two H20*8 nodes
Suppose there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`.
```bash
......@@ -71,7 +72,7 @@ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --di
If you have two H100 nodes, the usage is similar to the aforementioned H20.
### Example serving with Docker two H200*8 nodes
### Example: Serving with two H200*8 nodes and Docker
There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Expose the endpoint to other Docker containers with `--host 0.0.0.0` and `--port 40000`, and set up inter-node communication with `--dist-init-addr 192.168.114.10:20000`.
A single H200 node with 8 devices can run DeepSeek V3; the dual-H200 setup here just demonstrates multi-node usage.
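Once both nodes are up, requests can be sent from any machine that can reach the exposed endpoint. A minimal sketch, assuming the endpoint above (`http://192.168.114.10:40000`) and the same OpenAI-compatible API as in the earlier example:

```python3
import openai

# Point the client at the endpoint exposed by the first node.
client = openai.Client(base_url="http://192.168.114.10:40000/v1", api_key="None")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
)
print(response)
```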
......
......@@ -5,7 +5,7 @@
- Mistral / Mixtral / Mistral NeMo
- Gemma / Gemma 2
- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
- DeepSeek / DeepSeek 2
- DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3)
- OLMoE
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava`
......
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
import logging
from abc import abstractmethod
......@@ -16,7 +16,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
# workaround
# Workaround: many QuantizationConfig implementations still depend on this, so we have to use vLLM's LinearBase for now.
from vllm.model_executor.layers.linear import LinearBase
from sglang.srt.layers.parameter import (
......@@ -25,7 +25,6 @@ from sglang.srt.layers.parameter import (
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter,
_ColumnvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
......@@ -43,9 +42,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
"GPTQLinearMethod",
"QQQLinearMethod",
"GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod",
"GPTQLinearMethod",
"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
"IPEXAWQLinearMethod",
]
......@@ -95,62 +98,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight
def load_column_qkv_weight(
self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
):
if (
isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
and self.output_dim == self.packed_dim
):
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size
)
param_data = self.data
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_column_parallel_weight(
self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
):
if isinstance(self, _ColumnvLLMParameter):
if not use_presharded_weights:
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
else:
self.data.copy_(loaded_weight)
def load_row_parallel_weight(
self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
):
if isinstance(self, RowvLLMParameter):
if not use_presharded_weights:
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
else:
self.data.copy_(loaded_weight)
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
......@@ -426,9 +373,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
......@@ -437,7 +382,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
......@@ -565,9 +510,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data, loaded_weight, 0
)
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
......@@ -643,9 +586,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"the same for all partitions."
)
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint(
......@@ -697,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
......@@ -882,6 +824,7 @@ class QKVParallelLinear(ColumnParallelLinear):
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
......@@ -896,24 +839,14 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
if isinstance(param, _ColumnvLLMParameter):
load_column_qkv_weight(
param,
loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
)
else:
param.load_qkv_weight(
loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
)
param.load_qkv_weight(
loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
)
def weight_loader(
self,
......@@ -962,9 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data, loaded_weight, 0
)
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
......@@ -1105,9 +1036,7 @@ class QKVParallelLinear(ColumnParallelLinear):
"for all partitions."
)
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -1234,9 +1163,7 @@ class RowParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
......@@ -1247,7 +1174,18 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight)
if isinstance(param, BasevLLMParameter):
# This `BasevLLMParameter` is the one defined in sglang/srt/layers/parameter.py;
# it supports additional arguments such as tp_rank and use_presharded_weights.
param.load_row_parallel_weight(
loaded_weight,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
else:
# This `param` comes from `vllm/model_executor/parameter.py`;
# it does not accept these additional arguments.
param.load_row_parallel_weight(loaded_weight)
def forward(self, input_):
if self.input_is_parallel:
......
......@@ -24,7 +24,9 @@ def fused_topk_native(
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert (
hidden_states.shape[0] == gating_output.shape[0]
), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
......@@ -180,7 +182,7 @@ def select_experts(
num_expert_group=num_expert_group,
topk_group=topk_group,
)
elif torch_native:
elif torch_native and custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
......
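For context, the torch-native fallback used above boils down to a softmax over the router logits followed by a per-token top-k, optionally renormalizing the kept weights. A rough standalone sketch of that idea (not the exact `fused_topk_native` body):

```python3
import torch

def topk_routing_sketch(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # gating_output: [num_tokens, num_experts] router logits.
    probs = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```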
"""
Adapted from vLLM (0.6.4.post1).
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
"""
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py"""
import logging
from fractions import Fraction
......@@ -88,12 +85,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def output_dim(self):
return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
def load_column_parallel_weight(
self,
loaded_weight: torch.Tensor,
tp_rank: int,
use_presharded_weights: bool = False,
):
if not use_presharded_weights:
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
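The `use_presharded_weights` flag exists because some checkpoints already store only the current rank's slice, in which case narrowing again would drop real data. A minimal standalone illustration of the narrow-then-copy pattern, using hypothetical shapes:

```python3
import torch

tp_size, tp_rank, output_dim = 4, 1, 0
local_param = torch.empty(2, 4)  # this rank's [shard_size, hidden] slice

# Regular checkpoint: the full [8, 4] weight is stored, so slice out our rows.
full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)
shard_size = local_param.shape[output_dim]
shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
assert local_param.shape == shard.shape
local_param.copy_(shard)

# Presharded checkpoint: the file already holds only rows 2-3 for this rank,
# so the tensor is copied as-is without another narrow.
presharded = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size).clone()
local_param.copy_(presharded)
```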
......@@ -121,7 +123,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
......@@ -137,7 +139,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
)
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
......@@ -164,11 +165,14 @@ class RowvLLMParameter(BasevLLMParameter):
def input_dim(self):
return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
use_presharded_weights = kwargs.get("use_presharded_weights")
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.input_dim]
def load_row_parallel_weight(
self,
loaded_weight: torch.Tensor,
tp_rank: int,
use_presharded_weights: bool = False,
):
if not use_presharded_weights:
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)
......@@ -238,6 +242,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs):
kwargs.pop("tp_rank", None)
kwargs.pop("use_presharded_weights", None)
super().load_row_parallel_weight(*args, **kwargs)
def load_merged_column_weight(self, *args, **kwargs):
......@@ -247,6 +253,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
self._load_into_shard_id(*args, **kwargs)
def load_column_parallel_weight(self, *args, **kwargs):
kwargs.pop("tp_rank", None)
kwargs.pop("use_presharded_weights", None)
super().load_row_parallel_weight(*args, **kwargs)
def _load_into_shard_id(
......
from typing import List, Optional, Tuple
import torch
from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
......
......@@ -11,9 +11,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......
......@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_tp: bool = True,
use_presharded_weights: bool = False,
):
super().__init__()
self.quant_config = quant_config
......@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.use_presharded_weights = use_presharded_weights
if use_presharded_weights:
assert (
num_added_embeddings == 0
), "Lora is not supported with presharded weights."
self.org_vocab_size_padded = pad_vocab_size(
self.org_vocab_size, self.padding_size
)
......@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size
assert loaded_weight.shape[output_dim] == (
self.org_vocab_size
// (self.tp_size if self.use_presharded_weights else 1)
)
# Copy the data.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0] :].data.fill_(0)
......@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_presharded_weights: bool = False,
):
super().__init__(
num_embeddings,
......@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size,
quant_config,
prefix,
use_presharded_weights=use_presharded_weights,
)
self.quant_config = quant_config
if bias:
......
......@@ -13,6 +13,7 @@
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""
import faulthandler
import logging
import os
import signal
......@@ -399,6 +400,8 @@ class Scheduler:
self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2)
# Wait for a while so that the parent process can print the error.
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
@torch.no_grad()
......@@ -1582,6 +1585,7 @@ def run_scheduler_process(
pipe_writer,
):
setproctitle.setproctitle("sglang::scheduler")
faulthandler.enable()
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
......
......@@ -27,6 +27,7 @@ from enum import IntEnum
from functools import wraps
from typing import List, Tuple, Union
import numpy as np
import psutil
import torch
......@@ -35,6 +36,8 @@ from sglang.srt.utils import debug_timing, get_compiler_backend
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
......@@ -193,6 +196,11 @@ class MHATokenToKVPool(BaseTokenToKVPool):
self.layer_num = layer_num
self._create_buffers()
k_size, v_size = self.get_kv_size_bytes()
logger.info(
f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
)
def _create_buffers(self):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
......@@ -217,6 +225,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
del self.k_buffer
del self.v_buffer
def get_kv_size_bytes(self):
assert hasattr(self, "k_buffer")
assert hasattr(self, "v_buffer")
k_size_bytes = 0
for k_cache in self.k_buffer:
k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
v_size_bytes = 0
for v_cache in self.v_buffer:
v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
return k_size_bytes, v_size_bytes
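As a rough sanity check of the numbers logged above: the size is just each per-layer buffer's element count times the element size, summed over layers. A small sketch with hypothetical shapes:

```python3
import numpy as np
import torch

GB = 1024 * 1024 * 1024
# Hypothetical per-layer K cache: 65536 slots, 8 KV heads, head_dim 128, fp16.
k_cache = torch.empty(65536, 8, 128, dtype=torch.float16)
layer_num = 32
k_size_bytes = np.prod(k_cache.shape) * k_cache.dtype.itemsize * layer_num
print(f"K size: {k_size_bytes / GB:.2f} GB")  # 4.00 GB across 32 layers
```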
# Todo: different memory layout
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer
......
......@@ -611,6 +611,9 @@ def _set_envs_and_config(server_args: ServerArgs):
# The child processes will send SIGQUIT to this process when any error happens
# This process then cleans up the whole process tree.
def sigquit_handler(signum, frame):
logger.error(
"Received sigquit from a child proces. It usually means the child failed."
)
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
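A minimal standalone sketch of this parent/child error-propagation pattern (hypothetical worker, not the actual scheduler code): the failing child pauses briefly, signals the parent with SIGQUIT, and the parent's handler tears down the whole process tree.

```python3
import multiprocessing
import os
import signal
import time

import psutil  # used to walk and kill the process tree, like kill_process_tree


def kill_tree(pid):
    parent = psutil.Process(pid)
    for child in parent.children(recursive=True):
        child.kill()
    parent.kill()


def worker(parent_pid):
    try:
        raise RuntimeError("simulated scheduler failure")
    except Exception:
        time.sleep(1)  # brief pause so any error output can flush before teardown
        os.kill(parent_pid, signal.SIGQUIT)


if __name__ == "__main__":
    signal.signal(signal.SIGQUIT, lambda signum, frame: kill_tree(os.getpid()))
    child = multiprocessing.Process(target=worker, args=(os.getpid(),))
    child.start()
    child.join()
```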
......
......@@ -71,7 +71,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.62)
self.assertGreater(metrics["score"], 0.61)
if __name__ == "__main__":
......