Commit 0513d03d authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #3321 canceled with stages
from xfuser.model_executor.pipelines import (
xFuserPixArtAlphaPipeline,
xFuserPixArtSigmaPipeline,
xFuserStableDiffusion3Pipeline,
xFuserFluxPipeline,
xFuserLattePipeline,
xFuserHunyuanDiTPipeline,
xFuserCogVideoXPipeline,
)
from xfuser.config import xFuserArgs, EngineConfig
from xfuser.parallel import xDiTParallel
# Public API of the top-level xfuser package: parallel pipeline wrappers,
# CLI/config entry points and the xDiTParallel convenience wrapper.
__all__ = [
    "xFuserPixArtAlphaPipeline",
    "xFuserPixArtSigmaPipeline",
    "xFuserStableDiffusion3Pipeline",
    "xFuserFluxPipeline",
    "xFuserLattePipeline",
    "xFuserHunyuanDiTPipeline",
    "xFuserCogVideoXPipeline",
    "xFuserArgs",
    "EngineConfig",
    "xDiTParallel",
]

# Package version string.
__version__ = "0.4.0"
from .args import FlexibleArgumentParser, xFuserArgs
from .config import (
EngineConfig,
ParallelConfig,
TensorParallelConfig,
PipeFusionParallelConfig,
SequenceParallelConfig,
DataParallelConfig,
ModelConfig,
InputConfig,
RuntimeConfig
)
# Public API of xfuser.config: the flexible CLI parser, the flat CLI-argument
# dataclass, and every structured config object built from it.
__all__ = [
    "FlexibleArgumentParser",
    "xFuserArgs",
    "EngineConfig",
    "ParallelConfig",
    "TensorParallelConfig",
    "PipeFusionParallelConfig",
    "SequenceParallelConfig",
    "DataParallelConfig",
    "ModelConfig",
    "InputConfig",
    "RuntimeConfig"
]
\ No newline at end of file
import sys
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, List, Tuple, Union
import torch
import torch.distributed
from xfuser.logger import init_logger
from xfuser.core.distributed import init_distributed_environment
from xfuser.config.config import (
EngineConfig,
FastAttnConfig,
ParallelConfig,
TensorParallelConfig,
PipeFusionParallelConfig,
SequenceParallelConfig,
DataParallelConfig,
ModelConfig,
InputConfig,
RuntimeConfig,
)
logger = init_logger(__name__)
class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that accepts dashes and underscores interchangeably
    in long option names (``--foo-bar`` is normalized to ``--foo_bar``)."""

    def parse_args(self, args=None, namespace=None):
        raw = sys.argv[1:] if args is None else args
        normalized = [self._normalize_option(token) for token in raw]
        return super().parse_args(normalized, namespace)

    @staticmethod
    def _normalize_option(token):
        # Only long options are rewritten; positionals and short flags (and
        # any text after the first '=') pass through untouched.
        if not token.startswith("--"):
            return token
        name, sep, value = token.partition("=")
        return "--" + name[2:].replace("-", "_") + sep + value
def nullable_str(val: str):
    """Argparse type helper: map empty strings and the literal "None" to None,
    otherwise return the string unchanged."""
    return None if (not val or val == "None") else val
@dataclass
class xFuserArgs:
    """Arguments for xFuser engine.

    Flat collection of every CLI-configurable option.  ``add_cli_args``
    registers one argparse flag per field, ``from_cli_args`` maps a parsed
    namespace back onto this dataclass, and ``create_config`` converts the
    flat fields into the structured config objects consumed by the engine.
    """

    # Model arguments
    model: str
    download_dir: Optional[str] = None
    trust_remote_code: bool = False
    # Runtime arguments
    warmup_steps: int = 1
    # use_cuda_graph: bool = True
    use_parallel_vae: bool = False
    # use_profiler: bool = False
    use_torch_compile: bool = False
    use_onediff: bool = False
    # Parallel arguments
    # data parallel
    data_parallel_degree: int = 1
    use_cfg_parallel: bool = False
    # sequence parallel
    ulysses_degree: Optional[int] = None
    ring_degree: Optional[int] = None
    # tensor parallel
    tensor_parallel_degree: int = 1
    split_scheme: Optional[str] = "row"
    # pipefusion parallel
    pipefusion_parallel_degree: int = 1
    num_pipeline_patch: Optional[int] = None
    attn_layer_num_for_pp: Optional[List[int]] = None
    # Input arguments
    height: int = 1024
    width: int = 1024
    num_frames: int = 49
    num_inference_steps: int = 20
    max_sequence_length: int = 256
    prompt: Union[str, List[str]] = ""
    negative_prompt: Union[str, List[str]] = ""
    no_use_resolution_binning: bool = False
    seed: int = 42
    output_type: str = "pil"
    enable_model_cpu_offload: bool = False
    enable_sequential_cpu_offload: bool = False
    enable_tiling: bool = False
    enable_slicing: bool = False
    # DiTFastAttn arguments
    use_fast_attn: bool = False
    n_calib: int = 8
    threshold: float = 0.5
    window_size: int = 64
    coco_path: Optional[str] = None
    use_cache: bool = False
    use_fp8_t5_encoder: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser):
        """Shared CLI arguments for xFuser engine.

        Registers every flag (grouped by topic) on ``parser`` and returns the
        parser so calls can be chained.  Flag defaults intentionally mirror
        the dataclass field defaults above, except ``--model`` which gains a
        concrete default here.
        """
        # Model arguments
        model_group = parser.add_argument_group("Model Options")
        model_group.add_argument(
            "--model",
            type=str,
            default="PixArt-alpha/PixArt-XL-2-1024-MS",
            help="Name or path of the huggingface model to use.",
        )
        model_group.add_argument(
            "--download-dir",
            type=nullable_str,
            default=xFuserArgs.download_dir,
            help="Directory to download and load the weights, default to the default cache dir of huggingface.",
        )
        model_group.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Trust remote code from huggingface.",
        )

        # Runtime arguments
        runtime_group = parser.add_argument_group("Runtime Options")
        runtime_group.add_argument(
            "--warmup_steps", type=int, default=1, help="Warmup steps in generation."
        )
        # runtime_group.add_argument("--use_cuda_graph", action="store_true")
        runtime_group.add_argument("--use_parallel_vae", action="store_true")
        # runtime_group.add_argument("--use_profiler", action="store_true")
        runtime_group.add_argument(
            "--use_torch_compile",
            action="store_true",
            help="Enable torch.compile to accelerate inference in a single card",
        )
        runtime_group.add_argument(
            "--use_onediff",
            action="store_true",
            help="Enable onediff to accelerate inference in a single card",
        )

        # Parallel arguments
        parallel_group = parser.add_argument_group("Parallel Processing Options")
        parallel_group.add_argument(
            "--use_cfg_parallel",
            action="store_true",
            help="Use split batch in classifier_free_guidance. cfg_degree will be 2 if set",
        )
        parallel_group.add_argument(
            "--data_parallel_degree", type=int, default=1, help="Data parallel degree."
        )
        parallel_group.add_argument(
            "--ulysses_degree",
            type=int,
            default=None,
            help="Ulysses sequence parallel degree. Used in attention layer.",
        )
        parallel_group.add_argument(
            "--ring_degree",
            type=int,
            default=None,
            help="Ring sequence parallel degree. Used in attention layer.",
        )
        parallel_group.add_argument(
            "--pipefusion_parallel_degree",
            type=int,
            default=1,
            help="Pipefusion parallel degree. Indicates the number of pipeline stages.",
        )
        parallel_group.add_argument(
            "--num_pipeline_patch",
            type=int,
            default=None,
            help="Number of patches the feature map should be segmented in pipefusion parallel.",
        )
        parallel_group.add_argument(
            "--attn_layer_num_for_pp",
            default=None,
            nargs="*",
            type=int,
            help="List representing the number of layers per stage of the pipeline in pipefusion parallel",
        )
        parallel_group.add_argument(
            "--tensor_parallel_degree",
            type=int,
            default=1,
            help="Tensor parallel degree.",
        )
        parallel_group.add_argument(
            "--split_scheme",
            type=str,
            default="row",
            help="Split scheme for tensor parallel.",
        )

        # Input arguments
        input_group = parser.add_argument_group("Input Options")
        input_group.add_argument(
            "--height", type=int, default=1024, help="The height of image"
        )
        input_group.add_argument(
            "--width", type=int, default=1024, help="The width of image"
        )
        input_group.add_argument(
            "--num_frames", type=int, default=49, help="The frames of video"
        )
        input_group.add_argument(
            "--prompt", type=str, nargs="*", default="", help="Prompt for the model."
        )
        input_group.add_argument("--no_use_resolution_binning", action="store_true")
        input_group.add_argument(
            "--negative_prompt",
            type=str,
            nargs="*",
            default="",
            help="Negative prompt for the model.",
        )
        input_group.add_argument(
            "--num_inference_steps",
            type=int,
            default=20,
            help="Number of inference steps.",
        )
        input_group.add_argument(
            "--max_sequence_length",
            type=int,
            default=256,
            help="Max sequencen length of prompt",
        )
        # NOTE: the following flags live in the Runtime Options group even
        # though they are registered after the input group.
        runtime_group.add_argument(
            "--seed", type=int, default=42, help="Random seed for operations."
        )
        runtime_group.add_argument(
            "--output_type",
            type=str,
            default="pil",
            help="Output type of the pipeline.",
        )
        runtime_group.add_argument(
            "--enable_sequential_cpu_offload",
            action="store_true",
            help="Offloading the weights to the CPU.",
        )
        runtime_group.add_argument(
            "--enable_model_cpu_offload",
            action="store_true",
            help="Offloading the weights to the CPU.",
        )
        runtime_group.add_argument(
            "--enable_tiling",
            action="store_true",
            help="Making VAE decode a tile at a time to save GPU memory.",
        )
        runtime_group.add_argument(
            "--enable_slicing",
            action="store_true",
            help="Making VAE decode a tile at a time to save GPU memory.",
        )
        runtime_group.add_argument(
            "--use_fp8_t5_encoder",
            action="store_true",
            help="Quantize the T5 text encoder.",
        )

        # DiTFastAttn arguments
        fast_attn_group = parser.add_argument_group("DiTFastAttn Options")
        fast_attn_group.add_argument(
            "--use_fast_attn",
            action="store_true",
            help="Use DiTFastAttn to accelerate single inference. Only data parallelism can be used with DITFastAttn.",
        )
        fast_attn_group.add_argument(
            "--n_calib",
            type=int,
            default=8,
            help="Number of prompts for compression method seletion.",
        )
        fast_attn_group.add_argument(
            "--threshold",
            type=float,
            default=0.5,
            help="Threshold for selecting attention compression method.",
        )
        fast_attn_group.add_argument(
            "--window_size",
            type=int,
            default=64,
            help="Size of window attention.",
        )
        fast_attn_group.add_argument(
            "--coco_path",
            type=str,
            default=None,
            help="Path of MS COCO annotation json file.",
        )
        fast_attn_group.add_argument(
            "--use_cache",
            action="store_true",
            help="Use cache config for attention compression.",
        )
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an xFuserArgs from a parsed argparse namespace, copying one
        namespace attribute per dataclass field."""
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_config(
        self,
    ) -> Tuple[EngineConfig, InputConfig]:
        """Convert the flat CLI fields into (EngineConfig, InputConfig).

        Initializes the torch.distributed environment first if the caller
        has not already done so.
        """
        if not torch.distributed.is_initialized():
            logger.warning(
                "Distributed environment is not initialized. " "Initializing..."
            )
            init_distributed_environment()

        model_config = ModelConfig(
            model=self.model,
            download_dir=self.download_dir,
            trust_remote_code=self.trust_remote_code,
        )

        runtime_config = RuntimeConfig(
            warmup_steps=self.warmup_steps,
            # use_cuda_graph=self.use_cuda_graph,
            use_parallel_vae=self.use_parallel_vae,
            use_torch_compile=self.use_torch_compile,
            use_onediff=self.use_onediff,
            # use_profiler=self.use_profiler,
            use_fp8_t5_encoder=self.use_fp8_t5_encoder,
        )

        parallel_config = ParallelConfig(
            dp_config=DataParallelConfig(
                dp_degree=self.data_parallel_degree,
                use_cfg_parallel=self.use_cfg_parallel,
            ),
            sp_config=SequenceParallelConfig(
                ulysses_degree=self.ulysses_degree,
                ring_degree=self.ring_degree,
            ),
            tp_config=TensorParallelConfig(
                tp_degree=self.tensor_parallel_degree,
                split_scheme=self.split_scheme,
            ),
            pp_config=PipeFusionParallelConfig(
                pp_degree=self.pipefusion_parallel_degree,
                num_pipeline_patch=self.num_pipeline_patch,
                attn_layer_num_for_pp=self.attn_layer_num_for_pp,
            ),
        )

        # DiTFastAttn calibrates over the same number of steps as inference.
        fast_attn_config = FastAttnConfig(
            use_fast_attn=self.use_fast_attn,
            n_step=self.num_inference_steps,
            n_calib=self.n_calib,
            threshold=self.threshold,
            window_size=self.window_size,
            coco_path=self.coco_path,
            use_cache=self.use_cache,
        )

        engine_config = EngineConfig(
            model_config=model_config,
            runtime_config=runtime_config,
            parallel_config=parallel_config,
            fast_attn_config=fast_attn_config,
        )

        input_config = InputConfig(
            height=self.height,
            width=self.width,
            num_frames=self.num_frames,
            # The CLI exposes the negated flag; invert it here.
            use_resolution_binning=not self.no_use_resolution_binning,
            batch_size=len(self.prompt) if isinstance(self.prompt, list) else 1,
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            num_inference_steps=self.num_inference_steps,
            max_sequence_length=self.max_sequence_length,
            seed=self.seed,
            output_type=self.output_type,
        )

        return engine_config, input_config
import os
import torch
import torch.distributed as dist
from packaging import version
from dataclasses import dataclass, fields
from torch import distributed as dist
from xfuser.logger import init_logger
import xfuser.envs as envs
# from xfuser.envs import CUDA_VERSION, TORCH_VERSION, PACKAGES_CHECKER
from xfuser.envs import TORCH_VERSION, PACKAGES_CHECKER
logger = init_logger(__name__)

from typing import Union, Optional, List

# Probe optional acceleration packages once at import time.
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]  # 'yunchang' sequence-parallel kit
HAS_FLASH_ATTN = env_info["has_flash_attn"]
def check_packages():
    """Fail fast unless the installed diffusers is newer than 0.30.2."""
    import diffusers

    installed = version.parse(diffusers.__version__)
    if installed <= version.parse("0.30.2"):
        raise RuntimeError(
            "This project requires diffusers version > 0.30.2. Currently, you can not install a correct version of diffusers by pip install."
            "Please install it from source code!"
        )
def check_env():
    """Validate runtime requirements for CUDA-graph (NCCL) execution.

    Called only when use_cuda_graph is enabled (see RuntimeConfig).
    """
    # NCCL + CUDA graph compatibility notes:
    # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html
    # if CUDA_VERSION < version.parse("11.3"):
    #     raise RuntimeError("NCCL CUDA Graph support requires CUDA 11.3 or above")
    if TORCH_VERSION < version.parse("2.2.0"):
        # https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
        raise RuntimeError(
            "CUDAGraph with NCCL support requires PyTorch 2.2.0 or above. "
            "If it is not released yet, please install nightly built PyTorch "
            "with `pip3 install --pre torch torchvision torchaudio --index-url "
            "https://download.pytorch.org/whl/nightly/cu121`"
        )
@dataclass
class ModelConfig:
    """Identifies the model to load and how to fetch it."""

    # Name or local path of the huggingface model.
    model: str
    # Download/cache directory; None falls back to the huggingface default.
    download_dir: Optional[str] = None
    # Whether to trust and execute code shipped with the model repo.
    trust_remote_code: bool = False
@dataclass
class RuntimeConfig:
    """Engine-level runtime switches (dtype, compilation, offloading)."""

    warmup_steps: int = 1
    dtype: torch.dtype = torch.float16
    use_cuda_graph: bool = False
    use_parallel_vae: bool = False
    use_profiler: bool = False
    use_torch_compile: bool = False
    use_onediff: bool = False
    use_fp8_t5_encoder: bool = False

    def __post_init__(self):
        # Fail fast if the installed diffusers version is too old.
        check_packages()
        # CUDA graphs additionally require a recent PyTorch (see check_env).
        if self.use_cuda_graph:
            check_env()
@dataclass
class FastAttnConfig:
    """DiTFastAttn attention-compression settings."""

    use_fast_attn: bool = False
    # n_step is set from num_inference_steps by xFuserArgs.create_config.
    n_step: int = 20
    # Number of calibration prompts used to select a compression method.
    n_calib: int = 8
    threshold: float = 0.5
    window_size: int = 64
    coco_path: Optional[str] = None
    use_cache: bool = False

    def __post_init__(self):
        # NOTE: assert-based validation is stripped when running `python -O`.
        assert self.n_calib > 0, "n_calib must be greater than 0"
        assert self.threshold > 0.0, "threshold must be greater than 0"
@dataclass
class DataParallelConfig:
    """Data-parallel and classifier-free-guidance batch-split settings."""

    dp_degree: int = 1
    use_cfg_parallel: bool = False
    world_size: int = 1

    def __post_init__(self):
        assert self.dp_degree >= 1, "dp_degree must greater than or equal to 1"
        # set classifier_free_guidance_degree parallel for split batch
        if self.use_cfg_parallel:
            self.cfg_degree = 2
        else:
            self.cfg_degree = 1
        assert self.dp_degree * self.cfg_degree <= self.world_size, (
            "dp_degree * cfg_degree must be less than or equal to "
            "world_size because of classifier free guidance"
        )
        assert (
            self.world_size % (self.dp_degree * self.cfg_degree) == 0
        ), "world_size must be divisible by dp_degree * cfg_degree"
@dataclass
class SequenceParallelConfig:
    """Sequence-parallel settings (Ulysses degree x Ring degree)."""

    ulysses_degree: Optional[int] = None
    ring_degree: Optional[int] = None
    world_size: int = 1

    def __post_init__(self):
        # Unset degrees default to 1, i.e. that scheme is disabled.
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
            logger.info(
                f"Ulysses degree not set, " f"using default value {self.ulysses_degree}"
            )
        if self.ring_degree is None:
            self.ring_degree = 1
            logger.info(
                f"Ring degree not set, " f"using default value {self.ring_degree}"
            )
        # Effective sequence-parallel degree is the product of both schemes.
        self.sp_degree = self.ulysses_degree * self.ring_degree
        # Sequence parallelism requires the optional 'yunchang' package.
        if not HAS_LONG_CTX_ATTN and self.sp_degree > 1:
            raise ImportError(
                f"Sequence Parallel kit 'yunchang' not found but "
                f"sp_degree is {self.sp_degree}, please set it "
                f"to 1 or install 'yunchang' to use it"
            )
@dataclass
class TensorParallelConfig:
    """Tensor-parallelism settings.

    tp_degree == 1 simply disables tensor parallelism, so any degree >= 1
    not exceeding world_size is valid.
    """

    tp_degree: int = 1
    split_scheme: Optional[str] = "row"
    world_size: int = 1

    def __post_init__(self):
        # BUG FIX: the message previously said "must greater than 1" while the
        # check (correctly) accepts tp_degree == 1.
        assert self.tp_degree >= 1, "tp_degree must be greater than or equal to 1"
        assert (
            self.tp_degree <= self.world_size
        ), "tp_degree must be less than or equal to world_size"
@dataclass
class PipeFusionParallelConfig:
    """PipeFusion (patch-level pipeline) parallelism settings."""

    pp_degree: int = 1
    # Number of patches to segment the feature map into; defaults to pp_degree.
    num_pipeline_patch: Optional[int] = None
    # BUG FIX: the default used to be the tuple ``(None,)`` (stray trailing
    # comma), which made the "is not None" branch below fire for the default
    # value and fail the length assert for any pp_degree != 1.
    attn_layer_num_for_pp: Optional[List[int]] = None
    world_size: int = 1

    def __post_init__(self):
        assert (
            self.pp_degree is not None and self.pp_degree >= 1
        ), "pipefusion_degree must be set and greater than 1 to use pipefusion"
        assert (
            self.pp_degree <= self.world_size
        ), "pipefusion_degree must be less than or equal to world_size"
        if self.num_pipeline_patch is None:
            self.num_pipeline_patch = self.pp_degree
            logger.info(
                f"Pipeline patch number not set, "
                f"using default value {self.pp_degree}"
            )
        if self.attn_layer_num_for_pp is not None:
            logger.info(
                f"attn_layer_num_for_pp set, splitting attention layers"
                f"to {self.attn_layer_num_for_pp}"
            )
            # One entry per pipeline stage.
            assert len(self.attn_layer_num_for_pp) == self.pp_degree, (
                "attn_layer_num_for_pp must have the same "
                "length as pp_degree if not None"
            )
        # Patching is meaningless without an actual pipeline.
        if self.pp_degree == 1 and self.num_pipeline_patch > 1:
            logger.warning(
                f"Pipefusion degree is 1, pipeline will not be used,"
                f"num_pipeline_patch will be ignored"
            )
            self.num_pipeline_patch = 1
@dataclass
class ParallelConfig:
    """Aggregates all parallelism sub-configs and flattens their degrees.

    Validates that the product of all degrees exactly covers world_size.
    """

    dp_config: DataParallelConfig
    sp_config: SequenceParallelConfig
    pp_config: PipeFusionParallelConfig
    tp_config: TensorParallelConfig
    world_size: int = 1  # FIXME: remove this
    worker_cls: str = "xfuser.ray.worker.worker.Worker"

    def __post_init__(self):
        assert self.tp_config is not None, "tp_config must be set"
        assert self.dp_config is not None, "dp_config must be set"
        assert self.sp_config is not None, "sp_config must be set"
        assert self.pp_config is not None, "pp_config must be set"
        # All degrees multiply together; they must exactly use every rank.
        parallel_world_size = (
            self.dp_config.dp_degree
            * self.dp_config.cfg_degree
            * self.sp_config.sp_degree
            * self.tp_config.tp_degree
            * self.pp_config.pp_degree
        )
        world_size = self.world_size
        assert parallel_world_size == world_size, (
            f"parallel_world_size {parallel_world_size} "
            f"must be equal to world_size {self.world_size}"
        )
        assert (
            world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
        ), "world_size must be divisible by dp_degree * cfg_degree"
        assert (
            world_size % self.pp_config.pp_degree == 0
        ), "world_size must be divisible by pp_degree"
        assert (
            world_size % self.sp_config.sp_degree == 0
        ), "world_size must be divisible by sp_degree"
        assert (
            world_size % self.tp_config.tp_degree == 0
        ), "world_size must be divisible by tp_degree"
        # Flatten the per-scheme degrees for convenient top-level access.
        self.dp_degree = self.dp_config.dp_degree
        self.cfg_degree = self.dp_config.cfg_degree
        self.sp_degree = self.sp_config.sp_degree
        self.pp_degree = self.pp_config.pp_degree
        self.tp_degree = self.tp_config.tp_degree
        self.ulysses_degree = self.sp_config.ulysses_degree
        self.ring_degree = self.sp_config.ring_degree
@dataclass(frozen=True)
class EngineConfig:
    """Immutable top-level bundle of all engine configuration objects."""

    model_config: ModelConfig
    runtime_config: RuntimeConfig
    parallel_config: ParallelConfig
    fast_attn_config: FastAttnConfig

    def __post_init__(self):
        world_size = self.parallel_config.world_size
        # DiTFastAttn is only compatible with pure data parallelism.
        if self.fast_attn_config.use_fast_attn:
            assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn"

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs."""
        return dict((field.name, getattr(self, field.name)) for field in fields(self))
@dataclass
class InputConfig:
    """Per-request generation inputs (resolution, prompts, sampling options)."""

    height: int = 1024
    width: int = 1024
    num_frames: int = 49
    # BUG FIX: the default used to be the tuple ``(True,)`` (stray trailing
    # comma) — truthy, but not the documented bool value.
    use_resolution_binning: bool = True
    # Derived from prompt length in __post_init__ when left as None.
    batch_size: Optional[int] = None
    img_file_path: Optional[str] = None
    prompt: Union[str, List[str]] = ""
    negative_prompt: Union[str, List[str]] = ""
    num_inference_steps: int = 20
    max_sequence_length: int = 256
    seed: int = 42
    output_type: str = "pil"

    def __post_init__(self):
        if isinstance(self.prompt, list):
            # Negative prompts, when given, must pair up with prompts 1:1.
            assert (
                len(self.prompt) == len(self.negative_prompt)
                or len(self.negative_prompt) == 0
            ), "prompts and negative_prompts must have the same quantities"
            self.batch_size = self.batch_size or len(self.prompt)
        else:
            self.batch_size = self.batch_size or 1
        # BUG FIX: the message previously omitted "pt" and referred to the
        # nonexistent name "output_pil".
        assert self.output_type in [
            "pil",
            "latent",
            "pt",
        ], "output_type must be one of 'pil', 'latent' or 'pt'"
from .cache_manager import CacheManager
from .long_ctx_attention import xFuserLongContextAttention
from .utils import gpu_timer_decorator
# Public API re-exported by xfuser.core.
__all__ = [
    "CacheManager",
    "xFuserLongContextAttention",
    "gpu_timer_decorator",
]
from .cache_manager import CacheManager
# Public API re-exported by xfuser.core.cache_manager.
__all__ = [
    "CacheManager",
]
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from xfuser.core.distributed.runtime_state import runtime_state_is_initialized
from xfuser.logger import init_logger
logger = init_logger(__name__)
class CacheEntry:
    """Holds the cached tensor slots for one registered layer.

    ``tensors`` always ends up as a flat list of ``num_cache_tensors``
    entries (``None`` placeholders when no initial tensors are supplied).
    """

    def __init__(
        self,
        cache_type: str,
        num_cache_tensors: int = 1,
        tensors: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    ):
        self.cache_type: str = cache_type
        if tensors is None:
            # Pre-allocate empty slots to be filled on first update.
            self.tensors: List[torch.Tensor] = [None] * num_cache_tensors
        elif isinstance(tensors, torch.Tensor):
            assert (
                num_cache_tensors == 1
            ), "num_cache_tensors must be 1 if you pass a single tensor to tensors argument"
            self.tensors = [tensors]
        elif isinstance(tensors, list):
            assert num_cache_tensors == len(
                tensors
            ), "num_cache_tensors must be equal to num of tensors"
            # BUG FIX: this branch previously stored ``[tensors]`` — the whole
            # list nested inside another list — yielding a single slot instead
            # of num_cache_tensors slots and contradicting the assert above.
            self.tensors = list(tensors)
class CacheManager:
    """Registry of per-layer KV caches used for patch-wise (PipeFusion) inference.

    Each (layer_type, layer) pair owns one CacheEntry.  In patch mode only the
    tokens of the pipeline patch currently in flight are recomputed, and they
    are written back into the cached full-length KV tensor.
    """

    # Layer kinds that may register a cache.
    supported_layer = ["attn"]
    # "naive_cache" stores tokens in plain layout order;
    # "sequence_parallel_attn_cache" additionally reorders for ulysses sharding.
    supported_cache_type = ["naive_cache", "sequence_parallel_attn_cache"]

    def __init__(
        self,
    ):
        # Keyed by the (layer_type, layer instance) tuple.
        self.cache: Dict[Tuple[str, Any], CacheEntry] = {}

    def register_cache_entry(
        self, layer, layer_type: str, cache_type: str = "naive_cache"
    ):
        """Create (or reset) the cache entry for ``layer``.

        Raises ValueError for unsupported layer or cache types.
        """
        if layer_type not in self.supported_layer:
            raise ValueError(
                f"Layer type: {layer_type} is not supported. Supported layer type: {self.supported_layer}"
            )
        if cache_type not in self.supported_cache_type:
            raise ValueError(
                f"Cache type: {cache_type} is not supported. Supported cache type: {self.supported_cache_type}"
            )
        if self.cache.get((layer_type, layer), None) is not None:
            logger.warning(
                f"Cache for [layer_type, layer]: [{layer_type}, {layer.__class__}] is already initialized, resetting the cache..."
            )
        self.cache[layer_type, layer] = CacheEntry(cache_type)

    def update_and_get_kv_cache(
        self,
        new_kv: Union[torch.Tensor, List[torch.Tensor]],
        layer: Any,
        slice_dim: int = 1,
        layer_type: str = "attn",
        custom_get_kv: Optional[Callable[[Any, Any, str], torch.Tensor]] = None,
        **kwargs,
    ):
        """Merge ``new_kv`` into the layer's cache and return the full KV.

        A list input (e.g. ``[key, value]``) is concatenated on the last dim
        for the update and split back into two chunks on return.
        """
        return_list = False
        if isinstance(new_kv, List):
            return_list = True
            new_kv = torch.cat(new_kv, dim=-1)
        if custom_get_kv is not None:
            # NOTE(review): this call passes more arguments than the declared
            # Callable[[Any, Any, str], ...] hint suggests — confirm the
            # expected custom_get_kv signature with its implementers.
            return custom_get_kv(self, new_kv, layer, slice_dim, layer_type, **kwargs)
        else:
            cache_type = self.cache[layer_type, layer].cache_type
            # cache_type is constrained by register_cache_entry, so exactly
            # one of the branches below binds kv_cache.
            if cache_type == "naive_cache":
                kv_cache = self._naive_cache_update(
                    new_kv,
                    layer=layer,
                    slice_dim=slice_dim,
                    layer_type=layer_type,
                    **kwargs,
                )
            elif cache_type == "sequence_parallel_attn_cache":
                kv_cache = self._sequence_parallel_cache_update(
                    new_kv,
                    layer=layer,
                    slice_dim=slice_dim,
                    layer_type=layer_type,
                    **kwargs,
                )
            if return_list:
                # Undo the cat above: split back into (key, value).
                return torch.chunk(kv_cache, 2, dim=-1)
            else:
                return kv_cache

    def _naive_cache_update(
        self,
        new_kv: Union[torch.Tensor, List[torch.Tensor]],
        layer,
        slice_dim: int = 1,
        layer_type: str = "attn",
    ):
        """Cache update for the non-sequence-parallel layout.

        Outside patch mode the new tensor simply replaces the cache; in patch
        mode only the current patch's token span is overwritten in place.
        """
        from xfuser.core.distributed.runtime_state import get_runtime_state

        if (
            not runtime_state_is_initialized()
            or get_runtime_state().num_pipeline_patch == 1
            or not get_runtime_state().patch_mode
        ):
            kv_cache = new_kv
            self.cache[layer_type, layer].tensors[0] = kv_cache
        else:
            # Token span covered by the patch currently flowing through
            # the pipeline (half-open [start, end)).
            start_token_idx = get_runtime_state().pp_patches_token_start_idx_local[
                get_runtime_state().pipeline_patch_idx
            ]
            end_token_idx = get_runtime_state().pp_patches_token_start_idx_local[
                get_runtime_state().pipeline_patch_idx + 1
            ]
            kv_cache = self.cache[layer_type, layer].tensors[0]
            kv_cache = self._update_kv_in_dim(
                kv_cache=kv_cache,
                new_kv=new_kv,
                dim=slice_dim,
                start_idx=start_token_idx,
                end_idx=end_token_idx,
            )
            self.cache[layer_type, layer].tensors[0] = kv_cache
        return kv_cache

    # work inside ring attn
    def _sequence_parallel_cache_update(
        self,
        new_kv: torch.Tensor,
        layer,
        slice_dim: int = 1,
        layer_type: str = "attn",
    ):
        """Cache update aware of ulysses sharding along the sequence dim.

        On the full (non-patch) pass the tokens are regrouped so that each
        pipeline patch's tokens from all ulysses ranks are contiguous; in
        patch mode the matching scaled span is overwritten in place.
        """
        from xfuser.core.distributed import (
            get_ulysses_parallel_world_size,
            get_runtime_state,
        )

        ulysses_world_size = get_ulysses_parallel_world_size()
        if (
            not runtime_state_is_initialized()
            or get_runtime_state().num_pipeline_patch == 1
        ):
            return new_kv
        elif not get_runtime_state().patch_mode:
            pp_patches_token_num = get_runtime_state().pp_patches_token_num
            # Split each ulysses shard into its per-patch segments...
            kv_list = [
                kv.split(pp_patches_token_num, dim=slice_dim)
                for kv in torch.chunk(new_kv, ulysses_world_size, dim=slice_dim)
            ]
            # ...then regroup as rank-major within each patch.
            kv_cache = torch.cat(
                [
                    kv_list[rank][pp_patch_idx]
                    for rank in range(ulysses_world_size)
                    for pp_patch_idx in range(len(pp_patches_token_num))
                ],
                dim=slice_dim,
            )
            self.cache[layer_type, layer].tensors[0] = kv_cache
        else:
            pp_patches_token_start_idx_local = (
                get_runtime_state().pp_patches_token_start_idx_local
            )
            pp_patch_idx = get_runtime_state().pipeline_patch_idx
            # Local start indices are scaled by the ulysses world size because
            # the cache holds tokens gathered from all ulysses ranks.
            start_token_idx = (
                ulysses_world_size * pp_patches_token_start_idx_local[pp_patch_idx]
            )
            end_token_idx = (
                ulysses_world_size * pp_patches_token_start_idx_local[pp_patch_idx + 1]
            )
            # pp_patches_token_num = get_runtime_state().pp_patches_token_num
            # start_token_idx = ulysses_world_size * sum(pp_patches_token_num[:get_runtime_state().pipeline_patch_idx])
            # end_token_idx = ulysses_world_size * sum(pp_patches_token_num[:get_runtime_state().pipeline_patch_idx + 1])
            kv_cache = self.cache[layer_type, layer].tensors[0]
            kv_cache = self._update_kv_in_dim(
                kv_cache=kv_cache,
                new_kv=new_kv,
                dim=slice_dim,
                start_idx=start_token_idx,
                end_idx=end_token_idx,
            )
            self.cache[layer_type, layer].tensors[0] = kv_cache
        return kv_cache

    def _update_kv_in_dim(
        self,
        kv_cache: torch.Tensor,
        new_kv: torch.Tensor,
        dim: int,
        start_idx: int,
        end_idx: int,
    ):
        """Write ``new_kv`` into ``kv_cache[start_idx:end_idx]`` along ``dim``."""
        if dim < 0:
            dim += kv_cache.dim()
        # NOTE(review): the message says "bigger or equal" but the check only
        # rejects dim > ndim; dim == ndim — and any dim > 3 below — silently
        # falls through and returns the cache unchanged. Confirm intent.
        if dim > kv_cache.dim():
            raise ValueError(
                f"'dim' argument {dim} can not bigger or equal than kv cache dimemsions: {kv_cache.dim()}"
            )
        if dim == 0:
            kv_cache[start_idx:end_idx, ...] = new_kv
        elif dim == 1:
            kv_cache[:, start_idx:end_idx:, ...] = new_kv
        elif dim == 2:
            kv_cache[:, :, start_idx:end_idx, ...] = new_kv
        elif dim == 3:
            kv_cache[:, :, :, start_idx:end_idx, ...] = new_kv
        return kv_cache
# Process-wide singleton, created eagerly at import time.
_CACHE_MGR = CacheManager()


def get_cache_manager():
    """Return the process-wide CacheManager singleton."""
    global _CACHE_MGR
    assert _CACHE_MGR is not None, "Cache manager has not been initialized."
    return _CACHE_MGR
from .parallel_state import (
    get_world_group,
    get_dp_group,
    get_cfg_group,
    get_sp_group,
    get_pp_group,
    get_pipeline_parallel_world_size,
    get_pipeline_parallel_rank,
    is_pipeline_first_stage,
    is_pipeline_last_stage,
    get_data_parallel_world_size,
    get_data_parallel_rank,
    is_dp_last_group,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_ulysses_parallel_world_size,
    get_ulysses_parallel_rank,
    get_ring_parallel_world_size,
    get_ring_parallel_rank,
    init_distributed_environment,
    # BUG FIX: init_model_parallel_group was listed in __all__ below but never
    # imported, so `from xfuser.core.distributed import *` raised AttributeError.
    init_model_parallel_group,
    initialize_model_parallel,
    model_parallel_is_initialized,
    get_tensor_model_parallel_world_size,
)
from .runtime_state import (
    get_runtime_state,
    runtime_state_is_initialized,
    initialize_runtime_state,
)

__all__ = [
    "get_world_group",
    "get_dp_group",
    "get_cfg_group",
    "get_sp_group",
    "get_pp_group",
    "get_pipeline_parallel_world_size",
    "get_pipeline_parallel_rank",
    "is_pipeline_first_stage",
    "is_pipeline_last_stage",
    "get_data_parallel_world_size",
    "get_data_parallel_rank",
    "is_dp_last_group",
    "get_classifier_free_guidance_world_size",
    "get_classifier_free_guidance_rank",
    "get_sequence_parallel_world_size",
    "get_sequence_parallel_rank",
    "get_ulysses_parallel_world_size",
    "get_ulysses_parallel_rank",
    "get_ring_parallel_world_size",
    "get_ring_parallel_rank",
    "init_distributed_environment",
    "init_model_parallel_group",
    "initialize_model_parallel",
    "model_parallel_is_initialized",
    # BUG FIX: imported above but previously missing from the public API.
    "get_tensor_model_parallel_world_size",
    "get_runtime_state",
    "runtime_state_is_initialized",
    "initialize_runtime_state",
]
# Copyright 2024 xDiT team.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple, Union
import pickle
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
import xfuser.envs as envs
from xfuser.logger import init_logger
logger = init_logger(__name__)

# (device-type, dtype, size) triple describing a tensor without its payload;
# used when broadcasting tensor dicts (device index is assigned receiver-side).
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

env_info = envs.PACKAGES_CHECKER.get_packages_info()
HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = ""
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list = []
for key, value in tensor_dict.items():
assert "%" not in key, (
"Avoid having '%' in key "
"as it is used as a separator for nested entries."
)
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(prefix + key, TensorMetadata(device, value.dtype, value.size()))
)
tensor_list.append(value)
elif isinstance(value, dict):
if len(value) == 0:
metadata_list.append((prefix + key, value))
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
value, prefix + key + "%"
)
metadata_list.extend(inner_metadata_list)
tensor_list.extend(inner_tensor_list)
else:
metadata_list.append((prefix + key, value))
return metadata_list, tensor_list
def _update_nested_dict(nested_dict, flattened_key, value):
key_splits = flattened_key.split("%")
cur_dict = nested_dict
for k in key_splits[:-1]:
if k not in cur_dict:
cur_dict[k] = {}
cur_dict = cur_dict[k]
cur_dict[key_splits[-1]] = value
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to
a specific implementation (e.g. switch allreduce implementation
based on the tensor size and cuda graph mode).
"""
# available attributes:
rank: int # global rank
ranks: List[int] # global ranks in the group
world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
# Process | Node | Rank | Local Rank | Rank in Group
# 0 | 0 | 0 | 0 | 0
# 1 | 0 | 1 | 1 | 1
# 2 | 1 | 2 | 0 | 2
# 3 | 1 | 3 | 1 | 3
local_rank: int # local rank used to assign devices
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
def __init__(
    self,
    group_ranks: List[List[int]],
    local_rank: int,
    torch_distributed_backend: Union[str, Backend],
):
    """Create a process group per rank list and remember the one this
    process belongs to.

    Every process must call this with the same ``group_ranks`` so the
    collective ``new_group`` calls line up across ranks; the asserts below
    require the calling rank to appear in at least one of the lists (the
    last matching list wins).
    """
    self.rank = torch.distributed.get_rank()
    self.local_rank = local_rank
    self.device_group = None
    self.cpu_group = None
    for ranks in group_ranks:
        device_group = torch.distributed.new_group(
            ranks, backend=torch_distributed_backend
        )
        # a group with `gloo` backend, to allow direct coordination between
        # processes through the CPU.
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if self.rank in ranks:
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)
            self.device_group = device_group
            self.cpu_group = cpu_group
    assert self.cpu_group is not None
    assert self.device_group is not None
    # Bind the device this rank should use for device-group collectives.
    if torch.cuda.is_available():
        self.device = torch.device(f"cuda:{local_rank}")
    else:
        self.device = torch.device("cpu")
# Convenience accessors over self.ranks (global ranks of this group).
@property
def first_rank(self):
    """Return the global rank of the first process in the group"""
    return self.ranks[0]

@property
def last_rank(self):
    """Return the global rank of the last process in the group"""
    return self.ranks[-1]

@property
def is_first_rank(self):
    """Return whether the caller is the first process in the group"""
    return self.rank == self.first_rank

@property
def is_last_rank(self):
    """Return whether the caller is the last process in the group"""
    return self.rank == self.last_rank
@property
def next_rank(self):
"""Return the global rank of the process that follows the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group + 1) % world_size]
@property
def prev_rank(self):
"""Return the global rank of the process that precedes the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group - 1) % world_size]
@property
def group_next_rank(self):
"""Return the group rank of the process that follows the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return (rank_in_group + 1) % world_size
@property
def group_prev_rank(self):
"""Return the group rank of the process that precedes the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return (rank_in_group - 1) % world_size
@property
def skip_rank(self):
"""Return the global rank of the process that skip connects with the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(world_size - rank_in_group - 1) % world_size]
@property
def group_skip_rank(self):
"""Return the group rank of the process that skip connects with the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return (world_size - rank_in_group - 1) % world_size
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
    def all_gather(
        self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """All-gather ``input_`` from every rank in the group.

        Returns the concatenation of all ranks' tensors along ``dim``, or —
        when ``separate_tensors`` is True — a list of per-rank tensors each
        shaped like ``input_``.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert (
            -input_.dim() <= dim < input_.dim()
        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor: per-rank chunks stacked along dim 0.
        input_size = list(input_.size())
        input_size[0] *= world_size
        output_tensor = torch.empty(
            input_size, dtype=input_.dtype, device=input_.device
        )
        # All-gather.
        torch.distributed.all_gather_into_tensor(
            output_tensor, input_, group=self.device_group
        )
        if dim != 0:
            input_size[0] //= world_size
            # Re-expose the per-rank leading axis, then move it to `dim`.
            output_tensor = output_tensor.reshape([world_size, ] + input_size)
            output_tensor = output_tensor.movedim(0, dim)
        if separate_tensors:
            # NOTE(review): after the movedim above, `output_tensor` is
            # non-contiguous, so this `.view(-1)` looks like it would raise
            # for dim != 0 — the separate_tensors path appears safe only
            # with dim == 0; confirm against callers.
            tensor_list = [
                output_tensor.view(-1)
                .narrow(0, input_.numel() * i, input_.numel())
                .view_as(input_)
                for i in range(world_size)
            ]
            return tensor_list
        else:
            # Fold the gathered axis into `dim` to get the final shape.
            input_size = list(input_.size())
            input_size[dim] = input_size[dim] * world_size
            # Reshape
            output_tensor = output_tensor.reshape(input_size)
            return output_tensor
def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(self, input_: torch.Tensor, src: int = 0):
"""Broadcast the input tensor.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(
input_, src=self.ranks[src], group=self.device_group
)
return input_
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
"""Broadcast the input object.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj
if self.shm_broadcaster is not None:
assert src == 0, "Shared memory broadcaster only supports src=0"
return self.shm_broadcaster.broadcast_object(obj)
if self.rank_in_group == src:
torch.distributed.broadcast_object_list(
[obj], src=self.ranks[src], group=self.cpu_group
)
return obj
else:
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=self.ranks[src], group=self.cpu_group
)
return recv[0]
def broadcast_object_list(
self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None
):
"""Broadcast the input object list.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(
obj_list, src=self.ranks[src], group=self.device_group
)
return obj_list
def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank, (
"Invalid destination rank. Destination rank is the same "
"as the current rank."
)
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
size_tensor = torch.tensor(
[object_tensor.numel()], dtype=torch.long, device="cpu"
)
# Send object size
torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
# Send object
torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
return None
def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"
assert (
src != self.rank
), "Invalid source rank. Source rank is the same as the current rank."
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
# Receive object size
rank_size = torch.distributed.recv(
size_tensor, src=self.ranks[src], group=self.cpu_group
)
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu",
)
rank_object = torch.distributed.recv(
object_tensor, src=self.ranks[src], group=self.cpu_group
)
assert (
rank_object == rank_size
), "Received object sender rank does not match the size sender rank."
obj = pickle.loads(object_tensor.numpy().tobytes())
return obj
    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        Metadata (keys, shapes, non-tensor values) goes over the CPU group;
        tensor payloads go over the device group (CPU tensors over the CPU
        group), posted asynchronously and awaited before returning.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
        # The `group`/`metadata_group` parameters are ignored and overridden
        # with this coordinator's own groups.
        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"
        # NOTE(review): `src` becomes a *global* rank here, but is then
        # passed to `self.broadcast_object(...)`, which treats its `src`
        # argument as a *group-local* rank and indexes `self.ranks[src]`
        # again. This looks wrong whenever the group's ranks are not
        # 0..world_size-1 — confirm against callers.
        src = self.ranks[src]
        rank = self.rank
        if rank == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict, dict
            ), f"Expecting a dictionary, got {type(tensor_dict)}"
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(
                        tensor, src=src, group=metadata_group, async_op=True
                    )
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(
                        tensor, src=src, group=group, async_op=True
                    )
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()
        else:
            # Receiver side: learn the dict structure first, then post an
            # async receive for each non-empty tensor described in metadata.
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        _update_nested_dict(tensor_dict, key, tensor)
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor, src=src, group=metadata_group, async_op=True
                        )
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(
                            tensor, src=src, group=group, async_op=True
                        )
                    async_handles.append(handle)
                    _update_nested_dict(tensor_dict, key, tensor)
                else:
                    # Non-tensor values arrive inside the metadata list.
                    _update_nested_dict(tensor_dict, key, value)
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict
    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank; defaults to
        the next rank in the group's ring.
        Metadata goes over the CPU group; tensors are sent one by one over
        the device group (CPU tensors over the CPU group).
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
        group = self.device_group
        metadata_group = self.cpu_group
        if dst is None:
            dst = self.group_next_rank
        assert dst < self.world_size, f"Invalid dst rank ({dst})"
        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict, dict
        ), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(
                    tensor, dst=self.ranks[dst], group=metadata_group
                )
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
        return None
    def recv_tensor_dict(
        self, src: Optional[int] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank; defaults to the
        previous rank in the group's ring. Mirrors `send_tensor_dict`:
        metadata first (CPU group), then each non-empty tensor.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
        group = self.device_group
        metadata_group = self.cpu_group
        if src is None:
            src = self.group_prev_rank
        assert src < self.world_size, f"Invalid src rank ({src})"
        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                # Allocate a buffer matching the advertised shape/dtype/device.
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    _update_nested_dict(tensor_dict, key, tensor)
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(
                        tensor, src=self.ranks[src], group=metadata_group
                    )
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
                _update_nested_dict(tensor_dict, key, tensor)
            else:
                # Non-tensor values arrive inside the metadata list.
                _update_nested_dict(tensor_dict, key, value)
        return tensor_dict
    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        # Blocks until every member of this group reaches the barrier.
        torch.distributed.barrier(group=self.cpu_group)
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the rank_in_group of the destination rank."""
        # Default destination: the next rank in the ring.
        if dst is None:
            dst = self.group_next_rank
        # NOTE(review): `self.device_groups` is only created by
        # PipelineGroupCoordinator.__init__ (which builds one group per
        # direction for 2-rank pipelines); on a plain GroupCoordinator with
        # world_size == 2 this path would raise AttributeError — confirm
        # callers only use it on the pipeline coordinator.
        torch.distributed.send(
            tensor,
            self.ranks[dst],
            group=(
                self.device_groups[self.rank_in_group % 2]
                if self.world_size == 2
                else self.device_group
            ),
        )
    def recv(
        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
    ) -> torch.Tensor:
        """Receives a tensor from the src rank."""
        """NOTE: `src` is the rank_in_group of the source rank."""
        # Default source: the previous rank in the ring.
        if src is None:
            src = self.group_prev_rank
        # Allocate the receive buffer on this process's device.
        tensor = torch.empty(size, dtype=dtype, device=self.device)
        # NOTE(review): `self.device_groups` is only created by
        # PipelineGroupCoordinator.__init__; the index is offset by one
        # relative to `send` so the receiver uses the sender's group.
        torch.distributed.recv(
            tensor,
            self.ranks[src],
            (
                self.device_groups[(self.rank_in_group + 1) % 2]
                if self.world_size == 2
                else self.device_group
            ),
        )
        return tensor
def destroy(self):
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)
self.device_group = None
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
self.cpu_group = None
class PipelineGroupCoordinator(GroupCoordinator):
    """
    Coordinator for the pipeline-parallel group: point-to-point transfers
    between adjacent stages, plus "skip" transfers between mirror ranks
    (``skip_rank``) over a dedicated process group.

    available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    difference between `local_rank` and `rank_in_group`:
    if we have a group of size 4 across two nodes:
    Process | Node | Rank | Local Rank | Rank in Group
      0     |   0  |  0   |     0      |       0
      1     |   0  |  1   |     1      |       1
      2     |   1  |  2   |     0      |       2
      3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    """

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
    ):
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None
        self.cpu_groups = []
        self.device_groups = []
        # NOTE: `new_group` is a collective — every process must create every
        # group, in the same order, even groups it does not belong to.
        if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1:
            for ranks in group_ranks:
                device_group = torch.distributed.new_group(
                    ranks, backend=torch_distributed_backend
                )
                # a group with `gloo` backend, to allow direct coordination between
                # processes through the CPU.
                cpu_group = torch.distributed.new_group(ranks, backend="gloo")
                if self.rank in ranks:
                    self.ranks = ranks
                    self.world_size = len(ranks)
                    self.rank_in_group = ranks.index(self.rank)
                    self.device_group = device_group
                    self.cpu_group = cpu_group
        # when pipeline parallelism is 2, we need to create two groups to avoid
        # communication stall.
        # *_group_0_1 represents the group for communication from device 0 to
        # device 1.
        # *_group_1_0 represents the group for communication from device 1 to
        # device 0.
        elif len(group_ranks[0]) == 2:
            for ranks in group_ranks:
                device_group_0_1 = torch.distributed.new_group(
                    ranks, backend=torch_distributed_backend
                )
                device_group_1_0 = torch.distributed.new_group(
                    ranks, backend=torch_distributed_backend
                )
                # a group with `gloo` backend, to allow direct coordination between
                # processes through the CPU.
                cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo")
                cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo")
                if self.rank in ranks:
                    self.ranks = ranks
                    self.world_size = len(ranks)
                    self.rank_in_group = ranks.index(self.rank)
                    self.device_groups = [device_group_0_1, device_group_1_0]
                    self.cpu_groups = [cpu_group_0_1, cpu_group_1_0]
                    self.device_group = device_group_0_1
                    self.cpu_group = cpu_group_0_1
        assert self.cpu_group is not None
        assert self.device_group is not None
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        # --- receive-side bookkeeping for pipeline transfers ---
        self.recv_buffer_set: bool = False
        self.recv_tasks_queue: List[Tuple[str, int]] = []
        self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = []
        self.dtype: Optional[torch.dtype] = None
        self.num_pipefusion_patches: Optional[int] = None
        # name -> segment_idx -> shape of tensors exchanged so far.
        self.recv_shape: Dict[str, Dict[int, torch.Size]] = {}
        self.send_shape: Dict[str, Dict[int, torch.Size]] = {}
        # name -> segment_idx -> preallocated receive tensor. NOTE(review):
        # `set_recv_buffer` replaces this dict with a plain list indexed by
        # patch idx — the two buffering schemes appear mutually exclusive.
        self.recv_buffer: Dict[str, Dict[int, torch.Tensor]] = {}
        # FIX: this attribute was never initialized, so the first call to
        # `set_extra_tensors_recv_buffer` raised AttributeError.
        self.extra_tensors_recv_buffer: Dict[str, List[torch.Tensor]] = {}
        self.skip_tensor_recv_buffer_set: bool = False
        self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = []
        self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = []
        self.skip_tensor_recv_buffer: Optional[
            Union[List[torch.Tensor], torch.Tensor]
        ] = None
        # Dedicated group for skip-connection transfers so they do not
        # interleave with adjacent-stage traffic.
        self.skip_device_group = None
        for ranks in group_ranks:
            skip_device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            if self.rank in ranks:
                self.skip_device_group = skip_device_group
        assert self.skip_device_group is not None

    def reset_buffer(self):
        """Drop all queued tasks, cached shapes and receive buffers."""
        self.recv_tasks_queue = []
        self.receiving_tasks = []
        self.recv_shape = {}
        self.send_shape = {}
        self.recv_buffer = {}
        self.recv_skip_tasks_queue = []
        self.receiving_skip_tasks = []
        # FIX: was reset to `{}` although this buffer is used as a list
        # everywhere else (see `set_skip_tensor_recv_buffer`).
        self.skip_tensor_recv_buffer = []
        self.extra_tensors_recv_buffer = {}

    def set_config(self, dtype: torch.dtype):
        """Record the dtype used when allocating receive buffers."""
        self.dtype = dtype

    def set_recv_buffer(
        self,
        num_pipefusion_patches: int,
        patches_shape_list: List[List[int]],
        feature_map_shape: List[int],
        dtype: torch.dtype,
    ):
        """Preallocate per-patch receive buffers plus one full feature map.

        The resulting `recv_buffer` is a list: entries 0..n-1 hold one patch
        each, the final entry holds the whole feature map.
        """
        assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object"
        assert (
            isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1
        ), "num_pipefusion_patches must be greater than or equal to 1"
        self.dtype = dtype
        self.num_pipefusion_patches = num_pipefusion_patches
        self.recv_buffer = [
            torch.zeros(*shape, dtype=self.dtype, device=self.device)
            for shape in patches_shape_list
        ]
        self.recv_buffer.append(
            torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)
        )
        self.recv_buffer_set = True

    def set_extra_tensors_recv_buffer(
        self,
        name: str,
        shape: List[int],
        num_buffers: int = 1,
        dtype: torch.dtype = torch.float16,
    ):
        """Preallocate `num_buffers` zero tensors under `name` for extra
        (non-latent) tensors received through the pipeline."""
        self.extra_tensors_recv_buffer[name] = [
            torch.zeros(*shape, dtype=dtype, device=self.device)
            for _ in range(num_buffers)
        ]

    def _check_shape_and_buffer(
        self,
        tensor_send_to_next=None,
        recv_prev=False,
        name: Optional[str] = None,
        segment_idx: int = 0,
    ):
        """Handshake shapes with neighbors on first use of a (name, segment)
        pair and (re)allocate the matching receive buffer.

        Steady-state sends/recvs skip the shape exchange entirely.
        """
        send_flag = False
        name = name or "latent"
        if tensor_send_to_next is not None:
            shape_list = self.send_shape.get(name, None)
            if shape_list is None:
                self.send_shape[name] = {segment_idx: tensor_send_to_next.shape}
                send_flag = True
            elif shape_list.get(segment_idx, None) is None:
                self.send_shape[name][segment_idx] = tensor_send_to_next.shape
                send_flag = True
        recv_flag = False
        if recv_prev:
            shape_list = self.recv_shape.get(name, None)
            if shape_list is None:
                recv_flag = True
            elif shape_list.get(segment_idx, None) is None:
                recv_flag = True
        recv_prev_shape = self._communicate_shapes(
            tensor_send_to_next=tensor_send_to_next if send_flag else None,
            recv_prev=recv_flag,
        )
        if recv_flag:
            if self.recv_shape.get(name, None) is None:
                self.recv_shape[name] = {segment_idx: recv_prev_shape}
            else:
                self.recv_shape[name][segment_idx] = recv_prev_shape
            if self.recv_buffer.get(name, None) is None:
                self.recv_buffer[name] = {
                    segment_idx: torch.zeros(
                        recv_prev_shape, device=self.device, dtype=self.dtype
                    )
                }
            else:
                if self.recv_buffer[name].get(segment_idx, None) is not None:
                    logger.warning(
                        f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating..."
                    )
                self.recv_buffer[name][segment_idx] = torch.zeros(
                    recv_prev_shape, device=self.device, dtype=self.dtype
                )

    def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False):
        """Communicate tensor shapes between stages. Used to communicate
        tensor shapes before the actual tensor communication happens.
        Args:
            tensor_send_to_next: tensor to send to next rank (no tensor sent if
                set to None).
            recv_prev: boolean for whether tensor should be received from
                previous rank.
        Returns the shape received from the previous rank (zeros if none).
        """
        # Round 1: exchange the number of dimensions.
        ops = []
        if recv_prev:
            recv_prev_dim_tensor = torch.empty(
                (1), device=self.device, dtype=torch.int64
            )
            recv_prev_dim_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_dim_tensor,
                self.prev_rank,
                self.device_group,
            )
            ops.append(recv_prev_dim_op)
        if tensor_send_to_next is not None:
            send_next_dim_tensor = torch.tensor(
                tensor_send_to_next.dim(), device=self.device, dtype=torch.int64
            )
            send_next_dim_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_dim_tensor,
                self.next_rank,
                self.device_group,
            )
            ops.append(send_next_dim_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
            # To protect against race condition when using batch_isend_irecv().
            # should take this out once the bug with batch_isend_irecv is resolved.
            torch.cuda.synchronize()
        # Round 2: exchange the shape itself (sized by round 1's result).
        ops = []
        recv_prev_shape_tensor = None
        if recv_prev:
            recv_prev_shape_tensor = torch.empty(
                torch.Size(recv_prev_dim_tensor), device=self.device, dtype=torch.int64
            )
            recv_prev_shape_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_shape_tensor,
                self.prev_rank,
                self.device_group,
            )
            ops.append(recv_prev_shape_op)
        if tensor_send_to_next is not None:
            send_next_shape_tensor = torch.tensor(
                tensor_send_to_next.size(), device=self.device, dtype=torch.int64
            )
            send_next_shape_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_shape_tensor,
                self.next_rank,
                self.device_group,
            )
            ops.append(send_next_shape_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
            torch.cuda.synchronize()
        recv_prev_shape = [0, 0, 0]
        if recv_prev_shape_tensor is not None:
            recv_prev_shape = recv_prev_shape_tensor
        return torch.Size(recv_prev_shape)

    def pipeline_send(
        self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1
    ) -> None:
        """Blocking send of `tensor` to the next stage."""
        tensor = tensor.contiguous()
        self._check_shape_and_buffer(
            tensor_send_to_next=tensor, name=name, segment_idx=segment_idx
        )
        self._pipeline_isend(tensor).wait()

    def pipeline_isend(
        self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1
    ) -> None:
        """Non-blocking send of `tensor` to the next stage."""
        tensor = tensor.contiguous()
        self._check_shape_and_buffer(
            tensor_send_to_next=tensor, name=name, segment_idx=segment_idx
        )
        self._pipeline_isend(tensor)

    def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor:
        """Blocking receive into the (name, idx) buffer from the previous stage."""
        name = name or "latent"
        self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
        self._pipeline_irecv(self.recv_buffer[name][idx]).wait()
        return self.recv_buffer[name][idx]

    def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"):
        """Queue a receive to be posted later by `recv_next`."""
        name = name or "latent"
        self.recv_tasks_queue.append((name, idx))

    def recv_next(self):
        """Post the oldest queued receive as an async irecv."""
        if len(self.recv_tasks_queue) == 0:
            raise ValueError("No more tasks to receive")
        elif len(self.recv_tasks_queue) > 0:
            name, idx = self.recv_tasks_queue.pop(0)
            self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
            self.receiving_tasks.append(
                (self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx)
            )

    def get_pipeline_recv_data(
        self, idx: int = -1, name: str = "latent"
    ) -> torch.Tensor:
        """Wait for the oldest in-flight receive and return its buffer.

        Receives complete strictly in FIFO order; (name, idx) must match the
        task at the head of the queue.
        """
        assert (
            len(self.receiving_tasks) > 0
        ), "No tasks to receive, call add_pipeline_recv_task first"
        receiving_task = self.receiving_tasks.pop(0)
        receiving_task[0].wait()
        assert (
            receiving_task[1] == name and receiving_task[2] == idx
        ), "Received tensor does not match the requested"
        return self.recv_buffer[name][idx]

    def _pipeline_irecv(self, tensor: torch.Tensor):
        """Async receive from the previous stage; for 2-stage pipelines the
        direction-specific group avoids send/recv sharing one group."""
        return torch.distributed.irecv(
            tensor,
            src=self.prev_rank,
            group=(
                self.device_groups[(self.rank_in_group + 1) % 2]
                if self.world_size == 2
                else self.device_group
            ),
        )

    def _pipeline_isend(self, tensor: torch.Tensor):
        """Async send to the next stage (direction-specific group when
        world_size == 2, mirroring `_pipeline_irecv`)."""
        return torch.distributed.isend(
            tensor,
            dst=self.next_rank,
            group=(
                self.device_groups[self.rank_in_group % 2]
                if self.world_size == 2
                else self.device_group
            ),
        )

    def set_skip_tensor_recv_buffer(
        self,
        patches_shape_list: List[List[int]],
        feature_map_shape: List[int],
    ):
        """Preallocate skip-connection receive buffers: one per patch plus
        one full feature map (uses the dtype set via `set_config`/
        `set_recv_buffer`)."""
        self.skip_tensor_recv_buffer = [
            torch.zeros(*shape, dtype=self.dtype, device=self.device)
            for shape in patches_shape_list
        ]
        self.skip_tensor_recv_buffer.append(
            torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)
        )
        self.skip_tensor_recv_buffer_set = True

    def pipeline_send_skip(self, tensor: torch.Tensor) -> None:
        """Blocking send of `tensor` to the mirror (skip) rank."""
        tensor = tensor.contiguous()
        self._pipeline_isend_skip(tensor).wait()

    def pipeline_isend_skip(self, tensor: torch.Tensor) -> None:
        """Non-blocking send of `tensor` to the mirror (skip) rank."""
        tensor = tensor.contiguous()
        self._pipeline_isend_skip(tensor)

    def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor:
        """Blocking receive into the skip buffer at `idx` from the mirror rank."""
        self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait()
        return self.skip_tensor_recv_buffer[idx]

    def add_pipeline_recv_skip_task(self, idx: int = -1):
        """Queue a skip receive to be posted later by `recv_skip_next`."""
        self.recv_skip_tasks_queue.append(idx)

    def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor:
        """Wait for the oldest in-flight skip receive and return its buffer."""
        assert (
            len(self.receiving_skip_tasks) > 0
        ), "No tasks to receive, call add_pipeline_recv_skip_task first"
        receiving_skip_task = self.receiving_skip_tasks.pop(0)
        receiving_skip_task[0].wait()
        assert (
            receiving_skip_task[2] == idx
        ), "Received tensor does not match the requested"
        return self.skip_tensor_recv_buffer[idx]

    def recv_skip_next(self):
        """Post the oldest queued skip receive as an async irecv."""
        if len(self.recv_skip_tasks_queue) == 0:
            raise ValueError("No more tasks to receive")
        elif len(self.recv_skip_tasks_queue) > 0:
            task = self.recv_skip_tasks_queue.pop(0)
            idx = task
            self.receiving_skip_tasks.append(
                (
                    self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]),
                    None,
                    idx,
                )
            )

    def _pipeline_irecv_skip(self, tensor: torch.Tensor):
        """Async receive from the mirror rank over the dedicated skip group."""
        return torch.distributed.irecv(
            tensor, src=self.skip_rank, group=self.skip_device_group
        )

    def _pipeline_isend_skip(self, tensor: torch.Tensor):
        """Async send to the mirror rank over the dedicated skip group."""
        return torch.distributed.isend(
            tensor, dst=self.skip_rank, group=self.skip_device_group
        )
class SequenceParallelGroupCoordinator(GroupCoordinator):
    """Coordinator for the sequence-parallel group.

    When long-context attention is available, the `ulysses_group` and
    `ring_group` sub-groups must be supplied as keyword arguments and their
    sizes/ranks are recorded; otherwise both sub-groups degenerate to size 1.
    """

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        **kwargs,
    ):
        super().__init__(
            group_ranks=group_ranks,
            local_rank=local_rank,
            torch_distributed_backend=torch_distributed_backend,
        )
        if not HAS_LONG_CTX_ATTN:
            # No ulysses/ring split without long-context attention support.
            self.ulysses_world_size = self.ring_world_size = 1
            self.ulysses_rank = self.ring_rank = 0
            return
        ulysses_group = kwargs.get("ulysses_group", None)
        ring_group = kwargs.get("ring_group", None)
        if ulysses_group is None:
            raise RuntimeError(
                f"Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator"
            )
        if ring_group is None:
            raise RuntimeError(
                f"Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator"
            )
        self.ulysses_group = ulysses_group
        self.ring_group = ring_group
        self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group)
        self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group)
        self.ring_world_size = torch.distributed.get_world_size(self.ring_group)
        self.ring_rank = torch.distributed.get_rank(self.ring_group)
# Copyright 2024 xDiT team.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional
import torch
import torch.distributed
import xfuser.envs as envs
from xfuser.logger import init_logger
from .group_coordinator import (
GroupCoordinator,
PipelineGroupCoordinator,
SequenceParallelGroupCoordinator,
)
from .utils import RankGenerator, generate_masked_orthogonal_rank_groups
# Probe optional acceleration packages once at import time.
env_info = envs.PACKAGES_CHECKER.get_packages_info()
HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]  # long-context attention available?
HAS_FLASH_ATTN = env_info["has_flash_attn"]  # flash-attention available?
logger = init_logger(__name__)
# Process-wide singleton coordinators, populated by the init_* helpers below
# and read through the get_*_group accessors.
_WORLD: Optional[GroupCoordinator] = None  # group spanning all ranks
_TP: Optional[GroupCoordinator] = None  # tensor parallel group
_SP: Optional[SequenceParallelGroupCoordinator] = None  # sequence parallel group
_PP: Optional[PipelineGroupCoordinator] = None  # pipeline parallel group
_CFG: Optional[GroupCoordinator] = None  # classifier-free-guidance parallel group
_DP: Optional[GroupCoordinator] = None  # data parallel group
# * QUERY
def get_world_group() -> GroupCoordinator:
    """Coordinator spanning every rank (requires prior initialization)."""
    assert _WORLD is not None, "world group is not initialized"
    return _WORLD


# TP
def get_tp_group() -> GroupCoordinator:
    """Coordinator of the tensor-model-parallel group."""
    assert _TP is not None, "tensor model parallel group is not initialized"
    return _TP


def get_tensor_model_parallel_world_size():
    """World size of the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """This process's rank inside the tensor model parallel group."""
    return get_tp_group().rank_in_group
# SP
def get_sp_group() -> SequenceParallelGroupCoordinator:
    """Coordinator of the sequence-parallel group.

    FIX: the assertion message previously said "pipeline model parallel"
    (copied from get_pp_group); it now names the correct group.
    """
    assert _SP is not None, "sequence parallel group is not initialized"
    return _SP


def get_sequence_parallel_world_size():
    """Return world size for the sequence parallel group."""
    return get_sp_group().world_size


def get_sequence_parallel_rank():
    """Return my rank for the sequence parallel group."""
    return get_sp_group().rank_in_group


def get_ulysses_parallel_world_size():
    """World size of the ulysses sub-group within sequence parallelism."""
    return get_sp_group().ulysses_world_size


def get_ulysses_parallel_rank():
    """This process's rank inside the ulysses sub-group."""
    return get_sp_group().ulysses_rank


def get_ring_parallel_world_size():
    """World size of the ring sub-group within sequence parallelism."""
    return get_sp_group().ring_world_size


def get_ring_parallel_rank():
    """This process's rank inside the ring sub-group."""
    return get_sp_group().ring_rank
# PP
def get_pp_group() -> PipelineGroupCoordinator:
    """Coordinator of the pipeline-model-parallel group."""
    assert _PP is not None, "pipeline model parallel group is not initialized"
    return _PP


def get_pipeline_parallel_world_size():
    """World size of the pipeline model parallel group."""
    return get_pp_group().world_size


def get_pipeline_parallel_rank():
    """This process's rank inside the pipeline model parallel group."""
    return get_pp_group().rank_in_group


def is_pipeline_first_stage():
    """True iff this process runs the first pipeline stage."""
    return get_pipeline_parallel_rank() == 0


def is_pipeline_last_stage():
    """True iff this process runs the last pipeline stage."""
    stage = get_pipeline_parallel_rank()
    return stage == get_pipeline_parallel_world_size() - 1
# CFG
def get_cfg_group() -> GroupCoordinator:
    """Coordinator of the classifier-free-guidance parallel group."""
    assert (
        _CFG is not None
    ), "classifier_free_guidance parallel group is not initialized"
    return _CFG


def get_classifier_free_guidance_world_size():
    """World size of the classifier_free_guidance parallel group."""
    return get_cfg_group().world_size


def get_classifier_free_guidance_rank():
    """This process's rank inside the classifier_free_guidance group."""
    return get_cfg_group().rank_in_group
# DP
def get_dp_group() -> GroupCoordinator:
    """Coordinator of the data-parallel group.

    FIX: the assertion message previously said "pipeline model parallel"
    (copied from get_pp_group); it now names the correct group.
    """
    assert _DP is not None, "data parallel group is not initialized"
    return _DP


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return get_dp_group().world_size


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return get_dp_group().rank_in_group
def is_dp_last_group():
    """Return True if in the last data parallel group, False otherwise."""
    # NOTE(review): this checks that the caller is the *last rank* of its
    # SP, CFG and PP groups, but never inspects the DP rank itself — verify
    # this matches the intended "last group" semantics at call sites.
    return (
        get_sequence_parallel_rank() == (get_sequence_parallel_world_size() - 1)
        and get_classifier_free_guidance_rank()
        == (get_classifier_free_guidance_world_size() - 1)
        and get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)
    )
# * SET
def init_world_group(
    ranks: List[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Build the all-ranks coordinator from a flat list of global ranks."""
    all_group_ranks = [ranks]
    return GroupCoordinator(
        group_ranks=all_group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
    )
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the global world group.

    Idempotent: an existing process group is reused, and an existing
    ``_WORLD`` coordinator is only validated against the current world size.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert (
            _WORLD.world_size == torch.distributed.get_world_size()
        ), "world group already initialized with a different world size"
def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    groups = (_DP, _CFG, _SP, _PP, _TP)
    return all(group is not None for group in groups)
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    parallel_mode: str,
    **kwargs,
) -> GroupCoordinator:
    """Create the group coordinator matching ``parallel_mode``.

    Pipeline and sequence modes use specialized coordinators; the remaining
    modes (data / tensor / classifier_free_guidance) share the generic one.
    Extra ``kwargs`` are forwarded only to the sequence coordinator.
    """
    assert parallel_mode in [
        "data",
        "pipeline",
        "tensor",
        "sequence",
        "classifier_free_guidance",
    ], f"parallel_mode {parallel_mode} is not supported"
    if parallel_mode == "pipeline":
        coordinator_cls = PipelineGroupCoordinator
        extra_kwargs = {}
    elif parallel_mode == "sequence":
        coordinator_cls = SequenceParallelGroupCoordinator
        extra_kwargs = kwargs
    else:
        coordinator_cls = GroupCoordinator
        extra_kwargs = {}
    return coordinator_cls(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        **extra_kwargs,
    )
def initialize_model_parallel(
    data_parallel_degree: int = 1,
    classifier_free_guidance_degree: int = 1,
    sequence_parallel_degree: int = 1,
    ulysses_degree: int = 1,
    ring_degree: int = 1,
    tensor_parallel_degree: int = 1,
    pipeline_parallel_degree: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        data_parallel_degree: number of data parallelism groups.
        classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG)
        sequence_parallel_degree: number of GPUs used for sequence parallelism.
        ulysses_degree: number of GPUs used for ulysses sequence parallelism.
        ring_degree: number of GPUs used for ring sequence parallelism.
        tensor_parallel_degree: number of GPUs used for tensor parallelism.
        pipeline_parallel_degree: number of GPUs used for pipeline parallelism.
        backend: distributed backend of pytorch collective comm.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 groups to parallelize the batch dim (dp), 2 groups to parallelize
    the split batch caused by CFG, and 2 GPUs to parallelize sequence.

    dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16.

    The present function will create 2 data-parallel groups,
    8 CFG groups, 8 pipeline-parallel groups, and 8 sequence-parallel groups:
        2 data-parallel groups:
            [g0, g1, g2, g3, g4, g5, g6, g7],
            [g8, g9, g10, g11, g12, g13, g14, g15]
        8 CFG-parallel groups:
            [g0, g4], [g1, g5], [g2, g6], [g3, g7],
            [g8, g12], [g9, g13], [g10, g14], [g11, g15]
        8 sequence-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7],
            [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        8 pipeline-parallel groups:
            [g0, g2], [g4, g6], [g8, g10], [g12, g14],
            [g1, g3], [g5, g7], [g9, g11], [g13, g15]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
    if (
        world_size
        != data_parallel_degree
        * classifier_free_guidance_degree
        * sequence_parallel_degree
        * tensor_parallel_degree
        * pipeline_parallel_degree
    ):
        # BUG FIX: the original message lacked separator spaces after "x",
        # producing e.g. "...) xsequence_parallel_degree (...".
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_parallel_degree ({tensor_parallel_degree}) x "
            f"pipeline_parallel_degree ({pipeline_parallel_degree}) x "
            f"sequence_parallel_degree ({sequence_parallel_degree}) x "
            f"classifier_free_guidance_degree "
            f"({classifier_free_guidance_degree}) x "
            f"data_parallel_degree ({data_parallel_degree})"
        )
    # Rank layout: tp varies fastest, dp varies slowest.
    rank_generator: RankGenerator = RankGenerator(
        tensor_parallel_degree,
        sequence_parallel_degree,
        pipeline_parallel_degree,
        classifier_free_guidance_degree,
        data_parallel_degree,
        "tp-sp-pp-cfg-dp",
    )
    global _DP
    assert _DP is None, "data parallel group is already initialized"
    _DP = init_model_parallel_group(
        group_ranks=rank_generator.get_ranks("dp"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="data",
    )
    global _CFG
    assert _CFG is None, "classifier_free_guidance group is already initialized"
    _CFG = init_model_parallel_group(
        group_ranks=rank_generator.get_ranks("cfg"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="classifier_free_guidance",
    )
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    _PP = init_model_parallel_group(
        group_ranks=rank_generator.get_ranks("pp"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="pipeline",
    )
    global _SP
    assert _SP is None, "sequence parallel group is already initialized"
    # if HAS_LONG_CTX_ATTN and sequence_parallel_degree > 1:
    if HAS_LONG_CTX_ATTN:
        from yunchang import set_seq_parallel_pg
        from yunchang.globals import PROCESS_GROUP

        # yunchang owns the ulysses/ring sub-groups within each sp group.
        set_seq_parallel_pg(
            sp_ulysses_degree=ulysses_degree,
            sp_ring_degree=ring_degree,
            rank=get_world_group().rank_in_group,
            world_size=get_world_group().world_size,
        )
        _SP = init_model_parallel_group(
            group_ranks=rank_generator.get_ranks("sp"),
            local_rank=get_world_group().local_rank,
            backend=backend,
            parallel_mode="sequence",
            ulysses_group=PROCESS_GROUP.ULYSSES_PG,
            ring_group=PROCESS_GROUP.RING_PG,
        )
    else:
        _SP = init_model_parallel_group(
            group_ranks=rank_generator.get_ranks("sp"),
            local_rank=get_world_group().local_rank,
            backend=backend,
            parallel_mode="sequence",
        )
    global _TP
    assert _TP is None, "Tensor parallel group is already initialized"
    _TP = init_model_parallel_group(
        group_ranks=rank_generator.get_ranks("tp"),
        local_rank=get_world_group().local_rank,
        backend=backend,
        parallel_mode="tensor",
    )
def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    # Tear down each coordinator (if it was created) and clear the module
    # global so that initialize_model_parallel() can be called again.
    global _DP
    if _DP:
        _DP.destroy()
    _DP = None
    global _CFG
    if _CFG:
        _CFG.destroy()
    _CFG = None
    global _SP
    if _SP:
        _SP.destroy()
    _SP = None
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None
    global _PP
    if _PP:
        _PP.destroy()
    _PP = None
def destroy_distributed_environment():
    """Destroy the world group and shut down torch.distributed itself."""
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    # Only tear down the process group if it was actually brought up.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
from abc import ABCMeta
import random
from typing import List, Optional, Tuple
import numpy as np
import torch
from diffusers import DiffusionPipeline, CogVideoXPipeline
import torch.distributed
from xfuser.config.config import (
ParallelConfig,
RuntimeConfig,
InputConfig,
EngineConfig,
)
from xfuser.logger import init_logger
from .parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
get_pp_group,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
init_distributed_environment,
initialize_model_parallel,
model_parallel_is_initialized,
)
logger = init_logger(__name__)
def set_random_seed(seed: int):
    """Seed every RNG source used by the engine (python, numpy, torch CPU/CUDA)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
class RuntimeState(metaclass=ABCMeta):
    """Base class holding engine-wide runtime state for a pipeline.

    Stores the parallel/runtime/input configuration and ensures the
    distributed environment is up before any parallel work starts.
    """

    parallel_config: ParallelConfig
    runtime_config: RuntimeConfig
    input_config: InputConfig
    # number of patches the latent is split into for pipeline parallelism
    num_pipeline_patch: int
    # True once input parameters have been configured
    ready: bool = False

    def __init__(self, config: EngineConfig):
        self.parallel_config = config.parallel_config
        self.runtime_config = config.runtime_config
        self.input_config = InputConfig()
        self.num_pipeline_patch = self.parallel_config.pp_config.num_pipeline_patch
        self.ready = False
        self._check_distributed_env(config.parallel_config)

    def is_ready(self):
        """Return True once input parameters have been set."""
        return self.ready

    def _check_distributed_env(
        self,
        parallel_config: ParallelConfig,
    ):
        # Lazily bring up torch.distributed and the model-parallel groups so
        # a pipeline run works without explicit distributed setup.
        if not model_parallel_is_initialized():
            logger.warning("Model parallel is not initialized, initializing...")
            if not torch.distributed.is_initialized():
                init_distributed_environment()
            initialize_model_parallel(
                data_parallel_degree=parallel_config.dp_degree,
                classifier_free_guidance_degree=parallel_config.cfg_degree,
                sequence_parallel_degree=parallel_config.sp_degree,
                ulysses_degree=parallel_config.ulysses_degree,
                ring_degree=parallel_config.ring_degree,
                tensor_parallel_degree=parallel_config.tp_degree,
                pipeline_parallel_degree=parallel_config.pp_degree,
            )

    def destroy_distributed_env(self):
        """Tear down model-parallel groups and the distributed environment."""
        if model_parallel_is_initialized():
            destroy_model_parallel()
        destroy_distributed_environment()

    # BUG FIX (compat): the method name was originally misspelled; keep the
    # old name as an alias so existing callers continue to work.
    destory_distributed_env = destroy_distributed_env
class DiTRuntimeState(RuntimeState):
    """Runtime state for pipelines with a DiT (transformer) backbone.

    On top of :class:`RuntimeState`, tracks the latent patch layout used by
    pipeline parallelism (PipeFusion) and sequence parallelism, plus the
    VAE/backbone geometry needed to compute that layout.
    """

    # True while the pipeline runs in patched (PipeFusion) mode
    patch_mode: bool
    # index of the pipeline patch currently being processed
    pipeline_patch_idx: int
    vae_scale_factor: int
    vae_scale_factor_spatial: int
    vae_scale_factor_temporal: int
    backbone_patch_size: int
    # heights of the pipeline patches owned by this sequence-parallel rank
    pp_patches_height: Optional[List[int]]
    # local cumulative row offsets of those patches
    pp_patches_start_idx_local: Optional[List[int]]
    # [start, end) row indices of each patch in the full latent
    pp_patches_start_end_idx_global: Optional[List[List[int]]]
    pp_patches_token_start_idx_local: Optional[List[int]]
    pp_patches_token_start_end_idx_global: Optional[List[List[int]]]
    pp_patches_token_num: Optional[List[int]]
    max_condition_sequence_length: int
    split_text_embed_in_sp: bool

    def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
        super().__init__(config)
        self.patch_mode = False
        self.pipeline_patch_idx = 0
        self._check_model_and_parallel_config(
            pipeline=pipeline, parallel_config=config.parallel_config
        )
        self.cogvideox = False
        if isinstance(pipeline, CogVideoXPipeline):
            self._set_cogvideox_parameters(
                vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
                vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
                backbone_patch_size=pipeline.transformer.config.patch_size,
                backbone_in_channel=pipeline.transformer.config.in_channels,
                backbone_inner_dim=pipeline.transformer.config.num_attention_heads
                * pipeline.transformer.config.attention_head_dim,
            )
        else:
            self._set_model_parameters(
                vae_scale_factor=pipeline.vae_scale_factor,
                backbone_patch_size=pipeline.transformer.config.patch_size,
                backbone_in_channel=pipeline.transformer.config.in_channels,
                backbone_inner_dim=pipeline.transformer.config.num_attention_heads
                * pipeline.transformer.config.attention_head_dim,
            )

    def set_input_parameters(
        self,
        height: Optional[int] = None,
        width: Optional[int] = None,
        batch_size: Optional[int] = None,
        num_inference_steps: Optional[int] = None,
        seed: Optional[int] = None,
        max_condition_sequence_length: Optional[int] = None,
        split_text_embed_in_sp: bool = True,
    ):
        """Record image-generation inputs; recompute patch layout on size change."""
        self.input_config.num_inference_steps = (
            num_inference_steps or self.input_config.num_inference_steps
        )
        # NOTE(review): overwrites with None when the caller omits it — confirm
        # downstream users tolerate a None max_condition_sequence_length.
        self.max_condition_sequence_length = max_condition_sequence_length
        self.split_text_embed_in_sp = split_text_embed_in_sp
        if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
            self.runtime_config.warmup_steps = self.input_config.num_inference_steps
        if seed is not None and seed != self.input_config.seed:
            self.input_config.seed = seed
            set_random_seed(seed)
        if (
            not self.ready
            or (height and self.input_config.height != height)
            or (width and self.input_config.width != width)
            or (batch_size and self.input_config.batch_size != batch_size)
        ):
            self._input_size_change(height, width, batch_size)
        self.ready = True

    def set_video_input_parameters(
        self,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        batch_size: Optional[int] = None,
        num_inference_steps: Optional[int] = None,
        seed: Optional[int] = None,
        split_text_embed_in_sp: bool = True,
    ):
        """Record video-generation inputs; recompute patch layout on size change."""
        self.input_config.num_inference_steps = (
            num_inference_steps or self.input_config.num_inference_steps
        )
        if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
            self.runtime_config.warmup_steps = self.input_config.num_inference_steps
        self.split_text_embed_in_sp = split_text_embed_in_sp
        if seed is not None and seed != self.input_config.seed:
            self.input_config.seed = seed
            set_random_seed(seed)
        if (
            not self.ready
            or (height and self.input_config.height != height)
            or (width and self.input_config.width != width)
            or (num_frames and self.input_config.num_frames != num_frames)
            or (batch_size and self.input_config.batch_size != batch_size)
        ):
            self._video_input_size_change(height, width, num_frames, batch_size)
        self.ready = True

    def _set_cogvideox_parameters(
        self,
        vae_scale_factor_spatial: int,
        vae_scale_factor_temporal: int,
        backbone_patch_size: int,
        backbone_inner_dim: int,
        backbone_in_channel: int,
    ):
        """Store CogVideoX-specific VAE/backbone geometry."""
        self.vae_scale_factor_spatial = vae_scale_factor_spatial
        self.vae_scale_factor_temporal = vae_scale_factor_temporal
        self.backbone_patch_size = backbone_patch_size
        self.backbone_inner_dim = backbone_inner_dim
        self.backbone_in_channel = backbone_in_channel
        self.cogvideox = True

    def set_patched_mode(self, patch_mode: bool):
        """Enable/disable patched (PipeFusion) execution and reset the cursor."""
        self.patch_mode = patch_mode
        self.pipeline_patch_idx = 0

    def next_patch(self):
        """Advance the patch cursor, wrapping at num_pipeline_patch."""
        if self.patch_mode:
            self.pipeline_patch_idx += 1
            if self.pipeline_patch_idx == self.num_pipeline_patch:
                self.pipeline_patch_idx = 0
        else:
            self.pipeline_patch_idx = 0

    def _check_model_and_parallel_config(
        self,
        pipeline: DiffusionPipeline,
        parallel_config: ParallelConfig,
    ):
        # Ulysses attention shards heads across ranks, so the head count
        # must be divisible by (and at least) the ulysses degree.
        num_heads = pipeline.transformer.config.num_attention_heads
        ulysses_degree = parallel_config.sp_config.ulysses_degree
        if num_heads % ulysses_degree != 0 or num_heads < ulysses_degree:
            raise RuntimeError(
                f"transformer backbone has {num_heads} heads, which is not "
                f"divisible by or smaller than ulysses_degree "
                f"{ulysses_degree}."
            )

    def _set_model_parameters(
        self,
        vae_scale_factor: int,
        backbone_patch_size: int,
        backbone_inner_dim: int,
        backbone_in_channel: int,
    ):
        """Store generic image-pipeline VAE/backbone geometry."""
        self.vae_scale_factor = vae_scale_factor
        self.backbone_patch_size = backbone_patch_size
        self.backbone_inner_dim = backbone_inner_dim
        self.backbone_in_channel = backbone_in_channel

    def _input_size_change(
        self,
        height: Optional[int] = None,
        width: Optional[int] = None,
        batch_size: Optional[int] = None,
    ):
        # Update sizes (keeping previous values when omitted), then rebuild
        # the patch layout and communication buffers.
        self.input_config.height = height or self.input_config.height
        self.input_config.width = width or self.input_config.width
        self.input_config.batch_size = batch_size or self.input_config.batch_size
        self._calc_patches_metadata()
        self._reset_recv_buffer()

    def _video_input_size_change(
        self,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        batch_size: Optional[int] = None,
    ):
        self.input_config.height = height or self.input_config.height
        self.input_config.width = width or self.input_config.width
        self.input_config.num_frames = num_frames or self.input_config.num_frames
        self.input_config.batch_size = batch_size or self.input_config.batch_size
        if self.cogvideox:
            self._calc_cogvideox_patches_metadata()
        else:
            self._calc_patches_metadata()
        self._reset_recv_buffer()

    def _calc_patches_metadata(self):
        """Compute the patch layout for image pipelines (single VAE factor)."""
        latents_height = self.input_config.height // self.vae_scale_factor
        latents_width = self.input_config.width // self.vae_scale_factor
        self._calc_patches_layout(latents_height, latents_width)

    def _calc_cogvideox_patches_metadata(self):
        """Compute the patch layout for CogVideoX (spatial VAE factor)."""
        latents_height = self.input_config.height // self.vae_scale_factor_spatial
        latents_width = self.input_config.width // self.vae_scale_factor_spatial
        self._calc_patches_layout(latents_height, latents_width)

    def _calc_patches_layout(self, latents_height: int, latents_width: int):
        """Partition the latent rows across pipeline and sequence ranks.

        Shared implementation for ``_calc_patches_metadata`` and
        ``_calc_cogvideox_patches_metadata``, which previously duplicated
        this body verbatim and differed only in the latent-size computation.

        Raises:
            ValueError: if the latent height cannot be split evenly across
                sequence-parallel ranks, or the last pipeline patch is not a
                multiple of (patch_size * num_sp_patches).
        """
        num_sp_patches = get_sequence_parallel_world_size()
        sp_patch_idx = get_sequence_parallel_rank()
        patch_size = self.backbone_patch_size
        if latents_height % num_sp_patches != 0:
            raise ValueError(
                "The height of the input is not divisible by the number of sequence parallel devices"
            )
        self.num_pipeline_patch = self.parallel_config.pp_config.num_pipeline_patch
        # Pipeline patches
        pipeline_patches_height = (
            latents_height + self.num_pipeline_patch - 1
        ) // self.num_pipeline_patch
        # make sure pipeline_patches_height is a multiple of (num_sp_patches * patch_size)
        pipeline_patches_height = (
            (pipeline_patches_height + (num_sp_patches * patch_size) - 1)
            // (patch_size * num_sp_patches)
        ) * (patch_size * num_sp_patches)
        # get the number of pipeline patches that matches the height requirement
        num_pipeline_patch = (
            latents_height + pipeline_patches_height - 1
        ) // pipeline_patches_height
        if num_pipeline_patch != self.num_pipeline_patch:
            logger.warning(
                f"Pipeline patches num changed from "
                f"{self.num_pipeline_patch} to {num_pipeline_patch} due "
                f"to input size and parallelisation requirements"
            )
        pipeline_patches_height_list = [
            pipeline_patches_height for _ in range(num_pipeline_patch - 1)
        ]
        the_last_pp_patch_height = latents_height - pipeline_patches_height * (
            num_pipeline_patch - 1
        )
        if the_last_pp_patch_height % (patch_size * num_sp_patches) != 0:
            raise ValueError(
                f"The height of the last pipeline patch is {the_last_pp_patch_height}, "
                f"which is not a multiple of (patch_size * num_sp_patches): "
                f"{patch_size} * {num_sp_patches}. Please try to adjust 'num_pipeline_patches "
                f"or sp_degree argument so that the condition are met "
            )
        pipeline_patches_height_list.append(the_last_pp_patch_height)
        # Sequence parallel patches
        # len: sp_degree * num_pipeline_patches
        # NOTE(review): this repeats the pp-patch list once per sp rank while
        # the slicing below assumes grouping by pp patch; the two agree when
        # all patches share a height — verify behavior for an uneven last patch.
        flatten_patches_height = [
            pp_patch_height // num_sp_patches
            for _ in range(num_sp_patches)
            for pp_patch_height in pipeline_patches_height_list
        ]
        flatten_patches_start_idx = [0] + [
            sum(flatten_patches_height[:i])
            for i in range(1, len(flatten_patches_height) + 1)
        ]
        pp_sp_patches_height = [
            flatten_patches_height[
                pp_patch_idx * num_sp_patches : (pp_patch_idx + 1) * num_sp_patches
            ]
            for pp_patch_idx in range(num_pipeline_patch)
        ]
        pp_sp_patches_start_idx = [
            flatten_patches_start_idx[
                pp_patch_idx * num_sp_patches : (pp_patch_idx + 1) * num_sp_patches + 1
            ]
            for pp_patch_idx in range(num_pipeline_patch)
        ]
        # Select this sequence-parallel rank's slice of every pipeline patch.
        pp_patches_height = [
            sp_patches_height[sp_patch_idx]
            for sp_patches_height in pp_sp_patches_height
        ]
        pp_patches_start_idx_local = [0] + [
            sum(pp_patches_height[:i]) for i in range(1, len(pp_patches_height) + 1)
        ]
        pp_patches_start_end_idx_global = [
            sp_patches_start_idx[sp_patch_idx : sp_patch_idx + 2]
            for sp_patches_start_idx in pp_sp_patches_start_idx
        ]
        # Convert row ranges to flattened token ranges (tokens per row =
        # latents_width // patch_size).
        pp_patches_token_start_end_idx_global = [
            [
                (latents_width // patch_size) * (start_idx // patch_size),
                (latents_width // patch_size) * (end_idx // patch_size),
            ]
            for start_idx, end_idx in pp_patches_start_end_idx_global
        ]
        pp_patches_token_num = [
            end - start for start, end in pp_patches_token_start_end_idx_global
        ]
        pp_patches_token_start_idx_local = [
            sum(pp_patches_token_num[:i]) for i in range(len(pp_patches_token_num) + 1)
        ]
        self.num_pipeline_patch = num_pipeline_patch
        self.pp_patches_height = pp_patches_height
        self.pp_patches_start_idx_local = pp_patches_start_idx_local
        self.pp_patches_start_end_idx_global = pp_patches_start_end_idx_global
        self.pp_patches_token_start_idx_local = pp_patches_token_start_idx_local
        self.pp_patches_token_start_end_idx_global = (
            pp_patches_token_start_end_idx_global
        )
        self.pp_patches_token_num = pp_patches_token_num

    def _reset_recv_buffer(self):
        """Reset the pipeline-parallel receive buffer for the new layout."""
        get_pp_group().reset_buffer()
        get_pp_group().set_config(dtype=self.runtime_config.dtype)

    def _reset_recv_skip_buffer(self, num_blocks_per_stage):
        """Size the skip-connection receive buffers for each pipeline patch."""
        batch_size = self.input_config.batch_size
        # CFG doubles the effective batch unless it is split across ranks.
        batch_size = batch_size * (2 // self.parallel_config.cfg_degree)
        hidden_dim = self.backbone_inner_dim
        num_patches_tokens = [
            end - start for start, end in self.pp_patches_token_start_end_idx_global
        ]
        patches_shape = [
            [num_blocks_per_stage, batch_size, tokens, hidden_dim]
            for tokens in num_patches_tokens
        ]
        feature_map_shape = [
            num_blocks_per_stage,
            batch_size,
            sum(num_patches_tokens),
            hidden_dim,
        ]
        # reset pipeline communicator buffer
        get_pp_group().set_skip_tensor_recv_buffer(
            patches_shape_list=patches_shape,
            feature_map_shape=feature_map_shape,
        )
# _RUNTIME: Optional[RuntimeState] = None
# TODO: change to RuntimeState after implementing the unet
# Process-wide singleton: created by initialize_runtime_state() and read via
# get_runtime_state().
_RUNTIME: Optional[DiTRuntimeState] = None
def runtime_state_is_initialized():
    """Return True once the global runtime state singleton exists."""
    initialized = _RUNTIME is not None
    return initialized
def get_runtime_state():
    """Return the global runtime state; asserts it was initialized first."""
    # NOTE: assert-based guard disappears under ``python -O``.
    assert _RUNTIME is not None, "Runtime state has not been initialized."
    return _RUNTIME
def initialize_runtime_state(pipeline: DiffusionPipeline, engine_config: EngineConfig):
    """Create the global DiT runtime state for ``pipeline``.

    Reinitializes (with a warning) if called twice. Pipelines without a
    ``transformer`` backbone (e.g. UNet models) are not supported yet and
    leave the runtime state untouched.
    """
    global _RUNTIME
    if _RUNTIME is not None:
        logger.warning(
            "Runtime state is already initialized, reinitializing with pipeline..."
        )
    if hasattr(pipeline, "transformer"):
        _RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config)
    else:
        # Previously a silent no-op that surfaced only later as an assert in
        # get_runtime_state(); make the unsupported case visible here.
        logger.warning(
            "Pipeline has no `transformer` backbone; runtime state was not initialized."
        )
from typing import List
def generate_masked_orthogonal_rank_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
    """Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
        world_size (int): world size.
        parallel_size (List[int]):
            The size of each orthogonal parallel type. For example, if
            tensor_parallel_size = 2, pipeline_model_parallel_size = 3,
            data_parallel_size = 4, and the mapping order is tp-pp-dp, then
            parallel_size = [2, 3, 4].
        mask (List[bool]):
            Controls which parallel methods the generated groups represent.
            If mask[i] is True, the generated groups span the i-th
            parallelism method. E.g. with parallel_size =
            [tp_size, pp_size, dp_size], mask = [True, False, True] yields
            the `tp-dp` groups and mask = [False, True, False] the `pp`
            groups.

    Algorithm:
        For orthogonal parallelism each global rank decomposes uniquely as
            global_rank = sum_i(idx[i] * stride[i])
        where stride is the exclusive prefix product of parallel_size. The
        masked dimensions enumerate the ranks *within* a group, and the
        unmasked dimensions enumerate the groups themselves. This function
        inverts that decomposition to list every group's members.
    """

    def prefix_product(dims: List[int], init=1) -> List[int]:
        # Exclusive running product: [init, init*d0, init*d0*d1, ...]
        out = [init]
        acc = init
        for d in dims:
            acc *= d
            out.append(acc)
        return out

    def inner_product(a: List[int], b: List[int]) -> int:
        return sum(x * y for x, y in zip(a, b))

    def decompose(index, shape, stride=None):
        """Invert ``index = sum(idx[i] * stride[i])`` for the given shape."""
        if stride is None:
            stride = prefix_product(shape)
        idx = [(index // d) % s for s, d in zip(shape, stride)]
        # stride is a prefix_product result; its last entry is unused.
        assert (
            inner_product(idx, stride[:-1]) == index
        ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
        return idx

    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    global_stride = prefix_product(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    num_of_group = world_size // group_size

    groups = []
    for group_index in range(num_of_group):
        # The unmasked dimensions identify which group this is.
        group_offsets = decompose(group_index, unmasked_shape)
        base_rank = inner_product(group_offsets, unmasked_stride)
        # The masked dimensions enumerate the members within the group.
        members = [
            base_rank
            + inner_product(decompose(rank_in_group, masked_shape), masked_stride)
            for rank_in_group in range(group_size)
        ]
        groups.append(members)
    return groups
class RankGenerator(object):
    """Produces global-rank groupings for orthogonal parallel dimensions.

    The ``order`` string (e.g. ``"tp-sp-pp-cfg-dp"``) fixes which dimension
    varies fastest; dimensions of size 1 may be omitted and are appended
    automatically.
    """

    def __init__(
        self,
        tp: int,
        sp: int,
        pp: int,
        cfg: int,
        dp: int,
        order: str,
        rank_offset: int = 0,
    ) -> None:
        self.tp = tp
        self.sp = sp
        self.pp = pp
        self.cfg = cfg
        self.dp = dp
        # constant added to every generated rank (for sub-world layouts)
        self.rank_offset = rank_offset
        self.world_size = tp * sp * pp * cfg * dp

        self.name_to_size = {
            "tp": self.tp,
            "sp": self.sp,
            "pp": self.pp,
            "cfg": self.cfg,
            "dp": self.dp,
        }
        order = order.lower()

        for name in self.name_to_size.keys():
            if name not in order and self.name_to_size[name] != 1:
                # BUG FIX: this message previously read ``self.order`` before
                # it was assigned, so the intended RuntimeError surfaced as an
                # AttributeError instead.
                raise RuntimeError(
                    f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({order})."
                )
            elif name not in order:
                # Size-1 dimensions can be appended in any position.
                order = order + "-" + name

        self.order = order
        self.ordered_size = []

        for token in order.split("-"):
            self.ordered_size.append(self.name_to_size[token])

    def get_mask(self, order: str, token: str):
        """Return a boolean mask selecting ``token``'s dimensions in ``order``."""
        ordered_token = order.split("-")
        token = token.split("-")
        mask = [False] * len(ordered_token)
        for t in token:
            mask[ordered_token.index(t)] = True
        return mask

    def get_ranks(self, token):
        """Get rank groups for the requested parallel type(s).

        Arguments:
            token (str):
                The ranks type to get. Multiple parallel types may be
                combined with a hyphen, e.g. ``'tp-dp'`` for the TP_DP
                groups.
        """
        mask = self.get_mask(self.order, token)
        ranks = generate_masked_orthogonal_rank_groups(
            self.world_size, self.ordered_size, mask
        )
        if self.rank_offset > 0:
            for rank_group in ranks:
                for i in range(len(rank_group)):
                    rank_group[i] += self.rank_offset
        return ranks
from .fast_attn_state import (
get_fast_attn_state,
get_fast_attn_enable,
get_fast_attn_step,
get_fast_attn_calib,
get_fast_attn_threshold,
get_fast_attn_window_size,
get_fast_attn_coco_path,
get_fast_attn_use_cache,
get_fast_attn_config_file,
get_fast_attn_layer_name,
initialize_fast_attn_state,
)
from .attn_layer import (
FastAttnMethod,
xFuserFastAttention,
)
from .utils import fast_attention_compression
__all__ = [
"get_fast_attn_state",
"get_fast_attn_enable",
"get_fast_attn_step",
"get_fast_attn_calib",
"get_fast_attn_threshold",
"get_fast_attn_window_size",
"get_fast_attn_coco_path",
"get_fast_attn_use_cache",
"get_fast_attn_config_file",
"get_fast_attn_layer_name",
"initialize_fast_attn_state",
"xFuserFastAttention",
"FastAttnMethod",
"fast_attention_compression",
]
# Copyright 2024 xDiT team.
# Adapted from
# https://github.com/thu-nics/DiTFastAttn/blob/main/modules/fast_attn_processor.py
# Copyright (c) 2024 NICS-EFC Lab of Tsinghua University.
import torch
from diffusers.models.attention_processor import Attention
from typing import Optional
import torch.nn.functional as F
import flash_attn
from enum import Flag, auto
from .fast_attn_state import get_fast_attn_window_size
class FastAttnMethod(Flag):
    """Per-timestep attention strategies for DiTFastAttn compression."""

    # compute full attention for this step
    FULL_ATTN = auto()
    # windowed attention combined with a cached full-window residual
    RESIDUAL_WINDOW_ATTN = auto()
    # reuse the previous step's cached attention output
    OUTPUT_SHARE = auto()
    # reuse the unconditional CFG branch's output for the conditional branch
    CFG_SHARE = auto()
    # combined strategies
    RESIDUAL_WINDOW_ATTN_CFG_SHARE = RESIDUAL_WINDOW_ATTN | CFG_SHARE
    FULL_ATTN_CFG_SHARE = FULL_ATTN | CFG_SHARE

    def has(self, method: "FastAttnMethod"):
        # True when any bit of ``method`` is set in ``self``.
        return bool(self & method)
class xFuserFastAttention:
    """DiTFastAttn attention processor.

    Holds the per-timestep compression plan for one attention module and, in
    ``__call__``, runs the variant planned for the module's current timestep
    (``attn.stepi``):

    * ``FULL_ATTN``            - ordinary full attention (optionally caching
                                 the full-vs-windowed residual for later steps).
    * ``RESIDUAL_WINDOW_ATTN`` - windowed attention plus the cached residual.
    * ``OUTPUT_SHARE``         - reuse the output cached at an earlier step.
    * ``CFG_SHARE``            - compute only one CFG branch and duplicate its
                                 output for the other branch.

    Fix over the original: ``__init__`` used a mutable default argument
    (``steps_method=[]``). The calibration code mutates ``steps_method`` in
    place (``fast_attn.steps_method[now_stepi] = method``), so the shared
    default list would be silently corrupted across all instances built
    without an explicit plan.
    """

    # Class-level defaults; every attribute is overwritten per instance in __init__.
    window_size: list[int] = [-1, -1]
    steps_method: list[FastAttnMethod] = []
    cond_first: bool = False
    need_compute_residual: list[bool] = []
    need_cache_output: bool = False

    def __init__(
        self,
        steps_method: Optional[list[FastAttnMethod]] = None,
        cond_first: bool = False,
    ):
        """
        Args:
            steps_method: compression method for each timestep. ``None``
                (the default) means an empty plan; a fresh list is created so
                instances never share one mutable default.
            cond_first: whether the conditional CFG branch comes first along
                the batch dimension.
        """
        window_size = get_fast_attn_window_size()
        self.window_size = [window_size, window_size]
        self.steps_method = steps_method if steps_method is not None else []
        # CFG order flag (conditional first or unconditional first)
        self.cond_first = cond_first
        self.need_compute_residual = self.compute_need_compute_residual()
        self.need_cache_output = True

    def set_methods(
        self,
        steps_method: list[FastAttnMethod],
        selecting: bool = False,
    ):
        """Install a new per-step plan.

        Args:
            steps_method: new plan, one method per timestep.
            selecting: True while the calibration search is running — then the
                residual flags are only resized (all False), not recomputed,
                so trial forwards do not overwrite cached residuals.
        """
        self.steps_method = steps_method
        if selecting:
            if len(self.need_compute_residual) != len(self.steps_method):
                self.need_compute_residual = [False] * len(self.steps_method)
        else:
            self.need_compute_residual = self.compute_need_compute_residual()

    def compute_need_compute_residual(self):
        """Check at which timesteps do we need to compute the full-window residual of this attention module"""
        need_compute_residual = []
        for i, method in enumerate(self.steps_method):
            need = False
            if method.has(FastAttnMethod.FULL_ATTN):
                for j in range(i + 1, len(self.steps_method)):
                    if self.steps_method[j].has(FastAttnMethod.RESIDUAL_WINDOW_ATTN):
                        # If encountered a step that conduct WA-RS,
                        # this step needs the residual computation
                        need = True
                    if self.steps_method[j].has(FastAttnMethod.FULL_ATTN):
                        # If encountered another step using the `full-attn` strategy,
                        # this step doesn't need the residual computation
                        break
            need_compute_residual.append(need)
        return need_compute_residual

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        """Run one attention forward using the strategy planned for this step."""
        # Before calculating the attention, prepare the related parameters.
        # Fall back to full attention / no residual when stepi runs past the plan.
        method = self.steps_method[attn.stepi] if attn.stepi < len(self.steps_method) else FastAttnMethod.FULL_ATTN
        need_compute_residual = self.need_compute_residual[attn.stepi] if attn.stepi < len(self.need_compute_residual) else False
        # Run the forward method according to the selected strategy
        residual = hidden_states
        if method.has(FastAttnMethod.OUTPUT_SHARE):
            # Reuse the attention output cached at an earlier timestep.
            hidden_states = attn.cached_output
        else:
            if method.has(FastAttnMethod.CFG_SHARE):
                # Directly use the unconditional branch's attention output
                # as the conditional branch's attention output
                batch_size = hidden_states.shape[0]
                if self.cond_first:
                    hidden_states = hidden_states[: batch_size // 2]
                else:
                    hidden_states = hidden_states[batch_size // 2 :]
                if encoder_hidden_states is not None:
                    if self.cond_first:
                        encoder_hidden_states = encoder_hidden_states[: batch_size // 2]
                    else:
                        encoder_hidden_states = encoder_hidden_states[batch_size // 2 :]
                if attention_mask is not None:
                    if self.cond_first:
                        attention_mask = attention_mask[: batch_size // 2]
                    else:
                        attention_mask = attention_mask[batch_size // 2 :]
            if attn.spatial_norm is not None:
                hidden_states = attn.spatial_norm(hidden_states, temb)
            input_ndim = hidden_states.ndim
            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
            batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            if attention_mask is not None:
                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
            if attn.group_norm is not None:
                hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
            query = attn.to_q(hidden_states)
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)
            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads
            # flash_attn expects (batch, seq, heads, head_dim) layout.
            query = query.view(batch_size, -1, attn.heads, head_dim)
            key = key.view(batch_size, -1, attn.heads, head_dim)
            value = value.view(batch_size, -1, attn.heads, head_dim)
            if attention_mask is not None:
                assert not method.has(
                    FastAttnMethod.RESIDUAL_WINDOW_ATTN
                ), "Attention mask is not supported in windowed attention"
                hidden_states = F.scaled_dot_product_attention(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    attn_mask=attention_mask,
                    dropout_p=0.0,
                    is_causal=False,
                ).transpose(1, 2)
            elif method.has(FastAttnMethod.FULL_ATTN):
                all_hidden_states = flash_attn.flash_attn_func(query, key, value)
                if need_compute_residual:
                    # Compute the full-window attention residual
                    w_hidden_states = flash_attn.flash_attn_func(query, key, value, window_size=self.window_size)
                    window_residual = all_hidden_states - w_hidden_states
                    if method.has(FastAttnMethod.CFG_SHARE):
                        window_residual = torch.cat([window_residual, window_residual], dim=0)
                    # Save the residual for usage in follow-up steps
                    attn.cached_residual = window_residual
                hidden_states = all_hidden_states
            elif method.has(FastAttnMethod.RESIDUAL_WINDOW_ATTN):
                w_hidden_states = flash_attn.flash_attn_func(query, key, value, window_size=self.window_size)
                hidden_states = w_hidden_states + attn.cached_residual[:batch_size].view_as(w_hidden_states)
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
            hidden_states = hidden_states.to(query.dtype)
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
            if method.has(FastAttnMethod.CFG_SHARE):
                # Duplicate the computed branch's output for the skipped CFG branch.
                hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
            if self.need_cache_output:
                attn.cached_output = hidden_states
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        # After being called once, advance this module's timestep index by 1.
        attn.stepi += 1
        return hidden_states
# TODO: Implement classes to support DiTFastAttn in different diffusion models
class xFuserJointFastAttention(xFuserFastAttention):
    # Placeholder — presumably intended for joint-attention transformers;
    # currently inherits xFuserFastAttention behavior unchanged.
    pass
class xFuserFluxFastAttention(xFuserFastAttention):
    # Placeholder — presumably intended for Flux models; currently inherits
    # xFuserFastAttention behavior unchanged.
    pass
class xFuserHunyuanFastAttention(xFuserFastAttention):
    # Placeholder — presumably intended for HunyuanDiT models; currently
    # inherits xFuserFastAttention behavior unchanged.
    pass
from typing import Optional
from diffusers import DiffusionPipeline
from xfuser.config.config import (
ParallelConfig,
RuntimeConfig,
InputConfig,
FastAttnConfig,
EngineConfig,
)
from xfuser.logger import init_logger
logger = init_logger(__name__)
class FastAttnState:
    """Snapshot of the DiTFastAttn configuration for one pipeline.

    When fast attention is disabled only ``enable`` is meaningful; the other
    attributes keep their class-level defaults and ``config_file`` /
    ``layer_name`` are left unset.
    """

    # Defaults used when fast attention is disabled.
    enable: bool = False
    n_step: int = 20
    n_calib: int = 8
    threshold: float = 0.5
    window_size: int = 64
    coco_path: Optional[str] = None
    use_cache: bool = False
    config_file: str
    layer_name: str

    def __init__(self, pipe: DiffusionPipeline, config: FastAttnConfig):
        self.enable = config.use_fast_attn
        if not self.enable:
            # Leave the remaining attributes at their defaults.
            return
        self.n_step = config.n_step
        self.n_calib = config.n_calib
        self.threshold = config.threshold
        self.window_size = config.window_size
        self.coco_path = config.coco_path
        self.use_cache = config.use_cache
        self.config_file = self.config_file_path(pipe, config)
        self.layer_name = self.attn_name_to_wrap(pipe)

    def config_file_path(self, pipe: DiffusionPipeline, config: FastAttnConfig):
        """Return the config file path."""
        model_tag = pipe.config._name_or_path.replace("/", "_")
        return (
            f"cache/{model_tag}_{config.n_step}_{config.n_calib}"
            f"_{config.threshold}_{config.window_size}.json"
        )

    def attn_name_to_wrap(self, pipe: DiffusionPipeline):
        """Return the attr name of attention layer to wrap."""
        names = ["attn1", "attn"]  # names of self attention layer
        assert hasattr(pipe, "transformer"), "transformer is not found in pipeline."
        assert hasattr(pipe.transformer, "transformer_blocks"), "transformer_blocks is not found in pipeline."
        first_block = pipe.transformer.transformer_blocks[0]
        matches = [name for name in names if hasattr(first_block, name)]
        if matches:
            return matches[0]
        raise AttributeError(f"Attention layer name is not found in {names}.")
_FASTATTN: Optional[FastAttnState] = None
def get_fast_attn_state() -> Optional[FastAttnState]:
    """Return the global FastAttn state, or ``None`` if not yet initialized.

    The annotation is ``Optional`` because callers (e.g. the enable check)
    legitimately probe before initialization; accessors that require an
    initialized state assert on the result themselves.
    """
    return _FASTATTN
def get_fast_attn_enable() -> bool:
    """Return whether fast attention is enabled."""
    state = get_fast_attn_state()
    # Treat "not initialized" as disabled.
    return False if state is None else state.enable
def get_fast_attn_step() -> int:
    """Return the fast attention step."""
    state = get_fast_attn_state()
    assert state is not None, "FastAttn state is not initialized"
    return state.n_step
def get_fast_attn_calib() -> int:
    """Return the fast attention calibration."""
    state = get_fast_attn_state()
    assert state is not None, "FastAttn state is not initialized"
    return state.n_calib
def get_fast_attn_threshold() -> float:
    """Return the fast attention threshold."""
    state = get_fast_attn_state()
    return state.threshold
def get_fast_attn_window_size() -> int:
    """Return the fast attention window size."""
    state = get_fast_attn_state()
    return state.window_size
def get_fast_attn_coco_path() -> Optional[str]:
    """Return the fast attention coco path."""
    state = get_fast_attn_state()
    return state.coco_path
def get_fast_attn_use_cache() -> bool:
    """Return the fast attention use_cache."""
    state = get_fast_attn_state()
    return state.use_cache
def get_fast_attn_config_file() -> str:
    """Return the fast attention config file."""
    state = get_fast_attn_state()
    return state.config_file
def get_fast_attn_layer_name() -> str:
    """Return the fast attention layer name."""
    state = get_fast_attn_state()
    return state.layer_name
def initialize_fast_attn_state(pipeline: DiffusionPipeline, single_config: FastAttnConfig):
    """Create (or replace) the process-global FastAttnState singleton.

    Re-initialization is allowed but logged as a warning, since any previously
    derived config-file path / layer name is discarded.
    """
    global _FASTATTN
    if _FASTATTN is not None:
        logger.warning("FastAttn state is already initialized, reinitializing with pipeline...")
    _FASTATTN = FastAttnState(pipe=pipeline, config=single_config)
# Copyright 2024 xDiT team.
# Adapted from
# https://github.com/thu-nics/DiTFastAttn/blob/main/dit_fast_attention.py
# Copyright (c) 2024 NICS-EFC Lab of Tsinghua University.
import torch
from xfuser.core.distributed import (
get_dp_group,
get_data_parallel_rank,
)
from diffusers import DiffusionPipeline
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from xfuser.model_executor.layers.attention_processor import xFuserAttentionBaseWrapper
from collections import Counter
import os
import json
import numpy as np
from .fast_attn_state import (
get_fast_attn_step,
get_fast_attn_calib,
get_fast_attn_threshold,
get_fast_attn_coco_path,
get_fast_attn_use_cache,
get_fast_attn_config_file,
get_fast_attn_layer_name,
)
from .attn_layer import (
xFuserFastAttention,
FastAttnMethod,
)
from xfuser.logger import init_logger
logger = init_logger(__name__)
def save_config_file(step_methods, file_path):
    """Serialize the per-block, per-step compression plan to a JSON file.

    Args:
        step_methods: nested sequence where ``step_methods[blocki][stepi]`` is
            an object with a ``.name`` attribute (a ``FastAttnMethod`` member).
        file_path: destination path; parent directories are created on demand.

    The written layout is ``{"block0": {"step0": "FULL_ATTN", ...}, ...}``,
    matching what ``load_config_file`` expects.
    """
    folder = os.path.dirname(file_path)
    # exist_ok avoids the check-then-create race of the old exists()/makedirs
    # pair; the truthiness guard avoids os.makedirs("") raising when file_path
    # is a bare filename with no directory component.
    if folder:
        os.makedirs(folder, exist_ok=True)
    format_data = {
        f"block{blocki}": {f"step{stepi}": method.name for stepi, method in enumerate(methods)}
        for blocki, methods in enumerate(step_methods)
    }
    with open(file_path, "w") as file:
        json.dump(format_data, file, indent=2)
def load_config_file(file_path):
    """Load a compression plan previously written by ``save_config_file``.

    Returns a list (one entry per transformer block) of per-step
    ``FastAttnMethod`` members, looked up by name.
    """
    with open(file_path, "r") as file:
        raw = json.load(file)
    steps_methods = []
    for per_block in raw.values():
        steps_methods.append([FastAttnMethod[name] for name in per_block.values()])
    return steps_methods
def compression_loss(a, b):
    """Relative error between a reference output ``a`` and a compressed output ``b``.

    ``a``/``b`` are either lists of tensors or ``Transformer2DModelOutput``
    objects (detected by class name — presumably to avoid a hard diffusers
    import; TODO confirm), in which case only ``.sample`` is compared.

    Returns a float: the mean clipped relative difference, weighted by each
    rank's element count and reduced across the data-parallel group via
    ``all_reduce`` (assumed to be a sum reduction — verify against the DP
    group implementation).

    NOTE(review): assumes at least one tensor pair is present — with no
    tensors, ``ls`` is empty (ZeroDivisionError) and ``ai`` is unbound.
    """
    ls = []
    if a.__class__.__name__ == "Transformer2DModelOutput":
        a = [a.sample]
        b = [b.sample]
    # Accumulates this rank's element count for DP-weighted averaging.
    weight = torch.tensor(0.0)
    for ai, bi in zip(a, b):
        if isinstance(ai, torch.Tensor):
            weight += ai.numel()
            # Element-wise relative difference, stabilized by +1e-6 and
            # clipped to [0, 10] to bound outliers.
            diff = (ai - bi) / (torch.max(ai, bi) + 1e-6)
            loss = diff.abs().clip(0, 10).mean()
            ls.append(loss)
    weight_sum = get_dp_group().all_reduce(weight.clone().to(ai.device))
    # Weight the local mean loss by this rank's share of elements, then
    # reduce across ranks to obtain the global scalar loss.
    local_loss = (weight / weight_sum) * (sum(ls) / len(ls))
    global_loss = get_dp_group().all_reduce(local_loss.clone().to(ai.device)).item()
    return global_loss
def transformer_forward_pre_hook(m: Transformer2DModel, args, kwargs):
    """Forward pre-hook that greedily selects a compression method per block.

    Registered during calibration (see ``select_methods``). For the current
    timestep it: (1) computes the uncompressed reference output, (2) for each
    transformer block tries the candidate methods from most to least
    aggressive, keeping the first whose loss against the reference is below
    that block/step's threshold, and (3) re-enables residual computation and
    output caching so the real forward that follows this hook caches what the
    chosen plan needs.

    Note: the hook runs extra ``m.forward`` calls itself, so one user-visible
    step performs many internal forwards during calibration.
    """
    attn_name = get_fast_attn_layer_name()
    # All blocks share the same timestep counter; read it from block 0.
    now_stepi = getattr(m.transformer_blocks[0], attn_name).stepi
    for blocki, block in enumerate(m.transformer_blocks):
        # Set `need_compute_residual` to False to avoid the process of trying different
        # compression strategies to override the saved residual.
        fast_attn = getattr(block, attn_name).processor.fast_attn
        fast_attn.need_compute_residual[now_stepi] = False
        fast_attn.need_cache_output = False
    # Uncompressed reference output for this timestep.
    raw_outs = m.forward(*args, **kwargs)
    for blocki, block in enumerate(m.transformer_blocks):
        # Step 0 is never compressed; keep FULL_ATTN there.
        if now_stepi == 0:
            continue
        fast_attn = getattr(block, attn_name).processor.fast_attn
        # Candidates ordered from cheapest to most expensive; the first one
        # under the loss threshold wins, otherwise FULL_ATTN remains.
        method_candidates = [
            FastAttnMethod.OUTPUT_SHARE,
            FastAttnMethod.RESIDUAL_WINDOW_ATTN_CFG_SHARE,
            FastAttnMethod.RESIDUAL_WINDOW_ATTN,
            FastAttnMethod.FULL_ATTN_CFG_SHARE,
        ]
        selected_method = FastAttnMethod.FULL_ATTN
        for method in method_candidates:
            # Try compress this attention using `method`
            fast_attn.steps_method[now_stepi] = method
            # Set the timestep index of every layer back to now_stepi
            # (which are increased by one in every forward)
            for _block in m.transformer_blocks:
                for layer in _block.children():
                    if isinstance(layer, xFuserAttentionBaseWrapper):
                        layer.stepi = now_stepi
            # Compute the overall transformer output
            outs = m.forward(*args, **kwargs)
            loss = compression_loss(raw_outs, outs)
            threshold = m.loss_thresholds[now_stepi][blocki]
            if loss < threshold:
                selected_method = method
                break
        fast_attn.steps_method[now_stepi] = selected_method
        del loss, outs
    del raw_outs
    # Set the timestep index of every layer back to now_stepi
    # (which are increased by one in every forward)
    for _block in m.transformer_blocks:
        for layer in _block.children():
            if isinstance(layer, xFuserAttentionBaseWrapper):
                layer.stepi = now_stepi
    for blocki, block in enumerate(m.transformer_blocks):
        # During the compression plan decision process,
        # we set the `need_compute_residual` property of all attention modules to `True`,
        # so that all full attention modules will save its residual for convenience.
        # The residual will be saved in the follow-up forward call.
        fast_attn = getattr(block, attn_name).processor.fast_attn
        fast_attn.need_compute_residual[now_stepi] = True
        fast_attn.need_cache_output = True
def select_methods(pipe: DiffusionPipeline):
    """Calibrate the pipeline and return a per-block compression plan.

    Resets every block to FULL_ATTN, installs ``transformer_forward_pre_hook``
    as a forward pre-hook, and runs one calibration generation over captions
    sampled from an MS-COCO annotation file; the hook performs the actual
    per-step method selection. Returns ``blocks_methods`` — one list of
    per-step ``FastAttnMethod`` values per transformer block.

    NOTE(review): fixed seed (3) and guidance scale (4.5) make calibration
    deterministic for a given annotation file.
    """
    blocks = pipe.transformer.transformer_blocks
    transformer: Transformer2DModel = pipe.transformer
    attn_name = get_fast_attn_layer_name()
    n_steps = get_fast_attn_step()
    # reset all processors
    for block in blocks:
        fast_attn: xFuserFastAttention = getattr(block, attn_name).processor.fast_attn
        fast_attn.set_methods(
            [FastAttnMethod.FULL_ATTN] * n_steps,
            selecting=True,
        )
    # Setup loss threshold for each timestep and layer.
    # Thresholds grow linearly with block depth: deeper blocks tolerate more
    # compression error, up to the configured maximum threshold.
    loss_thresholds = []
    for step_i in range(n_steps):
        sub_list = []
        for blocki in range(len(blocks)):
            threshold_i = (blocki + 1) / len(blocks) * get_fast_attn_threshold()
            sub_list.append(threshold_i)
        loss_thresholds.append(sub_list)
    # calibration
    hook = transformer.register_forward_pre_hook(transformer_forward_pre_hook, with_kwargs=True)
    # Attached dynamically so the hook can read it; removed after calibration.
    transformer.loss_thresholds = loss_thresholds
    seed = 3
    guidance_scale = 4.5
    if not os.path.exists(get_fast_attn_coco_path()):
        raise FileNotFoundError(f"File {get_fast_attn_coco_path()} not found")
    with open(get_fast_attn_coco_path(), "r") as file:
        mscoco_anno = json.load(file)
    # Sample a fixed set of captions as calibration prompts.
    np.random.seed(seed)
    slice_ = np.random.choice(mscoco_anno["annotations"], get_fast_attn_calib())
    calib_x = [d["caption"] for d in slice_]
    pipe(
        prompt=calib_x,
        num_inference_steps=n_steps,
        generator=torch.manual_seed(seed),
        output_type="latent",
        negative_prompt="",
        return_dict=False,
        guidance_scale=guidance_scale,
    )
    hook.remove()
    del transformer.loss_thresholds
    # Collect the plan the hook selected for each block.
    blocks_methods = [getattr(block, attn_name).processor.fast_attn.steps_method for block in blocks]
    return blocks_methods
def set_methods(
    pipe: DiffusionPipeline,
    blocks_methods: list,
):
    """Install a per-block compression plan onto the pipeline.

    ``blocks_methods[i]`` holds the per-step FastAttnMethod values for
    transformer block ``i``.
    """
    attn_name = get_fast_attn_layer_name()
    transformer_blocks = pipe.transformer.transformer_blocks
    for index, block in enumerate(transformer_blocks):
        fast_attn = getattr(block, attn_name).processor.fast_attn
        fast_attn.set_methods(blocks_methods[index])
def statistics(pipe: DiffusionPipeline):
    """Log the fraction of timesteps assigned to each compression method."""
    attn_name = get_fast_attn_layer_name()
    blocks = pipe.transformer.transformer_blocks
    all_methods = []
    for block in blocks:
        all_methods.extend(getattr(block, attn_name).processor.fast_attn.steps_method)
    counts = Counter(all_methods)
    total = sum(counts.values())
    for k, v in counts.items():
        logger.info(f"{attn_name} {k} {v/total}")
def fast_attention_compression(pipe: DiffusionPipeline):
    """Apply DiTFastAttn compression to *pipe*.

    Loads a cached compression plan when caching is enabled and the config
    file exists; otherwise runs calibration to select one (rank 0 saves the
    result). The chosen plan is then installed and its statistics logged.
    """
    config_file = get_fast_attn_config_file()
    logger.info(f"config file is {config_file}")
    use_cache = get_fast_attn_use_cache()
    cache_hit = use_cache and os.path.exists(config_file)
    if cache_hit:
        logger.info(f"load config file {config_file} as DiTFastAttn compression methods.")
        blocks_methods = load_config_file(config_file)
    else:
        if use_cache:
            logger.warning(f"config file {config_file} not found.")
        logger.info("start to select DiTFastAttn compression methods.")
        blocks_methods = select_methods(pipe)
        # Only one DP rank persists the plan to avoid concurrent writes.
        if get_data_parallel_rank() == 0:
            save_config_file(blocks_methods, config_file)
            logger.info(f"save DiTFastAttn compression methods to {config_file}")
    set_methods(pipe, blocks_methods)
    statistics(pipe)
from .hybrid import xFuserLongContextAttention
from .ulysses import xFuserUlyssesAttention
__all__ = [
"xFuserLongContextAttention",
"xFuserUlyssesAttention",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment