Unverified Commit 22484a22 authored by fuheaven, committed by GitHub

Dcu: format code (#588)

parent 1f7bad54
@@ -8,17 +8,18 @@ from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
class DcuDevice:
"""
DCU (AMD GPU) Device implementation for LightX2V.
DCU uses ROCm, which provides CUDA-compatible APIs through HIP.
Most PyTorch operations work transparently through the ROCm backend.
"""
name = "dcu"
@staticmethod
def is_available() -> bool:
"""
Check if DCU is available.
DCU uses the standard CUDA API through ROCm's HIP compatibility layer.
Returns:
bool: True if DCU/CUDA is available
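A minimal standalone sketch of this availability check, assuming a ROCm build of PyTorch (where `torch.version.hip` is set and DCU devices are visible through the CUDA API); the helper name is hypothetical, not the repository's implementation:

```python
import torch

# Hypothetical helper, not the repository's exact code.
def is_dcu_available() -> bool:
    # On ROCm builds torch.version.hip is a version string (None on CUDA
    # builds), and torch.cuda.is_available() reports visible HIP devices.
    return torch.version.hip is not None and torch.cuda.is_available()
```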
@@ -32,10 +33,10 @@ class DcuDevice:
def get_device() -> str:
"""
Get the device type string.
Returns "cuda" because DCU uses CUDA-compatible APIs through ROCm.
This allows seamless integration with existing PyTorch code.
Returns:
str: "cuda" for ROCm compatibility
"""
@@ -45,11 +46,10 @@ class DcuDevice:
def init_parallel_env():
"""
Initialize distributed parallel environment for DCU.
Uses RCCL (ROCm Collective Communications Library), which is
API-compatible with NCCL, for multi-GPU communication.
"""
# RCCL is compatible with NCCL backend
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
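Because backend="nccl" is routed to RCCL on ROCm builds of PyTorch, multi-DCU jobs can be launched the same way as NVIDIA multi-GPU jobs, for example with torchrun, which sets the rank and world-size environment variables that init_process_group reads.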
@@ -20,15 +20,14 @@ except ImportError:
class FlashAttnDcu(AttnWeightTemplate):
"""
DCU Flash Attention implementation.
Uses the AMD ROCm build of Flash Attention 2.6.1 when available.
Falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash Attention is not installed.
Tested Environment:
- PyTorch: 2.7.1
- Python: 3.10
- Flash Attention: 2.6.1 (ROCm)
Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
"""
@@ -56,7 +55,6 @@ class FlashAttnDcu(AttnWeightTemplate):
):
"""
Execute Flash Attention computation.
Args:
q: [B, Lq, Nq, C1] Query tensor
k: [B, Lk, Nk, C1] Key tensor
@@ -68,7 +66,6 @@ class FlashAttnDcu(AttnWeightTemplate):
causal: Whether to apply causal mask
window_size: Sliding window size tuple (left, right)
deterministic: Whether to use deterministic algorithm
Returns:
Output tensor: [B, Lq, Nq, C2]
"""
@@ -129,14 +126,12 @@ class FlashAttnDcu(AttnWeightTemplate):
def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
"""
Fallback to PyTorch Scaled Dot Product Attention.
Args:
q: [B, Lq, Nq, C] Query tensor
k: [B, Lk, Nk, C] Key tensor
v: [B, Lk, Nk, C] Value tensor
causal: Whether to apply causal mask
dropout_p: Dropout probability
Returns:
Output tensor: [B, Lq, Nq, C]
"""
@@ -145,10 +140,7 @@ class FlashAttnDcu(AttnWeightTemplate):
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p)
# Transpose back to [B, Lq, Nq, C]
return out.transpose(1, 2).contiguous()
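For reference, a self-contained sketch of the same fallback logic outside the class, making the layout handling explicit: inputs are [B, L, N, C], while `scaled_dot_product_attention` expects [B, N, L, C]:

```python
import torch
import torch.nn.functional as F

def sdpa_fallback(q, k, v, causal=False, dropout_p=0.0):
    # [B, L, N, C] -> [B, N, L, C] for SDPA, then transpose back.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p
    )
    return out.transpose(1, 2).contiguous()

# CPU smoke test using the docstring shapes.
q = k = v = torch.randn(1, 128, 8, 64)
assert sdpa_fallback(q, k, v, causal=True).shape == (1, 128, 8, 64)
```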