Unverified Commit 122f75d9 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix pipeline parallel with multimodal models with the Transformers modelling backend (#37057)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent d8f8a7aa
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
# limitations under the License. # limitations under the License.
"""Transformers modeling backend base class.""" """Transformers modeling backend base class."""
from collections.abc import Iterable from collections.abc import Callable, Iterable
from itertools import chain from itertools import chain
from operator import attrgetter
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import regex as re import regex as re
...@@ -296,6 +297,15 @@ class Base( ...@@ -296,6 +297,15 @@ class Base(
# Apply mapping to quantization config if needed # Apply mapping to quantization config if needed
self._maybe_apply_model_mapping() self._maybe_apply_model_mapping()
def _get_tie_word_embeddings(self):
"""
Check if the model has tied word embeddings.
"""
# Transformers v4 and v5 will store this in different places
tie_word_embeddings_v4 = getattr(self.text_config, "tie_word_embeddings", False)
tie_word_embeddings_v5 = getattr(self.config, "tie_word_embeddings", False)
return tie_word_embeddings_v4 or tie_word_embeddings_v5
def pipeline_parallel(self): def pipeline_parallel(self):
""" """
Apply the model's pipeline parallelization plan. Apply the model's pipeline parallelization plan.
...@@ -311,11 +321,22 @@ class Base( ...@@ -311,11 +321,22 @@ class Base(
f"{type(self.model)} does not support pipeline parallel. {tip}" f"{type(self.model)} does not support pipeline parallel. {tip}"
) )
def attrsetter(attr: str) -> Callable[[object, object], None]:
"""Set a possibly nested attribute, like the inverse of attrgetter."""
parent, _, name = attr.rpartition(".")
def setter(obj: object, value: object):
attr_parent = attrgetter(parent)(obj) if parent else obj
setattr(attr_parent, name, value)
return setter
module_lists = [] module_lists = []
module_list_idx = None module_list_idx = None
pp_plan = list(self.model._pp_plan.keys()) pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan): for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList): # attrgetter in case the module is nested (e.g. "text_model.layers")
if isinstance(attrgetter(name)(self.model), nn.ModuleList):
module_lists.append(name) module_lists.append(name)
module_list_idx = i module_list_idx = i
...@@ -330,11 +351,11 @@ class Base( ...@@ -330,11 +351,11 @@ class Base(
# Layers before module list # Layers before module list
for name in pp_plan[:module_list_idx]: for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or ( if self.pp_group.is_first_rank or (
getattr(self.text_config, "tie_word_embeddings", False) self._get_tie_word_embeddings() and self.pp_group.is_last_rank
and self.pp_group.is_last_rank
): ):
continue continue
setattr(self.model, name, PPMissingLayer()) # attrsetter in case the module is nested (e.g. "text_model.embed_tokens")
attrsetter(name)(self.model, PPMissingLayer())
# Module list # Module list
start_layer, end_layer = get_pp_indices( start_layer, end_layer = get_pp_indices(
...@@ -343,7 +364,8 @@ class Base( ...@@ -343,7 +364,8 @@ class Base(
self.pp_group.world_size, self.pp_group.world_size,
) )
layers_name = pp_plan[module_list_idx] layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name) # attrgetter in case the module is nested (e.g. "text_model.layers")
layers = attrgetter(layers_name)(self.model)
for i in range(len(layers)): for i in range(len(layers)):
if start_layer <= i and i < end_layer: if start_layer <= i and i < end_layer:
continue continue
...@@ -353,7 +375,8 @@ class Base( ...@@ -353,7 +375,8 @@ class Base(
for name in pp_plan[module_list_idx + 1 :]: for name in pp_plan[module_list_idx + 1 :]:
# Modules that should be on last rank # Modules that should be on last rank
if not self.pp_group.is_last_rank: if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer()) # attrsetter in case the module is nested (e.g. "text_model.norm")
attrsetter(name)(self.model, PPMissingLayer())
def recursive_replace(self): def recursive_replace(self):
"""Recursively replace modules in the model as needed. """Recursively replace modules in the model as needed.
......
...@@ -38,7 +38,7 @@ class CausalMixin(VllmModelForTextGeneration): ...@@ -38,7 +38,7 @@ class CausalMixin(VllmModelForTextGeneration):
# Tell `Base.load_weights` to skip # Tell `Base.load_weights` to skip
# `lm_head` if the model has tied word embeddings # `lm_head` if the model has tied word embeddings
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) tie_word_embeddings = self._get_tie_word_embeddings()
if tie_word_embeddings: if tie_word_embeddings:
self.skip_prefixes.append("lm_head.") self.skip_prefixes.append("lm_head.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment