"vscode:/vscode.git/clone" did not exist on "0874dd04dc1bb359053935109dc95483218b086f"
Unverified commit 1dce6c48 authored by Chunyuan WU, committed by GitHub


[CPU] support the case where num_attention_heads or intermediate_size is not divisible by the TP size (#6771)
parent 9fcc9a80
from __future__ import annotations

from typing import TYPE_CHECKING

DEFAULT_MOE_PADDING_SIZE = 32

if TYPE_CHECKING:
    from sglang.srt.configs.load_config import LoadConfig
    from sglang.srt.configs.model_config import ModelConfig


def may_get_weight_block_size(model_config, load_config):
    from sglang.srt.model_loader.loader import _get_quantization_config
    from sglang.srt.model_loader.utils import get_model_architecture

    model_class, _ = get_model_architecture(model_config)
    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})

    quant_config = _get_quantization_config(
        model_config, load_config, packed_modules_mapping
    )

    if quant_config is not None and hasattr(quant_config, "weight_block_size"):
        return getattr(quant_config, "weight_block_size")
    return None


def get_moe_padding_size(weight_block_size):
    if weight_block_size is not None:
        # See NOTE(HandH1998): To ensure proper alignment of the block-wise
        # quantization scales, the output_size of the weights for both the gate
        # and up layers must be divisible by block_n.
        assert (
            len(weight_block_size) == 2
        ), "Only len(weight_block_size) == 2 is supported"
        assert (
            weight_block_size[0] == weight_block_size[1]
        ), "Only weight_block_size[0] == weight_block_size[1] is supported"
        return weight_block_size[0]

    return DEFAULT_MOE_PADDING_SIZE


def get_num_heads_padding_size(tp_size, weight_block_size):
    pad_size = (
        tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size
    )
    return pad_size


def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
    if hasattr(model_config.hf_config, attr_name):
        attr_value = getattr(model_config.hf_config, attr_name)
        if attr_value % intermediate_padding_size != 0:
            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

            attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
            setattr(model_config.hf_config, attr_name, attr_value)
            setattr(model_config.hf_text_config, attr_name, attr_value)
    return model_config


def adjust_config_with_unaligned_cpu_tp(
    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
) -> ModelConfig:
    # Support the case where the num_attention_heads is not divisible by the TP size.
    weight_block_size = may_get_weight_block_size(model_config, load_config)

    model_config.hf_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )
    model_config.hf_text_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )

    model_config.hf_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )
    model_config.hf_text_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )

    if (
        model_config.num_attention_heads % tp_size != 0
        or model_config.get_total_num_kv_heads() % tp_size != 0
    ):
        # Compute the head_dim using the model_config.num_attention_heads before padding
        if not hasattr(model_config.hf_config, "head_dim"):
            model_config.hf_config.head_dim = (
                model_config.hidden_size // model_config.num_attention_heads
            )

        query_heads_per_kv = (
            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
        )
        total_kv_heads = model_config.get_total_num_kv_heads()

        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)
        model_config.num_key_value_heads = num_key_value_heads
        model_config.hf_config.num_key_value_heads = num_key_value_heads
        model_config.hf_text_config.num_key_value_heads = num_key_value_heads

        num_attention_heads = num_key_value_heads * query_heads_per_kv
        model_config.num_attention_heads = num_attention_heads
        model_config.hf_config.num_attention_heads = num_attention_heads
        model_config.hf_text_config.num_attention_heads = num_attention_heads

    intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
    model_config = update_intermediate_size(
        model_config, "moe_intermediate_size", intermediate_padding_size
    )
    model_config = update_intermediate_size(
        model_config, "intermediate_size", intermediate_padding_size
    )
    return model_config
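To make the rounding above concrete, here is a small standalone sketch (not part of the commit) of the padding arithmetic, assuming pad_vocab_size rounds its first argument up to the nearest multiple of the second; all model numbers are hypothetical.

def round_up(value: int, multiple: int) -> int:
    # mimics pad_vocab_size
    return ((value + multiple - 1) // multiple) * multiple

tp_size = 6
num_attention_heads, total_kv_heads = 40, 10                  # hypothetical model
query_heads_per_kv = num_attention_heads // total_kv_heads    # 4
pad_size = tp_size                                            # no block-wise quantization
padded_kv_heads = round_up(total_kv_heads, pad_size)          # 12, divisible by 6
padded_q_heads = padded_kv_heads * query_heads_per_kv         # 48, divisible by 6

intermediate_size = 11008                                     # hypothetical
intermediate_padding = tp_size * 32                           # tp_size * DEFAULT_MOE_PADDING_SIZE
padded_intermediate = round_up(intermediate_size, intermediate_padding)  # 11136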
@@ -426,8 +426,26 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            if not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not self.use_presharded_weights,
+                )
+            else:
+                if not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )

         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
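The CPU branch has to cope with ranks whose shard is partly, or even entirely, padding once the output dimension has been rounded up. A rough illustration of the per-rank arithmetic with made-up numbers:

# Hypothetical: an output dimension padded from 42 to 48 rows on tp_size=6.
# Each rank owns 8 rows of the padded parameter, but the checkpoint only has
# 42 rows, so rank 5 gets 2 real rows and 6 zero-filled padding rows.
padded_rows, checkpoint_rows, tp_size = 48, 42, 6
shard_size = padded_rows // tp_size                      # 8
for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    real_rows = max(0, min(shard_size, checkpoint_rows - start_idx))
    print(tp_rank, start_idx, real_rows)                 # rank 5 -> start 40, 2 real rows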
@@ -644,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = self.tp_rank * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )

         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1112,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )

         # Special case for for AQLM codebooks.
         elif is_metadata:
@@ -1257,7 +1311,22 @@ class RowParallelLinear(LinearBase):
         ):
             shard_size = param_data.shape[input_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    input_dim,
+                    shard_size,
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -573,11 +574,6 @@ class FusedMoE(torch.nn.Module):
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )

         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
@@ -588,7 +584,24 @@ class FusedMoE(torch.nn.Module):
             start = shard_size
         else:
             start = 0
-        expert_data = expert_data.narrow(shard_dim, start, shard_size)
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                start,
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
+            expert_data = expert_data.narrow(shard_dim, start, shard_size)
+
         expert_data.copy_(loaded_weight)

     def _load_w2(
@@ -605,10 +618,21 @@ class FusedMoE(torch.nn.Module):
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )

         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
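In the fused w13 expert weight, `start` picks which half of the parameter a checkpoint tensor lands in: gate_proj (w1) fills the first shard_size rows and up_proj (w3) the second. A tiny illustrative sketch with made-up shapes:

import torch

shard_size = 4
w13 = torch.zeros(2 * shard_size, 8)                  # fused [gate; up] weight for one expert
gate = torch.ones(shard_size, 8)                      # stands in for a w1 / gate_proj shard
up = torch.full((shard_size, 8), 2.0)                 # stands in for a w3 / up_proj shard

w13.narrow(0, 0, shard_size).copy_(gate)              # start = 0          -> w1 / gate_proj
w13.narrow(0, shard_size, shard_size).copy_(up)       # start = shard_size -> w3 / up_proj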
@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter

+from sglang.srt.utils import is_cpu
+
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -21,6 +23,8 @@ __all__ = [

 logger = logging.getLogger(__name__)

+_is_cpu = is_cpu()
+

 class BasevLLMParameter(Parameter):
     """
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
-            )
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
+            )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.output_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )

         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
-            )
+
+        from sglang.srt.model_loader.weight_utils import (
+            narrow_padded_param_and_loaded_weight,
+        )
+
+        if _is_cpu:
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                tp_rank * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
+            )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )

         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
+            )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.input_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, tp_rank * shard_size, shard_size
+                )

         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
@@ -246,8 +246,16 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.tp_size = 1

         self.num_embeddings = num_embeddings
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
+
+        # Support the case where the vocab size is not divisible by the TP size.
+        if (
+            _is_cpu
+            and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
+        ):
+            padding_size *= self.tp_size
+            self.padding_size = padding_size
+
         num_added_embeddings = num_embeddings - self.org_vocab_size
         self.use_presharded_weights = use_presharded_weights
         if use_presharded_weights:
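Why the padding unit is widened by tp_size on CPU: with the default padding alone, the padded vocab can still fail to split evenly across ranks. A sketch with hypothetical numbers, assuming pad_vocab_size rounds up to the nearest multiple:

def round_up(n: int, m: int) -> int:
    return ((n + m - 1) // m) * m

org_vocab_size, padding_size, tp_size = 151936, 64, 6    # hypothetical
assert round_up(org_vocab_size, padding_size) == 151936  # already a multiple of 64
assert 151936 % tp_size != 0                             # but not divisible by 6
padding_size *= tp_size                                  # 384
padded = round_up(org_vocab_size, padding_size)          # 152064 -> 25344 rows per rank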
@@ -149,6 +149,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_bool_env_var,
     get_zmq_socket,
+    is_cpu,
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
@@ -167,6 +168,8 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

+_is_cpu = is_cpu()
+

 @dataclass
 class GenerationBatchResult:
@@ -2115,11 +2118,14 @@ class Scheduler(
                 "kvcache": round(
                     self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
                 ),
-                "cuda_graph": round(
-                    self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
-                ),
                 "token_capacity": int(self.max_total_num_tokens),
             }

+            if not _is_cpu:
+                ret["memory_usage"]["cuda_graph"] = round(
+                    self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
+                )
+
             if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                 ret["avg_spec_accept_length"] = (
                     self.cum_spec_accept_length / self.cum_spec_accept_count
@@ -29,6 +29,7 @@ import torch.distributed as dist
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
@@ -165,7 +166,6 @@ class ModelRunner:
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
-        self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
@@ -178,6 +178,7 @@ class ModelRunner:
         self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
+        self.model_config = model_config
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -604,6 +605,10 @@ class ModelRunner:
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
+        if self.device == "cpu":
+            self.model_config = adjust_config_with_unaligned_cpu_tp(
+                self.model_config, self.load_config, self.tp_size
+            )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
@@ -961,3 +961,57 @@ def kv_cache_scales_loader(
         tp_rank,
     )
     return []
+
+
+def get_actual_shard_size(shard_size, weight_start, weight_end):
+    if weight_end < weight_start:
+        return 0
+
+    return min(shard_size, weight_end - weight_start)
+
+
+def reset_param_data_if_needed(param_data, dim, start, length):
+    if length == 0:
+        return
+
+    assert length > 0, f"Length should be positive, but got {length}"
+    param_data.narrow(dim, start, length).zero_()
+    return
+
+
+def narrow_padded_param_and_loaded_weight(
+    param_data,
+    loaded_weight,
+    param_data_start,
+    weight_start,
+    dim,
+    shard_size,
+    narrow_weight=True,
+):
+    actual_shard_size = get_actual_shard_size(
+        shard_size, weight_start, loaded_weight.size(dim)
+    )
+
+    if narrow_weight:
+        if actual_shard_size > 0:
+            loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size)
+        else:
+            # No real data to load; create a dummy tensor filled with zeros
+            loaded_weight = torch.zeros_like(
+                param_data.narrow(dim, param_data_start, actual_shard_size)
+            )
+
+        # [Note] Reset padded weights to zero.
+        # If the actual shard size is less than the shard size, we need to reset
+        # the padded param_data to zero and then copy the loaded_weight into it.
+        reset_param_data_if_needed(
+            param_data,
+            dim,
+            param_data_start + actual_shard_size,
+            shard_size - actual_shard_size,
+        )
+
+    param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
+
+    return param_data, loaded_weight
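A rough usage sketch of the helper above with made-up shapes: the padded shard holds 6 rows, the checkpoint only supplies 4 of them for this rank, and the remaining 2 rows are zeroed before the copy.

import torch
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight

param_data = torch.empty(6, 8)            # this rank's padded shard (rows 4-5 are padding)
loaded_weight = torch.randn(10, 8)        # full, unpadded checkpoint tensor
tp_rank, shard_size = 1, 6                # rank 1 starts at row 6 of the checkpoint

shard, weight = narrow_padded_param_and_loaded_weight(
    param_data,
    loaded_weight,
    0,                     # param_data_start
    tp_rank * shard_size,  # weight_start = 6
    0,                     # dim
    shard_size,
)
shard.copy_(weight)        # copies checkpoint rows 6-9 into rows 0-3; rows 4-5 stay zero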
@@ -16,7 +16,9 @@ from sglang.srt.managers.mm_utils import (
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cpu
+
+_is_cpu = is_cpu()


 class Llama4ForConditionalGeneration(nn.Module):
@@ -107,13 +109,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         # rotary embeds should be sliced
         if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_key_value_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_total_num_kv_heads
+            else:
+                dim = self.language_model.config.num_key_value_heads
+            loaded_weight = permute(loaded_weight, dim)
         elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_attention_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_num_attention_heads
+            else:
+                dim = self.language_model.config.num_attention_heads
+            loaded_weight = permute(loaded_weight, dim)

         return name, loaded_weight
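The rotary permutation of q/k weights has to use the head count the checkpoint was saved with; on CPU the config's head counts may already be padded, hence the original_* attributes recorded by adjust_config_with_unaligned_cpu_tp. A generic illustration (not the model's actual permute function) of why the wrong head count breaks the reshape:

import torch

head_dim, hidden = 8, 16
orig_kv_heads, padded_kv_heads = 10, 12                   # hypothetical CPU padding
wk = torch.randn(orig_kv_heads * head_dim, hidden)        # checkpoint layout: 10 heads

wk.view(orig_kv_heads, head_dim // 2, 2, hidden)          # fine: matches the checkpoint
# wk.view(padded_kv_heads, head_dim // 2, 2, hidden)      # RuntimeError: shape mismatch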
@@ -100,6 +100,7 @@ class Qwen2Attention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        head_dim: Optional[int] = None,
         layer_id: int = 0,
         rope_theta: float = 1000000,
         rope_scaling: Optional[Dict[str, Any]] = None,
@@ -123,7 +124,10 @@ class Qwen2Attention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        if head_dim is not None:
+            self.head_dim = head_dim
+        else:
+            self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -191,10 +195,12 @@ class Qwen2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            head_dim=head_dim,
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
@@ -13,6 +13,8 @@
 # ==============================================================================
 """Common utilities."""

+from __future__ import annotations
+
 import base64
 import builtins
 import ctypes