Unverified Commit dcb6dd9b authored by Aryan, committed by GitHub

Context Parallel w/ Ring & Ulysses & Unified Attention (#11941)



* update

* update

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

* update

* refactor

* refactor; support flash attention 2 with cp

* fix

* support sage attention with cp

* make torch compile compatible

* update

* refactor

* update

* refactor

* refactor

* add ulysses backward

* try to make dreambooth script work; accelerator backward not playing well

* Revert "try to make dreambooth script work; accelerator backward not playing well"

This reverts commit 768d0ea6fa6a305d12df1feda2afae3ec80aa449.

* workaround compilation problems with triton when doing all-to-all

* support wan

* handle backward correctly

* support qwen

* support ltx

* make fix-copies

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* apply review suggestions

* update docs

* add explanation

* make fix-copies

* add docstrings

* support passing parallel_config to from_pretrained

* apply review suggestions

* make style

* update

* Update docs/source/en/api/parallel.md
Co-authored-by: Aryan <aryan@huggingface.co>

* up

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
parent 043ab252
@@ -70,6 +70,8 @@
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compiling and offloading quantized models
+- local: api/parallel
+title: Parallel inference
- title: Community optimizations
sections:
- local: optimization/pruna
...
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# Parallelism
Parallelism strategies speed up diffusion transformers by distributing computation across multiple devices, enabling faster inference and training.
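Below is a minimal sketch of enabling context parallelism. It assumes a `torchrun` launch with as many processes as `ring_degree * ulysses_degree`, and that the model accepts a `parallel_config` argument in `from_pretrained` as described in this PR; the exact entry point, checkpoint, and supported models in the sketch are illustrative and may differ from the final API.

```python
# Launch with, e.g.: torchrun --nproc-per-node=2 cp_example.py
import torch
from diffusers import AutoModel, ContextParallelConfig, ParallelConfig

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank)

# Ring attention across 2 devices (use ulysses_degree instead for Ulysses attention).
cp_config = ContextParallelConfig(ring_degree=2)
parallel_config = ParallelConfig(context_parallel_config=cp_config)

# Hypothetical checkpoint path; any supported diffusion transformer would be loaded the same way.
transformer = AutoModel.from_pretrained(
    "some/checkpoint",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=parallel_config,  # added in this PR
)
```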
## ParallelConfig
[[autodoc]] ParallelConfig
## ContextParallelConfig
[[autodoc]] ContextParallelConfig
[[autodoc]] hooks.apply_context_parallel
@@ -202,6 +202,7 @@ else:
"CogView4Transformer2DModel",
"ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
+"ContextParallelConfig",
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
@@ -229,6 +230,7 @@ else:
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
+"ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
"QwenImageControlNetModel",
@@ -888,6 +890,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
+ContextParallelConfig,
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
@@ -915,6 +918,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
+ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageControlNetModel,
...
@@ -16,6 +16,7 @@ from ..utils import is_torch_available
if is_torch_available():
+from .context_parallel import apply_context_parallel
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
...
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union
import torch
import torch.distributed._functional_collectives as funcol
from ..models._modeling_parallel import (
ContextParallelConfig,
ContextParallelInput,
ContextParallelModelPlan,
ContextParallelOutput,
)
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
cached_parameter_indices: Dict[str, int] = None
_cls: Type = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier], True, None
if self.cached_parameter_indices is not None:
index = self.cached_parameter_indices.get(identifier, None)
if index is None:
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
return args[index], False, index
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self.cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self.cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected at least {index + 1} positional arguments, but only got {len(args)}.")
return args[index], False, index
def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ContextParallelConfig,
plan: Dict[str, ContextParallelModelPlan],
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
for m in submodule:
if isinstance(cp_model_plan, dict):
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
if isinstance(cp_model_plan, ContextParallelOutput):
cp_model_plan = [cp_model_plan]
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry = HookRegistry.check_if_exists_or_initialize(m)
registry.register_hook(hook, hook_name)
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
for module_id, cp_model_plan in plan.items():
submodule = _get_submodule_by_name(module, module_id)
if not isinstance(submodule, list):
submodule = [submodule]
for m in submodule:
registry = HookRegistry.check_if_exists_or_initialize(m)
if isinstance(cp_model_plan, dict):
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
else:
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
registry.remove_hook(hook_name)
class ContextParallelSplitHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
self.module_forward_metadata = None
def initialize_hook(self, module):
cls = unwrap_module(module).__class__
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
return module
def pre_forward(self, module, *args, **kwargs):
args_list = list(args)
for name, cpm in self.metadata.items():
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
continue
# Maybe the parameter was passed as a keyword argument
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
name, args_list, kwargs
)
if input_val is None:
continue
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
# the output instead of input for a particular layer by setting split_output=True
if isinstance(input_val, torch.Tensor):
input_val = self._prepare_cp_input(input_val, cpm)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
)
sharded_input_val = []
for i, x in enumerate(input_val):
if torch.is_tensor(x) and not cpm[i].split_output:
x = self._prepare_cp_input(x, cpm[i])
sharded_input_val.append(x)
input_val = sharded_input_val
else:
raise ValueError(f"Unsupported input type: {type(input_val)}")
if is_kwarg:
kwargs[name] = input_val
elif index is not None and index < len(args_list):
args_list[index] = input_val
else:
raise ValueError(
f"An unexpected error occurred while processing the input '{name}'. Please open an "
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
f"example along with the full stack trace."
)
return tuple(args_list), kwargs
def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
if not is_tensor and not is_tensor_list:
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
output = [output] if is_tensor else list(output)
for index, cpm in self.metadata.items():
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
continue
if index >= len(output):
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
current_output = output[index]
current_output = self._prepare_cp_input(current_output, cpm)
output[index] = current_output
return output[0] if is_tensor else tuple(output)
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
class ContextParallelGatherHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
super().__init__()
self.metadata = metadata
self.parallel_config = parallel_config
def post_forward(self, module, output):
is_tensor = isinstance(output, torch.Tensor)
if is_tensor:
output = [output]
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
output = list(output)
if len(output) != len(self.metadata):
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
return output[0] if is_tensor else tuple(output)
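# Differentiable all-gather: the forward pass gathers shards from all ranks along `dim`; the
# backward pass returns only the gradient slice corresponding to this rank's original shard.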
class AllGatherFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, dim, group):
ctx.dim = dim
ctx.group = group
ctx.world_size = torch.distributed.get_world_size(group)
ctx.rank = torch.distributed.get_rank(group)
return funcol.all_gather_tensor(tensor, dim, group=group)
@staticmethod
def backward(ctx, grad_output):
grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
return grad_chunks[ctx.rank], None, None
class EquipartitionSharder:
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
# because the alternate case has not yet been tested/required for any model.
assert tensor.size()[dim] % mesh.size() == 0, (
"Tensor size along dimension to be sharded must be divisible by mesh size"
)
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
return tensor
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")
return _find_submodule_by_name(model, name)
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name == "":
return model
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
if first_atom == "*":
if not isinstance(model, torch.nn.ModuleList):
raise ValueError("Wildcard '*' can only be used with ModuleList")
submodules = []
for submodule in model:
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
if not isinstance(subsubmodules, list):
subsubmodules = [subsubmodules]
submodules.extend(subsubmodules)
return submodules
else:
if hasattr(model, first_atom):
submodule = getattr(model, first_atom)
return _find_submodule_by_name(submodule, remaining_name)
else:
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
@@ -25,6 +25,7 @@ from ..utils import (
_import_structure = {}
if is_torch_available():
+_import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
@@ -119,6 +120,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
+from ._modeling_parallel import ContextParallelConfig, ParallelConfig
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
...
# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
# Experimental changes are subject to change and APIs may break without warning.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
from ..utils import get_logger
if TYPE_CHECKING:
pass
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): add support for the following:
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel
@dataclass
class ContextParallelConfig:
"""
Configuration for context parallelism.
Args:
ring_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
ulysses_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
total number of devices in the context parallel mesh.
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
Whether to convert output and LSE to float32 for ring attention numerical stability.
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
"""
ring_degree: Optional[int] = None
ulysses_degree: Optional[int] = None
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
_rank: int = None
_world_size: int = None
_device: torch.device = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None
_flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_local_rank: int = None
_ulysses_local_rank: int = None
def __post_init__(self):
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
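# The mesh passed here is expected to be a device mesh with named "ring" and "ulysses" dimensions
# (matching ring_degree and ulysses_degree); it is flattened below for collectives that span the
# whole context-parallel group.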
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
if self._flattened_mesh is None:
self._flattened_mesh = self._mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self._mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self._mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
@dataclass
class ParallelConfig:
"""
Configuration for applying different parallelisms.
Args:
context_parallel_config (`ContextParallelConfig`, *optional*):
Configuration for context parallelism.
"""
context_parallel_config: Optional[ContextParallelConfig] = None
_rank: int = None
_world_size: int = None
_device: torch.device = None
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
def setup(
self,
rank: int,
world_size: int,
device: torch.device,
*,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
self._cp_mesh = cp_mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
@dataclass(frozen=True)
class ContextParallelInput:
"""
Configuration for splitting an input tensor across context parallel region.
Args:
split_dim (`int`):
The dimension along which to split the tensor.
expected_dims (`int`, *optional*):
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
tensor has the expected number of dimensions before splitting.
split_output (`bool`, *optional*, defaults to `False`):
Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
This is useful for layers whose outputs should be split after the layer does some preprocessing on its inputs (e.g., RoPE).
"""
split_dim: int
expected_dims: Optional[int] = None
split_output: bool = False
def __repr__(self):
return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
@dataclass(frozen=True)
class ContextParallelOutput:
"""
Configuration for gathering an output tensor across context parallel region.
Args:
gather_dim (`int`):
The dimension along which to gather the tensor.
expected_dims (`int`, *optional*):
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
tensor has the expected number of dimensions before gathering.
"""
gather_dim: int
expected_dims: Optional[int] = None
def __repr__(self):
return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
# A dictionary where keys denote the input to be split across context parallel region, and the
# value denotes the sharding configuration.
# If the key is a string, it denotes the name of the parameter in the forward function.
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across context parallel region.
ContextParallelInputType = Dict[
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
]
# A dictionary where keys denote the output to be gathered across context parallel region, and the
# value denotes the gathering configuration.
ContextParallelOutputType = Union[
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
]
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
# the module should be split/gathered across context parallel region.
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
#
# Each model should define a _cp_plan attribute that contains information on how to shard/gather
# tensors at different stages of the forward:
#
# ```python
# _cp_plan = {
# "": {
# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
# },
# "pos_embed": {
# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
# },
# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
# }
# ```
#
# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
# split/gathered according to this at the respective module level. Here, the following happens:
# - "":
# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
# the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
# - "pos_embed":
# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
# we can individually specify how they should be split
# - "proj_out":
# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
# layer forward has run).
#
# ContextParallelInput:
# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
#
# ContextParallelOutput:
# specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
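#
# A minimal illustrative sketch (not part of this file) of wiring a plan through the hook API,
# assuming `transformer` defines a `_cp_plan` attribute and that `rank`, `world_size`, `device`
# and a 2D ("ring", "ulysses") `cp_mesh` have already been created:
#
# ```python
# from diffusers import ContextParallelConfig, ParallelConfig
# from diffusers.hooks import apply_context_parallel
#
# cp_config = ContextParallelConfig(ring_degree=2)
# parallel_config = ParallelConfig(context_parallel_config=cp_config)
# parallel_config.setup(rank, world_size, device, cp_mesh=cp_mesh)
# apply_context_parallel(transformer, cp_config, transformer._cp_plan)
# ```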
@@ -17,9 +17,10 @@ import functools
import inspect
import math
from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
+import torch.distributed._functional_collectives as funcol
from ..utils import (
get_logger,
@@ -39,6 +40,9 @@ from ..utils import (
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+if TYPE_CHECKING:
+from ._modeling_parallel import ParallelConfig
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
@@ -56,9 +60,12 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
flash_attn_func = None
flash_attn_varlen_func = None
+_wrapped_flash_attn_backward = None
+_wrapped_flash_attn_forward = None
if _CAN_USE_FLASH_ATTN_3:
@@ -197,17 +204,24 @@ class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
+_supports_context_parallel = {}
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
@classmethod
-def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+def register(
+cls,
+backend: AttentionBackendName,
+constraints: Optional[List[Callable]] = None,
+supports_context_parallel: bool = False,
+):
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
def decorator(func):
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+cls._supports_context_parallel[backend] = supports_context_parallel
return func
return decorator
@@ -220,6 +234,17 @@ class _AttentionBackendRegistry:
def list_backends(cls):
return list(cls._backends.keys())
+@classmethod
+def _is_context_parallel_enabled(
+cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+) -> bool:
+supports_context_parallel = backend in cls._supports_context_parallel
+is_degree_greater_than_1 = parallel_config is not None and (
+parallel_config.context_parallel_config.ring_degree > 1
+or parallel_config.context_parallel_config.ulysses_degree > 1
+)
+return supports_context_parallel and is_degree_greater_than_1
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
@@ -253,6 +278,7 @@ def dispatch_attention_fn(
attention_kwargs: Optional[Dict[str, Any]] = None,
*,
backend: Optional[AttentionBackendName] = None,
+parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
attention_kwargs = attention_kwargs or {}
@@ -264,6 +290,14 @@ def dispatch_attention_fn(
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
+backend_name, parallel_config
+):
+raise ValueError(
+f"Backend {backend_name} either does not support context parallelism or context parallelism "
+f"was enabled with a world size of 1."
+)
kwargs = {
"query": query,
"key": key,
@@ -273,6 +307,7 @@
"is_causal": is_causal,
"scale": scale,
**attention_kwargs,
+"_parallel_config": parallel_config,
}
if is_torch_version(">=", "2.5.0"):
kwargs["enable_gqa"] = enable_gqa
@@ -521,22 +556,621 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3_original(
-query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3(
+q: torch.Tensor,
+k: torch.Tensor,
+v: torch.Tensor,
+softmax_scale: Optional[float] = None,
+causal: bool = False,
+qv: Optional[torch.Tensor] = None,
+q_descale: Optional[torch.Tensor] = None,
+k_descale: Optional[torch.Tensor] = None,
+v_descale: Optional[torch.Tensor] = None,
+attention_chunk: int = 0,
+softcap: float = 0.0,
+num_splits: int = 1,
+pack_gqa: Optional[bool] = None,
+deterministic: bool = False,
+sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
-out, lse = flash_attn_3_func(query, key, value)
+# Hardcoded for now because pytorch does not support tuple/int type hints
+window_size = (-1, -1)
+out, lse, *_ = flash_attn_3_func(
+q=q,
+k=k,
+v=v,
+softmax_scale=softmax_scale,
+causal=causal,
+qv=qv,
+q_descale=q_descale,
+k_descale=k_descale,
+v_descale=v_descale,
+window_size=window_size,
+attention_chunk=attention_chunk,
+softcap=softcap,
+num_splits=num_splits,
+pack_gqa=pack_gqa,
+deterministic=deterministic,
+sm_margin=sm_margin,
+)
lse = lse.permute(0, 2, 1)
return out, lse
-@_register_fake("flash_attn_3::_flash_attn_forward")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-batch_size, seq_len, num_heads, head_dim = query.shape
+@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
+def _(
+q: torch.Tensor,
+k: torch.Tensor,
+v: torch.Tensor,
+softmax_scale: Optional[float] = None,
+causal: bool = False,
+qv: Optional[torch.Tensor] = None,
+q_descale: Optional[torch.Tensor] = None,
+k_descale: Optional[torch.Tensor] = None,
+v_descale: Optional[torch.Tensor] = None,
+attention_chunk: int = 0,
+softcap: float = 0.0,
+num_splits: int = 1,
+pack_gqa: Optional[bool] = None,
+deterministic: bool = False,
+sm_margin: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+window_size = (-1, -1)  # noqa: F841
+# A lot of the parameters here are not yet used in any way within diffusers.
+# We can safely ignore for now and keep the fake op shape propagation simple.
+batch_size, seq_len, num_heads, head_dim = q.shape
lse_shape = (batch_size, seq_len, num_heads)
-return torch.empty_like(query), query.new_empty(lse_shape)
+return torch.empty_like(q), q.new_empty(lse_shape)
# ===== Helper functions to use attention backends with templated CP autograd functions =====
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
# forward declaration:
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
def _cudnn_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
tensors_to_save = ()
# Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
# if the input tensors are not contiguous.
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
tensors_to_save += (query, key, value)
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
torch.ops.aten._scaled_dot_product_cudnn_attention(
query=query,
key=key,
value=value,
attn_bias=attn_mask,
compute_log_sumexp=return_lse,
dropout_p=dropout_p,
is_causal=is_causal,
return_debug_mask=False,
scale=scale,
)
)
tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
if _save_ctx:
ctx.save_for_backward(*tensors_to_save)
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.scale = scale
ctx.attn_mask = attn_mask
ctx.max_q = max_q
ctx.max_k = max_k
out = out.transpose(1, 2).contiguous()
if lse is not None:
lse = lse.transpose(1, 2).contiguous()
return (out, lse) if return_lse else out
# backward declaration:
# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
def _cudnn_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
grad_out = grad_out.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
# Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
logsumexp=lse,
philox_seed=philox_seed,
philox_offset=philox_offset,
attn_bias=ctx.attn_mask,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=ctx.max_q,
max_k=ctx.max_k,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
)
grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
def _flash_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
# Hardcoded for now
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if scale is None:
scale = query.shape[-1] ** (-0.5)
# flash-attn only returns the LSE if dropout_p > 0, so we work around this with a tiny non-zero dropout probability.
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1)
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
lse_d = _wrapped_flash_attn_backward( # noqa: F841
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
# Head dimension may have been padded
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
out = sageattn(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1)
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
):
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
# ===== Context parallel =====
# Reference:
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
def _wait_tensor(tensor):
if isinstance(tensor, funcol.AsyncCollectiveTensor):
tensor = tensor.wait()
return tensor
def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
shape = x.shape
# HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
# to benchmark triton codegen fails somewhere:
# buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
# ValueError: Tensors must be contiguous
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
x = x.reshape(shape)
x = _wait_tensor(x)
return x
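# Ring attention: each rank keeps its local query shard and visits the key/value shards of every
# rank one step at a time (with rotate_method="allgather", the packed K/V buffer is all-gathered
# once and consumed chunk by chunk); per-step partial outputs are merged with a log-sum-exp update.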
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
):
ring_mesh = _parallel_config.context_parallel_config._ring_mesh
rank = _parallel_config.context_parallel_config._ring_local_rank
world_size = _parallel_config.context_parallel_config.ring_degree
next_rank = (rank + 1) % world_size
prev_out = prev_lse = None
ctx.forward_op = forward_op
ctx.backward_op = backward_op
ctx.q_shape = query.shape
ctx.kv_shape = key.shape
ctx._parallel_config = _parallel_config
kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
kv_buffer = kv_buffer.chunk(world_size)
for i in range(world_size):
if i > 0:
kv = kv_buffer[next_rank]
key_numel = key.numel()
key = kv[:key_numel].reshape_as(key)
value = kv[key_numel:].reshape_as(value)
next_rank = (next_rank + 1) % world_size
out, lse = forward_op(
ctx,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
True,
_save_ctx=i == 0,
_parallel_config=_parallel_config,
)
if _parallel_config.context_parallel_config.convert_to_fp32:
out = out.to(torch.float32)
lse = lse.to(torch.float32)
lse = lse.unsqueeze(-1)
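# Numerically stable merge of per-step partial attention results (online-softmax style):
# out_new = prev_out - sigmoid(lse - prev_lse) * (prev_out - out)
# lse_new = prev_lse - logsigmoid(prev_lse - lse) = log(exp(prev_lse) + exp(lse))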
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
prev_out = out
prev_lse = lse
out = out.to(query.dtype)
lse = lse.squeeze(-1)
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
):
ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh
rank = ctx._parallel_config.context_parallel_config._ring_local_rank
world_size = ctx._parallel_config.context_parallel_config.ring_degree
next_rank = (rank + 1) % world_size
next_ranks = list(range(1, world_size)) + [0]
accum_dtype = torch.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype
grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device)
grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
next_grad_kv = None
query, key, value, *_ = ctx.saved_tensors
kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
kv_buffer = kv_buffer.chunk(world_size)
for i in range(world_size):
if i > 0:
kv = kv_buffer[next_rank]
key_numel = key.numel()
key = kv[:key_numel].reshape_as(key)
value = kv[key_numel:].reshape_as(value)
next_rank = (next_rank + 1) % world_size
grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
if i > 0:
grad_kv_buffer = _wait_tensor(next_grad_kv)
grad_key_numel = grad_key.numel()
grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
grad_query += grad_query_op
grad_key += grad_key_op
grad_value += grad_value_op
if i < world_size - 1:
grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
class TemplatedUlyssesAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
):
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
world_size = _parallel_config.context_parallel_config.ulysses_degree
group = ulysses_mesh.get_group()
ctx.forward_op = forward_op
ctx.backward_op = backward_op
ctx._parallel_config = _parallel_config
B, S_Q_LOCAL, H, D = query.shape
_, S_KV_LOCAL, _, _ = key.shape
H_LOCAL = H // world_size
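# Ulysses attention: trade sequence sharding for head sharding. Before the all-to-all, each rank
# holds all H heads for a local sequence shard; afterwards it holds H_LOCAL heads for the full
# sequence, so the inner attention op runs unmodified. The output all-to-all reverses the exchange.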
query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
out = forward_op(
ctx,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
_save_ctx=True,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse, *_ = out
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
if return_lse:
lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
lse = _all_to_all_single(lse, group)
lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
else:
lse = None
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
):
ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh
world_size = ctx._parallel_config.context_parallel_config.ulysses_degree
group = ulysses_mesh.get_group()
B, S_LOCAL, H, D = grad_out.shape
H_LOCAL = H // world_size
grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
grad_out = _all_to_all_single(grad_out, group)
grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
grad_query, grad_key, grad_value = (
x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
for x in (grad_query_op, grad_key_op, grad_value_op)
)
grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value))
grad_query, grad_key, grad_value = (
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
def _templated_context_parallel_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
*,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("Attention mask is not yet supported for templated attention.")
if is_causal:
raise ValueError("Causal attention is not yet supported for templated attention.")
if enable_gqa:
raise ValueError("GQA is not yet supported for templated attention.")
# TODO: add support for unified attention with ring/ulysses degree both being > 1
if _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ulysses_degree > 1:
return TemplatedUlyssesAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
else:
raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
# ===== Attention backends ===== # ===== Attention backends =====
@@ -545,34 +1179,50 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+supports_context_parallel=True,
)
def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
-scale: Optional[float] = None,
is_causal: bool = False,
-window_size: Tuple[int, int] = (-1, -1),
-softcap: float = 0.0,
-alibi_slopes: Optional[torch.Tensor] = None,
-deterministic: bool = False,
-return_attn_probs: bool = False,
+scale: Optional[float] = None,
+return_lse: bool = False,
+_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
-out = flash_attn_func(
-q=query,
-k=key,
-v=value,
-dropout_p=dropout_p,
-softmax_scale=scale,
-causal=is_causal,
-window_size=window_size,
-softcap=softcap,
-alibi_slopes=alibi_slopes,
-deterministic=deterministic,
-return_attn_probs=return_attn_probs,
-)
-return out
+lse = None
+if _parallel_config is None:
+out = flash_attn_func(
+q=query,
+k=key,
+v=value,
+dropout_p=dropout_p,
+softmax_scale=scale,
+causal=is_causal,
+return_attn_probs=return_lse,
+)
+if return_lse:
+out, lse, *_ = out
+else:
+out = _templated_context_parallel_attention(
+query,
+key,
+value,
+None,
+dropout_p,
+is_causal,
+scale,
+False,
+return_lse,
+forward_op=_flash_attention_forward_op,
+backward_op=_flash_attention_backward_op,
+_parallel_config=_parallel_config,
+)
+if return_lse:
+out, lse = out
+return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -583,19 +1233,12 @@ def _flash_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
-cu_seqlens_q: Optional[torch.Tensor] = None,
-cu_seqlens_k: Optional[torch.Tensor] = None,
-max_seqlen_q: Optional[int] = None,
-max_seqlen_k: Optional[int] = None,
+attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
-window_size: Tuple[int, int] = (-1, -1),
-softcap: float = 0.0,
-alibi_slopes: Optional[torch.Tensor] = None,
-deterministic: bool = False,
-return_attn_probs: bool = False,
-attn_mask: Optional[torch.Tensor] = None,
+return_lse: bool = False,
+_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
@@ -603,16 +1246,11 @@ def _flash_varlen_attention(
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-_prepare_for_flash_attn_or_sage_varlen(
-batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-)
-)
-else:
-seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+_prepare_for_flash_attn_or_sage_varlen(
+batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+)
+)
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -635,11 +1273,7 @@ def _flash_varlen_attention(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
-window_size=window_size,
-softcap=softcap,
-alibi_slopes=alibi_slopes,
-deterministic=deterministic,
-return_attn_probs=return_attn_probs,
+return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
@@ -656,30 +1290,17 @@ def _flash_attention_3(
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
-window_size: Tuple[int, int] = (-1, -1),
-softcap: float = 0.0,
-deterministic: bool = False,
-return_attn_probs: bool = False,
+return_lse: bool = False,
+_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
-out, lse, *_ = flash_attn_3_func(
+out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
-qv=None,
-q_descale=None,
-k_descale=None,
-v_descale=None,
-window_size=window_size,
-attention_chunk=0,
-softcap=softcap,
-num_splits=1,
-pack_gqa=None,
-deterministic=deterministic,
-sm_margin=0,
)
-return (out, lse) if return_attn_probs else out
+return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -696,6 +1317,7 @@ def _flash_attention_3_hub(
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
+_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
q=query,
...@@ -728,17 +1350,11 @@ def _flash_varlen_attention_3( ...@@ -728,17 +1350,11 @@ def _flash_varlen_attention_3(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
scale: Optional[float] = None, scale: Optional[float] = None,
is_causal: bool = False, is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1), return_lse: bool = False,
softcap: float = 0.0, _parallel_config: Optional["ParallelConfig"] = None,
deterministic: bool = False,
return_attn_probs: bool = False,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape _, seq_len_kv, _, _ = key.shape
...@@ -746,16 +1362,11 @@ def _flash_varlen_attention_3( ...@@ -746,16 +1362,11 @@ def _flash_varlen_attention_3(
if attn_mask is not None: if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen(
_prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
) )
else: )
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
key_valid, value_valid = [], [] key_valid, value_valid = [], []
for b in range(batch_size): for b in range(batch_size):
...@@ -775,24 +1386,12 @@ def _flash_varlen_attention_3( ...@@ -775,24 +1386,12 @@ def _flash_varlen_attention_3(
cu_seqlens_k=cu_seqlens_k, cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k, max_seqlen_k=max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=scale, softmax_scale=scale,
causal=is_causal, causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
) )
out = out.unflatten(0, (batch_size, -1)) out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_attn_probs else out return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
...@@ -808,7 +1407,7 @@ def _native_flex_attention( ...@@ -808,7 +1407,7 @@ def _native_flex_attention(
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False, enable_gqa: bool = False,
return_lse: bool = False, return_lse: bool = False,
kernel_options: Optional[Dict[str, Any]] = None, _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: should we LRU cache the block mask creation? # TODO: should we LRU cache the block mask creation?
score_mod = None score_mod = None
...@@ -853,7 +1452,6 @@ def _native_flex_attention( ...@@ -853,7 +1452,6 @@ def _native_flex_attention(
scale=scale, scale=scale,
enable_gqa=enable_gqa, enable_gqa=enable_gqa,
return_lse=return_lse, return_lse=return_lse,
kernel_options=kernel_options,
) )
out = out.permute(0, 2, 1, 3) out = out.permute(0, 2, 1, 3)
return out return out
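The flex-attention backend no longer forwards `kernel_options` and keeps the mask handling (`score_mod` / block mask) internal. For context, a minimal standalone sketch of PyTorch's public flex-attention API with a block mask, assuming PyTorch ≥ 2.5 and a CUDA device (this is `torch.nn.attention.flex_attention`, not this file's wrapper):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # keep only keys at or before the query position
    return q_idx >= kv_idx

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)  # [batch, heads, seq, head_dim]
k, v = torch.randn_like(q), torch.randn_like(q)

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```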
...@@ -872,7 +1470,11 @@ def _native_attention( ...@@ -872,7 +1470,11 @@ def _native_attention(
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False, enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention( out = torch.nn.functional.scaled_dot_product_attention(
query=query, query=query,
...@@ -891,6 +1493,7 @@ def _native_attention( ...@@ -891,6 +1493,7 @@ def _native_attention(
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_CUDNN, AttentionBackendName._NATIVE_CUDNN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
) )
def _native_cudnn_attention( def _native_cudnn_attention(
query: torch.Tensor, query: torch.Tensor,
...@@ -901,21 +1504,43 @@ def _native_cudnn_attention( ...@@ -901,21 +1504,43 @@ def _native_cudnn_attention(
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False, enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) lse = None
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): if _parallel_config is None and not return_lse:
out = torch.nn.functional.scaled_dot_product_attention( query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
query=query, with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
key=key, out = torch.nn.functional.scaled_dot_product_attention(
value=value, query=query,
attn_mask=attn_mask, key=key,
dropout_p=dropout_p, value=value,
is_causal=is_causal, attn_mask=attn_mask,
scale=scale, dropout_p=dropout_p,
enable_gqa=enable_gqa, is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
else:
out = _templated_context_parallel_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op=_cudnn_attention_forward_op,
backward_op=_cudnn_attention_backward_op,
_parallel_config=_parallel_config,
) )
out = out.permute(0, 2, 1, 3) if return_lse:
return out out, lse = out
return (out, lse) if return_lse else out
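The cuDNN backend now takes the fast SDPA path only when no parallel config is set and no LSE is requested; otherwise it routes through `_templated_context_parallel_attention` with cuDNN-specific forward/backward ops. As background for the Ulysses strategy named in this PR's title: Ulysses-style context parallelism performs an all-to-all that trades sequence shards for head shards, so each rank runs ordinary attention over the full sequence for a subset of heads. A hedged, standalone sketch of that exchange (the layout, the group handle, and divisibility of heads by the world size are assumptions, not the PR's internal implementation):

```python
import torch
import torch.distributed as dist

def ulysses_all_to_all(x: torch.Tensor, group) -> torch.Tensor:
    """[batch, seq_local, heads, dim] -> [batch, seq_local * W, heads // W, dim] across W ranks."""
    w = dist.get_world_size(group)
    b, s, h, d = x.shape  # assumes h % w == 0
    # group heads into W buckets; bucket i is sent to rank i
    x = x.reshape(b, s, w, h // w, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    # dim 0 of `out` now indexes the source rank, i.e. the source sequence shard
    return out.permute(1, 0, 2, 3, 4).reshape(b, w * s, h // w, d)
```

After attention, the inverse all-to-all restores the original sharding; the ring strategy instead keeps all heads local and rotates key/value blocks, merging partial results with the LSE rule sketched earlier.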
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
...@@ -931,7 +1556,11 @@ def _native_efficient_attention( ...@@ -931,7 +1556,11 @@ def _native_efficient_attention(
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False, enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention( out = torch.nn.functional.scaled_dot_product_attention(
...@@ -960,7 +1589,11 @@ def _native_flash_attention( ...@@ -960,7 +1589,11 @@ def _native_flash_attention(
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False, enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention( out = torch.nn.functional.scaled_dot_product_attention(
...@@ -990,7 +1623,11 @@ def _native_math_attention( ...@@ -990,7 +1623,11 @@ def _native_math_attention(
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False, enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("Native math attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
out = torch.nn.functional.scaled_dot_product_attention( out = torch.nn.functional.scaled_dot_product_attention(
...@@ -1017,7 +1654,11 @@ def _native_npu_attention( ...@@ -1017,7 +1654,11 @@ def _native_npu_attention(
value: torch.Tensor, value: torch.Tensor,
dropout_p: float = 0.0, dropout_p: float = 0.0,
scale: Optional[float] = None, scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention( out = npu_fusion_attention(
query, query,
...@@ -1047,7 +1688,11 @@ def _native_xla_attention( ...@@ -1047,7 +1688,11 @@ def _native_xla_attention(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
is_causal: bool = False, is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
query = query / math.sqrt(query.shape[-1]) query = query / math.sqrt(query.shape[-1])
out = xla_flash_attention( out = xla_flash_attention(
...@@ -1063,6 +1708,7 @@ def _native_xla_attention( ...@@ -1063,6 +1708,7 @@ def _native_xla_attention(
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
AttentionBackendName.SAGE, AttentionBackendName.SAGE,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
) )
def _sage_attention( def _sage_attention(
query: torch.Tensor, query: torch.Tensor,
...@@ -1071,16 +1717,40 @@ def _sage_attention( ...@@ -1071,16 +1717,40 @@ def _sage_attention(
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
return_lse: bool = False, return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return sageattn( lse = None
q=query, if _parallel_config is None:
k=key, out = sageattn(
v=value, q=query,
tensor_layout="NHD", k=key,
is_causal=is_causal, v=value,
sm_scale=scale, tensor_layout="NHD",
return_lse=return_lse, is_causal=is_causal,
) sm_scale=scale,
return_lse=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
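The Sage backend follows the same pattern: a direct `sageattn` call when no parallel config is set, and the templated context-parallel path otherwise. For reference, a minimal standalone usage of the underlying kernel mirroring the call in this file (assuming the third-party `sageattention` package is installed and a CUDA device with bf16 support; shapes are placeholders):

```python
import torch
from sageattention import sageattn

# NHD layout: [batch, seq, heads, head_dim]
q = torch.randn(1, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = sageattn(q=q, k=k, v=v, tensor_layout="NHD", is_causal=False)
```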
@_AttentionBackendRegistry.register( @_AttentionBackendRegistry.register(
...@@ -1091,31 +1761,26 @@ def _sage_varlen_attention( ...@@ -1091,31 +1761,26 @@ def _sage_varlen_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
smooth_k: bool = True, return_lse: bool = False,
attn_mask: Optional[torch.Tensor] = None, _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
batch_size, seq_len_q, _, _ = query.shape batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape _, seq_len_kv, _, _ = key.shape
if attn_mask is not None: if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen(
_prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
) )
else: )
seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
key_valid, value_valid = [], [] key_valid, value_valid = [], []
for b in range(batch_size): for b in range(batch_size):
...@@ -1137,7 +1802,6 @@ def _sage_varlen_attention( ...@@ -1137,7 +1802,6 @@ def _sage_varlen_attention(
max_seqlen_k=max_seqlen_k, max_seqlen_k=max_seqlen_k,
is_causal=is_causal, is_causal=is_causal,
sm_scale=scale, sm_scale=scale,
smooth_k=smooth_k,
) )
out = out.unflatten(0, (batch_size, -1)) out = out.unflatten(0, (batch_size, -1))
...@@ -1154,11 +1818,8 @@ def _sage_qk_int8_pv_fp8_cuda_attention( ...@@ -1154,11 +1818,8 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
value: torch.Tensor, value: torch.Tensor,
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False, return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda( return sageattn_qk_int8_pv_fp8_cuda(
q=query, q=query,
...@@ -1166,11 +1827,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention( ...@@ -1166,11 +1827,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
v=value, v=value,
tensor_layout="NHD", tensor_layout="NHD",
is_causal=is_causal, is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale, sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
smooth_v=smooth_v,
return_lse=return_lse, return_lse=return_lse,
) )
...@@ -1185,10 +1842,8 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( ...@@ -1185,10 +1842,8 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
value: torch.Tensor, value: torch.Tensor,
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
smooth_k: bool = True,
return_lse: bool = False, return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda_sm90( return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query, q=query,
...@@ -1196,10 +1851,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( ...@@ -1196,10 +1851,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
v=value, v=value,
tensor_layout="NHD", tensor_layout="NHD",
is_causal=is_causal, is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale, sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
return_lse=return_lse, return_lse=return_lse,
) )
...@@ -1214,11 +1866,8 @@ def _sage_qk_int8_pv_fp16_cuda_attention( ...@@ -1214,11 +1866,8 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
value: torch.Tensor, value: torch.Tensor,
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False, return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_cuda( return sageattn_qk_int8_pv_fp16_cuda(
q=query, q=query,
...@@ -1226,11 +1875,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention( ...@@ -1226,11 +1875,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
v=value, v=value,
tensor_layout="NHD", tensor_layout="NHD",
is_causal=is_causal, is_causal=is_causal,
qk_quant_gran=qk_quant_gran,
sm_scale=scale, sm_scale=scale,
pv_accum_dtype=pv_accum_dtype,
smooth_k=smooth_k,
smooth_v=smooth_v,
return_lse=return_lse, return_lse=return_lse,
) )
...@@ -1245,19 +1890,16 @@ def _sage_qk_int8_pv_fp16_triton_attention( ...@@ -1245,19 +1890,16 @@ def _sage_qk_int8_pv_fp16_triton_attention(
value: torch.Tensor, value: torch.Tensor,
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
smooth_k: bool = True,
return_lse: bool = False, return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_triton( return sageattn_qk_int8_pv_fp16_triton(
q=query, q=query,
k=key, k=key,
v=value, v=value,
tensor_layout="NHD", tensor_layout="NHD",
quantization_backend=quantization_backend,
is_causal=is_causal, is_causal=is_causal,
sm_scale=scale, sm_scale=scale,
smooth_k=smooth_k,
return_lse=return_lse, return_lse=return_lse,
) )
...@@ -1275,7 +1917,12 @@ def _xformers_attention( ...@@ -1275,7 +1917,12 @@ def _xformers_attention(
is_causal: bool = False, is_causal: bool = False,
scale: Optional[float] = None, scale: Optional[float] = None,
enable_gqa: bool = False, enable_gqa: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if return_lse:
raise ValueError("xformers attention backend does not support setting `return_lse=True`.")
batch_size, seq_len_q, num_heads_q, _ = query.shape batch_size, seq_len_q, num_heads_q, _ = query.shape
_, seq_len_kv, num_heads_kv, _ = key.shape _, seq_len_kv, num_heads_kv, _ = key.shape
......
...@@ -65,6 +65,7 @@ from ..utils.hub_utils import ( ...@@ -65,6 +65,7 @@ from ..utils.hub_utils import (
populate_model_card, populate_model_card,
) )
from ..utils.torch_utils import empty_device_cache from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import ( from .model_loading_utils import (
_caching_allocator_warmup, _caching_allocator_warmup,
_determine_device_map, _determine_device_map,
...@@ -248,6 +249,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -248,6 +249,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_skip_layerwise_casting_patterns = None _skip_layerwise_casting_patterns = None
_supports_group_offloading = True _supports_group_offloading = True
_repeated_blocks = [] _repeated_blocks = []
_parallel_config = None
_cp_plan = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -620,8 +623,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -620,8 +623,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def reset_attention_backend(self) -> None: def reset_attention_backend(self) -> None:
""" """
Resets the attention backend for the model. Following calls to `forward` will use the environment default or Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
the torch native scaled dot product attention. set, or the torch native scaled dot product attention.
""" """
from .attention import AttentionModuleMixin from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention from .attention_processor import Attention, MochiAttention
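The reworded docstring makes the fallback order explicit: the environment default if one is configured, otherwise torch's native scaled dot product attention. A hedged usage sketch of switching backends on a loaded model (`set_attention_backend` is the counterpart setter in diffusers; the `"flash"` backend name and the flash-attn dependency are assumptions for illustration):

```python
import torch
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
model.set_attention_backend("flash")   # opt into a registered backend for subsequent forwards
# ... run inference ...
model.reset_attention_backend()        # back to the environment default or native SDPA
```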
...@@ -960,6 +963,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -960,6 +963,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
quantization_config = kwargs.pop("quantization_config", None) quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False) disable_mmap = kwargs.pop("disable_mmap", False)
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
if is_parallel_loading_enabled and not low_cpu_mem_usage: if is_parallel_loading_enabled and not low_cpu_mem_usage:
...@@ -1340,6 +1344,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -1340,6 +1344,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Set model in evaluation mode to deactivate DropOut modules by default # Set model in evaluation mode to deactivate DropOut modules by default
model.eval() model.eval()
if parallel_config is not None:
model.enable_parallelism(config=parallel_config)
if output_loading_info: if output_loading_info:
return model, loading_info return model, loading_info
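`from_pretrained` now accepts a `parallel_config` keyword and, when it is given, calls `enable_parallelism` on the loaded model. A hedged end-to-end sketch of the intended loading flow (the model id, degrees, and launch command are placeholders; one process per GPU is required, e.g. via `torchrun`):

```python
# torchrun --nproc-per-node=2 infer.py   (illustrative launch)
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxTransformer2DModel

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=2),  # ring attention across 2 ranks
)
```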
...@@ -1478,6 +1485,73 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -1478,6 +1485,73 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. " f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
) )
def enable_parallelism(
self,
*,
config: Union[ParallelConfig, ContextParallelConfig],
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
):
from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning(
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)
if isinstance(config, ContextParallelConfig):
config = ParallelConfig(context_parallel_config=config)
if not torch.distributed.is_initialized():
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
device = torch.device(device_type, rank % device_module.device_count())
cp_mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
)
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
mesh_dim_names=("ring", "ulysses"),
)
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
if config.context_parallel_config is not None:
apply_context_parallel(self, config.context_parallel_config, cp_plan)
self._parallel_config = config
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_parallel_config"):
continue
processor._parallel_config = config
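For a model that is already instantiated, the same configuration can be applied after the fact; `cp_plan` only needs to be passed when the class does not define `_cp_plan`, or to override it. A short sketch, assuming an initialized process group and that `transformer` is any diffusers model with a built-in plan (e.g. the Flux, Wan, LTX, or Qwen transformers touched below):

```python
import torch.distributed as dist
from diffusers import ContextParallelConfig

assert dist.is_initialized(), "enable_parallelism requires an initialized process group"

# Ulysses attention across 2 ranks; the model's _cp_plan is used since cp_plan is not passed
transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```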
@classmethod @classmethod
def _load_pretrained_model( def _load_pretrained_model(
cls, cls,
......
...@@ -120,6 +120,7 @@ def get_1d_rotary_pos_embed( ...@@ -120,6 +120,7 @@ def get_1d_rotary_pos_embed(
class BriaAttnProcessor: class BriaAttnProcessor:
_attention_backend = None _attention_backend = None
_parallel_config = None
def __init__(self): def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
...@@ -161,7 +162,12 @@ class BriaAttnProcessor: ...@@ -161,7 +162,12 @@ class BriaAttnProcessor:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn( hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
......
...@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin from ..cache_utils import CacheMixin
...@@ -73,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st ...@@ -73,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
class FluxAttnProcessor: class FluxAttnProcessor:
_attention_backend = None _attention_backend = None
_parallel_config = None
def __init__(self): def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
...@@ -114,7 +116,12 @@ class FluxAttnProcessor: ...@@ -114,7 +116,12 @@ class FluxAttnProcessor:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn( hidden_states = dispatch_attention_fn(
query, key, value, attn_mask=attention_mask, backend=self._attention_backend query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
...@@ -136,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module): ...@@ -136,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter.""" """Flux Attention processor for IP-Adapter."""
_attention_backend = None _attention_backend = None
_parallel_config = None
def __init__( def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
...@@ -220,6 +228,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module): ...@@ -220,6 +228,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
...@@ -252,6 +261,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module): ...@@ -252,6 +261,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
...@@ -556,6 +566,15 @@ class FluxTransformer2DModel( ...@@ -556,6 +566,15 @@ class FluxTransformer2DModel(
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
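The `_cp_plan` declares, per module name, which tensors are sharded on entry and which output is gathered at the end: the 3-D `hidden_states` / `encoder_hidden_states` are split along their sequence dimension (`split_dim=1`), the 2-D rotary-id tensors along their first dimension, and the full sequence is reassembled after `proj_out`. Conceptually the hooks do something like the following (a simplified sketch under assumed shapes, not the hook implementation added in this PR):

```python
import torch
import torch.distributed as dist

def shard_input(x: torch.Tensor, split_dim: int, group) -> torch.Tensor:
    # pre-forward: keep only this rank's contiguous slice of the sequence
    return x.chunk(dist.get_world_size(group), dim=split_dim)[dist.get_rank(group)]

def gather_output(x: torch.Tensor, gather_dim: int, group) -> torch.Tensor:
    # post-forward on the module named in the plan (here "proj_out"): rebuild the full sequence
    chunks = [torch.empty_like(x) for _ in range(dist.get_world_size(group))]
    dist.all_gather(chunks, x.contiguous(), group=group)
    return torch.cat(chunks, dim=gather_dim)
```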
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin from ..cache_utils import CacheMixin
...@@ -51,6 +52,7 @@ class LTXVideoAttnProcessor: ...@@ -51,6 +52,7 @@ class LTXVideoAttnProcessor:
""" """
_attention_backend = None _attention_backend = None
_parallel_config = None
def __init__(self): def __init__(self):
if is_torch_version("<", "2.0"): if is_torch_version("<", "2.0"):
...@@ -100,6 +102,7 @@ class LTXVideoAttnProcessor: ...@@ -100,6 +102,7 @@ class LTXVideoAttnProcessor:
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
...@@ -409,6 +412,18 @@ class LTXVideoTransformer3DModel( ...@@ -409,6 +412,18 @@ class LTXVideoTransformer3DModel(
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"] _skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTXVideoTransformerBlock"] _repeated_blocks = ["LTXVideoTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"rope": {
0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -25,6 +25,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -25,6 +25,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention from ..attention_processor import Attention
...@@ -261,6 +262,7 @@ class QwenDoubleStreamAttnProcessor2_0: ...@@ -261,6 +262,7 @@ class QwenDoubleStreamAttnProcessor2_0:
""" """
_attention_backend = None _attention_backend = None
_parallel_config = None
def __init__(self): def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
...@@ -334,6 +336,7 @@ class QwenDoubleStreamAttnProcessor2_0: ...@@ -334,6 +336,7 @@ class QwenDoubleStreamAttnProcessor2_0:
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
# Reshape back # Reshape back
...@@ -502,6 +505,18 @@ class QwenImageTransformer2DModel( ...@@ -502,6 +505,18 @@ class QwenImageTransformer2DModel(
_no_split_modules = ["QwenImageTransformerBlock"] _no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"] _repeated_blocks = ["QwenImageTransformerBlock"]
_cp_plan = {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"pos_embed": {
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -73,6 +73,7 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states ...@@ -73,6 +73,7 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states
class SkyReelsV2AttnProcessor: class SkyReelsV2AttnProcessor:
_attention_backend = None _attention_backend = None
_parallel_config = None
def __init__(self): def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
...@@ -139,6 +140,7 @@ class SkyReelsV2AttnProcessor: ...@@ -139,6 +140,7 @@ class SkyReelsV2AttnProcessor:
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query) hidden_states_img = hidden_states_img.type_as(query)
...@@ -151,6 +153,7 @@ class SkyReelsV2AttnProcessor: ...@@ -151,6 +153,7 @@ class SkyReelsV2AttnProcessor:
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.flatten(2, 3)
......
...@@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin from ..cache_utils import CacheMixin
...@@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t ...@@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
class WanAttnProcessor: class WanAttnProcessor:
_attention_backend = None _attention_backend = None
_parallel_config = None
def __init__(self): def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
...@@ -132,6 +134,7 @@ class WanAttnProcessor: ...@@ -132,6 +134,7 @@ class WanAttnProcessor:
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query) hidden_states_img = hidden_states_img.type_as(query)
...@@ -144,6 +147,7 @@ class WanAttnProcessor: ...@@ -144,6 +147,7 @@ class WanAttnProcessor:
dropout_p=0.0, dropout_p=0.0,
is_causal=False, is_causal=False,
backend=self._attention_backend, backend=self._attention_backend,
parallel_config=self._parallel_config,
) )
hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query) hidden_states = hidden_states.type_as(query)
...@@ -539,6 +543,19 @@ class WanTransformer3DModel( ...@@ -539,6 +543,19 @@ class WanTransformer3DModel(
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"] _repeated_blocks = ["WanTransformerBlock"]
_cp_plan = {
"rope": {
0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
},
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"blocks.*": {
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
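Compared to Flux, the Wan plan shards `hidden_states` only at `blocks.0` (the token sequence is built inside the model rather than passed to the top-level forward), uses the `blocks.*` glob so every block shards its text conditioning, and marks the rope outputs with `split_output=True`, meaning the hook splits what the module returns rather than what it receives. A hedged sketch of that output-splitting variant as a plain forward hook (names and the group handle are illustrative, not the PR's hook classes):

```python
import torch.distributed as dist

def make_split_output_hook(split_dim: int, group):
    # split_output=True: shard the tensors the module *returns* (here the rope frequencies)
    # so they line up with hidden states that were already sharded at blocks.0
    def hook(module, args, output):
        world, rank = dist.get_world_size(group), dist.get_rank(group)
        return tuple(o.chunk(world, dim=split_dim)[rank] for o in output)
    return hook

# e.g.: transformer.rope.register_forward_hook(make_split_output_hook(split_dim=1, group=cp_group))
```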
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -648,6 +648,21 @@ class ConsistencyDecoderVAE(metaclass=DummyObject): ...@@ -648,6 +648,21 @@ class ConsistencyDecoderVAE(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class ContextParallelConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ControlNetModel(metaclass=DummyObject): class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -1053,6 +1068,21 @@ class OmniGenTransformer2DModel(metaclass=DummyObject): ...@@ -1053,6 +1068,21 @@ class OmniGenTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class ParallelConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject): class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......