Unverified Commit 85bd6599 authored by Jennifer He's avatar Jennifer He Committed by GitHub
Browse files

[Model] Add AutoWeightsLoader support for BERT, RoBERTa (#20534)


Signed-off-by: default avatarJennifer He <islandhe@gmail.com>
Signed-off-by: <islandhe@gmail.com>
Signed-off-by: default avatarJen H <islandhe@gmail.com>
parent 91b3d190
...@@ -22,12 +22,11 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, ...@@ -22,12 +22,11 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
class BertEmbedding(nn.Module): class BertEmbedding(nn.Module):
...@@ -44,9 +43,11 @@ class BertEmbedding(nn.Module): ...@@ -44,9 +43,11 @@ class BertEmbedding(nn.Module):
config.type_vocab_size, config.hidden_size) config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute": if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" + raise ValueError("Only 'absolute' position_embedding_type" +
...@@ -358,45 +359,45 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -358,45 +359,45 @@ class BertModel(nn.Module, SupportsQuant):
("qkv_proj", "value", "v"), ("qkv_proj", "value", "v"),
] ]
loaded_stacked_params = []
other_weights = []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if self.pooler is None and "pooler" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models. if name not in params_dict:
if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
loaded_stacked_params.append(name)
break break
else: else:
# Skip loading extra bias for GPTQ models. if name in params_dict:
if name.endswith(".bias") and name not in params_dict: other_weights.append((name, loaded_weight))
continue
param = params_dict[name] loader = AutoWeightsLoader(
weight_loader = getattr(param, "weight_loader", self,
default_weight_loader) skip_prefixes=(["pooler."] if self.pooler is None else []),
weight_loader(param, loaded_weight) )
loaded_params.add(name) loaded_params = loader.load_weights(other_weights)
loaded_params.update(loaded_stacked_params)
return loaded_params return loaded_params
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions. embedding operations and customized pooling functions.
Attributes: Attributes:
model: An instance of BertModel used for forward operations. model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -425,10 +426,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -425,10 +426,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights) weights_list = list(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head.")) has_model_prefix = any(
self.model.load_weights(weights) name.startswith("model.") for name, _ in weights_list)
if not has_model_prefix:
mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
return loader.load_weights(weights_list, mapper=mapper)
def _build_model(self, def _build_model(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
...@@ -470,26 +476,9 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, ...@@ -470,26 +476,9 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
self.classifier, self.bert.pooler) self.classifier, self.bert.pooler)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
self_weights = [] loaded_params = loader.load_weights(weights)
return loaded_params
def weight_filter():
for name, weight in weights:
if name.startswith("bert."):
yield (name[len("bert."):], weight)
else:
self_weights.append((name, weight))
self.bert.load_weights(weight_filter())
params_dict = dict(self.named_parameters())
for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def pooler( def pooler(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Union from typing import Optional, Union
...@@ -13,9 +12,9 @@ from vllm.config import VllmConfig ...@@ -13,9 +12,9 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
...@@ -39,8 +38,10 @@ class RobertaEmbedding(nn.Module): ...@@ -39,8 +38,10 @@ class RobertaEmbedding(nn.Module):
config.hidden_size) config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.position_ids = nn.Parameter( self.register_buffer(
torch.empty((1, config.max_position_embeddings)), ) "position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute": if self.position_embedding_type != "absolute":
...@@ -136,16 +137,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel): ...@@ -136,16 +137,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
embedding_class=RobertaEmbedding) embedding_class=RobertaEmbedding)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights) weights_list = list(weights)
# Separate weights in "roberta"-prefixed and all else (not in memory). has_roberta_prefix = any(
# For use with models like FacebookAI/roberta-base. name.startswith("roberta.") for name, _ in weights_list)
bert_weights, task_weights = roberta_task_weights_filter(weights) if has_roberta_prefix:
loaded = self.model.load_weights(bert_weights) # For models with the `roberta.` prefix e.g.
if not len(loaded): # `FacebookAI/roberta-base`
# Fix for models like `sentence-transformers/stsb-roberta-base-v2` mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
# which use the same architecture, but have no "roberta" prefix. else:
loaded = self.model.load_weights(task_weights) # For models without the `roberta.` prefix e.g.
assert len(loaded), "Unable to load RobertaEmbeddingModel" # `sentence-transformers/stsb-roberta-base-v2`
mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
return loader.load_weights(weights_list, mapper=mapper)
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...@@ -187,19 +192,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -187,19 +192,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.classifier) self.classifier)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights) loader = AutoWeightsLoader(self)
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
self.roberta.load_weights(bert_weights)
params_dict = dict(self.named_parameters())
for name, loaded_weight in task_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def pooler( def pooler(
self, self,
...@@ -245,27 +239,3 @@ def create_position_ids_from_input_ids(input_ids, ...@@ -245,27 +239,3 @@ def create_position_ids_from_input_ids(input_ids,
past_key_values_length) * mask past_key_values_length) * mask
return incremental_indices.long() + padding_idx return incremental_indices.long() + padding_idx
def roberta_task_weights_filter(
all_weights: Iterable[tuple[str, torch.Tensor]]
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
torch.Tensor]]]:
"""
Separate task-specific weights that are applied on top
of the encoder-decoder bert base.
To do so, return two generators over the original iterator.
Also, remove the "roberta." prefix to make it loadable
from vanilla BertModel.
"""
# Copy of a lazy iterator without in-memory overhead so both
# iterators can be iterated upon independently.
all_weights1, all_weights2 = itertools.tee(all_weights)
def encoder_decoder_weights():
for name, weight in all_weights1:
if name.startswith("roberta."):
yield (name[len("roberta."):], weight)
return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
if not n.startswith("roberta."))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment