Unverified commit 72c77763, authored by Lianmin Zheng, committed by GitHub

Fix linear.py and improve weight loading (#2851)


Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent 4093aa46
@@ -39,7 +39,7 @@ python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-r
For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.
-### Example with OpenAI API
+### Example: Sending requests with OpenAI API
```python3
import openai
@@ -58,7 +58,8 @@ response = client.chat.completions.create(
)
print(response)
```
-### Example serving with 2 H20*8
+### Example: Serving with two H20*8 nodes
For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`.
```bash
@@ -71,7 +72,7 @@ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --di
If you have two H100 nodes, the usage is similar to the aforementioned H20.
-### Example serving with Docker two H200*8 nodes
+### Example: Serving with two H200*8 nodes and docker
There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`.
A single H200 with 8 devices can run DeepSeek V3; the dual H200 setup is just to demonstrate multi-node usage.
......
@@ -5,7 +5,7 @@
- Mistral / Mixtral / Mistral NeMo
- Gemma / Gemma 2
- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
-- DeepSeek / DeepSeek 2
+- DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3)
- OLMoE
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava`
......
-# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
import logging
from abc import abstractmethod
@@ -16,7 +16,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
-# workaround
+# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
from vllm.model_executor.layers.linear import LinearBase
from sglang.srt.layers.parameter import (
@@ -25,7 +25,6 @@ from sglang.srt.layers.parameter import (
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter,
-_ColumnvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -43,9 +42,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
-"GPTQLinearMethod",
"QQQLinearMethod",
+"GPTQMarlin24LinearMethod",
+"TPUInt8LinearMethod",
+"GPTQLinearMethod",
+"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
+"IPEXAWQLinearMethod",
]
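The list above controls which quantized linear methods are allowed to take the newer `weight_loader_v2` path. As a point of reference, a minimal sketch of how such a dispatch typically looks; the helper name and call site here are illustrative, not SGLang's actual code:

```python
# Sketch: route to weight_loader_v2 only when the quantization method's class
# name appears in WEIGHT_LOADER_V2_SUPPORTED; fall back to the v1 loader otherwise.
def pick_weight_loader(layer, quant_method, supported_v2: list[str]):
    if type(quant_method).__name__ in supported_v2:
        return layer.weight_loader_v2
    return layer.weight_loader
```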
@@ -95,62 +98,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight
-def load_column_qkv_weight(
-self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
-):
-if (
-isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
-and self.output_dim == self.packed_dim
-):
-shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
-shard_offset=shard_offset, shard_size=shard_size
-)
-param_data = self.data
-shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
-param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-loaded_weight = loaded_weight.narrow(
-self.output_dim, shard_id * shard_size, shard_size
-)
-assert param_data.shape == loaded_weight.shape
-param_data.copy_(loaded_weight)
-def load_column_parallel_weight(
-self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-if isinstance(self, _ColumnvLLMParameter):
-if not use_presharded_weights:
-shard_size = self.data.shape[self.output_dim]
-loaded_weight = loaded_weight.narrow(
-self.output_dim, tp_rank * shard_size, shard_size
-)
-assert self.data.shape == loaded_weight.shape
-self.data.copy_(loaded_weight)
-else:
-self.data.copy_(loaded_weight)
-def load_row_parallel_weight(
-self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-if isinstance(self, RowvLLMParameter):
-if not use_presharded_weights:
-shard_size = self.data.shape[self.input_dim]
-loaded_weight = loaded_weight.narrow(
-self.input_dim, tp_rank * shard_size, shard_size
-)
-if len(loaded_weight.shape) == 0:
-loaded_weight = loaded_weight.reshape(1)
-assert self.data.shape == loaded_weight.shape
-self.data.copy_(loaded_weight)
-else:
-self.data.copy_(loaded_weight)
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@@ -426,9 +373,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
-assert (
-param_data.shape == loaded_weight.shape
-), f"{param_data.shape=}, {loaded_weight.shape=}"
+assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -437,7 +382,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
-param.load_column_parallel_weight(loaded_weight=loaded_weight)
+param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
@@ -565,9 +510,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data, loaded_weight, 0
)
-assert (
-param_data.shape == loaded_weight.shape
-), f"{param_data.shape=}, {loaded_weight.shape=}"
+assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
@@ -643,9 +586,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"the same for all partitions."
)
-assert (
-param_data.shape == loaded_weight.shape
-), f"{param_data.shape=}, {loaded_weight.shape=}"
+assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint(
@@ -697,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
+# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
@@ -882,6 +824,7 @@ class QKVParallelLinear(ColumnParallelLinear):
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
+# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
@@ -896,24 +839,14 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
-if isinstance(param, _ColumnvLLMParameter):
-load_column_qkv_weight(
-param,
-loaded_weight,
-num_heads=self.num_kv_head_replicas,
-shard_id=loaded_shard_id,
-shard_offset=shard_offset,
-shard_size=shard_size,
-tp_rank=self.tp_rank,
-)
-else:
-param.load_qkv_weight(
-loaded_weight=loaded_weight,
-num_heads=self.num_kv_head_replicas,
-shard_id=loaded_shard_id,
-shard_offset=shard_offset,
-shard_size=shard_size,
-)
+param.load_qkv_weight(
+loaded_weight=loaded_weight,
+num_heads=self.num_kv_head_replicas,
+shard_id=loaded_shard_id,
+shard_offset=shard_offset,
+shard_size=shard_size,
+tp_rank=self.tp_rank,
+)
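With this change, QKVParallelLinear always delegates to the parameter's `load_qkv_weight`, now passing `tp_rank` explicitly instead of routing through the removed module-level helper. Roughly, that loader copies one checkpoint tensor (q, k, or v) into its slice of the fused QKV weight. A hedged sketch of the idea, with illustrative names and `output_dim == 0` assumed:

```python
import torch

# Sketch of fused-QKV loading (illustrative, not the actual implementation):
# the fused weight stacks the Q, K and V slices along dim 0, and each checkpoint
# tensor is narrowed to this rank's slice before being copied into its slot.
def copy_qkv_shard(fused: torch.Tensor, loaded: torch.Tensor, shard_id: str,
                   shard_offset: int, shard_size: int,
                   tp_rank: int, num_kv_head_replicas: int) -> None:
    # Q shards follow tp_rank directly; K/V heads may be replicated across ranks.
    rank = tp_rank if shard_id == "q" else tp_rank // num_kv_head_replicas
    dst = fused.narrow(0, shard_offset, shard_size)        # slot in the fused weight
    src = loaded.narrow(0, rank * shard_size, shard_size)  # this rank's slice
    assert dst.shape == src.shape
    dst.copy_(src)
```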
def weight_loader(
self,
@@ -962,9 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data, loaded_weight, 0
)
-assert (
-param_data.shape == loaded_weight.shape
-), f"{param_data.shape=}, {loaded_weight.shape=}"
+assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
@@ -1105,9 +1036,7 @@ class QKVParallelLinear(ColumnParallelLinear):
"for all partitions."
)
-assert (
-param_data.shape == loaded_weight.shape
-), f"{param_data.shape=}, {loaded_weight.shape=}"
+assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -1234,9 +1163,7 @@ class RowParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
-assert (
-param_data.shape == loaded_weight.shape
-), f"{param_data.shape=}, {loaded_weight.shape=}"
+assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1247,7 +1174,18 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
-param.load_row_parallel_weight(loaded_weight=loaded_weight)
+if isinstance(param, BasevLLMParameter):
+# This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
+# It supports additional parameters like tp_rank and use_presharded_weights.
+param.load_row_parallel_weight(
+loaded_weight,
+tp_rank=self.tp_rank,
+use_presharded_weights=self.use_presharded_weights,
+)
+else:
+# `params` is defined in `vllm/model_executor/parameter.py`,
+# It does not support additional parameters.
+param.load_row_parallel_weight(loaded_weight)
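The branch above exists because SGLang's `BasevLLMParameter` (in `sglang/srt/layers/parameter.py`) accepts `tp_rank` and `use_presharded_weights`, while the vLLM parameter class it may fall back to does not. A toy sketch of what the presharded flag changes for a row-parallel weight; the helper name and tensor shapes here are made up for illustration:

```python
import torch

# Sketch: row-parallel weights are split along the input dim (dim=1).
def load_row_shard(param: torch.Tensor, loaded: torch.Tensor,
                   tp_rank: int, use_presharded_weights: bool) -> None:
    if not use_presharded_weights:
        # Full checkpoint tensor: cut out this rank's slice.
        shard = param.shape[1]
        loaded = loaded.narrow(1, tp_rank * shard, shard)
    # Presharded checkpoints already store only this rank's slice, so no narrow.
    assert param.shape == loaded.shape
    param.copy_(loaded)

# Hypothetical usage with tp_size = 2: rank 1 receives columns 4..7.
full = torch.arange(16.0).reshape(2, 8)
rank1_param = torch.empty(2, 4)
load_row_shard(rank1_param, full, tp_rank=1, use_presharded_weights=False)
```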
def forward(self, input_):
if self.input_is_parallel:
......
@@ -24,7 +24,9 @@ def fused_topk_native(
topk: int,
renormalize: bool,
):
-assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+assert (
+hidden_states.shape[0] == gating_output.shape[0]
+), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
@@ -180,7 +182,7 @@ def select_experts(
num_expert_group=num_expert_group,
topk_group=topk_group,
)
-elif torch_native:
+elif torch_native and custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
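`select_experts` now takes the torch-native path only when no `custom_routing_function` is supplied. For reference, a hedged sketch of the kind of top-k routing `fused_topk_native` computes: softmax over the router logits, pick the top-k experts per token, optionally renormalize the selected weights. This is illustrative, not the exact SGLang kernel:

```python
import torch

# Illustrative top-k MoE routing.
def topk_routing(router_logits: torch.Tensor, topk: int, renormalize: bool):
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Make the selected experts' weights sum to 1 for every token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

weights, ids = topk_routing(torch.randn(4, 8), topk=2, renormalize=True)
```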
......
""" """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py"""
Adapted from vLLM (0.6.4.post1).
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
"""
import logging import logging
from fractions import Fraction from fractions import Fraction
...@@ -88,12 +85,17 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -88,12 +85,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def output_dim(self): def output_dim(self):
return self._output_dim return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor): def load_column_parallel_weight(
tp_rank = get_tensor_model_parallel_rank() self,
shard_size = self.data.shape[self.output_dim] loaded_weight: torch.Tensor,
loaded_weight = loaded_weight.narrow( tp_rank: int,
self.output_dim, tp_rank * shard_size, shard_size use_presharded_weights: bool = False,
) ):
if not use_presharded_weights:
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
assert self.data.shape == loaded_weight.shape assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight) self.data.copy_(loaded_weight)
@@ -121,7 +123,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
-def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
@@ -137,7 +139,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
)
param_data = self.data
-tp_rank = get_tensor_model_parallel_rank()
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
@@ -164,11 +165,14 @@ class RowvLLMParameter(BasevLLMParameter):
def input_dim(self):
return self._input_dim
-def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
-use_presharded_weights = kwargs.get("use_presharded_weights")
-tp_rank = get_tensor_model_parallel_rank()
-shard_size = self.data.shape[self.input_dim]
+def load_row_parallel_weight(
+self,
+loaded_weight: torch.Tensor,
+tp_rank: int,
+use_presharded_weights: bool = False,
+):
if not use_presharded_weights:
+shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)
@@ -238,6 +242,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs):
+kwargs.pop("tp_rank", None)
+kwargs.pop("use_presharded_weights", None)
super().load_row_parallel_weight(*args, **kwargs)
def load_merged_column_weight(self, *args, **kwargs):
@@ -247,6 +253,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
self._load_into_shard_id(*args, **kwargs)
def load_column_parallel_weight(self, *args, **kwargs):
+kwargs.pop("tp_rank", None)
+kwargs.pop("use_presharded_weights", None)
super().load_row_parallel_weight(*args, **kwargs)
def _load_into_shard_id(
......
from typing import List, Optional, Tuple
import torch
-from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
......
@@ -11,9 +11,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported,
requantize_with_max_scale,
)
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......
@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_tp: bool = True,
+use_presharded_weights: bool = False,
):
super().__init__()
self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
+self.use_presharded_weights = use_presharded_weights
+if use_presharded_weights:
+assert (
+num_added_embeddings == 0
+), "Lora is not supported with presharded weights."
self.org_vocab_size_padded = pad_vocab_size(
self.org_vocab_size, self.padding_size
)
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor
else:
-assert loaded_weight.shape[output_dim] == self.org_vocab_size
+assert loaded_weight.shape[output_dim] == (
+self.org_vocab_size
+// (self.tp_size if self.use_presharded_weights else 1)
+)
# Copy the data.
-loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+if not self.use_presharded_weights:
+loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0] :].data.fill_(0)
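The relaxed assertion reflects that a presharded checkpoint stores only this rank's slice of the vocabulary, so its vocab dimension is `org_vocab_size // tp_size` and the subsequent `narrow` is skipped. A quick sanity check with made-up numbers:

```python
# Hypothetical numbers: a 128k vocabulary split across tp_size = 8 ranks
# yields per-rank checkpoint shards with 16k rows each.
org_vocab_size, tp_size, use_presharded_weights = 128_000, 8, True
expected_rows = org_vocab_size // (tp_size if use_presharded_weights else 1)
assert expected_rows == 16_000
```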
@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_presharded_weights: bool = False,
):
super().__init__(
num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size,
quant_config,
prefix,
+use_presharded_weights=use_presharded_weights,
)
self.quant_config = quant_config
if bias:
......
@@ -13,6 +13,7 @@
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""
+import faulthandler
import logging
import os
import signal
@@ -399,6 +400,8 @@ class Scheduler:
self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2)
+# Wait sometimes so that the parent process can print the error.
+time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
@torch.no_grad()
@@ -1582,6 +1585,7 @@ def run_scheduler_process(
pipe_writer,
):
setproctitle.setproctitle("sglang::scheduler")
+faulthandler.enable()
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
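`faulthandler.enable()` makes the scheduler process dump the Python tracebacks of all threads when it dies on a low-level fault (segfault, abort, and so on), which complements the new SIGQUIT-based error reporting. A minimal, standalone sketch of the same facility:

```python
import faulthandler
import signal
import sys

# Dump tracebacks of every thread to stderr on fatal signals such as SIGSEGV.
faulthandler.enable(file=sys.stderr, all_threads=True)
# Optionally, also dump tracebacks on demand (e.g. to inspect a stuck process).
faulthandler.register(signal.SIGUSR1, all_threads=True)
```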
......
@@ -27,6 +27,7 @@ from enum import IntEnum
from functools import wraps
from typing import List, Tuple, Union
+import numpy as np
import psutil
import torch
@@ -35,6 +36,8 @@ from sglang.srt.utils import debug_timing, get_compiler_backend
logger = logging.getLogger(__name__)
+GB = 1024 * 1024 * 1024
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
@@ -193,6 +196,11 @@ class MHATokenToKVPool(BaseTokenToKVPool):
self.layer_num = layer_num
self._create_buffers()
+k_size, v_size = self.get_kv_size_bytes()
+logger.info(
+f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+)
def _create_buffers(self):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -217,6 +225,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
del self.k_buffer
del self.v_buffer
+def get_kv_size_bytes(self):
+assert hasattr(self, "k_buffer")
+assert hasattr(self, "v_buffer")
+k_size_bytes = 0
+for k_cache in self.k_buffer:
+k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+v_size_bytes = 0
+for v_cache in self.v_buffer:
+v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+return k_size_bytes, v_size_bytes
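`get_kv_size_bytes` sums, over all layers, the element count of each K/V buffer times the element size, which for this MHA pool is roughly `layer_num * size * head_num * head_dim * itemsize` per side. A back-of-the-envelope check of the logged numbers with hypothetical model dimensions:

```python
# Hypothetical: 32 layers, a 200k-token pool, 8 KV heads, head_dim 128, fp16 (2 bytes).
layer_num, pool_size, head_num, head_dim, itemsize = 32, 200_000, 8, 128, 2
GB = 1024 * 1024 * 1024
k_size = layer_num * pool_size * head_num * head_dim * itemsize  # V is the same size
print(f"K size: {k_size / GB:.2f} GB, V size: {k_size / GB:.2f} GB")  # ~12.21 GB each
```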
# Todo: different memory layout
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer
......
@@ -611,6 +611,9 @@ def _set_envs_and_config(server_args: ServerArgs):
# The child processes will send SIGQUIT to this process when any error happens
# This process then clean up the whole process tree
def sigquit_handler(signum, frame):
+logger.error(
+"Received sigquit from a child proces. It usually means the child failed."
+)
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
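The handler now logs before tearing the process tree down, so the cause is visible in the parent's output. `kill_process_tree` is SGLang's own helper; for illustration only, a psutil-based sketch of what such a helper typically does (this is an assumption, not SGLang's implementation):

```python
import psutil

# Kill a process and all of its descendants, children first,
# ignoring processes that have already exited.
def kill_tree(pid: int) -> None:
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    parent.kill()
```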
......
@@ -71,7 +71,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
-self.assertGreater(metrics["score"], 0.62)
+self.assertGreater(metrics["score"], 0.61)
if __name__ == "__main__":
......