Unverified commit 56a724eb authored by Qubitium-ModelCloud, committed by GitHub

[QUANT] Add GPTQModel Dynamic Quantization + `lm_head` Quantization (#3790)


Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
parent 583d6af7
...@@ -2,15 +2,25 @@
SGLang supports various quantization methods, including offline quantization and online dynamic quantization.

Offline quantization loads pre-quantized model weights directly during inference. It is required for quantization methods
such as GPTQ and AWQ, which collect and pre-compute various statistics from the original weights using a calibration dataset.

Online quantization dynamically computes scaling parameters, such as the maximum/minimum values of model weights, during runtime.
Like NVIDIA FP8 training's [delayed scaling](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8) mechanism, online quantization calculates the appropriate scaling factors
on the fly to convert high-precision weights into a lower-precision format.

**Note: For better performance, usability and convenience, offline quantization is recommended over online quantization.**
If you use a pre-quantized model, do not add `--quantization` on top of it, as that flag enables online quantization.
For popular pre-quantized models, please visit the [ModelCloud](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2) or [NeuralMagic](https://huggingface.co/collections/neuralmagic) collections on Hugging Face for
quality-validated quantized models. Quantized models must be validated with benchmarks after quantization
to guard against abnormal quantization-loss regressions.
## Offline Quantization
To load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline,
there is no need to add the `--quantization` argument when starting the engine. The quantization method will be parsed from the
downloaded Hugging Face config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.**
```bash
python3 -m sglang.launch_server \
...@@ -18,9 +28,38 @@ python3 -m sglang.launch_server \
  --port 30000 --host 0.0.0.0
```
### Examples of Offline Model Quantization
#### Using [GPTQModel](https://github.com/ModelCloud/GPTQModel)
```bash
# install
pip install gptqmodel --no-build-isolation -v
```
```py
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"
calibration_dataset = load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train",
).select(range(1024))["text"]
quant_config = QuantizeConfig(bits=4, group_size=128) # quantization config
model = GPTQModel.load(model_id, quant_config) # load model
model.quantize(calibration_dataset, batch_size=2) # quantize
model.save(quant_path) # save model
```
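
This change also teaches SGLang to honor GPTQModel's per-module `dynamic` overrides and `lm_head` quantization. The sketch below is a minimal, illustrative example of producing such a checkpoint; it assumes GPTQModel's `QuantizeConfig` exposes `lm_head` and `dynamic` arguments matching the `lm_head` and `dynamic` keys that SGLang now parses from `quantize_config.json`, and the regex rules are placeholders following the `+:`/`-:` rule format documented in the `GPTQConfig` comments in this change.

```py
from gptqmodel import GPTQModel, QuantizeConfig  # QuantizeConfig fields below are assumptions, see note above

# Illustrative per-module rules: a positive ("+:" or no prefix) match overrides
# fields of the base config for that module; a negative ("-:") match skips
# quantization for that module entirely.
dynamic = {
    r"+:.*\.(?:1[0-5])\..*": {"bits": 8},                          # e.g. layers 10-15: 8-bit
    r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},  # e.g. layers 16-21: 8-bit, group size 64
    r"-:.*\.moe\..*": {},                                          # skip all `moe` modules
}

quant_config = QuantizeConfig(
    bits=4,           # base config: 4-bit, group size 128
    group_size=128,
    lm_head=True,     # assumed flag: also quantize the `lm_head` projection
    dynamic=dynamic,  # assumed field: per-module overrides as above
)
# Pass quant_config to GPTQModel.load(...) and quantize/save exactly as in the example above.
```

Serve the saved checkpoint as in the earlier `launch_server` example, pointing `--model-path` at the output directory and omitting `--quantization`; the GPTQ settings, including any `dynamic` rules and `lm_head` quantization, are read from the checkpoint's config.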
#### Using [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
```bash
# install
pip install llmcompressor
```
...@@ -99,8 +138,7 @@ python3 -m sglang.launch_server \
## Reference
- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
...@@ -19,6 +19,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.utils import add_prefix
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
...@@ -122,20 +123,20 @@ class VisionAttention(nn.Module):
head_size=self.head_size,
total_num_heads=num_heads,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
else:
self.qkv_proj = ColumnParallelLinear(
input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.proj = RowParallelLinear(
input_size=embed_dim,
output_size=embed_dim,
quant_config=quant_config,
prefix=add_prefix("out_proj", prefix),
)
def forward(
...
...@@ -417,7 +417,7 @@ class LogitsProcessor(nn.Module):
)
else:
# GGUF models
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
if self.logit_scale is not None:
logits.mul_(self.logit_scale)
...
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union
import torch
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
...@@ -16,8 +18,6 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfi
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig
...@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
...@@ -61,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
return QUANTIZATION_METHODS[quantization]
# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
if isinstance(weight_bits, int):
config.weight_bits = weight_bits
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
if isinstance(group_size, int):
config.group_size = group_size
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
if isinstance(desc_act, bool):
config.desc_act = desc_act
config.pack_factor = 32 // config.weight_bits # packed into int32
if config.get_name() == "gptq_marlin":
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
if isinstance(is_sym, bool):
config.is_sym = is_sym
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
raise ValueError(
"Unsupported quantization config: "
f"bits={config.weight_bits}, sym={config.is_sym}"
)
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
elif config.get_name() == "gptq":
if config.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {config.weight_bits} bits."
)
def get_dynamic_override(
config: QuantizationConfig,
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
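# Returns False on a negative ("-:") match (the module is skipped and left
# unquantized), the matched rule dict (or the value of `key` within it) on a
# positive match, and `default_value` when no rule matches.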
for pattern, pattern_dict in config.dynamic.items():
# Negative match: matched modules are excluded from quantized init
if pattern.startswith("-:"):
if re.match(pattern.removeprefix("-:"), layer_name):
return False
# Positive match: matched modules have quant properties overrides
# base quant config
elif re.match(pattern.removeprefix("+:"), layer_name):
if key is None:
return pattern_dict
else:
return pattern_dict.get(key, default_value)
return default_value
def get_linear_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
prefix: str,
linear_method_cls: type,
):
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
UnquantizedEmbeddingMethod,
)
cloned_config = deepcopy(config)
parallel_lm_head_quantized = (
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
)
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
# False = skip module, None = no override, else = Positive match
if (
get_dynamic_override( # noqa: E712
cloned_config, layer_name=prefix # noqa: E712
)
== False
): # noqa: E712
if parallel_lm_head_quantized:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()
if prefix:
# Dynamic per module/layer rules may override base config
override_config(cloned_config, prefix=prefix)
return linear_method_cls(cloned_config)
return None
def gptq_get_quant_method(self, layer, prefix):
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
if isinstance(self, GPTQConfig):
return get_linear_quant_method(
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
)
elif isinstance(self, GPTQMarlinConfig):
return get_linear_quant_method(
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
)
return None
...@@ -155,6 +256,7 @@ def apply_monkey_patches():
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
...
import logging
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
import torch
from vllm.scalar_type import scalar_types
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
logger = logging.getLogger(__name__)
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
super().__init__()
self.dynamic = dynamic
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits."
)
def __repr__(self) -> str:
return (
f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic})"
)
def get_scaled_act_names(self) -> List[str]:
"""Returns the activation function names that should be post-scaled.
For now, this is only used by AWQ.
"""
raise NotImplementedError
@classmethod
def get_name(cls) -> str:
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["GPTQLinearMethod"]:
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from sglang.srt.layers.quantization import get_linear_quant_method
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any],
) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
self.dynamic = dynamic
self.weight_bits = weight_bits
self.is_sym = is_sym
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.full_config = full_config
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError(
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
)
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
def __repr__(self) -> str:
return (
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic})"
)
def get_scaled_act_names(self) -> List[str]:
"""Returns the activation function names that should be post-scaled.
For now, this is only used by AWQ.
"""
raise NotImplementedError
@classmethod
def get_name(cls) -> str:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
return cls(
weight_bits,
group_size,
desc_act,
is_sym,
lm_head_quantized,
dynamic,
config,
)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
)
if can_convert and is_valid_user_quant:
msg = (
"The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name())
)
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info(
"Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference"
)
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import get_linear_quant_method
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
# TODO: re-enable after SGLang syncs with vllm >= 0.7.3
# if layer.num_experts > 32:
# # For MoEs with many experts the moe_wna16 kernel is faster
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
# layer, prefix
# )
# else:
# return GPTQMarlinMoEMethod(self)
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
)
from vllm.platforms import current_platform
if not current_platform.is_cuda():
return False
if quant_method != "gptq":
return False
# Marlin conversion is only valid if required properties are found
if num_bits is None or group_size is None or sym is None or desc_act is None:
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_marlin_supported(
quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
)
class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
Reference: https://github.com/IST-DASLab/marlin/tree/master
"""
def __init__(
self,
group_size: int,
lm_head_quantized: bool,
) -> None:
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin, but got group_size of "
f"{self.group_size}"
)
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = 64
# Min in_features dim
self.min_k_threads = 128
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return (
f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})"
)
@classmethod
def get_name(cls) -> str:
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
return cls(group_size, lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = hf_quant_cfg.get(
"checkpoint_format"
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
)
if is_marlin_format and is_valid_user_quant:
msg = "The model is serialized in {} format. Using {} kernel.".format(
cls.get_name(), cls.get_name()
)
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["MarlinLinearMethod"]:
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
return MarlinLinearMethod(self)
return None
...@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
v_head_dim: int = -1,
sliding_window_size: int = -1,
is_cross_attention: bool = False,
prefix: str = "",
):
super().__init__()
self.tp_q_head_num = num_heads
...
...@@ -261,26 +261,27 @@ class VocabParallelEmbedding(torch.nn.Module):
)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
print("quant_method", quant_method)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method)
)
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod."
)
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
...@@ -301,7 +302,7 @@ class VocabParallelEmbedding(torch.nn.Module):
- self.shard_indices.added_vocab_start_index
)
self.quant_method.create_weights(
self,
self.embedding_dim,
[self.num_embeddings_per_partition],
...@@ -446,7 +447,7 @@ class VocabParallelEmbedding(torch.nn.Module):
packed_factor = (
param.packed_factor
if isinstance(param, BasevLLMParameter)
else param.packed_factor
)
assert loaded_weight.shape[output_dim] == (
self.org_vocab_size // param.packed_factor
...@@ -479,7 +480,7 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
...
...@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...@@ -80,13 +81,22 @@ class BaiChuanMLP(nn.Module): ...@@ -80,13 +81,22 @@ class BaiChuanMLP(nn.Module):
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
) )
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError( raise ValueError(
...@@ -114,6 +124,7 @@ class BaiChuanAttention(nn.Module): ...@@ -114,6 +124,7 @@ class BaiChuanAttention(nn.Module):
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
layer_id: int = 0, layer_id: int = 0,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -167,6 +178,7 @@ class BaiChuanAttention(nn.Module): ...@@ -167,6 +178,7 @@ class BaiChuanAttention(nn.Module):
scaling, scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
else: else:
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -182,6 +194,7 @@ class BaiChuanAttention(nn.Module): ...@@ -182,6 +194,7 @@ class BaiChuanAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -207,6 +220,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -207,6 +220,7 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding: str, position_embedding: str,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -220,12 +234,14 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -220,12 +234,14 @@ class BaiChuanDecoderLayer(nn.Module):
layer_id=layer_id, layer_id=layer_id,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
...@@ -264,6 +280,7 @@ class BaiChuanModel(nn.Module): ...@@ -264,6 +280,7 @@ class BaiChuanModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -281,6 +298,7 @@ class BaiChuanModel(nn.Module): ...@@ -281,6 +298,7 @@ class BaiChuanModel(nn.Module):
layer_id=i, layer_id=i,
position_embedding=position_embedding, position_embedding=position_embedding,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
) )
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
...@@ -330,18 +348,24 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -330,18 +348,24 @@ class BaiChuanBaseForCausalLM(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config) self.model = BaiChuanModel(
config, position_embedding, quant_config, prefix=add_prefix("model", prefix)
)
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
...@@ -404,11 +428,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -404,11 +428,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", quant_config) super().__init__(config, "ROPE", quant_config, prefix=prefix)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", quant_config) super().__init__(config, "ALIBI", quant_config, prefix=prefix)
EntryClass = [BaichuanForCausalLM] EntryClass = [BaichuanForCausalLM]
...@@ -41,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -41,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
LoraConfig = None LoraConfig = None
...@@ -51,6 +52,7 @@ class GLMAttention(nn.Module): ...@@ -51,6 +52,7 @@ class GLMAttention(nn.Module):
config, config,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -85,12 +87,14 @@ class GLMAttention(nn.Module): ...@@ -85,12 +87,14 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias, bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("query_key_value", prefix),
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("dense", prefix),
) )
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
...@@ -109,6 +113,7 @@ class GLMAttention(nn.Module): ...@@ -109,6 +113,7 @@ class GLMAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -142,6 +147,7 @@ class GLMMLP(nn.Module): ...@@ -142,6 +147,7 @@ class GLMMLP(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -153,6 +159,7 @@ class GLMMLP(nn.Module): ...@@ -153,6 +159,7 @@ class GLMMLP(nn.Module):
[config.ffn_hidden_size] * 2, [config.ffn_hidden_size] * 2,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("dense_h_to_4h", prefix),
) )
self.activation_func = SiluAndMul() self.activation_func = SiluAndMul()
...@@ -163,6 +170,7 @@ class GLMMLP(nn.Module): ...@@ -163,6 +170,7 @@ class GLMMLP(nn.Module):
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("dense_4h_to_h", prefix),
) )
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -186,6 +194,7 @@ class GLMBlock(nn.Module): ...@@ -186,6 +194,7 @@ class GLMBlock(nn.Module):
config, config,
layer_id: int, layer_id: int,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
...@@ -201,7 +210,9 @@ class GLMBlock(nn.Module): ...@@ -201,7 +210,9 @@ class GLMBlock(nn.Module):
) )
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, layer_id, quant_config) self.self_attention = GLMAttention(
config, layer_id, quant_config, prefix=add_prefix("self_attention", prefix)
)
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
...@@ -210,7 +221,7 @@ class GLMBlock(nn.Module): ...@@ -210,7 +221,7 @@ class GLMBlock(nn.Module):
) )
# MLP # MLP
self.mlp = GLMMLP(config, quant_config) self.mlp = GLMMLP(config, quant_config, prefix=add_prefix("mlp", prefix))
def forward( def forward(
self, self,
...@@ -257,6 +268,7 @@ class GLMTransformer(nn.Module): ...@@ -257,6 +268,7 @@ class GLMTransformer(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.post_layer_norm = config.post_layer_norm self.post_layer_norm = config.post_layer_norm
...@@ -266,7 +278,15 @@ class GLMTransformer(nn.Module): ...@@ -266,7 +278,15 @@ class GLMTransformer(nn.Module):
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[GLMBlock(config, i, quant_config) for i in range(self.num_layers)] [
GLMBlock(
config,
i,
quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(self.num_layers)
]
) )
if self.post_layer_norm: if self.post_layer_norm:
...@@ -301,19 +321,28 @@ class ChatGLMM(nn.Module): ...@@ -301,19 +321,28 @@ class ChatGLMM(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.embedding = VocabParallelEmbedding( self.embedding = VocabParallelEmbedding(
config.padded_vocab_size, config.hidden_size config.padded_vocab_size,
config.hidden_size,
prefix=add_prefix("embedding", prefix),
) )
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, quant_config) self.encoder = GLMTransformer(
config, quant_config, add_prefix("encoder", prefix)
)
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) self.output_layer = ParallelLMHead(
config.padded_vocab_size,
config.hidden_size,
prefix=add_prefix("output_layer", prefix),
)
def forward( def forward(
self, self,
...@@ -351,12 +380,15 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -351,12 +380,15 @@ class ChatGLMForCausalLM(nn.Module):
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
self.transformer = ChatGLMM(config, quant_config) self.transformer = ChatGLMM(
config, quant_config, prefix=add_prefix("transformer", prefix)
)
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -65,7 +65,7 @@ from sglang.srt.model_loader.weight_utils import ( ...@@ -65,7 +65,7 @@ from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from sglang.srt.utils import get_compiler_backend, set_weight_attrs from sglang.srt.utils import add_prefix, get_compiler_backend, set_weight_attrs
@torch.compile(backend=get_compiler_backend()) @torch.compile(backend=get_compiler_backend())
...@@ -110,6 +110,7 @@ class CohereMLP(nn.Module): ...@@ -110,6 +110,7 @@ class CohereMLP(nn.Module):
self, self,
config, config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -120,12 +121,14 @@ class CohereMLP(nn.Module): ...@@ -120,12 +121,14 @@ class CohereMLP(nn.Module):
[self.intermediate_size] * 2, [self.intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
self.intermediate_size, self.intermediate_size,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
...@@ -142,6 +145,7 @@ class CohereAttention(nn.Module): ...@@ -142,6 +145,7 @@ class CohereAttention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -177,12 +181,14 @@ class CohereAttention(nn.Module): ...@@ -177,12 +181,14 @@ class CohereAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -198,6 +204,7 @@ class CohereAttention(nn.Module): ...@@ -198,6 +204,7 @@ class CohereAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = LayerNorm( self.q_norm = LayerNorm(
...@@ -239,15 +246,23 @@ class CohereDecoderLayer(nn.Module): ...@@ -239,15 +246,23 @@ class CohereDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = CohereAttention( self.self_attn = CohereAttention(
config, layer_id=layer_id, quant_config=quant_config config,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
) )
self.mlp = CohereMLP(config, quant_config=quant_config) self.mlp = CohereMLP(
config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = LayerNorm( self.input_layernorm = LayerNorm(
param_shape=(config.hidden_size), eps=config.layer_norm_eps param_shape=(config.hidden_size), eps=config.layer_norm_eps
) )
...@@ -279,6 +294,7 @@ class CohereModel(nn.Module): ...@@ -279,6 +294,7 @@ class CohereModel(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -288,7 +304,12 @@ class CohereModel(nn.Module): ...@@ -288,7 +304,12 @@ class CohereModel(nn.Module):
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
CohereDecoderLayer(config, i, quant_config=quant_config) CohereDecoderLayer(
config,
i,
quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
) )
...@@ -321,12 +342,15 @@ class CohereForCausalLM(nn.Module): ...@@ -321,12 +342,15 @@ class CohereForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.model = CohereModel(config, quant_config) self.model = CohereModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
@torch.no_grad() @torch.no_grad()
def forward( def forward(
......
...@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import ( ...@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import add_prefix, set_weight_attrs
class DbrxRouter(nn.Module): class DbrxRouter(nn.Module):
...@@ -58,6 +58,7 @@ class DbrxRouter(nn.Module): ...@@ -58,6 +58,7 @@ class DbrxRouter(nn.Module):
self, self,
config: DbrxConfig, config: DbrxConfig,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
...@@ -89,6 +90,7 @@ class DbrxExperts(nn.Module): ...@@ -89,6 +90,7 @@ class DbrxExperts(nn.Module):
config: DbrxConfig, config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
...@@ -189,6 +191,7 @@ class DbrxAttention(nn.Module): ...@@ -189,6 +191,7 @@ class DbrxAttention(nn.Module):
config: DbrxConfig, config: DbrxConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -207,12 +210,14 @@ class DbrxAttention(nn.Module): ...@@ -207,12 +210,14 @@ class DbrxAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("Wqkv", prefix),
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
self.d_model, self.d_model,
self.d_model, self.d_model,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("out_proj", prefix),
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -244,6 +249,7 @@ class DbrxAttention(nn.Module): ...@@ -244,6 +249,7 @@ class DbrxAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -268,10 +274,16 @@ class DbrxFusedNormAttention(nn.Module): ...@@ -268,10 +274,16 @@ class DbrxFusedNormAttention(nn.Module):
config: DbrxConfig, config: DbrxConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
self.attn = DbrxAttention(config, layer_id, quant_config=quant_config) self.attn = DbrxAttention(
config,
layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
self.norm_1 = nn.LayerNorm(self.d_model) self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model)
...@@ -300,10 +312,14 @@ class DbrxBlock(nn.Module): ...@@ -300,10 +312,14 @@ class DbrxBlock(nn.Module):
config: DbrxConfig, config: DbrxConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention( self.norm_attn_norm = DbrxFusedNormAttention(
config, layer_id, quant_config=quant_config config,
layer_id,
quant_config=quant_config,
prefix=add_prefix("norm_attn_norm", prefix),
) )
self.ffn = DbrxExperts(config, quant_config=quant_config) self.ffn = DbrxExperts(config, quant_config=quant_config)
...@@ -328,6 +344,7 @@ class DbrxModel(nn.Module): ...@@ -328,6 +344,7 @@ class DbrxModel(nn.Module):
self, self,
config: DbrxConfig, config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
...@@ -336,7 +353,12 @@ class DbrxModel(nn.Module): ...@@ -336,7 +353,12 @@ class DbrxModel(nn.Module):
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
DbrxBlock(config, i, quant_config=quant_config) DbrxBlock(
config,
i,
quant_config=quant_config,
prefix=add_prefix(f"blocks.{i}", prefix),
)
for i in range(config.n_layers) for i in range(config.n_layers)
] ]
) )
...@@ -369,17 +391,21 @@ class DbrxForCausalLM(nn.Module): ...@@ -369,17 +391,21 @@ class DbrxForCausalLM(nn.Module):
self, self,
config: DbrxConfig, config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, quant_config=quant_config) self.transformer = DbrxModel(
config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=add_prefix("lm_head", prefix),
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class DeepseekMLP(nn.Module):
@@ -57,10 +58,15 @@ class DeepseekMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -68,6 +74,7 @@ class DeepseekMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -89,6 +96,7 @@ class DeepseekMoE(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -110,6 +118,7 @@ class DeepseekMoE(nn.Module):
                     hidden_act=config.hidden_act,
                     quant_config=quant_config,
                     reduce_results=False,
+                    prefix=add_prefix(f"{idx}.experts", prefix),
                 )
                 for idx in range(self.n_routed_experts)
             ]
@@ -117,7 +126,11 @@ class DeepseekMoE(nn.Module):
         self.pack_params()
         self.gate = ReplicatedLinear(
-            config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
+            config.hidden_size,
+            self.n_routed_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
         )
         if config.n_shared_experts is not None:
@@ -128,6 +141,7 @@ class DeepseekMoE(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
             )
     def pack_params(self):
@@ -185,6 +199,7 @@ class DeepseekAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -216,6 +231,7 @@ class DeepseekAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -223,6 +239,7 @@ class DeepseekAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
@@ -238,6 +255,7 @@ class DeepseekAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -261,6 +279,7 @@ class DeepseekDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -276,19 +295,25 @@ class DeepseekDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         if (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
+            self.mlp = DeepseekMoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = DeepseekMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -328,6 +353,7 @@ class DeepseekModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -339,7 +365,12 @@ class DeepseekModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
+                DeepseekDecoderLayer(
+                    config,
+                    layer_id,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
+                )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -368,13 +399,19 @@ class DeepseekForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekModel(config, quant_config)
+        self.model = DeepseekModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)
...
@@ -38,7 +38,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import add_prefix, is_hip
 is_hip_ = is_hip()
@@ -48,6 +48,7 @@ class DeepseekModelNextN(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.vocab_size = config.vocab_size
@@ -56,6 +57,7 @@ class DeepseekModelNextN(nn.Module):
             config.vocab_size,
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -64,7 +66,11 @@ class DeepseekModelNextN(nn.Module):
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
         self.decoder = DeepseekV2DecoderLayer(
-            config, 0, quant_config=quant_config, is_nextn=True
+            config,
+            0,
+            quant_config=quant_config,
+            is_nextn=True,
+            prefix=add_prefix("decoder", prefix),
         )
         self.shared_head = nn.Module()
@@ -108,18 +114,22 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekModelNextN(config, quant_config)
+        self.model = DeepseekModelNextN(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
                 config.vocab_size,
                 bias=False,
+                prefix=add_prefix("model.shared_head.head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
@@ -127,6 +137,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
+                prefix=add_prefix("model.shared_head.head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)
...
@@ -63,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 is_hip_ = is_hip()
@@ -79,10 +79,15 @@ class DeepseekV2MLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -90,6 +95,7 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -106,7 +112,11 @@ class DeepseekV2MLP(nn.Module):
 class MoEGate(nn.Module):
-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        prefix: str = "",
+    ):
         super().__init__()
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
@@ -129,6 +139,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -147,7 +158,7 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
-        self.gate = MoEGate(config=config)
+        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
@@ -161,6 +172,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            prefix=add_prefix("experts", prefix),
         )
         if config.n_shared_experts is not None:
@@ -171,6 +183,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
             )
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -217,6 +230,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -241,6 +255,7 @@ class DeepseekV2Attention(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -248,6 +263,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -255,6 +271,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -262,8 +279,7 @@ class DeepseekV2Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-            # FIXME: quick fix for skip quantization
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -271,6 +287,7 @@ class DeepseekV2Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -278,6 +295,7 @@ class DeepseekV2Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -303,6 +321,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -368,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -394,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.q_lora_rank,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_a_proj", prefix),
                 )
                 self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                 self.q_b_proj = ReplicatedLinear(
@@ -401,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_b_proj", prefix),
                 )
             else:
                 self.q_proj = ReplicatedLinear(
@@ -408,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_proj", prefix),
                 )
             self.kv_b_proj = ReplicatedLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = ReplicatedLinear(
@@ -421,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
         else:
             # For tensor parallel attention
@@ -430,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.q_lora_rank,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_a_proj", prefix),
                 )
                 self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                 self.q_b_proj = ColumnParallelLinear(
@@ -437,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_b_proj", prefix),
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
@@ -444,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_proj", prefix),
                 )
             self.kv_b_proj = ColumnParallelLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = RowParallelLinear(
@@ -457,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -464,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-            # FIXME: quick fix for skip quantization
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
@@ -496,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            prefix=add_prefix("attn_mqa", prefix),
         )
         self.attn_mha = RadixAttention(
@@ -505,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            prefix=add_prefix("attn_mha", prefix),
         )
         self.w_kc = None
@@ -848,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -880,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 use_dp=self.enable_dp_attention,
+                prefix=add_prefix("self_attn", prefix),
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -898,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
+            self.mlp = DeepseekV2MoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -962,6 +1001,7 @@ class DeepseekV2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -978,6 +1018,7 @@ class DeepseekV2Model(nn.Module):
                     config,
                     layer_id,
                     quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -1008,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, quant_config)
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
                 config.vocab_size,
                 bias=False,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
            )
             self.logits_processor = LogitsProcessor(config)
...
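The DeepSeek V2 hunks above also drop the hard-coded prefix=f"self_attn.kv_a_proj_with_mqa" workaround in favor of a prefix threaded down from the model root, so each projection is created under its real fully qualified name (e.g. model.layers.0.self_attn.kv_a_proj_with_mqa). A hedged sketch of why that matters for per-module quantization decisions; the rule keys and override schema below are illustrative assumptions, not the actual GPTQModel or SGLang API:

import re

# Hypothetical per-module rules keyed on fully qualified module names.
dynamic_rules = {
    r".*\.kv_a_proj_with_mqa$": False,   # leave this projection unquantized
    r"(.*\.)?lm_head$": {"bits": 8},     # hypothetical override for lm_head
}

def resolve_quant_override(full_name: str):
    # Return the first matching rule, or None to fall back to the global config.
    for pattern, override in dynamic_rules.items():
        if re.fullmatch(pattern, full_name):
            return override
    return None

print(resolve_quant_override("model.layers.0.self_attn.kv_a_proj_with_mqa"))  # False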
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class ExaoneGatedMLP(nn.Module):
@@ -56,14 +57,14 @@ class ExaoneGatedMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_proj",
+            prefix=add_prefix("c_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -130,14 +131,14 @@ class ExaoneAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.out_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
+            prefix=add_prefix("out_proj", prefix),
         )
         self.rotary_emb = get_rope(
@@ -201,14 +202,14 @@ class ExaoneDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = ExaoneGatedMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.activation_function,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         rms_norm_eps = config.layer_norm_epsilon
         self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
@@ -244,6 +245,7 @@ class ExaoneModel(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -256,7 +258,10 @@ class ExaoneModel(nn.Module):
         self.h = nn.ModuleList(
             [
                 ExaoneDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.h.{i}"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -293,12 +298,17 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = ExaoneModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.transformer = ExaoneModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
@@ -37,6 +37,7 @@ from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class GemmaMLP(nn.Module):
@@ -45,6 +46,7 @@ class GemmaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -52,12 +54,14 @@ class GemmaMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = GeluAndMul("none")
@@ -79,6 +83,7 @@ class GemmaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -109,12 +114,14 @@ class GemmaAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
@@ -130,6 +137,7 @@ class GemmaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -152,6 +160,7 @@ class GemmaDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -164,11 +173,13 @@ class GemmaDecoderLayer(nn.Module):
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -205,6 +216,7 @@ class GemmaModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -215,7 +227,12 @@ class GemmaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                GemmaDecoderLayer(config, i, quant_config=quant_config)
+                GemmaDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -277,11 +294,14 @@ class GemmaForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = GemmaModel(config, quant_config=quant_config)
+        self.model = GemmaModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
@@ -39,7 +39,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -56,13 +56,22 @@ class Gemma2MLP(nn.Module):
         hidden_act: str,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
@@ -91,6 +100,7 @@ class Gemma2Attention(nn.Module):
         max_position_embeddings: int,
         rope_theta: float,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -123,12 +133,14 @@ class Gemma2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -151,6 +163,7 @@ class Gemma2Attention(nn.Module):
                 if use_sliding_window
                 else None
             ),
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -173,6 +186,7 @@ class Gemma2DecoderLayer(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -186,6 +200,7 @@ class Gemma2DecoderLayer(nn.Module):
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.hidden_size = config.hidden_size
         self.mlp = Gemma2MLP(
@@ -194,6 +209,7 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmaRMSNorm(
@@ -238,6 +254,7 @@ class Gemma2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -253,7 +270,7 @@ class Gemma2Model(nn.Module):
                 config=config,
                 quant_config=quant_config,
             ),
-            prefix="",
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -339,11 +356,14 @@ class Gemma2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Gemma2Model(config, quant_config)
+        self.model = Gemma2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.gemma2 import Gemma2ForCausalLM, Gemma2Model
+from sglang.srt.utils import add_prefix
 class Gemma2ForSequenceClassification(nn.Module):
@@ -29,12 +30,15 @@ class Gemma2ForSequenceClassification(nn.Module):
         self,
         config: Gemma2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.num_labels = config.num_labels
-        self.model = Gemma2Model(config, quant_config=quant_config)
+        self.model = Gemma2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
...
@@ -36,6 +36,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class GPT2Attention(nn.Module):
@@ -62,14 +63,14 @@ class GPT2Attention(nn.Module):
             total_num_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_attn",
+            prefix=add_prefix("c_attn", prefix),
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_proj",
+            prefix=add_prefix("c_proj", prefix),
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -108,14 +109,14 @@ class GPT2MLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_fc",
+            prefix=add_prefix("c_fc", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_proj",
+            prefix=add_prefix("c_proj", prefix),
         )
         self.act = act_layer()
@@ -145,7 +146,7 @@ class GPT2Block(nn.Module):
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(
-            layer_id, config, quant_config, prefix=f"{prefix}.attn"
+            layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPT2MLP(
@@ -153,7 +154,7 @@ class GPT2Block(nn.Module):
             config,
             act_layer=act_layer,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
     def forward(
@@ -196,7 +197,12 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, quant_config=quant_config)
+                GPT2Block(
+                    i,
+                    config,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -227,11 +233,14 @@ class GPT2LMHeadModel(nn.Module):
         self,
         config: GPT2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(config, quant_config, prefix="transformer")
+        self.transformer = GPT2Model(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config)
...
@@ -35,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class GPTBigCodeAttention(nn.Module):
@@ -44,6 +45,7 @@ class GPTBigCodeAttention(nn.Module):
         layer_id: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -69,6 +71,7 @@ class GPTBigCodeAttention(nn.Module):
             total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_attn", prefix),
         )
         self.c_proj = RowParallelLinear(
@@ -76,6 +79,7 @@ class GPTBigCodeAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -83,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
             scaling=self.scale,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -111,6 +116,7 @@ class GPTBigMLP(nn.Module):
         intermediate_size: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -119,12 +125,14 @@ class GPTBigMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_fc", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.act = get_act_fn(
             config.activation_function, quant_config, intermediate_size
@@ -144,15 +152,20 @@ class GPTBigCodeBlock(nn.Module):
         layer_id: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
+        self.attn = GPTBigCodeAttention(
+            layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
+        )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+        self.mlp = GPTBigMLP(
+            inner_dim, config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
     def forward(
         self,
@@ -181,6 +194,7 @@ class GPTBigCodeModel(nn.Module):
         self,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -190,12 +204,17 @@ class GPTBigCodeModel(nn.Module):
         lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
-            self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
+            self.vocab_size,
+            self.embed_dim,
+            org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("wte", prefix),
         )
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, quant_config)
+                GPTBigCodeBlock(
+                    i, config, quant_config, prefix=add_prefix(f"h.{i}", prefix)
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -235,13 +254,16 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(config, quant_config)
+        self.transformer = GPTBigCodeModel(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
         self.logits_processor = LogitsProcessor(config)
...
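Taken together, these hunks give every quantizable submodule a prefix that mirrors its checkpoint name, all the way up to lm_head on the causal-LM wrappers. A minimal sketch of how a quantization config could consult that name when a layer is constructed; the class and method below are hypothetical stand-ins, not SGLang's actual QuantizationConfig interface:

# Hypothetical config that quantizes everything except an explicit skip list.
class SelectiveQuantConfig:
    def __init__(self, skip_names=("lm_head",)):
        self.skip_names = set(skip_names)

    def get_quant_method(self, layer, prefix: str):
        # prefix is the fully qualified name built via add_prefix(), e.g.
        # "transformer.h.0.mlp.c_proj" or "lm_head".
        leaf = prefix.rsplit(".", 1)[-1]
        if prefix in self.skip_names or leaf in self.skip_names:
            return None  # fall back to the unquantized linear path
        return "quantized-linear"  # placeholder for a real per-layer method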