Unverified Commit 22484a22 authored by fuheaven, committed by GitHub

Dcu: format code (#588)

parent 1f7bad54
@@ -12,6 +12,7 @@ class DcuDevice:
    DCU uses ROCm which provides CUDA-compatible APIs through HIP.
    Most PyTorch operations work transparently through the ROCm backend.
    """
    name = "dcu"
    @staticmethod
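
Because the ROCm build of PyTorch exposes the DCU through the regular torch.cuda namespace, availability can be sanity-checked as below. This is a minimal standalone sketch assuming a ROCm/HIP build of PyTorch; it is not code from this repository.

import torch

# On a ROCm build, torch.version.hip is a version string (it is None on CUDA builds)
# and the DCU is reported through the CUDA-compatible API surface.
if torch.cuda.is_available() and getattr(torch.version, "hip", None):
    dev = torch.device("cuda:0")  # HIP device, addressed via the torch.cuda namespace
    x = torch.randn(4, 4, device=dev)
    print(torch.cuda.get_device_name(0), x.sum().item())
else:
    print("No ROCm/HIP device visible; running on CPU.")
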
@@ -52,4 +53,3 @@ class DcuDevice:
        # RCCL is compatible with NCCL backend
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
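
Since the "nccl" backend name resolves to RCCL on ROCm, a multi-DCU job is launched the same way as a multi-GPU CUDA job. A minimal single-node sketch follows; the file name and torchrun command are illustrative and not taken from this repository.

# Save as ddp_smoke_test.py and run with: torchrun --nproc_per_node=4 ddp_smoke_test.py
import torch
import torch.distributed as dist


def main():
    # "nccl" maps to RCCL on ROCm builds, so no DCU-specific backend name is needed.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # single-node assumption: global rank == local device index

    # All-reduce a small tensor as a smoke test of RCCL communication.
    t = torch.ones(1, device="cuda") * rank
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: all-reduce sum = {t.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
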
from .flash_attn import *
@@ -28,7 +28,6 @@ class FlashAttnDcu(AttnWeightTemplate):
    - PyTorch: 2.7.1
    - Python: 3.10
    - Flash Attention: 2.6.1 (ROCm)
    Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
    """
@@ -56,7 +55,6 @@ class FlashAttnDcu(AttnWeightTemplate):
    ):
        """
        Execute Flash Attention computation.
        Args:
            q: [B, Lq, Nq, C1] Query tensor
            k: [B, Lk, Nk, C1] Key tensor
@@ -68,7 +66,6 @@ class FlashAttnDcu(AttnWeightTemplate):
            causal: Whether to apply causal mask
            window_size: Sliding window size tuple (left, right)
            deterministic: Whether to use deterministic algorithm
        Returns:
            Output tensor: [B, Lq, Nq, C2]
        """
@@ -129,14 +126,12 @@ class FlashAttnDcu(AttnWeightTemplate):
    def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
        """
        Fallback to PyTorch Scaled Dot Product Attention.
        Args:
            q: [B, Lq, Nq, C] Query tensor
            k: [B, Lk, Nk, C] Key tensor
            v: [B, Lk, Nk, C] Value tensor
            causal: Whether to apply causal mask
            dropout_p: Dropout probability
        Returns:
            Output tensor: [B, Lq, Nq, C]
        """
@@ -145,10 +140,7 @@ class FlashAttnDcu(AttnWeightTemplate):
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
-        out = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p
-        )
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p)
        # Transpose back to [B, Lq, Nq, C]
        return out.transpose(1, 2).contiguous()
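
To make the layout handling in the fallback concrete: scaled_dot_product_attention expects [B, N, L, C] while the surrounding code works in [B, L, N, C], so the two transposes form a round trip. The sketch below uses a hypothetical standalone helper (CPU-only, not part of this class) that mirrors that logic and cross-checks it against an explicit softmax(QK^T / sqrt(C)) V.

import torch
import torch.nn.functional as F


def sdpa_blnc(q, k, v, causal=False, dropout_p=0.0):
    # Inputs are [B, L, N, C]; SDPA wants [B, N, L, C], so swap dims 1 and 2,
    # attend, then swap back -- the same round trip as the fallback above.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal, dropout_p=dropout_p)
    return out.transpose(1, 2).contiguous()


B, L, N, C = 2, 16, 4, 32
q, k, v = (torch.randn(B, L, N, C) for _ in range(3))
out = sdpa_blnc(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 16, 4, 32])

# Cross-check against explicit attention computed directly in the [B, L, N, C] layout.
scores = torch.einsum("blnc,bsnc->bnls", q, k) / C ** 0.5
scores = scores.masked_fill(torch.triu(torch.ones(L, L, dtype=torch.bool), 1), float("-inf"))
ref = torch.einsum("bnls,bsnc->blnc", scores.softmax(-1), v)
print(torch.allclose(out, ref, atol=1e-5))  # True
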