"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "168a88e57070871eef5a9fcdad3ed1a4d708d7bd"
Unverified commit 8d45f219, authored by Dhruv Nair, committed by GitHub

Fix Context Parallel validation checks (#12446)



* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 0fd58c77
@@ -44,11 +44,16 @@ class ContextParallelConfig:
     Args:
         ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention. The sequence is split across devices, and each device
+            computes attention between its local Q and the KV chunks passed sequentially around the ring. Lower
+            memory (only holds 1/N of the KV at a time) and overlaps compute with communication, but requires N
+            iterations to see all tokens. Best for long sequences with limited memory/bandwidth. Must be a divisor
+            of the total number of devices in the context parallel mesh.
         ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention. The sequence is split across devices. Each device
+            computes local QKV, then all-gathers the KV chunks to compute full attention in one pass. Higher memory
+            (stores all KV) and requires high-bandwidth all-to-all communication, but lower latency. Best for
+            moderate sequences with good interconnect bandwidth. Must be a divisor of the total number of devices
+            in the context parallel mesh.
         convert_to_fp32 (`bool`, *optional*, defaults to `True`):
             Whether to convert output and LSE to float32 for ring attention numerical stability.
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
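Note for readers of the diff: the two degrees documented above map onto two different sharding strategies. A minimal, illustrative sketch of how a caller might pick between them (the top-level `diffusers` import path is assumed; the field names come from the docstring above):

```python
from diffusers import ContextParallelConfig  # import path assumed

# Ring Attention: lowest memory per device, KV chunks rotate around the ring.
ring_cfg = ContextParallelConfig(ring_degree=4)

# Ulysses Attention: one all-to-all pass, higher memory, lower latency.
ulysses_cfg = ContextParallelConfig(ulysses_degree=4)
```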
@@ -79,29 +84,46 @@ class ContextParallelConfig:
         if self.ulysses_degree is None:
             self.ulysses_degree = 1
+        if self.ring_degree == 1 and self.ulysses_degree == 1:
+            raise ValueError(
+                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
+            )
+        if self.ring_degree < 1 or self.ulysses_degree < 1:
+            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+        if self.ring_degree > 1 and self.ulysses_degree > 1:
+            raise ValueError(
+                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+            )
+        if self.rotate_method != "allgather":
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )
+
+    @property
+    def mesh_shape(self) -> Tuple[int, int]:
+        return (self.ring_degree, self.ulysses_degree)
+
+    @property
+    def mesh_dim_names(self) -> Tuple[str, str]:
+        """Dimension names for the device mesh."""
+        return ("ring", "ulysses")

     def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
         self._rank = rank
         self._world_size = world_size
         self._device = device
         self._mesh = mesh
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-        if self.rotate_method != "allgather":
-            raise NotImplementedError(
-                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
-            )
-        if self._flattened_mesh is None:
-            self._flattened_mesh = self._mesh._flatten()
-        if self._ring_mesh is None:
-            self._ring_mesh = self._mesh["ring"]
-        if self._ulysses_mesh is None:
-            self._ulysses_mesh = self._mesh["ulysses"]
-        if self._ring_local_rank is None:
-            self._ring_local_rank = self._ring_mesh.get_local_rank()
-        if self._ulysses_local_rank is None:
-            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+        if self.ulysses_degree * self.ring_degree > world_size:
+            raise ValueError(
+                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
+            )
+        self._flattened_mesh = self._mesh._flatten()
+        self._ring_mesh = self._mesh["ring"]
+        self._ulysses_mesh = self._mesh["ulysses"]
+        self._ring_local_rank = self._ring_mesh.get_local_rank()
+        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()


 @dataclass
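With this hunk, the validation that previously ran in `setup()` (and in `enable_parallelism`) moves into `__post_init__`, so an invalid configuration now fails at construction time. A small illustrative sketch of what the new checks accept and reject (import path assumed; argument values are placeholders chosen to trigger each branch):

```python
from diffusers import ContextParallelConfig  # import path assumed

ContextParallelConfig(ring_degree=2)                              # ok: 2-way ring attention
ContextParallelConfig()                                           # ValueError: at least one degree must be > 1
ContextParallelConfig(ring_degree=0)                              # ValueError: degrees must be >= 1
ContextParallelConfig(ring_degree=2, ulysses_degree=2)            # ValueError: unified ring + ulysses not supported yet
ContextParallelConfig(ring_degree=2, rotate_method="p2p")         # NotImplementedError: only "allgather" is supported
```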
@@ -119,7 +141,7 @@ class ParallelConfig:
     _rank: int = None
     _world_size: int = None
     _device: torch.device = None
-    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None

     def setup(
         self,
@@ -127,14 +149,14 @@ class ParallelConfig:
         world_size: int,
         device: torch.device,
         *,
-        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
     ):
         self._rank = rank
         self._world_size = world_size
         self._device = device
-        self._cp_mesh = cp_mesh
+        self._mesh = mesh
         if self.context_parallel_config is not None:
-            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+            self.context_parallel_config.setup(rank, world_size, device, mesh)


 @dataclass(frozen=True)
...
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
     _backends = {}
     _constraints = {}
     _supported_arg_names = {}
-    _supports_context_parallel = {}
+    _supports_context_parallel = set()
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS
@@ -237,7 +237,9 @@ class _AttentionBackendRegistry:
             cls._backends[backend] = func
             cls._constraints[backend] = constraints or []
             cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
-            cls._supports_context_parallel[backend] = supports_context_parallel
+            if supports_context_parallel:
+                cls._supports_context_parallel.add(backend.value)
+
             return func

         return decorator
@@ -251,15 +253,12 @@ class _AttentionBackendRegistry:
         return list(cls._backends.keys())

     @classmethod
-    def _is_context_parallel_enabled(
-        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+    def _is_context_parallel_available(
+        cls,
+        backend: AttentionBackendName,
     ) -> bool:
-        supports_context_parallel = backend in cls._supports_context_parallel
-        is_degree_greater_than_1 = parallel_config is not None and (
-            parallel_config.context_parallel_config.ring_degree > 1
-            or parallel_config.context_parallel_config.ulysses_degree > 1
-        )
-        return supports_context_parallel and is_degree_greater_than_1
+        supports_context_parallel = backend.value in cls._supports_context_parallel
+        return supports_context_parallel

     @contextlib.contextmanager
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

-    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
-        backend_name, parallel_config
-    ):
-        raise ValueError(
-            f"Backend {backend_name} either does not support context parallelism or context parallelism "
-            f"was enabled with a world size of 1."
-        )
-
     kwargs = {
         "query": query,
         "key": key,
...
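In short, the registry now records context-parallel support as a set of backend names populated at registration time, and the per-call check in `dispatch_attention_fn` is gone; compatibility is checked once, up front, in `enable_parallelism`. A rough sketch of the new query (registry and enum live in `diffusers.models.attention_dispatch` per the imports in this commit; the backend name is an assumption for illustration):

```python
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry

backend = AttentionBackendName("native")  # backend name assumed; any registered backend works here
if _AttentionBackendRegistry._is_context_parallel_available(backend):
    # True only if the backend was registered with supports_context_parallel=True
    pass
```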
@@ -1484,59 +1484,71 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         config: Union[ParallelConfig, ContextParallelConfig],
         cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
     ):
-        from ..hooks.context_parallel import apply_context_parallel
-        from .attention import AttentionModuleMixin
-        from .attention_processor import Attention, MochiAttention
-
         logger.warning(
             "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )

+        if not torch.distributed.is_available() and not torch.distributed.is_initialized():
+            raise RuntimeError(
+                "torch.distributed must be available and initialized before calling `enable_parallelism`."
+            )
+
+        from ..hooks.context_parallel import apply_context_parallel
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
+        from .attention_processor import Attention, MochiAttention
+
         if isinstance(config, ContextParallelConfig):
             config = ParallelConfig(context_parallel_config=config)

-        if not torch.distributed.is_initialized():
-            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
-
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
         device_type = torch._C._get_accelerator().type
         device_module = torch.get_device_module(device_type)
         device = torch.device(device_type, rank % device_module.device_count())

-        cp_mesh = None
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         if config.context_parallel_config is not None:
-            cp_config = config.context_parallel_config
-            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
-                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
-                raise ValueError(
-                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-                )
-            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
-                raise ValueError(
-                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
-                )
-            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
-                device_type=device_type,
-                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
-                mesh_dim_names=("ring", "ulysses"),
-            )
-
-        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
-
-        if cp_plan is None and self._cp_plan is None:
-            raise ValueError(
-                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
-            )
-        cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+            for module in self.modules():
+                if not isinstance(module, attention_classes):
+                    continue
+                processor = module.processor
+                if processor is None or not hasattr(processor, "_attention_backend"):
+                    continue
+                attention_backend = processor._attention_backend
+                if attention_backend is None:
+                    attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
+                else:
+                    attention_backend = AttentionBackendName(attention_backend)
+                if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+                    compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+                    raise ValueError(
+                        f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
+                        f"is using backend '{attention_backend.value}' which does not support context parallelism. "
+                        f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
+                        f"calling `enable_parallelism()`."
+                    )
+                # All modules use the same attention processor and backend. We don't need to
+                # iterate over all modules after checking the first processor
+                break

+        mesh = None
         if config.context_parallel_config is not None:
-            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+            cp_config = config.context_parallel_config
+            mesh = torch.distributed.device_mesh.init_device_mesh(
+                device_type=device_type,
+                mesh_shape=cp_config.mesh_shape,
+                mesh_dim_names=cp_config.mesh_dim_names,
+            )
+
+        config.setup(rank, world_size, device, mesh=mesh)

         self._parallel_config = config

-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
@@ -1545,6 +1557,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 continue
             processor._parallel_config = config

+        if config.context_parallel_config is not None:
+            if cp_plan is None and self._cp_plan is None:
+                raise ValueError(
+                    "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+                )
+            cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
     @classmethod
     def _load_pretrained_model(
         cls,
...
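Taken together, the caller is now responsible for selecting a context-parallel-capable attention backend before calling `enable_parallelism`, and `enable_parallelism` fails fast with the list of compatible backends otherwise. A rough end-to-end sketch under torchrun; the model, backend name, and process-group backend are placeholders and not part of this commit:

```python
# torchrun --nproc_per_node=2 run_cp.py
import torch
from diffusers import ContextParallelConfig  # import path assumed

torch.distributed.init_process_group("nccl")

transformer = ...  # a diffusers transformer model that defines a `_cp_plan`
transformer.set_attention_backend("native")  # placeholder; must be a backend registered with context-parallel support
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

# ... run inference as usual; attention is sharded across the 2 ranks ...

torch.distributed.destroy_process_group()
```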