xuwx1 / LightX2V, commit 1f7bad54 (Unverified)
add dcu platform (#584)
Authored Dec 09, 2025 by fuheaven; committed via GitHub on Dec 09, 2025
Parent: 5546f759
Showing 5 changed files with 220 additions and 1 deletion (+220, -1)
lightx2v_platform/base/__init__.py            +2    -1
lightx2v_platform/base/dcu.py                 +55   -0
lightx2v_platform/ops/__init__.py             +7    -0
lightx2v_platform/ops/attn/dcu/__init__.py    +2    -0
lightx2v_platform/ops/attn/dcu/flash_attn.py  +154  -0
lightx2v_platform/base/__init__.py  (view file @ 1f7bad54)

  from lightx2v_platform.base.base import check_ai_device, init_ai_device
  from lightx2v_platform.base.cambricon_mlu import MluDevice
+ from lightx2v_platform.base.dcu import DcuDevice
  from lightx2v_platform.base.metax import MetaxDevice
  from lightx2v_platform.base.nvidia import CudaDevice

- __all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice"]
+ __all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "DcuDevice"]
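With DcuDevice exported from lightx2v_platform.base, callers can query it the same way as the other device classes. A minimal usage sketch (hypothetical caller code, not part of this commit; it relies only on the methods defined in dcu.py below):

from lightx2v_platform.base import DcuDevice

# Probe the DCU backend exactly as the class advertises: ROCm exposes the
# CUDA API through HIP, so availability and device selection go through torch.cuda.
if DcuDevice.is_available():
    device = DcuDevice.get_device()  # returns "cuda" for ROCm compatibility
    print(f"DCU backend ready, using device string: {device}")
else:
    print("No DCU/ROCm device visible to PyTorch")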
lightx2v_platform/base/dcu.py  0 → 100644  (view file @ 1f7bad54)

import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


@PLATFORM_DEVICE_REGISTER("dcu")
class DcuDevice:
    """
    DCU (AMD GPU) Device implementation for LightX2V.

    DCU uses ROCm, which provides CUDA-compatible APIs through HIP.
    Most PyTorch operations work transparently through the ROCm backend.
    """

    name = "dcu"

    @staticmethod
    def is_available() -> bool:
        """
        Check if DCU is available.

        DCU uses the standard CUDA API through ROCm's HIP compatibility layer.

        Returns:
            bool: True if DCU/CUDA is available
        """
        try:
            return torch.cuda.is_available()
        except ImportError:
            return False

    @staticmethod
    def get_device() -> str:
        """
        Get the device type string.

        Returns "cuda" because DCU uses CUDA-compatible APIs through ROCm.
        This allows seamless integration with existing PyTorch code.

        Returns:
            str: "cuda" for ROCm compatibility
        """
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """
        Initialize the distributed parallel environment for DCU.

        Uses RCCL (ROCm Collective Communications Library), which is
        compatible with NCCL APIs for multi-GPU communication.
        """
        # RCCL is compatible with the NCCL backend
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
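init_parallel_env relies on RCCL answering to the standard "nccl" backend name, so a multi-DCU run is launched the same way as on NVIDIA hardware. A hedged sketch of such a launch (hypothetical script, assuming torchrun provides the rendezvous environment variables and that one process is started per DCU on a single node):

# run_dcu.py, launched with: torchrun --nproc_per_node=<num_dcus> run_dcu.py
import torch
import torch.distributed as dist

from lightx2v_platform.base import DcuDevice

# init_process_group(backend="nccl") inside init_parallel_env is served by RCCL on ROCm,
# and set_device(dist.get_rank()) pins each process to its own DCU.
DcuDevice.init_parallel_env()

# Simple sanity check: sum the rank ids across all processes.
t = torch.tensor([dist.get_rank()], device=DcuDevice.get_device(), dtype=torch.float32)
dist.all_reduce(t)
print(f"rank {dist.get_rank()} sees all-reduced value {t.item()}")

dist.destroy_process_group()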
lightx2v_platform/ops/__init__.py  (view file @ 1f7bad54)

import os

from lightx2v_platform.base.global_var import AI_DEVICE

if AI_DEVICE == "mlu":
    from .attn.cambricon_mlu import *
    from .mm.cambricon_mlu import *
elif AI_DEVICE == "cuda":
    # Check if running on the DCU platform
    if os.getenv("PLATFORM") == "dcu":
        from .attn.dcu import *
        from .mm.dcu import *
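Because DCU reports itself as "cuda", the ops package can only tell it apart from NVIDIA through the PLATFORM environment variable, which therefore has to be set before lightx2v_platform.ops is first imported. A hedged sketch of enabling the DCU kernels (hypothetical driver code; it assumes AI_DEVICE has already been resolved to "cuda" by the platform initialization elsewhere in the package):

import os

# Must happen before the first import of lightx2v_platform.ops,
# since the platform check runs at module import time.
os.environ["PLATFORM"] = "dcu"

import lightx2v_platform.ops  # pulls in .attn.dcu and .mm.dcu when AI_DEVICE == "cuda"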
lightx2v_platform/ops/attn/dcu/__init__.py  0 → 100644  (view file @ 1f7bad54)

from .flash_attn import *
lightx2v_platform/ops/attn/dcu/flash_attn.py  0 → 100644  (view file @ 1f7bad54)

import torch
from loguru import logger

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

# Try to import Flash Attention (ROCm version 2.6.1)
try:
    from flash_attn import flash_attn_varlen_func

    FLASH_ATTN_AVAILABLE = True
    logger.info("Flash Attention (ROCm) is available")
except ImportError:
    logger.warning("Flash Attention not found. Will use PyTorch SDPA as fallback.")
    flash_attn_varlen_func = None
    FLASH_ATTN_AVAILABLE = False


@PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_dcu")
class FlashAttnDcu(AttnWeightTemplate):
    """
    DCU Flash Attention implementation.

    Uses the AMD ROCm version of Flash Attention 2.6.1 when available.
    Falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash Attention is not installed.

    Tested Environment:
        - PyTorch: 2.7.1
        - Python: 3.10
        - Flash Attention: 2.6.1 (ROCm)

    Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
    """

    def __init__(self, weight_name="flash_attn_dcu"):
        super().__init__(weight_name)
        self.use_flash_attn = FLASH_ATTN_AVAILABLE
        if self.use_flash_attn:
            logger.info("Flash Attention 2.6.1 (ROCm) is available and will be used.")
        else:
            logger.warning("Flash Attention not available. Using PyTorch SDPA fallback.")

    def apply(
        self,
        q,
        k,
        v,
        q_lens=None,
        k_lens=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        deterministic=False,
    ):
        """
        Execute the Flash Attention computation.

        Args:
            q: [B, Lq, Nq, C1] Query tensor
            k: [B, Lk, Nk, C1] Key tensor
            v: [B, Lk, Nk, C2] Value tensor
            q_lens: [B] Optional sequence lengths for queries
            k_lens: [B] Optional sequence lengths for keys
            dropout_p: Dropout probability
            softmax_scale: Scaling factor for QK^T before softmax
            causal: Whether to apply a causal mask
            window_size: Sliding window size tuple (left, right)
            deterministic: Whether to use the deterministic algorithm

        Returns:
            Output tensor: [B, Lq, Nq, C2]
        """
        if not self.use_flash_attn:
            # Fallback to PyTorch SDPA
            return self._sdpa_fallback(q, k, v, causal, dropout_p)

        # Ensure data types are half precision
        half_dtypes = (torch.float16, torch.bfloat16)
        dtype = q.dtype if q.dtype in half_dtypes else torch.bfloat16
        out_dtype = q.dtype
        b, lq, lk = q.size(0), q.size(1), k.size(1)

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        # Preprocess query
        if q_lens is None:
            q_flat = half(q.flatten(0, 1))
            q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
        else:
            q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

        # Preprocess key/value
        if k_lens is None:
            k_flat = half(k.flatten(0, 1))
            v_flat = half(v.flatten(0, 1))
            k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device)
        else:
            k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
            v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

        # Compute cumulative sequence lengths
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)

        # Use Flash Attention 2.6.1 (ROCm version)
        output = flash_attn_varlen_func(
            q=q_flat,
            k=k_flat,
            v=v_flat,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
        )

        # Reshape back to the batch dimension
        output = output.unflatten(0, (b, lq))
        return output.to(out_dtype)

    def _sdpa_fallback(self, q, k, v, causal=False, dropout_p=0.0):
        """
        Fallback to PyTorch Scaled Dot Product Attention.

        Args:
            q: [B, Lq, Nq, C] Query tensor
            k: [B, Lk, Nk, C] Key tensor
            v: [B, Lk, Nk, C] Value tensor
            causal: Whether to apply a causal mask
            dropout_p: Dropout probability

        Returns:
            Output tensor: [B, Lq, Nq, C]
        """
        # Transpose to [B, Nq, Lq, C] for SDPA
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p
        )

        # Transpose back to [B, Lq, Nq, C]
        return out.transpose(1, 2).contiguous()
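The variable-length path packs all batch entries into one flattened sequence and describes their boundaries with cumulative lengths, which is the layout flash_attn_varlen_func expects. A small self-contained sketch of that packing math, using only PyTorch (illustrative shapes and lengths; no flash_attn install required):

import torch

# Two query sequences of lengths 3 and 5, padded to a common length of 5.
b, nq, c = 2, 4, 16
q = torch.randn(b, 5, nq, c, dtype=torch.bfloat16)
q_lens = torch.tensor([3, 5], dtype=torch.int32)

# Same packing as in apply(): keep only the valid prefix of each sequence.
q_flat = torch.cat([u[:l] for u, l in zip(q, q_lens)])  # [3 + 5, Nq, C]

# Cumulative sequence lengths mark where each batch entry starts and ends.
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
print(q_flat.shape)   # torch.Size([8, 4, 16])
print(cu_seqlens_q)   # tensor([0, 3, 8], dtype=torch.int32)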