"vscode:/vscode.git/clone" did not exist on "da1f7cc12a12ea4a744d26122e9a13ea4b3f4c7b"
Unverified Commit 97cfa65d authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Add pipeline parallel support to `TransformersModel` (#12832)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent 911c8eb0
...@@ -73,7 +73,7 @@ The Transformers fallback explicitly supports the following features: ...@@ -73,7 +73,7 @@ The Transformers fallback explicitly supports the following features:
- <project:#quantization-index> (except GGUF) - <project:#quantization-index> (except GGUF)
- <project:#lora-adapter> - <project:#lora-adapter>
- <project:#distributed-serving> (pipeline parallel coming soon <gh-pr:12832>!) - <project:#distributed-serving> (requires `transformers>=4.49.0`)
#### Remote code #### Remote code
......
...@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = { ...@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = {
"inceptionai/jais-13b-chat": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersModel
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(), "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama # Uses Llama
...@@ -243,6 +245,7 @@ TEST_MODELS = [ ...@@ -243,6 +245,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION] # [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct", "microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905
"ibm/PowerLM-3b", "ibm/PowerLM-3b",
# [LANGUAGE EMBEDDING] # [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct", "intfloat/e5-mistral-7b-instruct",
......
...@@ -15,21 +15,25 @@ ...@@ -15,21 +15,25 @@
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models""" """Wrapper around `transformers` models"""
import re import re
from itertools import chain
from typing import Iterable, Literal, Optional, Union from typing import Iterable, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import AutoModel, PreTrainedModel from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import VllmConfig from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
from vllm.distributed import get_tensor_model_parallel_world_size ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsQuant from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import maybe_prefix from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -53,7 +58,7 @@ def vllm_flash_attention_forward( ...@@ -53,7 +58,7 @@ def vllm_flash_attention_forward(
# Transformers kwargs # Transformers kwargs
scaling: Optional[float] = None, scaling: Optional[float] = None,
# vLLM kwargs # vLLM kwargs
attention_instances: Optional[list[Attention]] = None, attention_instances: Optional[dict[Attention]] = None,
**kwargs): **kwargs):
self_attn = attention_instances[module.layer_idx] self_attn = attention_instances[module.layer_idx]
if scaling is not None: if scaling is not None:
...@@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): ...@@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
def replace_linear_class( def replace_linear_class(
linear: nn.Linear, linear: nn.Linear, style: Literal["colwise", "rowwise"],
style: Literal["colwise", "rowwise"], quant_config: QuantizationConfig
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]: ) -> Union[ColumnParallelLinear, RowParallelLinear]:
""" """
Replace nn.Linear with one of vLLM's tensor parallel linear classes. Replace nn.Linear with one of vLLM's tensor parallel linear classes.
`quant_config` is not yet supported.
Args: Args:
linear (nn.Linear): `nn.Linear` to be replaced. linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise". style (str): Tensor parallel style of the new linear, e.g. "colwise".
...@@ -105,7 +109,7 @@ def replace_linear_class( ...@@ -105,7 +109,7 @@ def replace_linear_class(
) )
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens" embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it ] # TODO transformers will have a util to get it
...@@ -114,109 +118,246 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -114,109 +118,246 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
super().__init__() super().__init__()
logger.info("Using Transformers backend.") logger.info("Using Transformers backend.")
config = vllm_config.model_config.hf_config config: PretrainedConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config: CacheConfig = vllm_config.cache_config
model_config = vllm_config.model_config device_config: DeviceConfig = vllm_config.device_config
parallel_config = vllm_config.parallel_config model_config: ModelConfig = vllm_config.model_config
parallel_config: ParallelConfig = vllm_config.parallel_config
quant_config: QuantizationConfig = vllm_config.quant_config
self.config = config self.config = config
self.cache_config = cache_config
self.device_config = device_config
self.model_config = model_config
self.parallel_config = parallel_config
self.quant_config = quant_config
self.vocab_size = model_config.get_vocab_size() self.vocab_size = model_config.get_vocab_size()
self.unpadded_vocab_size = model_config.get_vocab_size() self.unpadded_vocab_size = model_config.get_vocab_size()
self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()
# Use meta device to delay allocating GPU tensors
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(
self.config, config,
attn_implementation="vllm", attn_implementation="vllm",
torch_dtype=vllm_config.model_config.dtype, torch_dtype=model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
prefix = self.model.base_model_prefix prefix = self.model.base_model_prefix
# MLP modifications self.pipeline_parallel()
self.apply_base_model_tp_plan(self.model) self.tensor_parallel()
# Attention modifications (assumes 1 attention op per hidden layer) # Input embeddings
num_heads = model_config.get_num_attention_heads(parallel_config) if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
head_size = model_config.get_head_size() self.model.set_input_embeddings(
num_kv_heads = model_config.get_num_kv_heads(parallel_config) VocabParallelEmbedding(
self.attention_instances = [ config.vocab_size,
Attention( config.hidden_size,
num_heads=num_heads, org_num_embeddings=config.vocab_size,
head_size=head_size, quant_config=quant_config,
# NOTE: We use Llama scale as default, if it's set by ))
# Transformers, it's updated in vllm_flash_attention_forward
scale=head_size**-0.5,
num_kv_heads=num_kv_heads,
cache_config=cache_config,
quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
]
# Model modifications # Attention layers
self.replace_vocab_embed_class(self.model) self.attention_instances = self.create_attention_instances()
# ForCausalLM modifications # Output embeddings
self.lm_head = ParallelLMHead(self.vocab_size, if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=self.quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head")) prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale) config.vocab_size,
logit_scale)
# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)
# Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model)
self.sampler = get_sampler() self.sampler = get_sampler()
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""): self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
"""
if self.pp_size <= 1:
return
if not self.model.supports_pp_plan:
raise ValueError(
f"{type(self.model)} does not support pipeline parallel yet!")
module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
module_lists.append(name)
module_list_idx = i
if len(module_lists) > 1:
raise ValueError(
"Pipeline parallel of models with multiple `ModuleList`s "
"in the base model are not supported yet!")
if module_list_idx is None:
raise ValueError(
f"Could not find `ModuleList` in {type(self.model)}")
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
and self.pp_group.is_last_rank):
continue
setattr(self.model, name, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
layers[i] = PPMissingLayer(return_tuple=True)
# Layers after module list
for name in pp_plan[module_list_idx + 1:]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
if not self.pp_group.is_last_rank:
self.lm_head = PPMissingLayer()
def tensor_parallel(self):
""" """
Apply the base model tensor parallelization plan to a module. Apply the model's tensor parallelization plan.
Currently only supports linear layers. Currently only supports linear layers.
""" """
if (self.config.base_model_tp_plan is None if self.tp_size > 1 and self.config.base_model_tp_plan is None:
and get_tensor_model_parallel_world_size() > 1):
raise ValueError( raise ValueError(
"Trying to run tensor parallelization but the model does not " f"{type(self.model)} does not support tensor parallel yet!")
"support it yet!")
tp_plan = self.model._tp_plan
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children(): for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name) qual_name = maybe_prefix(prefix, child_name)
for pattern, style in self.config.base_model_tp_plan.items(): for pattern, style in tp_plan.items():
if re.match(pattern, qual_name) and isinstance( if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear): child_module, nn.Linear):
new_module = replace_linear_class(child_module, style, new_module = replace_linear_class(
self.quant_config) child_module, style, self.quant_config)
setattr(module, child_name, new_module) setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module) log_replacement(qual_name, child_module, new_module)
else: else:
self.apply_base_model_tp_plan(child_module, prefix=qual_name) _tensor_parallel(child_module, prefix=qual_name)
def replace_vocab_embed_class(self, module: nn.Module): _tensor_parallel(self.model)
# Use native set input embeddings
new_module = VocabParallelEmbedding( def create_attention_instances(self) -> dict[int, Attention]:
self.vocab_size, """
self.config.hidden_size, Create `Attention` instances to inform KV cache allocation.
org_num_embeddings=self.vocab_size, """
quant_config=None, num_heads = self.model_config.get_num_attention_heads(
) self.parallel_config)
log_replacement("input embedding", self.model.get_input_embeddings(), head_size = self.model_config.get_head_size()
new_module) num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
module.set_input_embeddings(new_module) start, end = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
return {
i:
Attention(
num_heads=num_heads,
head_size=head_size,
# NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in vllm_flash_attention_forward
scale=head_size**-0.5,
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"{i}.attn")
for i in range(start, end)
}
def init_buffers(self, module: nn.Module):
"""
If a `buffer` is on the `meta` device, then its parent
`module` is the original module created by:
```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
This means that:
- `type(module)` is a class from `transformers`
- This class is constructed using a `PretrainedConfig`
"""
for name, buffer in module.named_buffers(recurse=False):
if buffer.device == torch.device("meta"):
new_buffer = getattr(type(module)(self.config), name)
setattr(module, name, new_buffer)
for child in module.children():
self.init_buffers(child)
def meta_to_empty(self, module: nn.Module):
tensors = list(chain(module.buffers(), module.parameters()))
if tensors and all(t.device == torch.device("meta") for t in tensors):
module.to_empty(device=self.device_config.device)
return # We can stop recursing because to_empty is recursive
for child in module.children():
self.meta_to_empty(child)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model( if not get_pp_group().is_first_rank:
input_ids[None, ...], assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False, use_cache=False,
position_ids=positions[None, ...], position_ids=positions[None, ...],
intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances, attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now return_dict=False)[0][0, ...] # we remove batch dimension for now
return model_output
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
def compute_logits( def compute_logits(
self, self,
...@@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params = set[str]() loaded_params = set[str]()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name not in params_dict: # Necessary for some models which use remote code
name = f"{self.model.base_model_prefix}.{name}" if not name.startswith(prefix := self.model.base_model_prefix):
name = maybe_prefix(prefix, name)
if is_pp_missing_parameter(name, self):
continue
if name in params_dict: if name in params_dict:
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
......
...@@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity): ...@@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
self.return_tuple = kwargs.get("return_tuple", False)
def forward(self, *args, **kwargs):
"""
Return the first arg from args or the first value from kwargs.
Wraps the input in a tuple if `self.return_tuple` is True.
"""
input = args[0] if args else next(iter(kwargs.values()))
return (input, ) if self.return_tuple else input
_CPU_OFFLOAD_BYTES = 0 _CPU_OFFLOAD_BYTES = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment