Unverified Commit dcb6dd9b authored by Aryan, committed by GitHub

Context Parallel w/ Ring & Ulysses & Unified Attention (#11941)



* update

* update

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

* update

* refactor

* refactor; support flash attention 2 with cp

* fix

* support sage attention with cp

* make torch compile compatible

* update

* refactor

* update

* refactor

* refactor

* add ulysses backward

* try to make dreambooth script work; accelerator backward not playing well

* Revert "try to make dreambooth script work; accelerator backward not playing well"

This reverts commit 768d0ea6fa6a305d12df1feda2afae3ec80aa449.

* workaround compilation problems with triton when doing all-to-all

* support wan

* handle backward correctly

* support qwen

* support ltx

* make fix-copies

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* apply review suggestions

* update docs

* add explanation

* make fix-copies

* add docstrings

* support passing parallel_config to from_pretrained

* apply review suggestions

* make style

* update

* Update docs/source/en/api/parallel.md
Co-authored-by: Aryan <aryan@huggingface.co>

* up

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
parent 043ab252
......@@ -70,6 +70,8 @@
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compiling and offloading quantized models
- local: api/parallel
title: Parallel inference
- title: Community optimizations
sections:
- local: optimization/pruna
......
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Parallelism
Parallelism strategies speed up diffusion transformers by distributing computation across multiple devices, enabling faster inference and training.
## ParallelConfig
[[autodoc]] ParallelConfig
## ContextParallelConfig
[[autodoc]] ContextParallelConfig
[[autodoc]] hooks.apply_context_parallel
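A minimal usage sketch, assuming a CUDA/NCCL setup and treating the checkpoint id as a placeholder; launch with `torchrun --nproc_per_node=2`:

```python
# Hedged sketch: launch with `torchrun --nproc_per_node=2 run_cp.py`.
# The checkpoint id is a placeholder; any model that ships a `_cp_plan`
# in this PR (Flux, Wan, LTX-Video, QwenImage, ...) works the same way.
import torch
from diffusers import ContextParallelConfig, FluxTransformer2DModel

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",          # placeholder checkpoint id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=2),  # or ulysses_degree=2
).to(device)

# ... assemble the rest of the pipeline around `transformer` and run inference ...

torch.distributed.destroy_process_group()
```

Passing `parallel_config` at load time is equivalent to loading normally and then calling `enable_parallelism` on the model, as the `modeling_utils.py` changes later in this diff show.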
......@@ -202,6 +202,7 @@ else:
"CogView4Transformer2DModel",
"ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
"ContextParallelConfig",
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
......@@ -229,6 +230,7 @@ else:
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
"ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
"QwenImageControlNetModel",
......@@ -888,6 +890,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
ContextParallelConfig,
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
......@@ -915,6 +918,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageControlNetModel,
......
......@@ -16,6 +16,7 @@ from ..utils import is_torch_available
if is_torch_available():
from .context_parallel import apply_context_parallel
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
......
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union
import torch
import torch.distributed._functional_collectives as funcol
from ..models._modeling_parallel import (
ContextParallelConfig,
ContextParallelInput,
ContextParallelModelPlan,
ContextParallelOutput,
)
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
cached_parameter_indices: Dict[str, int] = None
_cls: Type = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier], True, None
if self.cached_parameter_indices is not None:
index = self.cached_parameter_indices.get(identifier, None)
if index is None:
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
return args[index], False, index
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self.cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self.cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
return args[index], False, index
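For intuition, a minimal self-contained sketch of the positional-argument resolution above, using a hypothetical `ToyBlock` (name and signature invented for illustration):

```python
# Minimal sketch of the index-resolution logic in ModuleForwardMetadata
# (ToyBlock and its forward signature are hypothetical).
import inspect

import torch


class ToyBlock(torch.nn.Module):
    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        return hidden_states


params = list(inspect.signature(ToyBlock.forward).parameters)[1:]  # skip `self`
indices = {name: i for i, name in enumerate(params)}
print(indices)  # {'hidden_states': 0, 'encoder_hidden_states': 1, 'attention_mask': 2}

# With purely positional call args, "encoder_hidden_states" resolves to args[1];
# if it had been passed as a keyword, the hook would read it from kwargs instead.
args = (torch.zeros(1, 8, 16), torch.zeros(1, 4, 16))
print(args[indices["encoder_hidden_states"]].shape)  # torch.Size([1, 4, 16])
```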
def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ContextParallelConfig,
plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
for m in submodule:
if isinstance(cp_model_plan, dict):
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
if isinstance(cp_model_plan, ContextParallelOutput):
cp_model_plan = [cp_model_plan]
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry = HookRegistry.check_if_exists_or_initialize(m)
registry.register_hook(hook, hook_name)
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]
for m in submodule:
registry = HookRegistry.check_if_exists_or_initialize(m)
if isinstance(cp_model_plan, dict):
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry.remove_hook(hook_name)
class ContextParallelSplitHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
self.module_forward_metadata = None
def initialize_hook(self, module):
cls = unwrap_module(module).__class__
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
return module
def pre_forward(self, module, *args, **kwargs):
args_list = list(args)
for name, cpm in self.metadata.items():
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
continue
# Maybe the parameter was passed as a keyword argument
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
name, args_list, kwargs
)
if input_val is None:
continue
# The input_val may be a tensor or a list/tuple of tensors. In certain cases, the user may choose to shard
# the output instead of the input for a particular layer by setting split_output=True
if isinstance(input_val, torch.Tensor):
input_val = self._prepare_cp_input(input_val, cpm)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
)
sharded_input_val = []
for i, x in enumerate(input_val):
if torch.is_tensor(x) and not cpm[i].split_output:
x = self._prepare_cp_input(x, cpm[i])
sharded_input_val.append(x)
input_val = sharded_input_val
else:
raise ValueError(f"Unsupported input type: {type(input_val)}")
if is_kwarg:
kwargs[name] = input_val
elif index is not None and index < len(args_list):
args_list[index] = input_val
else:
raise ValueError(
f"An unexpected error occurred while processing the input '{name}'. Please open an "
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
f"example along with the full stack trace."
)
return tuple(args_list), kwargs
def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
if not is_tensor and not is_tensor_list:
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
output = [output] if is_tensor else list(output)
for index, cpm in self.metadata.items():
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
continue
if index >= len(output):
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
current_output = output[index]
current_output = self._prepare_cp_input(current_output, cpm)
output[index] = current_output
return output[0] if is_tensor else tuple(output)
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
class ContextParallelGatherHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)
if is_tensor:
output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
output = list(output)
if len(output) != len(self.metadata):
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
return output[0] if is_tensor else tuple(output)
class AllGatherFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, dim, group):
ctx.dim = dim
ctx.group = group
ctx.world_size = torch.distributed.get_world_size(group)
ctx.rank = torch.distributed.get_rank(group)
return funcol.all_gather_tensor(tensor, dim, group=group)
@staticmethod
def backward(ctx, grad_output):
grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
return grad_chunks[ctx.rank], None, None
class EquipartitionSharder:
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
# because the alternate case has not yet been tested/required for any model.
assert tensor.size()[dim] % mesh.size() == 0, (
"Tensor size along dimension to be sharded must be divisible by mesh size"
)
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
return tensor
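The shard/unshard contract can be seen without any process group; this single-process sketch mirrors the equipartition behaviour (the real `unshard` performs an all-gather through `AllGatherFunction`):

```python
# Single-process analogue of EquipartitionSharder (illustrative only):
# shard() gives rank r the r-th equal chunk along `dim`; unshard() (an
# all-gather in the real code) concatenates the chunks back in rank order.
import torch

world_size, dim = 4, 1
x = torch.arange(2 * 8 * 3, dtype=torch.float32).reshape(2, 8, 3)
assert x.size(dim) % world_size == 0  # same divisibility requirement as above

shards = x.chunk(world_size, dim=dim)       # rank r would keep shards[r]
reconstructed = torch.cat(shards, dim=dim)  # what the all-gather reassembles
assert torch.equal(reconstructed, x)
print([tuple(s.shape) for s in shards])     # [(2, 2, 3), (2, 2, 3), (2, 2, 3), (2, 2, 3)]
```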
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")
return _find_submodule_by_name(model, name)
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name == "":
return model
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
if first_atom == "*":
if not isinstance(model, torch.nn.ModuleList):
raise ValueError("Wildcard '*' can only be used with ModuleList")
submodules = []
for submodule in model:
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
if not isinstance(subsubmodules, list):
subsubmodules = [subsubmodules]
submodules.extend(subsubmodules)
return submodules
else:
if hasattr(model, first_atom):
submodule = getattr(model, first_atom)
return _find_submodule_by_name(submodule, remaining_name)
else:
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
......@@ -25,6 +25,7 @@ from ..utils import (
_import_structure = {}
if is_torch_available():
_import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
......@@ -119,6 +120,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from ._modeling_parallel import ContextParallelConfig, ParallelConfig
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
......
# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
# Experimental changes are subject to change and APIs may break without warning.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
from ..utils import get_logger
if TYPE_CHECKING:
pass
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): add support for the following:
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel
@dataclass
class ContextParallelConfig:
"""
Configuration for context parallelism.
Args:
ring_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
ulysses_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
Whether to convert the attention output and log-sum-exp (LSE) to float32 for numerical stability when using ring attention.
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
"""
ring_degree: Optional[int] = None
ulysses_degree: Optional[int] = None
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
_rank: int = None
_world_size: int = None
_device: torch.device = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None
_flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_local_rank: int = None
_ulysses_local_rank: int = None
def __post_init__(self):
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
if self._flattened_mesh is None:
self._flattened_mesh = self._mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self._mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self._mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
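A quick sanity sketch of the constraints this config must satisfy (the same checks are enforced by `enable_parallelism` later in this diff); the `world_size` value is illustrative:

```python
# Constraints on ContextParallelConfig, mirrored as plain asserts (sketch).
from diffusers import ContextParallelConfig

world_size = 4                                 # illustrative value
cfg = ContextParallelConfig(ulysses_degree=4)  # ring_degree defaults to 1

assert cfg.ring_degree >= 1 and cfg.ulysses_degree >= 1
assert not (cfg.ring_degree > 1 and cfg.ulysses_degree > 1)  # unified Ulysses-Ring not supported yet
assert cfg.ring_degree * cfg.ulysses_degree <= world_size
# setup() then receives a device mesh of shape (ring_degree, ulysses_degree)
# with dim names ("ring", "ulysses") and flattens it for input/output sharding.
```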
@dataclass
class ParallelConfig:
"""
Configuration for applying different parallelisms.
Args:
context_parallel_config (`ContextParallelConfig`, *optional*):
Configuration for context parallelism.
"""
context_parallel_config: Optional[ContextParallelConfig] = None
_rank: int = None
_world_size: int = None
_device: torch.device = None
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
def setup(
self,
rank: int,
world_size: int,
device: torch.device,
*,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
self._cp_mesh = cp_mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
@dataclass(frozen=True)
class ContextParallelInput:
"""
Configuration for splitting an input tensor across the context parallel region.
Args:
split_dim (`int`):
The dimension along which to split the tensor.
expected_dims (`int`, *optional*):
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
tensor has the expected number of dimensions before splitting.
split_output (`bool`, *optional*, defaults to `False`):
Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
This is useful for layers whose outputs should be split after the layer does some preprocessing on the inputs (e.g.
RoPE).
"""
split_dim: int
expected_dims: Optional[int] = None
split_output: bool = False
def __repr__(self):
return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
@dataclass(frozen=True)
class ContextParallelOutput:
"""
Configuration for gathering an output tensor across the context parallel region.
Args:
gather_dim (`int`):
The dimension along which to gather the tensor.
expected_dims (`int`, *optional*):
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
tensor has the expected number of dimensions before gathering.
"""
gather_dim: int
expected_dims: Optional[int] = None
def __repr__(self):
return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
# A dictionary where keys denote the inputs to be split across the context parallel region, and the
# values denote the sharding configuration.
# If the key is a string, it denotes the name of the parameter in the forward function.
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across the context parallel region.
ContextParallelInputType = Dict[
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
]
# The output(s) to be gathered across the context parallel region, together with the gathering
# configuration for each.
ContextParallelOutputType = Union[
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
]
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
# the module should be split/gathered across the context parallel region.
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
#
# Each model should define a _cp_plan attribute that contains information on how to shard/gather
# tensors at different stages of the forward:
#
# ```python
# _cp_plan = {
# "": {
# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
# },
# "pos_embed": {
# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
# },
# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
# }
# ```
#
# The dictionary maps module names to their respective CP plans. The inputs/outputs of the named layers will be
# split/gathered according to these plans at the respective module level. Here, the following happens:
# - "":
# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
# the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
# - "pos_embed":
# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
# we can individually specify how they should be split
# - "proj_out":
# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
# layer forward has run).
#
# ContextParallelInput:
# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
#
# ContextParallelOutput:
# specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
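A custom plan of this shape can also be written by hand and passed to `enable_parallelism` to override a model's built-in `_cp_plan`. The module names below are hypothetical, and the call itself requires an initialized `torch.distributed` process group (e.g. under `torchrun`):

```python
# Sketch: hand-written ContextParallelModelPlan for a hypothetical model whose
# forward takes (hidden_states, encoder_hidden_states) and ends with `proj_out`.
from diffusers import ContextParallelConfig
from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput

custom_cp_plan = {
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
    },
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}

# Under torchrun (torch.distributed initialized), the plan overrides `_cp_plan`:
#   model.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2), cp_plan=custom_cp_plan)
```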
This diff is collapsed.
......@@ -65,6 +65,7 @@ from ..utils.hub_utils import (
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
......@@ -248,6 +249,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_repeated_blocks = []
_parallel_config = None
_cp_plan = None
def __init__(self):
super().__init__()
......@@ -620,8 +623,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
set, or the torch native scaled dot product attention.
"""
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
......@@ -960,6 +963,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
if is_parallel_loading_enabled and not low_cpu_mem_usage:
......@@ -1340,6 +1344,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if parallel_config is not None:
model.enable_parallelism(config=parallel_config)
if output_loading_info:
return model, loading_info
......@@ -1478,6 +1485,73 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
)
def enable_parallelism(
self,
*,
config: Union[ParallelConfig, ContextParallelConfig],
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
):
from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning(
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)
if isinstance(config, ContextParallelConfig):
config = ParallelConfig(context_parallel_config=config)
if not torch.distributed.is_initialized():
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
device = torch.device(device_type, rank % device_module.device_count())
cp_mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
)
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
mesh_dim_names=("ring", "ulysses"),
)
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
if config.context_parallel_config is not None:
apply_context_parallel(self, config.context_parallel_config, cp_plan)
self._parallel_config = config
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_parallel_config"):
continue
processor._parallel_config = config
@classmethod
def _load_pretrained_model(
cls,
......
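For completeness, a minimal sketch of calling `enable_parallelism` on an already-loaded model, assuming a CUDA/NCCL setup and treating the checkpoint id as a placeholder (launch with `torchrun --nproc_per_node=2`):

```python
# Sketch: enable context parallelism on a loaded model instead of via from_pretrained.
import torch

from diffusers import ContextParallelConfig, WanTransformer3DModel

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # placeholder checkpoint id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Ulysses (all-to-all) attention across 2 ranks, using the model's built-in _cp_plan.
transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))

# ... run the denoising loop / pipeline with this transformer ...
torch.distributed.destroy_process_group()
```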
......@@ -120,6 +120,7 @@ def get_1d_rotary_pos_embed(
class BriaAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
......@@ -161,7 +162,12 @@ class BriaAttnProcessor:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
......
......@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
......@@ -73,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
class FluxAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
......@@ -114,7 +116,12 @@ class FluxAttnProcessor:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
......@@ -136,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
_attention_backend = None
_parallel_config = None
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
......@@ -220,6 +228,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
......@@ -252,6 +261,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
......@@ -556,6 +566,15 @@ class FluxTransformer2DModel(
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config
def __init__(
......
......@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
......@@ -51,6 +52,7 @@ class LTXVideoAttnProcessor:
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
......@@ -100,6 +102,7 @@ class LTXVideoAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
......@@ -409,6 +412,18 @@ class LTXVideoTransformer3DModel(
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTXVideoTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"rope": {
0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config
def __init__(
......
......@@ -25,6 +25,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
......@@ -261,6 +262,7 @@ class QwenDoubleStreamAttnProcessor2_0:
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
......@@ -334,6 +336,7 @@ class QwenDoubleStreamAttnProcessor2_0:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape back
......@@ -502,6 +505,18 @@ class QwenImageTransformer2DModel(
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"pos_embed": {
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config
def __init__(
......
......@@ -73,6 +73,7 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states
class SkyReelsV2AttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
......@@ -139,6 +140,7 @@ class SkyReelsV2AttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
......@@ -151,6 +153,7 @@ class SkyReelsV2AttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
......
......@@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
......@@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
class WanAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
......@@ -132,6 +134,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
......@@ -144,6 +147,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
......@@ -539,6 +543,19 @@ class WanTransformer3DModel(
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
_cp_plan = {
"rope": {
0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
},
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"blocks.*": {
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config
def __init__(
......
......@@ -648,6 +648,21 @@ class ConsistencyDecoderVAE(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ContextParallelConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
......@@ -1053,6 +1068,21 @@ class OmniGenTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ParallelConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
......