Commit 0513d03d authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #3321 canceled with stages
from .attn_layer import (
xFuserLongContextAttention,
)
__all__ = [
"xFuserLongContextAttention",
]
import torch
from torch import Tensor
import torch.distributed
from yunchang import LongContextAttention
from yunchang.comm.all_to_all import SeqAllToAll4D
from xfuser.logger import init_logger
logger = init_logger(__name__)
class xFuserLongContextAttention(LongContextAttention):
    """Hybrid sequence-parallel attention: Ulysses all-to-all over heads
    combined with ring flash-attention over the sequence dimension, with
    optional joint (replicated) q/k/v tensors and a KV cache for PipeFusion."""

    # Ring implementations that are compatible with the SP KV cache.
    ring_impl_type_supported_kv_cache = ["basic"]

    def __init__(
        self,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        ring_impl_type: str = "basic",
        use_pack_qkv: bool = False,
        use_kv_cache: bool = False,
    ) -> None:
        """
        Arguments:
            scatter_idx: int = 2, the scatter dimension index for Ulysses All2All
            gather_idx: int = 1, the gather dimension index for Ulysses All2All
            ring_impl_type: str = "basic", the ring implementation type, currently only support "basic"
            use_pack_qkv: bool = False, whether to use pack qkv in the input
            use_kv_cache: bool = False, whether to use kv cache in the attention layer, which is applied in PipeFusion.

        Raises:
            RuntimeError: if use_kv_cache is requested with a ring
                implementation that does not support the SP KV cache.
        """
        super().__init__(
            scatter_idx=scatter_idx,
            gather_idx=gather_idx,
            ring_impl_type=ring_impl_type,
            use_pack_qkv=use_pack_qkv,
        )
        self.use_kv_cache = use_kv_cache
        if (
            use_kv_cache
            and ring_impl_type not in self.ring_impl_type_supported_kv_cache
        ):
            raise RuntimeError(
                f"ring_impl_type: {ring_impl_type} do not support SP kv cache."
            )

        # Imported here (not at module top) to avoid a circular import.
        from xfuser.core.long_ctx_attention.ring import xdit_ring_flash_attn_func

        self.ring_attn_fn = xdit_ring_flash_attn_func

    @torch.compiler.disable
    def forward(
        self,
        attn,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *,
        joint_tensor_query=None,
        joint_tensor_key=None,
        joint_tensor_value=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        joint_strategy="none",
    ) -> Tensor:
        """forward

        Arguments:
            attn (Attention): the attention module
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            joint_tensor_query: Tensor = None, a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy
            joint_tensor_key: Tensor = None, a replicated tensor among processes appended to the front or rear of key, depends the joint_strategy
            joint_tensor_value: Tensor = None, a replicated tensor among processes appended to the front or rear of value, depends the joint_strategy
            joint_strategy: str = "none", the joint strategy for joint attention, currently only support "front" and "rear"
            (the remaining keyword arguments mirror the flash_attn interface)

        Returns:
            * output (Tensor): context output
        """
        # The three joint tensors must be provided together or not at all.
        is_joint = False
        if (joint_tensor_query is not None and
            joint_tensor_key is not None and
            joint_tensor_value is not None):
            supported_joint_strategy = ["front", "rear"]
            if joint_strategy not in supported_joint_strategy:
                raise ValueError(
                    f"joint_strategy: {joint_strategy} not supported. supported joint strategy: {supported_joint_strategy}"
                )
            elif joint_strategy == "rear":
                query = torch.cat([query, joint_tensor_query], dim=1)
                is_joint = True
            else:
                query = torch.cat([joint_tensor_query, query], dim=1)
                is_joint = True
        elif (joint_tensor_query is None and
              joint_tensor_key is None and
              joint_tensor_value is None):
            pass
        else:
            raise ValueError(
                f"joint_tensor_query, joint_tensor_key, and joint_tensor_value should be None or not None simultaneously."
            )

        if is_joint:
            # The joint K/V stay replicated while the main K/V are
            # head-sharded by the all-to-all below, so slice this rank's
            # head range out of the joint tensors to match.
            ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
            ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
            attn_heads_per_ulysses_rank = (
                joint_tensor_key.shape[-2] // ulysses_world_size
            )
            joint_tensor_key = joint_tensor_key[
                ...,
                attn_heads_per_ulysses_rank
                * ulysses_rank : attn_heads_per_ulysses_rank
                * (ulysses_rank + 1),
                :,
            ]
            joint_tensor_value = joint_tensor_value[
                ...,
                attn_heads_per_ulysses_rank
                * ulysses_rank : attn_heads_per_ulysses_rank
                * (ulysses_rank + 1),
                :,
            ]

        # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
        # scatter 2, gather 1
        if self.use_pack_qkv:
            # (3*bs, seq_len/N, head_cnt, head_size)
            # BUGFIX: was `.continous()`, which raises AttributeError.
            qkv = torch.cat([query, key, value]).contiguous()
            # (3*bs, seq_len, head_cnt/N, head_size)
            qkv = SeqAllToAll4D.apply(
                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx
            )
            qkv = torch.chunk(qkv, 3, dim=0)
            query_layer, key_layer, value_layer = qkv
        else:
            query_layer = SeqAllToAll4D.apply(
                self.ulysses_pg, query, self.scatter_idx, self.gather_idx
            )
            key_layer = SeqAllToAll4D.apply(
                self.ulysses_pg, key, self.scatter_idx, self.gather_idx
            )
            value_layer = SeqAllToAll4D.apply(
                self.ulysses_pg, value, self.scatter_idx, self.gather_idx
            )

        out = self.ring_attn_fn(
            query_layer,
            key_layer,
            value_layer,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=return_attn_probs,
            group=self.ring_pg,
            attn_layer=attn if self.use_kv_cache else None,
            joint_tensor_key=joint_tensor_key,
            joint_tensor_value=joint_tensor_value,
            joint_strategy=joint_strategy,
        )

        # The ring function may return (out, lse, probs); keep only the output.
        if isinstance(out, tuple):
            context_layer, _, _ = out
        else:
            context_layer = out

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = SeqAllToAll4D.apply(
            self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx
        )
        # out e.g., [s/p::h]
        return output
from .ring_flash_attn import xdit_ring_flash_attn_func
__all__ = [
"xdit_ring_flash_attn_func",
]
import torch
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from yunchang.ring.utils import RingComm, update_out_and_lse
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
def xdit_ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    attn_layer=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="none",
):
    """Ring flash-attention forward pass with optional joint K/V tensors and
    an optional KV cache (engaged when ``attn_layer`` is given, for PipeFusion).

    Returns:
        (out, lse): attention output and its log-sum-exp statistics.

    Raises:
        ValueError: on an unsupported joint_strategy, or when only one of
            joint_tensor_key / joint_tensor_value is provided.
    """
    # joint_tensor_key / joint_tensor_value must be provided together.
    is_joint = False
    if (joint_tensor_key is not None and
        joint_tensor_value is not None):
        supported_joint_strategy = ["front", "rear"]
        if joint_strategy not in supported_joint_strategy:
            raise ValueError(
                f"joint_strategy: {joint_strategy} not supported. supported joint strategy: {supported_joint_strategy}"
            )
        else:
            is_joint = True
    elif (joint_tensor_key is None and
          joint_tensor_value is None):
        pass
    else:
        raise ValueError(
            f"joint_tensor_key and joint_tensor_value should be None or not None simultaneously."
        )

    comm = RingComm(process_group)

    out = None
    lse = None
    next_k, next_v = None, None

    if attn_layer is not None:
        k, v = get_cache_manager().update_and_get_kv_cache(
            new_kv=[k, v],
            layer=attn_layer,
            slice_dim=1,
            layer_type="attn",
        )
        k = k.contiguous()
        v = v.contiguous()

    # BUGFIX: compare numeric version components instead of raw strings —
    # lexicographic comparison misorders versions such as "2.10.0" < "2.6.3".
    # Hoisted out of the ring loop since it is loop-invariant.
    _fa_version = tuple(
        int(part)
        for part in flash_attn.__version__.split("+")[0].split(".")[:3]
        if part.isdigit()
    )
    use_legacy_fa_api = _fa_version <= (2, 6, 3)

    for step in range(comm.world_size):
        # Pre-post the next ring exchange so communication overlaps compute.
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        # Attach the replicated joint K/V exactly once per ring traversal:
        # on the last step for "rear", on the first step for "front".
        if is_joint and joint_strategy == "rear":
            if step + 1 == comm.world_size:
                key = torch.cat([k, joint_tensor_key], dim=1)
                value = torch.cat([v, joint_tensor_value], dim=1)
            else:
                key, value = k, v
        elif is_joint and joint_strategy == "front":
            if step == 0:
                key = torch.cat([joint_tensor_key, k], dim=1)
                value = torch.cat([joint_tensor_value, v], dim=1)
            else:
                key, value = k, v
        else:
            key, value = k, v

        if not causal or step <= comm.rank:
            if use_legacy_fa_api:
                # flash-attn <= 2.6.3: single window_size tuple, 8-tuple return.
                block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
                    q,
                    key,
                    value,
                    dropout_p,
                    softmax_scale,
                    causal=causal and step == 0,
                    window_size=window_size,
                    softcap=0.0,
                    alibi_slopes=alibi_slopes,
                    return_softmax=True and dropout_p > 0,
                )
            else:
                # flash-attn > 2.6.3: window_size split into left/right kwargs.
                block_out, block_lse, _, _ = _flash_attn_forward(
                    q,
                    key,
                    value,
                    dropout_p,
                    softmax_scale,
                    causal=causal and step == 0,
                    window_size_left=window_size[0],
                    window_size_right=window_size[1],
                    softcap=0.0,
                    alibi_slopes=alibi_slopes,
                    return_softmax=True and dropout_p > 0,
                )
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse
class xFuserRingFlashAttnFunc(RingFlashAttnFunc):
    """Autograd Function wrapping xdit_ring_flash_attn_forward.

    Only ``forward`` is overridden; the backward pass comes from the
    RingFlashAttnFunc base class.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
        attn_layer,
        joint_tensor_key,
        joint_tensor_value,
        joint_strategy,
    ):
        # Default scale is 1/sqrt(head_dim).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        assert alibi_slopes is None
        # When the KV cache is in use (attn_layer is not None), contiguity is
        # handled inside xdit_ring_flash_attn_forward after the cache update.
        if attn_layer is None:
            k = k.contiguous()
            v = v.contiguous()
        out, softmax_lse = xdit_ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            # NOTE(review): forward hard-codes deterministic=False while
            # ctx.deterministic below saves the caller's flag for backward —
            # confirm this asymmetry is intentional.
            deterministic=False,
            attn_layer=attn_layer,
            joint_tensor_key=joint_tensor_key,
            joint_tensor_value=joint_tensor_value,
            joint_strategy=joint_strategy,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        # Mirror flash_attn's convention: a 3-tuple when softmax is requested.
        return out if not return_softmax else (out, softmax_lse, None)
def xdit_ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_layer=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="none",
):
    """Functional entry point for the xDiT ring flash-attention autograd op.

    Simply forwards every argument, positionally and in order, to
    xFuserRingFlashAttnFunc.apply.
    """
    forward_args = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_layer,
        joint_tensor_key,
        joint_tensor_value,
        joint_strategy,
    )
    return xFuserRingFlashAttnFunc.apply(*forward_args)
from .attn_layer import xFuserUlyssesAttention
__all__ = [
"xFuserUlyssesAttention",
]
from typing import Any
import torch
import torch.distributed as dist
from torch import Tensor
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from yunchang import UlyssesAttention
from yunchang.globals import PROCESS_GROUP
from yunchang.comm.all_to_all import SeqAllToAll4D
# Resolve torch_attn across yunchang versions. Only ImportError is caught —
# the original bare `except:` would also have swallowed unrelated failures
# (including KeyboardInterrupt/SystemExit).
try:
    # yunchang > 0.4.0
    from yunchang.kernels.attention import torch_attn
except ImportError:
    print("detect you are not use the latest yunchang. Please install yunchang>=0.4.0")
    try:
        # older yunchang layout
        from yunchang.ulysses.attn_layer import torch_attn
    except ImportError as e:
        raise ImportError("yunchang import torch_attn error") from e
class xFuserUlyssesAttention(UlyssesAttention):
    """Ulysses (head-sharded all-to-all) sequence-parallel attention with an
    optional KV cache and joint-tensor support."""

    def __init__(
        self,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        use_fa: bool = True,
        use_kv_cache: bool = True,
    ) -> None:
        """
        Arguments:
            scatter_idx: the scatter dimension index for the Ulysses All2All
            gather_idx: the gather dimension index for the Ulysses All2All
            use_fa: whether to use flash-attention; automatically disabled on
                GPUs it does not support and on CPU-only hosts
            use_kv_cache: whether to use the kv cache (applied in PipeFusion)
        """
        # Deliberately skip UlyssesAttention.__init__ and set fields directly.
        super(UlyssesAttention, self).__init__()
        self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.use_fa = use_fa

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device.type == "cuda":
            gpu_name = torch.cuda.get_device_name(device)
            if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
                self.use_fa = False
        else:
            # BUGFIX: torch.cuda.get_device_name() raises on CPU-only hosts;
            # flash-attn requires CUDA anyway, so disable it here.
            self.use_fa = False

        self.use_kv_cache = use_kv_cache

        if self.use_fa:
            from flash_attn import flash_attn_func

            self.fn = flash_attn_func
        else:
            self.fn = torch_attn

    def forward(
        self,
        attn,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *,
        joint_tensor_query=None,
        joint_tensor_key=None,
        joint_tensor_value=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        joint_strategy="none",
    ) -> Tensor:
        """forward

        Arguments:
            attn: the attention module (used as the KV-cache key)
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            joint_tensor_query/key/value: replicated tensors concatenated to
                q/k/v at the front or rear depending on joint_strategy
            joint_strategy: "none", "front" or "rear"

        Returns:
            * output (Tensor): context output
        """
        if (
            joint_tensor_key is not None
            and joint_tensor_value is not None
            and joint_tensor_query is not None
        ):
            if joint_strategy == "rear":
                query = torch.cat([query, joint_tensor_query], dim=1)
            elif joint_strategy == "front":
                query = torch.cat([joint_tensor_query, query], dim=1)
            elif joint_strategy == "none":
                raise ValueError(
                    f"joint_strategy: {joint_strategy} not supported when joint tensors is not None."
                )
            else:
                raise ValueError(f"joint_strategy: {joint_strategy} not supported.")

            # Joint K/V stay replicated while the main K/V get head-sharded
            # below; slice this rank's head range so the shapes line up.
            ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
            ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
            attn_heads_per_ulysses_rank = (
                joint_tensor_key.shape[-2] // ulysses_world_size
            )
            joint_tensor_key = joint_tensor_key[
                ...,
                attn_heads_per_ulysses_rank
                * ulysses_rank : attn_heads_per_ulysses_rank
                * (ulysses_rank + 1),
                :,
            ]
            joint_tensor_value = joint_tensor_value[
                ...,
                attn_heads_per_ulysses_rank
                * ulysses_rank : attn_heads_per_ulysses_rank
                * (ulysses_rank + 1),
                :,
            ]

        # TODO Merge three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
        # in shape : e.g., [s/p:h:]
        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
        # scatter 2, gather 1
        q = SeqAllToAll4D.apply(
            self.ulysses_pg, query, self.scatter_idx, self.gather_idx
        )
        k = SeqAllToAll4D.apply(self.ulysses_pg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(
            self.ulysses_pg, value, self.scatter_idx, self.gather_idx
        )

        if self.use_kv_cache:
            k, v = get_cache_manager().update_and_get_kv_cache(
                new_kv=[k, v],
                layer=attn,
                slice_dim=1,
                layer_type="attn",
            )

        # NOTE(review): if joint_strategy != "none" but the joint tensors were
        # not all supplied, torch.cat below receives None — callers must pass
        # the strategy and the tensors together.
        if joint_strategy != "none":
            if joint_strategy == "rear":
                k = torch.cat([k, joint_tensor_key], dim=1)
                v = torch.cat([v, joint_tensor_value], dim=1)
            elif joint_strategy == "front":
                k = torch.cat([joint_tensor_key, k], dim=1)
                v = torch.cat([joint_tensor_value, v], dim=1)

        context_layer = self.fn(
            q,
            k,
            v,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=return_attn_probs,
        )

        # The attention fn may return (out, lse, ...) when probs are requested.
        if isinstance(context_layer, tuple):
            context_layer = context_layer[0]

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = SeqAllToAll4D.apply(
            self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx
        )
        # out e.g., [s/p::h]
        return output
from .timer import gpu_timer_decorator
import torch
import time
def gpu_timer_decorator(func):
    """Decorator reporting the wall-clock GPU time of *func* on rank 0.

    CUDA kernels launch asynchronously, so we synchronize before and after the
    call to measure the real device time rather than just the launch time.
    """
    import functools

    @functools.wraps(func)  # preserve the wrapped function's metadata
    def wrapper(*args, **kwargs):
        # Only synchronize when CUDA exists; torch.cuda.synchronize() raises
        # on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.perf_counter()  # monotonic clock for intervals
        result = func(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        # Print once (rank 0) in distributed runs; fall back to always
        # printing when torch.distributed has not been initialized.
        if (
            not torch.distributed.is_available()
            or not torch.distributed.is_initialized()
            or torch.distributed.get_rank() == 0
        ):
            print(
                f"{func.__name__} took {end_time - start_time} seconds to run on GPU."
            )
        return result

    return wrapper
import os
import torch
import diffusers
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from packaging import version
from xfuser.logger import init_logger
logger = init_logger(__name__)
# Static type stubs for the lazily-resolved module attributes below; this
# branch is only evaluated by type checkers, never at runtime.
if TYPE_CHECKING:
    MASTER_ADDR: str = ""
    MASTER_PORT: Optional[int] = None
    CUDA_HOME: Optional[str] = None
    LOCAL_RANK: int = 0
    CUDA_VISIBLE_DEVICES: Optional[str] = None
    XDIT_LOGGING_LEVEL: str = "INFO"
    CUDA_VERSION: version.Version
    TORCH_VERSION: version.Version

# Each entry is a zero-argument callable so the environment is re-read on
# every attribute access (via the module-level __getattr__ below).
environment_variables: Dict[str, Callable[[], Any]] = {
    # ================== Runtime Env Vars ==================
    # used in distributed environment to determine the master address
    "MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""),
    # used in distributed environment to manually set the communication port
    "MASTER_PORT": lambda: (
        int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None
    ),
    # path to cudatoolkit home directory, under which should be bin, include,
    # and lib directories.
    "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None),
    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
    "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
    # used to control the visible devices in the distributed setting
    "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
    # this is used for configuring the default logging level
    "XDIT_LOGGING_LEVEL": lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"),
}
# Derived values resolved lazily via the module-level __getattr__, which
# invokes every entry as a zero-argument callable.
variables: Dict[str, Callable[[], Any]] = {
    # ================== Other Vars ==================
    # used in version checking
    # "CUDA_VERSION": lambda: version.parse(torch.version.cuda),
    # BUGFIX: must be a callable — __getattr__ calls variables[name]().
    # NOTE(review): hard-coded for gfx928 builds where torch.version.cuda is
    # unavailable; confirm this is the intended value on CUDA hosts.
    "CUDA_VERSION": lambda: "gfx928",
    "TORCH_VERSION": lambda: version.parse(
        version.parse(torch.__version__).base_version
    ),
}
class PackagesEnvChecker:
    """Singleton that probes the optional dependencies (flash_attn, yunchang)
    once and validates the installed diffusers version."""

    _instance = None

    def __new__(cls):
        # Classic singleton: run the probes only on first instantiation.
        if cls._instance is None:
            cls._instance = super(PackagesEnvChecker, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        """Populate the cached package-availability info."""
        self.packages_info = {
            "has_flash_attn": self.check_flash_attn(),
            "has_long_ctx_attn": self.check_long_ctx_attn(),
            "diffusers_version": self.check_diffusers_version(),
        }

    def check_flash_attn(self):
        """Return True when flash_attn >= 2.6.0 is importable and usable."""
        try:
            # BUGFIX: on CPU-only hosts torch.cuda.get_device_name() raises a
            # non-ImportError which escaped the handler below; flash-attn
            # requires CUDA, so bail out early instead.
            if not torch.cuda.is_available():
                return False
            gpu_name = torch.cuda.get_device_name(torch.device("cuda"))
            if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
                return False
            else:
                from flash_attn import flash_attn_func
                from flash_attn import __version__

                # BUGFIX: compare numeric components; plain string comparison
                # misorders versions such as "2.10.0" < "2.6.0".
                fa_version = tuple(
                    int(part)
                    for part in __version__.split("+")[0].split(".")[:3]
                    if part.isdigit()
                )
                if fa_version < (2, 6, 0):
                    raise ImportError(f"install flash_attn >= 2.6.0")
                return True
        except ImportError:
            logger.warning(
                f'Flash Attention library "flash_attn" not found, '
                f"using pytorch attention implementation"
            )
            return False

    def check_long_ctx_attn(self):
        """Return True when the yunchang sequence-parallel library is available."""
        try:
            from yunchang import (
                set_seq_parallel_pg,
                ring_flash_attn_func,
                UlyssesAttention,
                LongContextAttention,
                LongContextAttentionQKVPacked,
            )

            return True
        except ImportError:
            logger.warning(
                f'Ring Flash Attention library "yunchang" not found, '
                f"using pytorch attention implementation"
            )
            return False

    def check_diffusers_version(self):
        """Return the installed diffusers base version.

        Raises:
            RuntimeError: when diffusers is older than 0.30.0.
        """
        if version.parse(
            version.parse(diffusers.__version__).base_version
        ) < version.parse("0.30.0"):
            raise RuntimeError(
                f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported,"
                f"please upgrade to version > 0.30.0"
            )
        return version.parse(version.parse(diffusers.__version__).base_version)

    def get_packages_info(self):
        """Return the cached probe results."""
        return self.packages_info
# Module-level singleton; all importers share one cached probe result.
PACKAGES_CHECKER = PackagesEnvChecker()
def __getattr__(name):
    """Module-level hook: lazily resolve attributes from the env-var and
    derived-value tables, re-evaluating the entry on every access."""
    for table in (environment_variables, variables):
        if name in table:
            return table[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
    """List the lazily-resolved module attributes.

    Includes the `variables` table too, since __getattr__ serves entries from
    both tables (previously only environment_variables was listed).
    """
    return list(environment_variables.keys()) + list(variables.keys())
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration."""
import logging
import sys
import os
from typing import Optional
# Log line shape: "LEVEL MM-DD HH:MM:SS [file.py:123] message"
_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
# Level name from the environment (default "debug"), resolved to the numeric
# logging level; unknown names fall back to 0 (NOTSET).
_LOG_LEVEL = os.environ.get("LOG_LEVEL", "debug")
_LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0)
# Optional directory for file handlers; file logging is disabled when unset.
_LOG_DIR = os.environ.get("LOG_DIR", None)
class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None):
        super().__init__(fmt, datefmt)

    def format(self, record):
        formatted = super().format(record)
        if record.message != "":
            # Everything before the message text is the prefix; repeat it
            # after every newline so continuation lines stay aligned.
            prefix = formatted.split(record.message)[0]
            formatted = formatted.replace("\n", "\r\n" + prefix)
        return formatted
_root_logger = logging.getLogger("xfuser")
_default_handler = None
_default_file_handler = None
_inference_log_file_handler = {}
def _setup_logger():
    """Attach the stream handler (and optional file handler) to the root
    "xfuser" logger. Idempotent: handlers are created at most once."""
    _root_logger.setLevel(_LOG_LEVEL)
    global _default_handler
    global _default_file_handler
    fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
    if _default_handler is None:
        _default_handler = logging.StreamHandler(sys.stdout)
        _default_handler.flush = sys.stdout.flush  # type: ignore
        _default_handler.setLevel(_LOG_LEVEL)
        _root_logger.addHandler(_default_handler)
    if _default_file_handler is None and _LOG_DIR is not None:
        try:
            # exist_ok avoids the check-then-create race between processes.
            os.makedirs(_LOG_DIR, exist_ok=True)
        except OSError as e:
            # warning(): Logger.warn() is a deprecated alias.
            _root_logger.warning(f"Error creating directory {_LOG_DIR} : {e}")
        _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log")
        _default_file_handler.setLevel(_LOG_LEVEL)
        _default_file_handler.setFormatter(fmt)
        _root_logger.addHandler(_default_file_handler)
    _default_handler.setFormatter(fmt)
    # Setting this will avoid the message
    # being propagated to the parent logger.
    _root_logger.propagate = False
# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()
def init_logger(name: str):
    """Create a child logger mirroring the root "xfuser" logger's settings.

    When LOG_DIR is set, each process additionally logs to its own
    ``process.<pid>.log`` file; per-pid file handlers are cached so repeated
    calls within a process share one handler.
    """
    pid = os.getpid()
    # Use the same settings as above for root logger
    logger = logging.getLogger(name)
    logger.setLevel(_LOG_LEVEL)
    logger.addHandler(_default_handler)
    # NOTE(review): a previous branch special-cased "pid is None", but
    # os.getpid() always returns an int, so that path was unreachable.
    if _LOG_DIR is not None:
        if _inference_log_file_handler.get(pid, None) is not None:
            logger.addHandler(_inference_log_file_handler[pid])
        else:
            try:
                # exist_ok avoids the check-then-create race between processes.
                os.makedirs(_LOG_DIR, exist_ok=True)
            except OSError as e:
                # warning(): Logger.warn() is a deprecated alias.
                _root_logger.warning(f"Error creating directory {_LOG_DIR} : {e}")
            _inference_log_file_handler[pid] = logging.FileHandler(
                _LOG_DIR + f"/process.{pid}.log"
            )
            _inference_log_file_handler[pid].setLevel(_LOG_LEVEL)
            _inference_log_file_handler[pid].setFormatter(
                NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
            )
            _root_logger.addHandler(_inference_log_file_handler[pid])
            logger.addHandler(_inference_log_file_handler[pid])
    logger.propagate = False
    return logger
from abc import abstractmethod, ABCMeta
from functools import wraps
from typing import Any, List, Optional
from xfuser.core.distributed.parallel_state import (
get_classifier_free_guidance_world_size,
get_pipeline_parallel_world_size,
get_sequence_parallel_world_size,
get_tensor_model_parallel_world_size,
)
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.core.fast_attention import get_fast_attn_enable
class xFuserBaseWrapper(metaclass=ABCMeta):
    """Base wrapper that delegates unknown attribute access to the wrapped
    module, so the wrapper is a drop-in replacement for it."""

    def __init__(
        self,
        module: Any,
    ):
        self.module = module
        self.module_type = type(module)

    def __getattr__(self, name: str):
        """Fall through to the wrapped module for unknown attributes."""
        try:
            return getattr(self.module, name)
        except RecursionError:
            # Accessing self.module before it is set re-enters __getattr__
            # indefinitely; surface that as a plain AttributeError.
            raise AttributeError(
                f"module {type(self.module).__name__} has no " f"attribute {name}"
            )

    def __str__(self):
        return str(self.module)

    @staticmethod
    def forward_check_condition(func):
        """Decorator ensuring the runtime state is initialized before forward
        runs under any parallel or fast-attention configuration."""

        @wraps(func)
        def check_condition_fn(self, *args, **kwargs):
            # Fully serial run (all world sizes 1, fast-attn off): no runtime
            # state needed — call straight through.
            if (
                get_pipeline_parallel_world_size() == 1
                and get_classifier_free_guidance_world_size() == 1
                and get_sequence_parallel_world_size() == 1
                and get_tensor_model_parallel_world_size() == 1
                and not get_fast_attn_enable()
            ):
                return func(self, *args, **kwargs)
            if not get_runtime_state().is_ready():
                raise ValueError(
                    "Runtime state is not ready, please call RuntimeState.set_input_parameters "
                    "before calling forward"
                )
            return func(self, *args, **kwargs)

        return check_condition_fn
from .register import xFuserLayerWrappersRegister
from .base_layer import xFuserLayerBaseWrapper
from .attention_processor import xFuserAttentionWrapper
from .conv import xFuserConv2dWrapper
from .embeddings import xFuserPatchEmbedWrapper
from .feedforward import xFuserFeedForwardWrapper
__all__ = [
"xFuserLayerWrappersRegister",
"xFuserLayerBaseWrapper",
"xFuserAttentionWrapper",
"xFuserConv2dWrapper",
"xFuserPatchEmbedWrapper",
"xFuserFeedForwardWrapper",
]
import inspect
from typing import Optional, Union, Tuple
import torch
from torch import nn
import torch.distributed
from torch.nn import functional as F
from diffusers.utils import deprecate
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import (
AttnProcessor2_0,
JointAttnProcessor2_0,
FluxAttnProcessor2_0,
HunyuanAttnProcessor2_0,
CogVideoXAttnProcessor2_0
)
from diffusers.models.embeddings import apply_rotary_emb
from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_pipeline_parallel_world_size
)
from xfuser.core.fast_attention import (
xFuserFastAttention,
get_fast_attn_enable,
)
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.model_executor.layers import xFuserLayerBaseWrapper
from xfuser.model_executor.layers import xFuserLayerWrappersRegister
from xfuser.logger import init_logger
from xfuser.envs import PACKAGES_CHECKER
# BUGFIX: pick the USP implementation by comparing numeric (major, minor)
# components — the previous raw string comparison misorders versions such as
# "2.10.0" vs "2.5.0".
_torch_version = tuple(
    int(part)
    for part in torch.__version__.split("+")[0].split(".")[:2]
    if part.isdigit()
)
if _torch_version >= (2, 5):
    from xfuser.model_executor.layers.usp import USP
else:
    from xfuser.model_executor.layers.usp_legacy import USP

logger = init_logger(__name__)

env_info = PACKAGES_CHECKER.get_packages_info()
HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]
HAS_FLASH_ATTN = env_info["has_flash_attn"]
def is_v100():
    """Return True when the current CUDA device reports itself as a V100."""
    if not torch.cuda.is_available():
        return False
    current_name = torch.cuda.get_device_name(torch.cuda.current_device())
    return "V100" in current_name
def torch_compile_disable_if_v100(func):
    """Wrap *func* with torch.compiler.disable when running on a V100;
    otherwise return it unchanged."""
    return torch.compiler.disable(func) if is_v100() else func
class xFuserAttentionBaseWrapper(xFuserLayerBaseWrapper):
    """Base wrapper for diffusers ``Attention`` modules.

    Validates that the wrapped attention's K and V projections are
    structurally compatible: both nn.Linear, matching weight shapes, and
    consistent bias usage.
    """

    def __init__(
        self,
        attention: Attention,
    ):
        super().__init__(module=attention)
        # Sanity-check the K/V projections of the wrapped module; identical
        # shapes are presumably relied on by the parallel K/V handling —
        # TODO confirm.
        to_k = self.module.to_k
        to_v = self.module.to_v
        assert isinstance(to_k, nn.Linear)
        assert isinstance(to_v, nn.Linear)
        assert (to_k.bias is None) == (to_v.bias is None)
        assert to_k.weight.shape == to_v.weight.shape
class xFuserAttentionProcessorRegister:
    """Registry mapping diffusers attention-processor classes to their
    xFuser counterparts."""

    _XFUSER_ATTENTION_PROCESSOR_MAPPING = {}

    @classmethod
    def register(cls, origin_processor_class):
        """Return a decorator registering an xFuser processor class for
        ``origin_processor_class``.

        Raises:
            ValueError: if the decorated class is not a subclass of
                ``origin_processor_class``.
        """

        def decorator(xfuser_processor):
            if not issubclass(xfuser_processor, origin_processor_class):
                # BUGFIX: use __name__ on the classes themselves;
                # __class__.__name__ on a class yields the metaclass name
                # (e.g. "type"), not the class name.
                raise ValueError(
                    f"{xfuser_processor.__name__} is not a subclass of origin class {origin_processor_class.__name__}"
                )
            cls._XFUSER_ATTENTION_PROCESSOR_MAPPING[origin_processor_class] = (
                xfuser_processor
            )
            return xfuser_processor

        return decorator

    @classmethod
    def get_processor(cls, processor):
        """Return the registered xFuser processor class for a processor
        instance.

        Raises:
            ValueError: if no registered origin class matches ``processor``.
        """
        for (
            origin_processor_class,
            xfuser_processor,
        ) in cls._XFUSER_ATTENTION_PROCESSOR_MAPPING.items():
            if isinstance(processor, origin_processor_class):
                return xfuser_processor
        raise ValueError(
            f"Attention Processor class {processor.__class__.__name__} is not supported by xFuser"
        )
@xFuserLayerWrappersRegister.register(Attention)
class xFuserAttentionWrapper(xFuserAttentionBaseWrapper):
    """Wrapper around diffusers ``Attention`` that swaps in the registered
    xFuser attention processor and forwards calls to it."""

    def __init__(
        self,
        attention: Attention,
        latte_temporal_attention: bool = False,
    ):
        super().__init__(attention=attention)
        # Replace the original diffusers processor with its registered
        # xFuser counterpart (instantiated here).
        self.processor = xFuserAttentionProcessorRegister.get_processor(
            attention.processor
        )()
        # Flag forwarded to the processor on every call.
        self.latte_temporal_attention = latte_temporal_attention

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        # Drop kwargs the processor does not accept, warning about all but a
        # known-quiet set.
        attn_parameters = set(
            inspect.signature(self.processor.__call__).parameters.keys()
        )
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k
            for k, _ in cross_attention_kwargs.items()
            if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {
            k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
        }

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            latte_temporal_attention=self.latte_temporal_attention,
            **cross_attention_kwargs,
        )
@xFuserAttentionProcessorRegister.register(AttnProcessor2_0)
class xFuserAttnProcessor2_0(AttnProcessor2_0):
    """xFuser drop-in for diffusers ``AttnProcessor2_0`` adding hybrid
    sequence-parallel attention (Ulysses/ring), KV caching, and an optional
    fast-attention path."""

    def __init__(self):
        super().__init__()
        use_long_ctx_attn_kvcache = True
        # The KV cache is managed inside the sequence-parallel attention only
        # when hybrid long-context attention is actually in use.
        self.use_long_ctx_attn_kvcache = (
            HAS_LONG_CTX_ATTN
            and use_long_ctx_attn_kvcache
            and get_sequence_parallel_world_size() > 1
        )
        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            from xfuser.core.long_ctx_attention import (
                xFuserLongContextAttention,
                xFuserUlyssesAttention,
            )

            if HAS_FLASH_ATTN:
                # self.hybrid_seq_parallel_attn = LongContextAttention()
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                    use_kv_cache=self.use_long_ctx_attn_kvcache
                )
            else:
                # Without flash-attn, use Ulysses-only with torch attention.
                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
                    use_fa=False,
                    use_kv_cache=self.use_long_ctx_attn_kvcache,
                )
        else:
            self.hybrid_seq_parallel_attn = None

        if get_fast_attn_enable():
            self.fast_attn = xFuserFastAttention()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        latte_temporal_attention: Optional[bool] = False,
        *args,
        **kwargs,
    ):
        """Run attention, routing through fast attention, hybrid
        sequence-parallel attention, flash-attn, or SDPA as configured."""
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        #! ---------------------------------------- Fast Attention ----------------------------------------
        if get_fast_attn_enable():
            return self.fast_attn(
                attn,
                hidden_states,
                encoder_hidden_states,
                attention_mask,
                temb,
                *args,
                **kwargs,
            )
        #! ---------------------------------------- Fast Attention ----------------------------------------

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # 4D (B, C, H, W) inputs are flattened to (B, H*W, C) for attention.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # Self-attention when no encoder states are given; otherwise
        # cross-attention (with optional normalization of encoder states).
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape to (batch, heads, seq_len, head_dim) for the SDPA layout.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        #! ---------------------------------------- KV CACHE ----------------------------------------
        # Skip the external cache manager when the sequence-parallel attention
        # manages its own KV cache.
        if not self.use_long_ctx_attn_kvcache:
            key, value = get_cache_manager().update_and_get_kv_cache(
                new_kv=[key, value],
                layer=attn,
                slice_dim=2,
                layer_type="attn",
            )
        #! ---------------------------------------- KV CACHE ----------------------------------------

        #! ---------------------------------------- ATTENTION ----------------------------------------
        if (
            HAS_LONG_CTX_ATTN
            and get_sequence_parallel_world_size() > 1
            and not latte_temporal_attention
        ):
            # Hybrid seq-parallel attention takes (batch, seq_len, heads, head_dim).
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            hidden_states = self.hybrid_seq_parallel_attn(
                attn,
                query,
                key,
                value,
                dropout_p=0.0,
                causal=False,
                joint_strategy="none",
            )
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        else:
            if HAS_FLASH_ATTN:
                from flash_attn import flash_attn_func

                # flash_attn takes (batch, seq_len, heads, head_dim).
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)
                value = value.transpose(1, 2)
                hidden_states = flash_attn_func(
                    query, key, value, dropout_p=0.0, causal=False
                )
                hidden_states = hidden_states.reshape(
                    batch_size, -1, attn.heads * head_dim
                )
            else:
                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.module.scale when we move to Torch 2.1
                hidden_states = F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attention_mask,
                    dropout_p=0.0,
                    is_causal=False,
                )
                hidden_states = hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )

        #! ORIGIN
        # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # # TODO: add support for attn.scale when we move to Torch 2.1
        # hidden_states = F.scaled_dot_product_attention(
        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        # )
        # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #! ---------------------------------------- ATTENTION ----------------------------------------

        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
@xFuserAttentionProcessorRegister.register(JointAttnProcessor2_0)
class xFuserJointAttnProcessor2_0(JointAttnProcessor2_0):
    """Parallel-aware drop-in for diffusers' ``JointAttnProcessor2_0``
    (SD3-style joint attention over latent tokens plus text tokens).

    On top of the stock processor this adds:
      * hybrid sequence-parallel attention (``hybrid_seq_parallel_attn``)
        when long-context attention is available and the sequence-parallel
        world size is greater than 1;
      * an explicit KV cache (PipeFusion) used only when the long-context
        attention layer does not manage the cache itself;
      * optional fast attention when ``get_fast_attn_enable()`` is set.
    """

    def __init__(self):
        super().__init__()
        # When SP is active and long-context attention exists, the attention
        # layer manages the KV cache internally; otherwise __call__ uses the
        # cache manager explicitly.
        use_long_ctx_attn_kvcache = True
        self.use_long_ctx_attn_kvcache = (
            HAS_LONG_CTX_ATTN
            and use_long_ctx_attn_kvcache
            and get_sequence_parallel_world_size() > 1
        )
        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            from xfuser.core.long_ctx_attention import (
                xFuserLongContextAttention,
                xFuserUlyssesAttention,
            )

            # Prefer the flash-attn backed hybrid implementation; fall back to
            # pure Ulysses attention when flash-attn is unavailable.
            if HAS_FLASH_ATTN:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                    use_kv_cache=self.use_long_ctx_attn_kvcache
                )
            else:
                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
                    use_fa=False,
                    use_kv_cache=self.use_long_ctx_attn_kvcache,
                )

        if get_fast_attn_enable():
            self.fast_attn = xFuserFastAttention()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run joint attention; returns ``(hidden_states, encoder_hidden_states)``.

        ``hidden_states`` carries the latent tokens and ``encoder_hidden_states``
        the text tokens; they are projected, attended jointly (text appended at
        the rear of the sequence), then split back at ``residual.shape[1]``.
        """
        residual = hidden_states

        # Flatten optional 4D (B, C, H, W) inputs to (B, H*W, C).
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        #! ---------------------------------------- KV CACHE ----------------------------------------
        # PipeFusion KV cache, sliced along the sequence dim (dim 1 here since
        # key/value are still (B, S, C)). Skipped when the long-context
        # attention layer owns the cache.
        if not self.use_long_ctx_attn_kvcache:
            key, value = get_cache_manager().update_and_get_kv_cache(
                new_kv=[key, value],
                layer=attn,
                slice_dim=1,
                layer_type="attn",
            )
        #! ---------------------------------------- KV CACHE ----------------------------------------

        #! ---------------------------------------- ATTENTION ----------------------------------------
        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            if get_runtime_state().split_text_embed_in_sp:
                # Text tokens are sharded across SP ranks together with the
                # latent tokens: fuse them into one sequence and pass no joint
                # tensors to the hybrid attention.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
                encoder_hidden_states_query_proj = None
                encoder_hidden_states_key_proj = None
                encoder_hidden_states_value_proj = None
            else:
                # Text tokens are replicated on each rank: hand them to the
                # hybrid attention as joint tensors appended at the rear.
                encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                    batch_size, -1, attn.heads, head_dim
                )
                encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                    batch_size, -1, attn.heads, head_dim
                )
                encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                    batch_size, -1, attn.heads, head_dim
                )

            # Hybrid seq-parallel attention expects (B, S, H, D) layout.
            query = query.view(batch_size, -1, attn.heads, head_dim)
            key = key.view(batch_size, -1, attn.heads, head_dim)
            value = value.view(batch_size, -1, attn.heads, head_dim)
            hidden_states = self.hybrid_seq_parallel_attn(
                attn,
                query,
                key,
                value,
                dropout_p=0.0,
                causal=False,
                joint_tensor_query=encoder_hidden_states_query_proj,
                joint_tensor_key=encoder_hidden_states_key_proj,
                joint_tensor_value=encoder_hidden_states_value_proj,
                joint_strategy="rear",
            )
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        else:
            # Single-rank path: fuse text + latent tokens, then use flash-attn
            # if present, otherwise PyTorch SDPA.
            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            if HAS_FLASH_ATTN:
                from flash_attn import flash_attn_func

                # flash_attn_func expects (B, S, H, D).
                query = query.view(batch_size, -1, attn.heads, head_dim)
                key = key.view(batch_size, -1, attn.heads, head_dim)
                value = value.view(batch_size, -1, attn.heads, head_dim)
                hidden_states = flash_attn_func(
                    query, key, value, dropout_p=0.0, causal=False
                )
                hidden_states = hidden_states.reshape(
                    batch_size, -1, attn.heads * head_dim
                )
            else:
                # SDPA expects (B, H, S, D).
                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.module.scale when we move to Torch 2.1
                hidden_states = F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attention_mask,
                    dropout_p=0.0,
                    is_causal=False,
                )
                hidden_states = hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )

        #! ORIGIN
        # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # hidden_states = hidden_states = F.scaled_dot_product_attention(
        #     query, key, value, dropout_p=0.0, is_causal=False
        # )
        # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #! ---------------------------------------- ATTENTION ----------------------------------------

        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # Restore the original 4D layouts if the inputs were 4D.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        return hidden_states, encoder_hidden_states
@xFuserAttentionProcessorRegister.register(FluxAttnProcessor2_0)
class xFuserFluxAttnProcessor2_0(FluxAttnProcessor2_0):
    """Attention processor used typically in processing the SD3-like self-attention projections.

    xDiT extension of diffusers' ``FluxAttnProcessor2_0``: adds hybrid
    sequence-parallel attention, a USP fast path, and a PipeFusion KV cache.
    For Flux the text tokens are placed at the *front* of the fused sequence
    (``joint_strategy="front"``), unlike the SD3 processor which appends them
    at the rear.
    """

    def __init__(self):
        super().__init__()
        # The long-context attention layer manages the KV cache itself when SP
        # is active; otherwise __call__ splices the cache in manually.
        use_long_ctx_attn_kvcache = True
        self.use_long_ctx_attn_kvcache = (
            HAS_LONG_CTX_ATTN
            and use_long_ctx_attn_kvcache
            and get_sequence_parallel_world_size() > 1
        )
        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            from xfuser.core.long_ctx_attention import (
                xFuserLongContextAttention,
                xFuserUlyssesAttention,
            )

            # Prefer the flash-attn backed hybrid implementation; fall back to
            # pure Ulysses attention when flash-attn is unavailable.
            if HAS_FLASH_ATTN:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                    use_kv_cache=self.use_long_ctx_attn_kvcache
                )
            else:
                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
                    use_fa=False,
                    use_kv_cache=self.use_long_ctx_attn_kvcache,
                )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run Flux attention.

        Returns ``(hidden_states, encoder_hidden_states)`` when text tokens are
        provided (double-stream block), or just ``hidden_states`` otherwise
        (single-stream block, where the text tokens are already fused into
        ``hidden_states``).
        """
        batch_size, _, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, S, C) -> (B, H, S, D)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(
                    encoder_hidden_states_query_proj
                )
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(
                    encoder_hidden_states_key_proj
                )

            num_encoder_hidden_states_tokens = encoder_hidden_states_query_proj.shape[2]
            num_query_tokens = query.shape[2]

            # attention: text tokens go in front of the latent tokens.
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            # Single-stream block: the sequence already starts with the text
            # tokens; recover their count from the runtime state.
            num_encoder_hidden_states_tokens = (
                get_runtime_state().max_condition_sequence_length
            )
            num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        #! ---------------------------------------- KV CACHE ----------------------------------------
        # PipeFusion only: temporarily peel off the text tokens (they are not
        # patched), update the latent-token KV cache, then re-attach them.
        if get_runtime_state().num_pipeline_patch > 1 and not self.use_long_ctx_attn_kvcache:
            encoder_hidden_states_key_proj, key = key.split(
                [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
            )
            encoder_hidden_states_value_proj, value = value.split(
                [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
            )
            key, value = get_cache_manager().update_and_get_kv_cache(
                new_kv=[key, value],
                layer=attn,
                slice_dim=2,
                layer_type="attn",
            )
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        #! ---------------------------------------- KV CACHE ----------------------------------------

        #! ---------------------------------------- ATTENTION ----------------------------------------
        if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
            # Unified sequence-parallel fast path over the fused sequence.
            hidden_states = USP(
                query, key, value, dropout_p=0.0, is_causal=False
            )
            hidden_states = hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
        elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            # Hybrid seq-parallel attention expects (B, S, H, D).
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            if get_runtime_state().split_text_embed_in_sp:
                # Text tokens are sharded with the latents; no joint tensors.
                encoder_hidden_states_query_proj = None
                encoder_hidden_states_key_proj = None
                encoder_hidden_states_value_proj = None
            else:
                # Re-split the fused sequence so the replicated text tokens can
                # be passed as joint tensors (front strategy).
                encoder_hidden_states_query_proj, query = query.split(
                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
                )
                encoder_hidden_states_key_proj, key = key.split(
                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
                )
                encoder_hidden_states_value_proj, value = value.split(
                    [num_encoder_hidden_states_tokens, num_query_tokens], dim=1
                )
            hidden_states = self.hybrid_seq_parallel_attn(
                # The attn layer is only needed for its KV cache, i.e. when
                # pipeline patching is in use.
                attn if get_runtime_state().num_pipeline_patch > 1 else None,
                query,
                key,
                value,
                dropout_p=0.0,
                causal=False,
                joint_tensor_query=encoder_hidden_states_query_proj,
                joint_tensor_key=encoder_hidden_states_key_proj,
                joint_tensor_value=encoder_hidden_states_value_proj,
                joint_strategy="front",
            )
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        else:
            if HAS_FLASH_ATTN:
                from flash_attn import flash_attn_func

                # flash_attn_func expects (B, S, H, D).
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)
                value = value.transpose(1, 2)
                hidden_states = flash_attn_func(
                    query, key, value, dropout_p=0.0, causal=False
                )
                hidden_states = hidden_states.reshape(
                    batch_size, -1, attn.heads * head_dim
                )
            else:
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0.0, is_causal=False
                )
                hidden_states = hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )
        #! ---------------------------------------- ATTENTION ----------------------------------------

        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Text tokens are at the front: split them back off.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
@xFuserAttentionProcessorRegister.register(HunyuanAttnProcessor2_0)
class xFuserHunyuanAttnProcessor2_0(HunyuanAttnProcessor2_0):
    """Parallel-aware drop-in for diffusers' ``HunyuanAttnProcessor2_0``.

    Adds hybrid sequence-parallel attention (self-attention only — the SP path
    is skipped for cross-attention and Latte temporal attention) and an
    explicit PipeFusion KV cache when the long-context attention layer does
    not manage the cache itself.
    """

    def __init__(self):
        super().__init__()
        # Long-context attention owns the KV cache when SP is active.
        use_long_ctx_attn_kvcache = True
        self.use_long_ctx_attn_kvcache = (
            HAS_LONG_CTX_ATTN
            and use_long_ctx_attn_kvcache
            and get_sequence_parallel_world_size() > 1
        )
        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            from xfuser.core.long_ctx_attention import (
                xFuserLongContextAttention,
                xFuserUlyssesAttention,
            )

            # Prefer the flash-attn backed hybrid implementation; fall back to
            # pure Ulysses attention when flash-attn is unavailable.
            if HAS_FLASH_ATTN:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                    use_kv_cache=self.use_long_ctx_attn_kvcache
                )
            else:
                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
                    use_fa=False,
                    use_kv_cache=self.use_long_ctx_attn_kvcache,
                )
        else:
            self.hybrid_seq_parallel_attn = None

    # NOTE: torch.compile does not work on V100, so compilation is disabled
    # there by this decorator.
    @torch_compile_disable_if_v100
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        latte_temporal_attention: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run (self- or cross-) attention with optional RoPE and return the
        projected hidden states (same contract as the diffusers processor)."""
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # Flatten an optional 4D (B, C, H, W) input to (B, H*W, C).
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # Self-attention when no context is given; otherwise (optionally
        # normalized) cross-attention.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, S, C) -> (B, H, S, D)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        # print(f"Q {query.shape}, {key.shape}, {image_rotary_emb[0].shape}")
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            # For cross-attention the keys come from the context and get no RoPE.
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        #! ---------------------------------------- KV CACHE ----------------------------------------
        # PipeFusion KV cache, sliced along the sequence dim (dim 2 in the
        # (B, H, S, D) layout). Skipped when the long-context layer owns it.
        if not self.use_long_ctx_attn_kvcache:
            key, value = get_cache_manager().update_and_get_kv_cache(
                new_kv=[key, value],
                layer=attn,
                slice_dim=2,
                layer_type="attn",
            )
        #! ---------------------------------------- KV CACHE ----------------------------------------

        #! ---------------------------------------- ATTENTION ----------------------------------------
        # SP attention is only valid for self-attention over the sharded
        # spatial sequence; cross/temporal attention uses the local fallback.
        if (
            HAS_LONG_CTX_ATTN
            and get_sequence_parallel_world_size() > 1
            and not attn.is_cross_attention
            and not latte_temporal_attention
        ):
            # Hybrid seq-parallel attention expects (B, S, H, D).
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            hidden_states = self.hybrid_seq_parallel_attn(
                attn,
                query,
                key,
                value,
                dropout_p=0.0,
                causal=False,
                joint_strategy="none",
            )
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        else:
            if HAS_FLASH_ATTN:
                from flash_attn import flash_attn_func

                # flash_attn_func expects (B, S, H, D).
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)
                value = value.transpose(1, 2)
                hidden_states = flash_attn_func(
                    query, key, value, dropout_p=0.0, causal=False
                )
                hidden_states = hidden_states.reshape(
                    batch_size, -1, attn.heads * head_dim
                )
            else:
                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.module.scale when we move to Torch 2.1
                hidden_states = F.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attention_mask,
                    dropout_p=0.0,
                    is_causal=False,
                )
                hidden_states = hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )

        #! ORIGIN
        # # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # # TODO: add support for attn.scale when we move to Torch 2.1
        # hidden_states = F.scaled_dot_product_attention(
        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        # )
        # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #! ---------------------------------------- ATTENTION ----------------------------------------

        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the original 4D layout if the input was 4D.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
@xFuserAttentionProcessorRegister.register(CogVideoXAttnProcessor2_0)
class xFuserCogVideoXAttnProcessor2_0(CogVideoXAttnProcessor2_0):
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.

    xDiT extension: adds a USP fast path and hybrid sequence-parallel
    attention, with the text tokens handled as "front" joint tensors when they
    are replicated across sequence-parallel ranks.
    """

    def __init__(self):
        super().__init__()
        # Long-context attention owns the KV cache when SP is active.
        use_long_ctx_attn_kvcache = True
        self.use_long_ctx_attn_kvcache = (
            HAS_LONG_CTX_ATTN
            and use_long_ctx_attn_kvcache
            and get_sequence_parallel_world_size() > 1
        )
        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            from xfuser.core.long_ctx_attention import (
                xFuserLongContextAttention,
                xFuserUlyssesAttention,
            )

            # Prefer the flash-attn backed hybrid implementation; fall back to
            # pure Ulysses attention when flash-attn is unavailable.
            if HAS_FLASH_ATTN:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                    use_kv_cache=self.use_long_ctx_attn_kvcache
                )
            else:
                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
                    use_fa=False,
                    use_kv_cache=self.use_long_ctx_attn_kvcache,
                )
        else:
            self.hybrid_seq_parallel_attn = None

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run CogVideoX attention over ``[text; latent]`` tokens and return
        ``(hidden_states, encoder_hidden_states)`` split back apart."""
        text_seq_length = encoder_hidden_states.size(1)
        latent_seq_length = hidden_states.size(1)

        # CogVideoX attends over text tokens followed by latent tokens.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # NOTE(review): encoder_hidden_states is a required argument and is
        # dereferenced above, so the else branch is always taken here and
        # sequence_length equals text_seq_length — confirm this matches the
        # intended mask preparation.
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects the mask as
            # (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, S, C) -> (B, H, S, D)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed — only to the latent (non-text) positions.
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(
                query[:, :, text_seq_length:], image_rotary_emb
            )
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(
                    key[:, :, text_seq_length:], image_rotary_emb
                )

        #! ---------------------------------------- ATTENTION ----------------------------------------
        if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
            # Unified sequence-parallel fast path over the fused sequence.
            hidden_states = USP(
                query, key, value, dropout_p=0.0, is_causal=False
            )
            hidden_states = hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
        elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            if get_runtime_state().split_text_embed_in_sp:
                # Text tokens are sharded with the latents; no joint tensors.
                encoder_query = None
                encoder_key = None
                encoder_value = None
            else:
                # Peel the replicated text tokens off the front and hand them
                # to the hybrid attention as (B, S, H, D) joint tensors.
                encoder_query = query[:, :, :text_seq_length, :]
                query = query[:, :, text_seq_length:, :]
                encoder_key = key[:, :, :text_seq_length, :]
                key = key[:, :, text_seq_length:, :]
                encoder_value = value[:, :, :text_seq_length, :]
                value = value[:, :, text_seq_length:, :]

                encoder_query = encoder_query.transpose(1, 2)
                encoder_key = encoder_key.transpose(1, 2)
                encoder_value = encoder_value.transpose(1, 2)

            # Hybrid seq-parallel attention expects (B, S, H, D).
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            hidden_states = self.hybrid_seq_parallel_attn(
                None,
                query,
                key,
                value,
                dropout_p=0.0,
                causal=False,
                joint_tensor_query=encoder_query,
                joint_tensor_key=encoder_key,
                joint_tensor_value=encoder_value,
                joint_strategy="front",
            )
            hidden_states = hidden_states.reshape(
                batch_size, -1, attn.heads * head_dim
            )
        else:
            if HAS_FLASH_ATTN:
                from flash_attn import flash_attn_func

                # flash_attn_func expects (B, S, H, D).
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)
                value = value.transpose(1, 2)
                hidden_states = flash_attn_func(
                    query, key, value, dropout_p=0.0, causal=False
                )
                hidden_states = hidden_states.reshape(
                    batch_size, -1, attn.heads * head_dim
                )
            else:
                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0.0, is_causal=False
                )
                hidden_states = hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )

        #! ORIGIN
        # hidden_states = F.scaled_dot_product_attention(
        #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        # )
        # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        #! ---------------------------------------- ATTENTION ----------------------------------------

        assert text_seq_length + latent_seq_length == hidden_states.shape[1]
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Split text tokens (front) and latent tokens back apart.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, latent_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
from abc import abstractmethod, ABCMeta
from typing import List
import torch
import torch.nn as nn
from xfuser.config.config import InputConfig, ParallelConfig, RuntimeConfig
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
class xFuserLayerBaseWrapper(nn.Module, xFuserBaseWrapper, metaclass=ABCMeta):
    """Abstract base for xFuser layer wrappers.

    An ``nn.Module`` that holds a wrapped module (via ``xFuserBaseWrapper``)
    and transparently delegates attribute access to it, so the wrapper can be
    used wherever the wrapped layer was used. Subclasses implement ``forward``.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        # Skip past nn.Module in the MRO so xFuserBaseWrapper receives the
        # wrapped module.
        super(nn.Module, self).__init__(module=module)
        # Scratch slot subclasses may use to cache activations between calls.
        self.activation_cache = None

    def __getattr__(self, name: str):
        # Mirror nn.Module's lookup order over its three registries before
        # falling back to the wrapped module.
        for registry_key in ("_parameters", "_buffers", "_modules"):
            registry = self.__dict__.get(registry_key)
            if registry is not None and name in registry:
                return registry[name]
        try:
            # Delegate anything else to the wrapped module.
            return getattr(self.module, name)
        except RecursionError:
            # `self.module` itself is missing: surface a normal AttributeError
            # instead of an unbounded recursion.
            raise AttributeError(
                f"module {type(self.module).__name__} has no " f"attribute {name}"
            )

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass
import torch
from torch import nn
from torch.nn import functional as F
from xfuser.config import ParallelConfig, RuntimeConfig
from xfuser.core.distributed.parallel_state import get_sequence_parallel_world_size
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.model_executor.layers import xFuserLayerBaseWrapper
from xfuser.logger import init_logger
from xfuser.model_executor.layers import xFuserLayerWrappersRegister
from xfuser.core.distributed import (
get_pipeline_parallel_world_size,
)
logger = init_logger(__name__)
@xFuserLayerWrappersRegister.register(nn.Conv2d)
class xFuserConv2dWrapper(xFuserLayerBaseWrapper):
    """Conv2d wrapper that supports pipeline/sequence-parallel patched inputs.

    When no parallelism is active (or the kernel is 1x1) it simply forwards to
    the wrapped conv. Otherwise, for the first conv layer, it accumulates the
    per-patch activations into a full-height cache and convolves the slice
    belonging to the current pipeline patch.
    """

    def __init__(
        self,
        conv2d: nn.Conv2d,
        *,
        is_first_layer: bool = True,
    ):
        super().__init__(
            module=conv2d,
        )
        # Only the first conv layer (patchify) has a patched-forward path; for
        # any other conv the patched path is not implemented.
        self.is_first_layer = is_first_layer

    def naive_forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, H, W]
        output = self.module(x)
        return output

    # TODO fix implementation problems in sliced_forward
    # only available for patchify process
    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve only the height slice owned by the current pipeline patch.

        Pads the slice manually (width always, height only at the tensor's
        outer boundaries) and runs the conv with "valid" padding so the result
        matches the corresponding slice of a full-input convolution.
        """
        b, c, h, w = x.shape
        stride = self.module.stride[0]
        padding = self.module.padding[0]

        idx = get_runtime_state().pipeline_patch_idx
        pp_patches_start_idx_local = get_runtime_state().pp_patches_start_idx_local
        # Extend the slice by `padding` rows on each side so the conv sees the
        # halo rows it needs at the patch boundaries.
        h_begin = pp_patches_start_idx_local[idx] - padding
        h_end = pp_patches_start_idx_local[idx + 1] + padding
        final_padding = [padding, padding, 0, 0]
        if h_begin < 0:
            # Top of the image: no rows above, use zero padding instead.
            h_begin = 0
            final_padding[2] = padding
        if h_end > h:
            # Bottom of the image: no rows below, use zero padding instead.
            h_end = h
            final_padding[3] = padding
        sliced_input = x[:, :, h_begin:h_end, :]
        padded_input = F.pad(sliced_input, final_padding, mode="constant")
        result = F.conv2d(
            padded_input,
            self.module.weight,
            self.module.bias,
            stride=stride,
            padding="valid",
        )
        return result

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Dispatch between the plain conv and the patched/sliced path."""
        # Fast path: no parallelism, or a 1x1 kernel that needs no halo rows.
        if (
            (
                get_pipeline_parallel_world_size() == 1
                and get_sequence_parallel_world_size() == 1
            )
            or self.module.kernel_size == (1, 1)
            or self.module.kernel_size == 1
        ):
            output = self.naive_forward(x)
        else:
            if self.is_first_layer:
                if (
                    not get_runtime_state().patch_mode
                    or get_runtime_state().num_pipeline_patch == 1
                ):
                    # Full input available: run it whole and remember it.
                    self.activation_cache = x
                    output = self.naive_forward(self.activation_cache)
                else:
                    # Patch mode: lazily allocate a full-height cache, write
                    # the current patch's rows into it, then convolve only the
                    # slice this patch owns.
                    if self.activation_cache is None:
                        self.activation_cache = torch.zeros(
                            [
                                x.shape[0],
                                x.shape[1],
                                get_runtime_state().pp_patches_start_idx_local[-1],
                                x.shape[3],
                            ],
                            dtype=x.dtype,
                            device=x.device,
                        )

                    self.activation_cache[
                        :,
                        :,
                        get_runtime_state()
                        .pp_patches_start_idx_local[
                            get_runtime_state().pipeline_patch_idx
                        ] : get_runtime_state()
                        .pp_patches_start_idx_local[
                            get_runtime_state().pipeline_patch_idx + 1
                        ],
                        :,
                    ] = x

                    output = self.sliced_forward(self.activation_cache)
            else:
                # Patched forward for non-first conv layers is not supported.
                raise NotImplementedError

        # else:
        #     boundary_size = self.module.padding[0]
        #     # if self.buffer_list is None:
        #     # if self.comm_manager.buffer_list is None: # # self.idx = self.comm_manager.register_tensor( # # shape=[2, x.shape[0], x.shape[1], boundary_size, x.shape[3]], # # torch_dtype=x.dtype,
        #     # layer_type="conv2d",
        #     # )
        #     # else:
        #     # self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
        #     if self.buffer_list is None:
        #         output = self.naive_forward(x)
        #         self.buffer_list = x
        #     else:
        #         def create_padded_x():
        #             if distri_config.split_idx() == 0:
        #                 concat_x = torch.cat([x, self.buffer_list[distri_config.split_idx() + 1][0]], dim=2)
        #                 padded_x = F.pad(concat_x, [0, 0, boundary_size, 0], mode="constant")
        #             elif distri_config.split_idx() == distri_config.n_device_per_batch - 1:
        #                 concat_x = torch.cat([self.buffer_list[distri_config.split_idx() - 1][1], x], dim=2)
        #                 padded_x = F.pad(concat_x, [0, 0, 0, boundary_size], mode="constant")
        #             else:
        #                 padded_x = torch.cat(
        #                     [
        #                         self.buffer_list[distri_config.split_idx() - 1][1],
        #                         x,
        #                         self.buffer_list[distri_config.split_idx() + 1][0],
        #                     ],
        #                     dim=2,
        #                 )
        #             return padded_x
        #         boundary = torch.stack([x[:, :, :boundary_size, :], x[:, :, -boundary_size:, :]], dim=0)
        #         if distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps:
        #             dist.all_gather(self.buffer_list, boundary, group=distri_config.batch_group, async_op=False)
        #             padded_x = create_padded_x()
        #             output = F.conv2d(
        #                 padded_x,
        #                 self.module.weight,
        #                 self.module.bias,
        #                 stride=self.module.stride[0],
        #                 padding=(0, self.module.padding[1]),
        #             )
        #         else:
        #             padded_x = create_padded_x()
        #             output = F.conv2d(
        #                 padded_x,
        #                 self.module.weight,
        #                 self.module.bias,
        #                 stride=self.module.stride[0],
        #                 padding=(0, self.module.padding[1]),
        #             )
        #             if distri_config.mode != "no_sync":
        #                 self.comm_manager.enqueue(self.idx, boundary)

        return output
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0/src/diffusers/models/embeddings.py
import torch
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed, CogVideoXPatchEmbed
import torch.distributed
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.model_executor.layers import xFuserLayerBaseWrapper
from xfuser.model_executor.layers import xFuserLayerWrappersRegister
from xfuser.logger import init_logger
logger = init_logger(__name__)
@xFuserLayerWrappersRegister.register(PatchEmbed)
class xFuserPatchEmbedWrapper(xFuserLayerBaseWrapper):
    def __init__(
        self,
        patch_embedding: PatchEmbed,
    ):
        """Wrap a diffusers ``PatchEmbed`` so its positional embedding can be
        handled per pipeline patch (see ``forward``)."""
        super().__init__(
            module=patch_embedding,
        )
        # Narrow the wrapped module's type for readers and type checkers.
        self.module: PatchEmbed
        # NOTE(review): `self.pos_embed` is initialized but the visible
        # forward() uses `self.module.pos_embed` instead — confirm whether this
        # attribute is still needed.
        self.pos_embed = None
def forward(self, latent):
height = (
get_runtime_state().input_config.height
// get_runtime_state().vae_scale_factor
)
width = latent.shape[-1]
if not get_runtime_state().patch_mode:
if getattr(self.module, "pos_embed_max_size", None) is not None:
pass
else:
height, width = (
height // self.module.patch_size,
width // self.module.patch_size,
)
else:
if getattr(self.module, "pos_embed_max_size", None) is not None:
pass
else:
height, width = (
height // self.module.patch_size,
width // self.module.patch_size,
)
latent = self.module.proj(latent)
if self.module.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.module.layer_norm:
# TODO: NOT SURE whether compatible with norm
latent = self.module.norm(latent)
# [2, 4096 / c, 1152]
if self.module.pos_embed is None:
return latent.to(latent.dtype)
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
# TODO: There might be a more faster way to generate a smaller pos_embed
if getattr(self.module, "pos_embed_max_size", None):
pos_embed = self.module.cropped_pos_embed(height, width)
else:
if self.module.height != height or self.module.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.module.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.module.base_size,
interpolation_scale=self.module.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
self.module.pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
self.module.height = height
self.module.width = width
pos_embed = self.module.pos_embed
else:
pos_embed = self.module.pos_embed
b, c, h = pos_embed.shape
if get_runtime_state().patch_mode:
start, end = get_runtime_state().pp_patches_token_start_end_idx_global[
get_runtime_state().pipeline_patch_idx
]
pos_embed = pos_embed[
:,
start:end,
:,
]
else:
pos_embed_list = [
pos_embed[
:,
get_runtime_state()
.pp_patches_token_start_end_idx_global[i][0] : get_runtime_state()
.pp_patches_token_start_end_idx_global[i][1],
:,
]
for i in range(get_runtime_state().num_pipeline_patch)
]
pos_embed = torch.cat(pos_embed_list, dim=1)
return (latent + pos_embed).to(latent.dtype)
@xFuserLayerWrappersRegister.register(CogVideoXPatchEmbed)
class xFuserCogVideoXPatchEmbedWrapper(xFuserLayerBaseWrapper):
    """Parallelism-aware wrapper around diffusers' ``CogVideoXPatchEmbed``.

    Repeats the wrapped module's text/image patchification, but slices the
    image part of the positional embedding so each rank only adds the rows
    for its local (height-split) patch of every frame.

    NOTE(review): attribute reads such as ``self.text_proj`` /
    ``self.patch_size_t`` are not defined on this wrapper — presumably they
    delegate to the wrapped module via ``xFuserLayerBaseWrapper``; confirm.
    """
    def __init__(
        self,
        patch_embedding: CogVideoXPatchEmbed,
    ):
        super().__init__(
            module=patch_embedding,
        )
        self.module: CogVideoXPatchEmbed
    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""Patchify and concatenate text + image embeddings, adding local pos-embed rows.

        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).

        Returns:
            `torch.Tensor` of shape (batch, seq_length + local image tokens, channels).
        """
        # height is the height of a batch on a GPU, sum_height is the total height of the video
        sum_height = (
            get_runtime_state().input_config.height
            // get_runtime_state().vae_scale_factor_spatial
        )
        text_embeds = self.text_proj(text_embeds)
        batch_size, num_frames, channels, height, width = image_embeds.shape
        if self.patch_size_t is None:
            # Spatial-only patchification: conv-project each frame, then
            # flatten to (batch, frames * H' * W', channels).
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            # Spatio-temporal patchification: fold (p_t, p, p) sub-volumes
            # into the channel dimension before the linear projection.
            p = self.patch_size
            p_t = self.patch_size_t
            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)
        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != sum_height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )
            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
            # Regenerate the positional embedding when the requested video
            # size differs from the module's configured sample size.
            if (
                self.sample_height != sum_height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(sum_height, width, pre_time_compression_frames)
                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
            else:
                pos_embedding = self.pos_embedding
            # extract the image part of the positional embedding
            pos_embedding = pos_embedding[:, self.max_text_seq_length :]
            # slice the positional embedding
            post_patch_height = sum_height // self.patch_size
            post_patch_width = width // self.patch_size
            post_time_compression_frames = (pre_time_compression_frames - 1) // self.temporal_compression_ratio + 1
            # For every compressed frame, keep only this rank's token range
            # (offset by i frames' worth of full-frame tokens).
            # NOTE(review): only index [0] of the start/end table is used —
            # presumably each rank holds exactly one spatial patch; verify.
            pos_embed_list = [
                pos_embedding[
                    :,
                    post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][0]:
                    post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][1],
                    :,
                ]
                for i in range(post_time_compression_frames)
            ]
            pos_embedding = torch.cat(pos_embed_list, dim=1)
            # In-place add onto the image tokens only; text tokens get no pos-embed here.
            embeds[:, self.max_text_seq_length :] += pos_embedding
        return embeds
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from diffusers.models.attention import FeedForward, GELU, GEGLU
from torch import nn
from xfuser.core.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_tp_group,
)
import torch
from xfuser.model_executor.layers.base_layer import xFuserLayerBaseWrapper
from xfuser.model_executor.layers.register import xFuserLayerWrappersRegister
@xFuserLayerWrappersRegister.register(FeedForward)
class xFuserFeedForwardWrapper(xFuserLayerBaseWrapper):
    """Tensor-parallel wrapper for diffusers' ``FeedForward``.

    Column-shards the first projection (GELU or GEGLU) and row-shards the
    output projection across the TP group. ``forward`` all-reduces the
    partial outputs, then adds the (unsharded) output bias exactly once.
    """

    def __init__(self, feedforward: FeedForward):
        super(xFuserFeedForwardWrapper, self).__init__(module=feedforward)
        tp_degree = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        act = self.module.net[0]
        if isinstance(act, GELU):
            # Column parallel: each rank keeps a contiguous row-slice of the
            # up-projection weight (and matching bias slice).
            act.proj.weight.data = act.proj.weight.data.chunk(tp_degree, dim=0)[tp_rank]
            if act.proj.bias is not None:
                act.proj.bias.data = act.proj.bias.data.chunk(tp_degree, dim=0)[tp_rank]
        elif isinstance(act, GEGLU):
            # GEGLU packs [value, gate] halves along dim 0; shard each half
            # separately and re-concatenate so the packed layout survives.
            weight_halves = act.proj.weight.data.chunk(2, dim=0)
            act.proj.weight.data = torch.cat(
                [half.chunk(tp_degree, dim=0)[tp_rank] for half in weight_halves],
                dim=0,
            )
            # Robustness fix: the original dereferenced proj.bias.data
            # unconditionally here, unlike the GELU branch.
            if act.proj.bias is not None:
                bias_halves = act.proj.bias.data.chunk(2, dim=0)
                act.proj.bias.data = torch.cat(
                    [half.chunk(tp_degree, dim=0)[tp_rank] for half in bias_halves],
                    dim=0,
                )
        else:
            # Bug fix: the original formatted `type(isinstance(self.module.net[0]))`,
            # which raises TypeError (isinstance needs 2 args) before the
            # intended message could ever be produced.
            raise TypeError(
                f"activation_fn {type(self.module.net[0])} not supported"
            )
        # Row parallel: each rank keeps a contiguous column-slice of the
        # down-projection weight.
        self.module.net[2].weight.data = self.module.net[2].weight.chunk(
            tp_degree, dim=1
        )[tp_rank]
        self.has_output_bias = False
        if self.module.net[2].bias is not None:
            # The output bias must be applied exactly once (after the
            # all-reduce), so detach it from the sharded linear and keep it
            # on the wrapper instead.
            self.register_parameter(
                "output_bias", nn.Parameter(self.module.net[2].bias.data.clone())
            )
            self.module.net[2].bias = None
            self.has_output_bias = True
        torch.cuda.empty_cache()

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the sharded FFN, all-reduce partial sums, then add the output bias."""
        hidden_states = self.module(hidden_states, *args, **kwargs)
        get_tp_group().all_reduce(hidden_states)
        if self.has_output_bias:
            hidden_states += self.output_bias
        return hidden_states
from typing import Dict, Type
import torch
import torch.nn as nn
from xfuser.logger import init_logger
from xfuser.model_executor.layers.base_layer import xFuserLayerBaseWrapper
logger = init_logger(__name__)
class xFuserLayerWrappersRegister:
    """Registry mapping original ``nn.Module`` classes to xFuser wrapper classes.

    Wrappers are registered via the ``register`` class decorator; ``get_wrapper``
    resolves the most specific registered wrapper for a given layer instance.
    """

    _XFUSER_LAYER_MAPPING: Dict[
        Type[nn.Module], Type[xFuserLayerBaseWrapper]
    ] = {}

    @classmethod
    def register(cls, origin_layer_class: Type[nn.Module]):
        """Return a decorator that registers a wrapper for ``origin_layer_class``."""

        def decorator(xfuser_layer_wrapper: Type[xFuserLayerBaseWrapper]):
            if not issubclass(xfuser_layer_wrapper, xFuserLayerBaseWrapper):
                raise ValueError(
                    f"{xfuser_layer_wrapper.__class__.__name__} is not a "
                    f"subclass of xFuserLayerBaseWrapper"
                )
            cls._XFUSER_LAYER_MAPPING[origin_layer_class] = xfuser_layer_wrapper
            return xfuser_layer_wrapper

        return decorator

    @classmethod
    def get_wrapper(cls, layer: nn.Module) -> xFuserLayerBaseWrapper:
        """Return the wrapper registered for the most specific class matching ``layer``."""
        best_origin = None
        best_wrapper = None
        for origin_cls, wrapper_cls in cls._XFUSER_LAYER_MAPPING.items():
            if not isinstance(layer, origin_cls):
                continue
            # Prefer an exact class match, or a registered class that is a
            # subclass of (i.e. more specific than) the current best.
            if (
                best_origin is None
                or origin_cls == layer.__class__
                or issubclass(origin_cls, best_origin)
            ):
                best_origin = origin_cls
                best_wrapper = wrapper_cls
        if best_wrapper is not None:
            return best_wrapper
        raise ValueError(
            f"Layer class {layer.__class__.__name__} "
            f"is not supported by xFuser"
        )
# This file implements USP with torch version >= '2.5.0'
import torch
from torch.nn import functional as F
from torch.distributed.tensor.experimental._attention import _templated_ring_attention
aten = torch.ops.aten
import torch.distributed._functional_collectives as ft_c
from yunchang.globals import PROCESS_GROUP
from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_ulysses_parallel_world_size,
get_ring_parallel_world_size,
)
def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
    """Run ring attention over the ring process group.

    Delegates to torch's experimental templated ring attention with the
    flash SDPA kernel; only the attention output (first result) is returned.
    """
    results = _templated_ring_attention(
        PROCESS_GROUP.RING_PG,
        aten._scaled_dot_product_flash_attention,
        query,
        key,
        value,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )
    return results[0]
def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
"""
When tracing the code, the result tensor is not an AsyncCollectiveTensor,
so we cannot call ``wait()``.
"""
if isinstance(tensor, ft_c.AsyncCollectiveTensor):
return tensor.wait()
return tensor
def _sdpa_all_to_all_single(x):
    """All-to-all ``x`` across the Ulysses group, preserving its shape.

    The tensor is flattened for the collective (uniform split sizes are
    implied by passing ``None``) and reshaped back afterwards.
    """
    original_shape = x.shape
    exchanged = ft_c.all_to_all_single(
        x.flatten(),
        output_split_sizes=None,
        input_split_sizes=None,
        group=PROCESS_GROUP.ULYSSES_PG,
    )
    return _maybe_wait(exchanged).reshape(original_shape)
def _ft_c_input_all_to_all(x):
    """Scatter attention heads across the Ulysses group before attention.

    Input is 4-D; based on the head-divisibility assert and the reshapes,
    this trades head shards for sequence shards: each rank keeps
    ``h // world_size`` heads over the full gathered sequence.
    """
    world_size = get_ulysses_parallel_world_size()
    if world_size <= 1:
        return x
    assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
    b, h, s, d = x.shape
    assert h % world_size == 0, "h must be divisible by world_size, got {} and {}".format(h, world_size)
    local_heads = h // world_size
    # Head-major layout so the collective exchanges whole head shards.
    exchanged = _sdpa_all_to_all_single(x.permute(1, 0, 2, 3).contiguous())
    exchanged = exchanged.reshape(world_size, local_heads, b, -1, d)
    exchanged = exchanged.permute(2, 1, 0, 3, 4)
    return exchanged.reshape(b, local_heads, -1, d)
def _ft_c_output_all_to_all(x):
    """Inverse of ``_ft_c_input_all_to_all``: regather heads, re-shard sequence.

    Input is 4-D; based on the sequence-divisibility assert and the
    reshapes, each rank ends with all heads over ``s // world_size`` tokens.
    """
    world_size = get_ulysses_parallel_world_size()
    if world_size <= 1:
        return x
    assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
    b, h, s, d = x.shape
    assert s % world_size == 0, "s must be divisible by world_size, got {} and {}".format(s, world_size)
    local_seq = s // world_size
    # Sequence-major layout so the collective exchanges whole sequence shards.
    exchanged = _sdpa_all_to_all_single(x.permute(2, 0, 1, 3).contiguous())
    exchanged = exchanged.reshape(world_size, local_seq, b, -1, d)
    exchanged = exchanged.permute(2, 0, 3, 1, 4)
    return exchanged.reshape(b, -1, local_seq, d)
def USP(query, key, value, dropout_p=0.0, is_causal=False):
    """Unified sequence-parallel attention.

    Dispatches on the configured parallel world sizes: plain SDPA when
    there is no sequence parallelism, pure ring attention when Ulysses is
    disabled, and Ulysses head-scatter (optionally combined with ring
    attention) otherwise.
    """
    if get_sequence_parallel_world_size() == 1:
        # No sequence parallelism at all: single-device SDPA.
        return F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=is_causal
        )
    if get_ulysses_parallel_world_size() == 1:
        # Ring-only parallelism.
        return ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
    # Ulysses (possibly hybrid with ring): scatter heads, attend, regather.
    query, key, value = (
        _ft_c_input_all_to_all(t) for t in (query, key, value)
    )
    if get_ring_parallel_world_size() == 1:
        out = F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=is_causal
        )
    else:
        out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
    return _ft_c_output_all_to_all(out)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment