Unverified Commit 10d76548 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

`FusedMoE` support for the Transformers backend (#22650)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 39b643dc
...@@ -17,12 +17,12 @@ These models are what we list in [supported-text-models][supported-text-models] ...@@ -17,12 +17,12 @@ These models are what we list in [supported-text-models][supported-text-models]
### Transformers ### Transformers
vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
Currently, the Transformers backend works for the following: Currently, the Transformers backend works for the following:
- Modalities: embedding models, language models and vision-language models* - Modalities: embedding models, language models and vision-language models*
- Architectures: encoder-only, decoder-only - Architectures: encoder-only, decoder-only, mixture-of-experts
- Attention types: full attention and/or sliding attention - Attention types: full attention and/or sliding attention
_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._ _*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
...@@ -31,6 +31,7 @@ If the Transformers model implementation follows all the steps in [writing a cus ...@@ -31,6 +31,7 @@ If the Transformers model implementation follows all the steps in [writing a cus
- All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature) - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
- Any combination of the following vLLM parallelisation schemes: - Any combination of the following vLLM parallelisation schemes:
- Data parallel
- Pipeline parallel - Pipeline parallel
- Tensor parallel - Tensor parallel
......
...@@ -661,6 +661,10 @@ _TRANSFORMERS_BACKEND_MODELS = { ...@@ -661,6 +661,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501 "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
} }
_EXAMPLE_MODELS = { _EXAMPLE_MODELS = {
......
...@@ -66,6 +66,7 @@ def check_implementation( ...@@ -66,6 +66,7 @@ def check_implementation(
[ [
("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("allenai/OLMoE-1B-7B-0924", "transformers"), # MoE
]) # trust_remote_code=True by default ]) # trust_remote_code=True by default
def test_models( def test_models(
hf_runner: type[HfRunner], hf_runner: type[HfRunner],
...@@ -74,6 +75,14 @@ def test_models( ...@@ -74,6 +75,14 @@ def test_models(
model: str, model: str,
model_impl: str, model_impl: str,
) -> None: ) -> None:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip("MoE models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
check_implementation(hf_runner, check_implementation(hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
......
...@@ -430,6 +430,17 @@ def dummy_hf_overrides( ...@@ -430,6 +430,17 @@ def dummy_hf_overrides(
update_dict = { update_dict = {
"num_layers": num_layers, "num_layers": num_layers,
# For Gemma-3n
"num_kv_shared_layers": 1,
}
class DummyConfig:
hf_text_config = text_config
# Only set MoE related config when the model has MoE layers.
# Otherwise all models detected as MoE by _get_transformers_backend_cls.
if ModelConfig.get_num_experts(DummyConfig) > 0:
update_dict.update({
"num_experts": num_experts, "num_experts": num_experts,
"num_experts_per_tok": 2, "num_experts_per_tok": 2,
"num_local_experts": num_experts, "num_local_experts": num_experts,
...@@ -437,9 +448,7 @@ def dummy_hf_overrides( ...@@ -437,9 +448,7 @@ def dummy_hf_overrides(
"first_k_dense_replace": 0, "first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3 # To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts, "n_routed_experts": num_experts,
# For Gemma-3n })
"num_kv_shared_layers": 1,
}
# Update num_hidden_layers for non-Longcat architectures # Update num_hidden_layers for non-Longcat architectures
if model_arch != "LongcatFlashForCausalLM" \ if model_arch != "LongcatFlashForCausalLM" \
......
...@@ -20,7 +20,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode, ...@@ -20,7 +20,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
MultiModalConfig) MultiModalConfig)
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
...@@ -667,6 +667,8 @@ class ModelConfig: ...@@ -667,6 +667,8 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str: def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if """Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`.""" `model_impl` is set to `transformers` or `auto`."""
prefix = "Transformers"
prefix += "MoE" if self.get_num_experts() > 1 else ""
# Check if the architecture we're wrapping has defaults # Check if the architecture we're wrapping has defaults
runner = None runner = None
convert = None convert = None
...@@ -685,15 +687,15 @@ class ModelConfig: ...@@ -685,15 +687,15 @@ class ModelConfig:
# Resolve Transformers backend pooling classes # Resolve Transformers backend pooling classes
if runner == "pooling": if runner == "pooling":
if convert == "embed": if convert == "embed":
return "TransformersEmbeddingModel" return prefix + "EmbeddingModel"
if convert == "classify": if convert == "classify":
return "TransformersForSequenceClassification" return prefix + "ForSequenceClassification"
# Resolve Transformers backend generate classes # Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config: if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is the same as 'hf_config'. If not, it is # If 'hf_text_config' is the same as 'hf_config'. If not, it is
# probably a composite config, i.e. multimodal # probably a composite config, i.e. multimodal
return "TransformersForMultimodalLM" return prefix + "ForMultimodalLM"
return "TransformersForCausalLM" return prefix + "ForCausalLM"
def using_transformers_backend(self) -> bool: def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class.""" """Check if the model is using the Transformers backend class."""
...@@ -1025,17 +1027,7 @@ class ModelConfig: ...@@ -1025,17 +1027,7 @@ class ModelConfig:
self.enforce_eager = True self.enforce_eager = True
def _verify_with_expert_parallelism(self) -> None: def _verify_with_expert_parallelism(self) -> None:
num_expert_names = [ num_experts = self.get_num_experts()
"moe_num_experts", # Dbrx
"num_experts", # Jamba
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = 0
for name in num_expert_names:
num_experts = getattr(self.hf_text_config, name, 0)
if num_experts > 0:
break
if num_experts < 1: if num_experts < 1:
raise ValueError( raise ValueError(
"Number of experts in the model must be greater than 0 " "Number of experts in the model must be greater than 0 "
...@@ -1220,6 +1212,21 @@ class ModelConfig: ...@@ -1220,6 +1212,21 @@ class ModelConfig:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size return num_heads // parallel_config.tensor_parallel_size
def get_num_experts(self) -> int:
    """Return the number of experts declared by the model's text config.

    The attribute holding the expert count is named differently across
    architectures, so the known names are looked up on `hf_text_config`
    with `getattr_iter`, in the order listed below, falling back to 0.

    Returns:
        The number of experts as an int. 0 is returned when no known
        attribute is found, or when the attribute is explicitly `None` or
        an empty list, so callers can always compare the result
        numerically.
    """
    num_expert_names = [
        "num_experts",  # Jamba
        "moe_num_experts",  # Dbrx
        "n_routed_experts",  # DeepSeek
        "num_local_experts",  # Mixtral
    ]
    num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
    if isinstance(num_experts, list):
        # Ernie VL's remote code uses list[int]...
        # The values are always the same so we just take the first one.
        # An empty list means "no experts" rather than an IndexError.
        return num_experts[0] if num_experts else 0
    # Some configs set the attribute explicitly to None; coerce that to 0
    # so numeric comparisons on the result do not raise TypeError.
    return num_experts or 0
def get_layers_start_end_indices( def get_layers_start_end_indices(
self, parallel_config: ParallelConfig) -> tuple[int, int]: self, parallel_config: ParallelConfig) -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
......
...@@ -960,6 +960,7 @@ class FusedMoE(CustomOp): ...@@ -960,6 +960,7 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False, is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0, zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None, zero_expert_type: Optional[str] = None,
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
): ):
super().__init__() super().__init__()
if params_dtype is None: if params_dtype is None:
...@@ -996,6 +997,9 @@ class FusedMoE(CustomOp): ...@@ -996,6 +997,9 @@ class FusedMoE(CustomOp):
self.zero_expert_num = zero_expert_num self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type self.zero_expert_type = zero_expert_type
# Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping
# Round up hidden size if needed. # Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype, hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
quant_config, quant_config,
...@@ -1617,6 +1621,33 @@ class FusedMoE(CustomOp): ...@@ -1617,6 +1621,33 @@ class FusedMoE(CustomOp):
return False if return_success else None return False if return_success else None
def load_weights(
        self, weights: Iterable[tuple[str,
                                      torch.Tensor]]) -> Iterable[str]:
    """Load expert weights into this layer using `self.expert_mapping`.

    Each incoming weight name is qualified with `self.layer_name` and
    checked against every `(param_name, weight_name, expert_id, shard_id)`
    entry of the mapping. For every entry whose checkpoint name occurs in
    the qualified name, the matching parameter of this layer is handed to
    `self.weight_loader`.

    Args:
        weights: Pairs of (name relative to this layer, tensor).

    Yields:
        The layer-local parameter name for each load that
        `self.weight_loader` reports as successful.

    Raises:
        ValueError: If `self.expert_mapping` is `None`. Because this is a
            generator, the error surfaces on first iteration.
    """
    mapping = self.expert_mapping
    if mapping is None:
        raise ValueError("`self.expert_mapping` must be provided to "
                         "load weights using `self.load_weights`.")

    for expert_name, loaded_weight in weights:
        qual_name = f"{self.layer_name}.{expert_name}"
        for mapped_param, ckpt_fragment, expert_id, shard_id in mapping:
            if ckpt_fragment in qual_name:
                # Swap the checkpoint naming for the fused parameter naming,
                # then strip the layer prefix to get the local attribute.
                fused_name = qual_name.replace(ckpt_fragment, mapped_param)
                local_name = fused_name.removeprefix(f"{self.layer_name}.")
                param = getattr(self, local_name)
                loaded = self.weight_loader(
                    param=param,
                    loaded_weight=loaded_weight,
                    weight_name=fused_name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                )
                if loaded:
                    logger.debug("Loaded %s for expert %d into %s",
                                 local_name, expert_id, self.layer_name)
                    yield local_name
def get_expert_weights(self) -> Iterable[torch.Tensor]: def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters()) weights = list(self.named_parameters())
assert all(weight.is_contiguous() for _, weight in weights) assert all(weight.is_contiguous() for _, weight in weights)
......
...@@ -307,10 +307,14 @@ _TRANSFORMERS_SUPPORTED_MODELS = { ...@@ -307,10 +307,14 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
} }
_TRANSFORMERS_BACKEND_MODELS = { _TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501
"TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"), # noqa: E501
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"), # noqa: E501
"TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"), # noqa: E501
} }
# yapf: enable # yapf: enable
......
...@@ -22,6 +22,8 @@ from typing import Literal, Optional, Union ...@@ -22,6 +22,8 @@ from typing import Literal, Optional, Union
import regex as re import regex as re
import torch import torch
import transformers
from packaging.version import Version
from torch import nn from torch import nn
from transformers import (AutoModel, BatchFeature, PretrainedConfig, from transformers import (AutoModel, BatchFeature, PretrainedConfig,
PreTrainedModel) PreTrainedModel)
...@@ -35,6 +37,7 @@ from vllm.config.utils import getattr_iter ...@@ -35,6 +37,7 @@ from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
...@@ -121,10 +124,14 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: ...@@ -121,10 +124,14 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
return enable return enable
# Tensor parallel style names accepted by `replace_linear_class`.
# "replicate" is that function's default style.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep",
                "replicate"]
def replace_linear_class( def replace_linear_class(
linear: nn.Linear, linear: nn.Linear,
style: Literal["colwise", "rowwise"], style: Style = "replicate",
quant_config: QuantizationConfig, quant_config: Optional[QuantizationConfig] = None,
*, *,
prefix: str = "", prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
...@@ -132,11 +139,11 @@ def replace_linear_class( ...@@ -132,11 +139,11 @@ def replace_linear_class(
Replace nn.Linear with one of vLLM's tensor parallel linear classes. Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Args: Args:
linear (nn.Linear): `nn.Linear` to be replaced. linear: `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise". style: Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear. quant_config: Quantization config for the new linear.
Returns: Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear. The new linear.
""" """
if not isinstance(style, str): if not isinstance(style, str):
...@@ -166,6 +173,31 @@ def replace_linear_class( ...@@ -166,6 +173,31 @@ def replace_linear_class(
) )
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Build a vLLM `RMSNorm` configured like a Transformers RMSNorm module.

    Assumptions about `rms_norm`:
    - Its weight, if any, is stored as `weight`.
    - Its epsilon is stored as `eps` or `variance_epsilon` (1e-6 is used
      when neither is present).
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
      and Transformers doesn't appear to have the same concept, so it is
      not handled here.

    Args:
        rms_norm: The Transformers RMSNorm module to mirror.
        hidden_size: Hidden size passed to the new `RMSNorm`.

    Returns:
        A new vLLM `RMSNorm`. Weights are not copied here; only the
        configuration (eps, has_weight, dtype) is taken from `rms_norm`.
    """
    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
    has_weight = getattr(rms_norm, "with_scale", True)
    weight = getattr(rms_norm, "weight", None)

    if weight is None:
        # No weight, fall back to weightless RMSNorm
        return RMSNorm(hidden_size=hidden_size, eps=eps, has_weight=False)

    # If weight is a Parameter, read the dtype from its data tensor
    dtype = getattr(weight, "data", weight).dtype
    return RMSNorm(hidden_size=hidden_size,
                   eps=eps,
                   has_weight=has_weight,
                   dtype=dtype)
# Copied from `accelerate` # Copied from `accelerate`
@contextmanager @contextmanager
def init_on_device_without_buffers(device: torch.device): def init_on_device_without_buffers(device: torch.device):
...@@ -463,8 +495,14 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -463,8 +495,14 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.ignore_unexpected_suffixes: list[str] = [] self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes.""" """Ignore unexpected weights whose qualname ends with these suffixes."""
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
if quant_method_name == "mxfp4":
raise NotImplementedError("Transformers backend does not "
"support MXFP4 quantization yet.")
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if self.quant_config and "gptq" in self.quant_config.get_name(): if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias") self.ignore_unexpected_suffixes.append(".bias")
# Set correct attn and init on "meta" to delay allocating GPU tensors # Set correct attn and init on "meta" to delay allocating GPU tensors
...@@ -478,8 +516,12 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -478,8 +516,12 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
# Remove layers not on this pipeline parallel rank
self.pipeline_parallel() self.pipeline_parallel()
self.tensor_parallel() # Substitute remaining layers with vLLM's layers as needed
self.recursive_replace()
# Create attention instances for KV cache allocation
self.attention_instances = self.create_attention_instances()
# Input embeddings # Input embeddings
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
...@@ -494,12 +536,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -494,12 +536,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
quant_config=self.quant_config, quant_config=self.quant_config,
)) ))
# Attention layers
self.attention_instances = self.create_attention_instances()
# Initialize any parameters that have not had their modules replaced # Initialize any parameters that have not had their modules replaced
self.init_parameters(self.model) self.init_parameters(self.model)
# Pipeline parallel intermediate tensors
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states"], self.text_config.hidden_size)) ["hidden_states"], self.text_config.hidden_size))
...@@ -558,56 +598,53 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -558,56 +598,53 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
if not self.pp_group.is_last_rank: if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer()) setattr(self.model, name, PPMissingLayer())
def tensor_parallel(self): def recursive_replace(self):
""" """Recursively replace modules in the model as needed.
Apply the model's tensor parallelization plan.
Currently only supports linear layers. Currently, this replaces:
- `nn.Linear` with vLLM's tensor parallel linear classes
- `*RMSNorm` with vLLM's `RMSNorm`
""" """
# Look for tp plans in all of the PreTrainedModels found in self.model tp_plan = self.model.tp_plan
is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
pretrained_models = filter(is_pretrained_model, self.model.modules())
models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
if not any(models_with_tp_plan) and self.tp_size > 1: if not tp_plan and self.tp_size > 1:
tip = get_feature_request_tip(self.model_config.model, tip = get_feature_request_tip(self.model_config.model,
self.model_config.trust_remote_code) self.model_config.trust_remote_code)
raise ValueError( raise ValueError(
f"{type(self.model)} does not support tensor parallel. {tip}") f"{type(self.model)} does not support tensor parallel. {tip}")
def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None): # Prefix the patterns because we always start from `self.model`
tp_plan = tp_plan or {} tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}
# If the current module is a PreTrainedModel, set the tp_plan for
# all of its children
if isinstance(module, PreTrainedModel):
tp_plan = module.config.base_model_tp_plan or {}
tp_plan = {
maybe_prefix(prefix, k): v
for k, v in tp_plan.items()
}
# Some weight loaders expect linear layers to inherit from vLLM's def _recursive_replace(module: nn.Module, prefix: str):
# LinearBase class, so we set a default style which causes any
# unspecified linear layers to be replaced with ReplicatedLinear
for child_name, child_module in module.named_children(): for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name) qual_name = maybe_prefix(prefix, child_name)
if isinstance(child_module, nn.Linear): if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name)) generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None) pattern = next(generator, None)
# Some weight loaders expect all linear layers to inherit
# LinearBase, so we set a default style which causes any
# unspecified layers to be replaced with ReplicatedLinear
style = tp_plan.get(pattern, "replicate") style = tp_plan.get(pattern, "replicate")
new_module = replace_linear_class(child_module, new_module = replace_linear_class(child_module,
style, style,
self.quant_config, self.quant_config,
prefix=qual_name) prefix=qual_name)
# TODO(hmellor): Enable RMSNorm replacement once we have a way
# to choose RMSNorm vs GemmaRMSNorm
# elif child_module.__class__.__name__.endswith("RMSNorm"):
# new_module = replace_rms_norm_class(
# child_module, self.config.hidden_size)
else:
_recursive_replace(child_module, prefix=qual_name)
if new_module is not child_module:
setattr(module, child_name, new_module) setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module) log_replacement(qual_name, child_module, new_module)
else:
_tensor_parallel(child_module,
prefix=qual_name,
tp_plan=tp_plan)
_tensor_parallel(self.model, prefix="model") _recursive_replace(self.model, prefix="model")
def create_attention_instances( def create_attention_instances(
self, self,
...@@ -657,15 +694,21 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -657,15 +694,21 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.model: PreTrainedModel = AutoModel.from_config(...) self.model: PreTrainedModel = AutoModel.from_config(...)
``` ```
""" """
def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
for name, param in module.named_parameters(recurse=False): for name, param in module.named_parameters(recurse=False):
if param.device == torch.device("meta"): if param.device == torch.device("meta"):
new_param = nn.Parameter( new_param = nn.Parameter(
torch.empty_like(param.data, torch.empty_like(
param.data,
dtype=dtype or self.model_config.dtype, dtype=dtype or self.model_config.dtype,
device=self.device_config.device)) device=self.device_config.device,
))
setattr(module, name, new_param) setattr(module, name, new_param)
for child in module.children(): for child in module.children():
self.init_parameters(child, dtype) _init_parameters(child, dtype)
_init_parameters(module, dtype)
def forward( def forward(
self, self,
...@@ -702,8 +745,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -702,8 +745,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[tuple[str, def load_weights(
torch.Tensor]]) -> set[str]: self,
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=self.skip_prefixes, skip_prefixes=self.skip_prefixes,
...@@ -713,6 +758,14 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -713,6 +758,14 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
) )
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def check_version(self, min_version: str, feature: str):
    """Ensure the installed `transformers` is recent enough for a feature.

    Args:
        min_version: Minimum required `transformers` version string.
        feature: Human-readable feature name used in the error message.

    Raises:
        ImportError: If the installed `transformers` version is lower than
            `min_version`.
    """
    installed = Version(transformers.__version__)
    required = Version(min_version)
    if installed >= required:
        return
    raise ImportError(
        f"Transformers backend requires transformers>={required} "
        f"for {feature}, but got {installed}")
@support_torch_compile(enable_if=can_enable_torch_compile) @support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase): class TransformersForCausalLM(TransformersBase):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` MoE models."""
from typing import Any
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config.utils import getattr_iter
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .transformers import (TransformersBase, TransformersForCausalLM,
TransformersForMultimodalLM,
can_enable_torch_compile, log_replacement)
from .utils import maybe_prefix
@CustomOp.register("transformers_fused_moe")
class TransformersFusedMoE(FusedMoE):
    """Custom FusedMoE for the Transformers backend.

    `forward` receives the `top_k_index` and `top_k_weights` as arguments
    instead of computing the routing itself. They are passed to the
    `transformers_moe_forward` custom op, which stores the indices on the
    layer so that `custom_routing_function` can return them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set by `transformers_moe_forward` before `forward_impl` is called;
        # stays None until the first forward pass.
        self._top_k_index: torch.Tensor = None

        def custom_routing_function(hidden_states, gating_output, topk,
                                    renormalize):
            """Return `top_k_weights` from `gating_output` and the
            `top_k_index` we stored in the layer earlier.

            `hidden_states`, `topk` and `renormalize` are accepted but not
            used.
            NOTE(review): presumably `forward_impl` passes its second
            argument (the top-k weights) through as `gating_output` -
            confirm against `FusedMoE`.
            """
            return gating_output, self._top_k_index

        self.custom_routing_function = custom_routing_function

    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """In Transformers `experts.forward` will have this signature.

        We discard any extra kwargs because we cannot use them here.
        The layer is identified to the custom op by `self.layer_name`.
        """
        return torch.ops.vllm.transformers_moe_forward(hidden_states,
                                                       top_k_index,
                                                       top_k_weights,
                                                       self.layer_name)
def transformers_moe_forward(hidden_states: torch.Tensor,
                             top_k_index: torch.Tensor,
                             top_k_weights: torch.Tensor,
                             layer_name: str) -> torch.Tensor:
    """Run the MoE layer registered under `layer_name`.

    The layer is looked up in the current forward context's
    `no_compile_layers`. `top_k_index` is stashed on the layer (as
    `_top_k_index`) before `forward_impl` is invoked with a copy of
    `hidden_states` and the `top_k_weights`.

    Returns:
        Whatever the layer's `forward_impl` returns.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    layer._top_k_index = top_k_index
    # FusedMoE mutates its input in-place, so hand it a clone and leave the
    # caller's tensor untouched.
    hidden_states_copy = hidden_states.clone()
    return layer.forward_impl(hidden_states_copy, top_k_weights)
def transformers_moe_forward_fake(hidden_states: torch.Tensor,
                                  top_k_index: torch.Tensor,
                                  top_k_weights: torch.Tensor,
                                  layer_name: str) -> torch.Tensor:
    """Fake implementation of `transformers_moe_forward`.

    Only describes the output: an uninitialised tensor with the same
    properties as `hidden_states`. The other arguments are ignored.
    """
    output = torch.empty_like(hidden_states)
    return output
# Register `transformers_moe_forward` as a custom op (invoked above as
# `torch.ops.vllm.transformers_moe_forward`), with
# `transformers_moe_forward_fake` as its fake implementation.
# NOTE(review): `hidden_states` is declared as mutated, yet the real
# implementation clones it before calling `forward_impl` - confirm whether
# this declaration is intentionally conservative.
direct_register_custom_op(
    op_name="transformers_moe_forward",
    op_func=transformers_moe_forward,
    mutates_args=["hidden_states"],
    fake_impl=transformers_moe_forward_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
class TransformersMoEBase(TransformersBase):
def __init__(self, *, vllm_config, prefix=""):
    """Initialise the MoE variant of the Transformers backend.

    Checks the installed `transformers` version first, then runs the base
    class initialisation, and finally rejects parallel configurations that
    are not supported.

    Raises:
        ImportError: From `check_version` when `transformers` is older
            than 4.57.0.dev0.
        NotImplementedError: If expert parallel or expert parallel load
            balancing is enabled in the parallel config.
    """
    self.check_version("4.57.0.dev0", "MoE models support")
    super().__init__(vllm_config=vllm_config, prefix=prefix)

    parallel_config = self.parallel_config
    if parallel_config.enable_expert_parallel:
        raise NotImplementedError(
            "Transformers backend does not support expert parallel yet.")
    if parallel_config.enable_eplb:
        raise NotImplementedError(
            "Transformers backend does not support expert parallel load "
            "balancing yet.")
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Build the expert parameter mapping for all known checkpoint styles.

    `FusedMoE.make_expert_params_mapping` is called once per naming style
    and the results are concatenated in that order.

    Returns:
        Entries of the form (param_name, weight_name, expert_id, shard_id),
        covering params for weights, fp8 weight scales and fp8 activation
        scales.
    """
    # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
    ckpt_name_styles = [
        ("gate_proj", "down_proj", "up_proj"),  # Most common MoE style
        ("w1", "w2", "w3"),  # Granite, Mixtral, Phi MoE style
        ("linear", "linear_1", "linear_v"),  # Grok1 style
    ]
    return [
        entry for gate_name, down_name, up_name in ckpt_name_styles
        for entry in FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name=gate_name,
            ckpt_down_proj_name=down_name,
            ckpt_up_proj_name=up_name,
            num_experts=self.model_config.get_num_experts(),
            num_redundant_experts=0,  # TODO: enable EPLB
        )
    ]
def recursive_replace(self):
"""Initialize the MoE layers."""
text_config = self.text_config
# Positional arguments
num_experts = self.model_config.get_num_experts()
top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"],
None)
assert top_k is not None
hidden_size = text_config.hidden_size
intermediate_size = getattr_iter(
text_config, ["moe_intermediate_size", "intermediate_size"], None)
assert intermediate_size is not None
# If there are shared experts, the results are
# reduced after mlp.forward() not inside FusedMoE
num_experts_shared = getattr_iter(text_config, [
"num_experts_shared", "n_shared_experts", "moe_num_shared_experts"
], 0)
reduce_results = num_experts_shared == 0
def add_all_reduce(mlp: nn.Module):
"""Adds an all-reduce to the output of `mlp.forward()`."""
class MLPWithAllReduce(mlp.__class__):
def forward(self, *args, **kwargs):
output = super().forward(*args, **kwargs)
return self.experts.maybe_all_reduce_tensor_model_parallel(
output)
mlp.__class__ = MLPWithAllReduce
# Unused kwargs since we use custom_routing_function:
# - `scoring_func` and `e_score_correction_bias` only used for grouped
# topk routing inside vLLM and are non-trivial to infer
# and hard code `use_grouped_topk=False`
# - `renormalize` passed anyway because it's easy to infer
# - `num_expert_group` and `topk_group` used for inferring expert
# placement strategy in FusedMoE
# - `apply_router_weight_on_input` is already applied in Transformers
renormalize = getattr(text_config, "norm_topk_prob", top_k > 1)
num_expert_group = getattr(text_config, "n_group", None)
topk_group = getattr(text_config, "topk_group", None)
# MoE activation function
activation = "silu"
wrapped_arch = self.config.architectures[0].lower()
if "gptoss" in wrapped_arch:
activation = "swigluoai"
elif "grok1" in wrapped_arch:
activation = "gelu"
# Expert mapping for `AutoWeightsLoader`
expert_mapping = self.get_expert_mapping()
# Configs
parallel_config = self.parallel_config
eplb_config = parallel_config.eplb_config
# Expert parallel load balancing kwargs
enable_eplb = parallel_config.enable_eplb
num_redundant_experts = eplb_config.num_redundant_experts
# Recursively fuse MoE layers
def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
if (child_name == "experts"
and isinstance(child_module, nn.ModuleList)):
# Alias for readability
mlp = module
experts = child_module
# Do the experts have biases
has_bias = False
for experts_param_name, _ in experts.named_parameters():
if "bias" in experts_param_name:
has_bias = True
break
# Double check there are no shared experts
nonlocal reduce_results
if reduce_results:
for mlp_param_name, _ in mlp.named_parameters():
if "shared_expert" in mlp_param_name:
reduce_results = False
break
# Replace experts module with FusedMoE
fused_experts = TransformersFusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=reduce_results,
renormalize=renormalize,
# Hard coded because topk happens in Transformers
use_grouped_topk=False,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=self.quant_config,
prefix=qual_name,
activation=activation,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
has_bias=has_bias,
expert_mapping=expert_mapping,
)
mlp.experts = fused_experts
log_replacement(qual_name, experts, fused_experts)
# If results are not all-reduced in FusedMoE, ensure they
# are all-reduced at the end of mlp.forward() if tensor
# parallel or expert parallel is enabled
if not reduce_results and (fused_experts.tp_size > 1
or fused_experts.ep_size > 1):
add_all_reduce(mlp)
else:
_recursive_replace(child_module, prefix=qual_name)
_recursive_replace(self.model, prefix="model")
# Continue with the replacement of layers in TransformersBase
super().recursive_replace()
# Causal LM model for the Transformers backend with MoE support: combines
# `TransformersMoEBase` with `TransformersForCausalLM` and defines no members
# of its own.
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
    pass
# Multimodal model for the Transformers backend with MoE support: combines
# `TransformersMoEForCausalLM` with `TransformersForMultimodalLM` and defines
# no members of its own.
@support_torch_compile(
    # set `positions` to last dim to support Qwen-mrope
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    },
    enable_if=can_enable_torch_compile)
class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM,
                                     TransformersForMultimodalLM):
    pass
...@@ -20,7 +20,7 @@ from typing import Optional, Union ...@@ -20,7 +20,7 @@ from typing import Optional, Union
import torch import torch
from transformers import AutoModelForSequenceClassification from transformers import AutoModelForSequenceClassification
from vllm.attention import AttentionType from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
...@@ -29,6 +29,7 @@ from vllm.sequence import IntermediateTensors ...@@ -29,6 +29,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces_base import VllmModelForPooling from .interfaces_base import VllmModelForPooling
from .transformers import TransformersBase, can_enable_torch_compile from .transformers import TransformersBase, can_enable_torch_compile
from .transformers_moe import TransformersMoEBase
from .utils import WeightsMapper from .utils import WeightsMapper
...@@ -79,7 +80,9 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling): ...@@ -79,7 +80,9 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
self.padding_idx = self.text_config.pad_token_id self.padding_idx = self.text_config.pad_token_id
def create_attention_instances( def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER): self,
attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
# TODO(hmellor): Better way to detect encoder models # TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False` # In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True) is_encoder = lambda m: not getattr(m, "is_causal", True)
...@@ -90,14 +93,7 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling): ...@@ -90,14 +93,7 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
# Check minimum transformers version for encoder models support # Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY: if attn_type == AttentionType.ENCODER_ONLY:
import transformers self.check_version("4.57.0.dev0", "encoder models support")
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
raise ValueError(
"Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")
return super().create_attention_instances(attn_type) return super().create_attention_instances(attn_type)
...@@ -198,3 +194,15 @@ class TransformersForSequenceClassification(TransformersPoolingBase): ...@@ -198,3 +194,15 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
vllm_config.model_config), vllm_config.model_config),
), ),
}) })
# Embedding model for the Transformers backend with MoE support: combines
# `TransformersMoEBase` with `TransformersEmbeddingModel` and defines no
# members of its own.
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(TransformersMoEBase,
                                    TransformersEmbeddingModel):
    pass
# Sequence classification model for the Transformers backend with MoE support:
# combines `TransformersMoEBase` with `TransformersForSequenceClassification`
# and defines no members of its own.
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
        TransformersMoEBase, TransformersForSequenceClassification):
    pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment