Unverified commit 1f7bad54 authored by fuheaven, committed by GitHub

add dcu platform (#584)

parent 5546f759
from lightx2v_platform.base.base import check_ai_device, init_ai_device
from lightx2v_platform.base.cambricon_mlu import MluDevice
from lightx2v_platform.base.dcu import DcuDevice
from lightx2v_platform.base.metax import MetaxDevice
from lightx2v_platform.base.nvidia import CudaDevice

__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "DcuDevice"]
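With DcuDevice re-exported here, callers can treat it like the existing device classes. A minimal sketch of how the new export might be consumed (only the methods defined in this commit are assumed):

# Hypothetical usage sketch, not part of the commit.
from lightx2v_platform.base import DcuDevice

if DcuDevice.is_available():
    # On ROCm builds of PyTorch the device string is still "cuda".
    print(DcuDevice.name, DcuDevice.get_device())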
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


@PLATFORM_DEVICE_REGISTER("dcu")
class DcuDevice:
    """
    DCU (AMD GPU) device implementation for LightX2V.

    DCU uses ROCm, which provides CUDA-compatible APIs through HIP.
    Most PyTorch operations work transparently through the ROCm backend.
    """

    name = "dcu"

    @staticmethod
    def is_available() -> bool:
        """
        Check if a DCU is available.

        DCU uses the standard CUDA API through ROCm's HIP compatibility layer.

        Returns:
            bool: True if DCU/CUDA is available
        """
        try:
            return torch.cuda.is_available()
        except ImportError:
            return False

    @staticmethod
    def get_device() -> str:
        """
        Get the device type string.

        Returns "cuda" because DCU uses CUDA-compatible APIs through ROCm,
        which allows seamless integration with existing PyTorch code.

        Returns:
            str: "cuda" for ROCm compatibility
        """
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """
        Initialize the distributed parallel environment for DCU.

        Uses RCCL (ROCm Collective Communications Library), which is
        compatible with the NCCL API for multi-GPU communication.
        """
        # RCCL is compatible with the NCCL backend
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
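Since init_parallel_env() only initializes the NCCL-compatible RCCL backend and pins each rank to a device, a multi-DCU run can be driven the same way as a multi-GPU CUDA run. A hedged sketch, assuming the script is launched with torchrun so the usual RANK/WORLD_SIZE/MASTER_ADDR environment variables are already set; the all_reduce is purely illustrative:

# Hypothetical multi-DCU sketch; relies only on the DcuDevice methods above.
import torch
import torch.distributed as dist

from lightx2v_platform.base.dcu import DcuDevice

if DcuDevice.is_available():
    DcuDevice.init_parallel_env()                     # RCCL via the "nccl" backend
    x = torch.ones(4, device=DcuDevice.get_device())  # lands on this rank's DCU
    dist.all_reduce(x)                                # sums the tensor across ranks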
import os

from lightx2v_platform.base.global_var import AI_DEVICE

if AI_DEVICE == "mlu":
    from .attn.cambricon_mlu import *
    from .mm.cambricon_mlu import *
elif AI_DEVICE == "cuda":
    # Check if running on the DCU platform
    if os.getenv("PLATFORM") == "dcu":
        from .attn.dcu import *
        from .mm.dcu import *
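Because these imports run at module import time, the DCU kernels are only loaded when AI_DEVICE resolves to "cuda" (the ROCm build of PyTorch reports itself as CUDA) and the PLATFORM environment variable is set beforehand. A hedged opt-in sketch:

# Hypothetical opt-in sketch: PLATFORM must be set before lightx2v_platform.ops is imported,
# because the conditional imports above execute at import time.
import os

os.environ["PLATFORM"] = "dcu"   # select the DCU attention/matmul kernels

import lightx2v_platform.ops     # pulls in .attn.dcu and .mm.dcu when AI_DEVICE == "cuda"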
import torch
from loguru import logger

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

# Try to import Flash Attention (ROCm build, version 2.6.1)
try:
    from flash_attn import flash_attn_varlen_func

    FLASH_ATTN_AVAILABLE = True
    logger.info("Flash Attention (ROCm) is available")
except ImportError:
    logger.warning("Flash Attention not found. Will use PyTorch SDPA as fallback.")
    flash_attn_varlen_func = None
    FLASH_ATTN_AVAILABLE = False

@PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_dcu")
class FlashAttnDcu(AttnWeightTemplate):
    """
    DCU Flash Attention implementation.

    Uses the AMD ROCm build of Flash Attention 2.6.1 when available and
    falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash
    Attention is not installed.

    Tested environment:
        - PyTorch: 2.7.1
        - Python: 3.10
        - Flash Attention: 2.6.1 (ROCm)

    Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
    """

    def __init__(self, weight_name="flash_attn_dcu"):
        super().__init__(weight_name)
        self.use_flash_attn = FLASH_ATTN_AVAILABLE
        if self.use_flash_attn:
            logger.info("Flash Attention 2.6.1 (ROCm) is available and will be used.")
        else:
            logger.warning("Flash Attention not available. Using PyTorch SDPA fallback.")
    def apply(
        self,
        q,
        k,
        v,
        q_lens=None,
        k_lens=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        deterministic=False,
    ):
        """
        Execute the Flash Attention computation.

        Args:
            q: [B, Lq, Nq, C1] Query tensor
            k: [B, Lk, Nk, C1] Key tensor
            v: [B, Lk, Nk, C2] Value tensor
            q_lens: [B] Optional sequence lengths for queries
            k_lens: [B] Optional sequence lengths for keys
            dropout_p: Dropout probability
            softmax_scale: Scaling factor for QK^T before softmax
            causal: Whether to apply a causal mask
            window_size: Sliding window size tuple (left, right)
            deterministic: Whether to use a deterministic algorithm

        Returns:
            Output tensor: [B, Lq, Nq, C2]
        """
        if not self.use_flash_attn:
            # Fall back to PyTorch SDPA
            return self._sdpa_fallback(q, k, v, causal, dropout_p)

        # Ensure data types are half precision
        half_dtypes = (torch.float16, torch.bfloat16)
        dtype = q.dtype if q.dtype in half_dtypes else torch.bfloat16
        out_dtype = q.dtype
        b, lq, lk = q.size(0), q.size(1), k.size(1)

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        # Preprocess query
        if q_lens is None:
            q_flat = half(q.flatten(0, 1))
            q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
        else:
            q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

        # Preprocess key/value
        if k_lens is None:
            k_flat = half(k.flatten(0, 1))
            v_flat = half(v.flatten(0, 1))
            k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device)
        else:
            k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
            v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

        # Compute cumulative sequence lengths (required by the varlen kernel)
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)

        # Call Flash Attention 2.6.1 (ROCm build)
        output = flash_attn_varlen_func(
            q=q_flat,
            k=k_flat,
            v=v_flat,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
        )

        # Reshape back to the batch dimension
        output = output.unflatten(0, (b, lq))
        return output.to(out_dtype)
    def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
        """
        Fall back to PyTorch Scaled Dot Product Attention.

        Args:
            q: [B, Lq, Nq, C] Query tensor
            k: [B, Lk, Nk, C] Key tensor
            v: [B, Lk, Nk, C] Value tensor
            causal: Whether to apply a causal mask
            dropout_p: Dropout probability

        Returns:
            Output tensor: [B, Lq, Nq, C]
        """
        # Transpose to [B, Nq, Lq, C] for SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p
        )

        # Transpose back to [B, Lq, Nq, C]
        return out.transpose(1, 2).contiguous()
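A minimal usage sketch of the registered op, assuming half-precision tensors on a DCU and the [B, L, N, C] layout documented in apply(); the random inputs are purely illustrative:

# Hypothetical usage sketch for FlashAttnDcu, not part of the commit.
import torch

attn = FlashAttnDcu()
b, l, n, c = 2, 128, 8, 64
q = torch.randn(b, l, n, c, dtype=torch.bfloat16, device="cuda")
k = torch.randn(b, l, n, c, dtype=torch.bfloat16, device="cuda")
v = torch.randn(b, l, n, c, dtype=torch.bfloat16, device="cuda")

out = attn.apply(q, k, v)   # uses flash_attn_varlen_func, or SDPA if flash_attn is missing
print(out.shape)            # torch.Size([2, 128, 8, 64])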