"vscode:/vscode.git/clone" did not exist on "5aa31bd6741a3469415bd95c91b0b47b98ba87a8"
Unverified commit 22484a22, authored by fuheaven, committed by GitHub

Dcu: format code (#588)

parent 1f7bad54
@@ -8,17 +8,18 @@ from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
class DcuDevice:
    """
    DCU (AMD GPU) Device implementation for LightX2V.
    DCU uses ROCm, which provides CUDA-compatible APIs through HIP.
    Most PyTorch operations work transparently through the ROCm backend.
    """
    name = "dcu"
    @staticmethod
    def is_available() -> bool:
        """
        Check if DCU is available.
        DCU uses the standard CUDA API through ROCm's HIP compatibility layer.
        Returns:
            bool: True if DCU/CUDA is available
@@ -32,10 +33,10 @@ class DcuDevice:
    def get_device() -> str:
        """
        Get the device type string.
        Returns "cuda" because DCU uses CUDA-compatible APIs through ROCm.
        This allows seamless integration with existing PyTorch code.
        Returns:
            str: "cuda" for ROCm compatibility
        """
@@ -45,11 +46,10 @@ class DcuDevice:
    def init_parallel_env():
        """
        Initialize the distributed parallel environment for DCU.
        Uses RCCL (ROCm Collective Communications Library), which is
        compatible with NCCL APIs for multi-GPU communication.
        """
        # RCCL is compatible with the NCCL backend
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
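For context, a minimal usage sketch (not part of this commit) of how DcuDevice might be exercised on a ROCm build of PyTorch; the tensor shapes and launcher command are illustrative assumptions:

import torch
import torch.distributed as dist

# Illustrative only: on a ROCm build, torch.cuda.* calls target the DCU devices,
# so the CUDA-style code below runs unchanged.
if DcuDevice.is_available():
    device = DcuDevice.get_device()        # "cuda", via ROCm's HIP compatibility layer
    x = torch.randn(4, 4, device=device)   # allocated on the DCU
    print(x.device)

# Multi-GPU initialization would typically run under a launcher, e.g.
#   torchrun --nproc_per_node=4 script.py
# with DcuDevice.init_parallel_env() called inside the script.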
from .flash_attn import *
@@ -20,15 +20,14 @@ except ImportError:
class FlashAttnDcu(AttnWeightTemplate):
    """
    DCU Flash Attention implementation.
    Uses the AMD ROCm build of Flash Attention 2.6.1 when available.
    Falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash Attention is not installed.
    Tested Environment:
        - PyTorch: 2.7.1
        - Python: 3.10
        - Flash Attention: 2.6.1 (ROCm)
    Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
    """
@@ -56,7 +55,6 @@ class FlashAttnDcu(AttnWeightTemplate):
    ):
        """
        Execute the Flash Attention computation.
        Args:
            q: [B, Lq, Nq, C1] Query tensor
            k: [B, Lk, Nk, C1] Key tensor
@@ -68,7 +66,6 @@ class FlashAttnDcu(AttnWeightTemplate):
            causal: Whether to apply a causal mask
            window_size: Sliding window size tuple (left, right)
            deterministic: Whether to use a deterministic algorithm
        Returns:
            Output tensor: [B, Lq, Nq, C2]
        """
@@ -129,14 +126,12 @@ class FlashAttnDcu(AttnWeightTemplate):
    def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
        """
        Fall back to PyTorch Scaled Dot Product Attention.
        Args:
            q: [B, Lq, Nq, C] Query tensor
            k: [B, Lk, Nk, C] Key tensor
            v: [B, Lk, Nk, C] Value tensor
            causal: Whether to apply a causal mask
            dropout_p: Dropout probability
        Returns:
            Output tensor: [B, Lq, Nq, C]
        """
@@ -145,10 +140,7 @@ class FlashAttnDcu(AttnWeightTemplate):
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p)
        # Transpose back to [B, Lq, Nq, C]
        return out.transpose(1, 2).contiguous()
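As a quick check of the layout handling in _sdpa_fallback, the following standalone sketch (illustrative shapes, plain PyTorch) reproduces the transpose-call-transpose pattern and confirms the output keeps the [B, Lq, Nq, C] layout:

import torch
import torch.nn.functional as F

B, Lq, Lk, N, C = 2, 16, 16, 4, 32
q = torch.randn(B, Lq, N, C)
k = torch.randn(B, Lk, N, C)
v = torch.randn(B, Lk, N, C)

# SDPA expects [B, heads, seq, head_dim], so move the head axis next to batch...
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False)
# ...then restore the [B, Lq, Nq, C] layout used throughout the class.
out = out.transpose(1, 2).contiguous()
assert out.shape == (B, Lq, N, C)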