xuwx1 / LightX2V · Commits · b50498fa

Add lightx2v_platform (#541)

Commit b50498fa (unverified signature), authored Dec 02, 2025 by Yang Yong (雍洋), committed via GitHub on Dec 02, 2025.
Parent: 31da6925
Changes: 75 files in the full commit; this page shows 15 changed files with 332 additions and 0 deletions (+332, -0).
lightx2v_platform/ops/attn/__init__.py                   +0   -0
lightx2v_platform/ops/attn/cambricon_mlu/__init__.py     +2   -0
lightx2v_platform/ops/attn/cambricon_mlu/flash_attn.py   +42  -0
lightx2v_platform/ops/attn/cambricon_mlu/sage_attn.py    +31  -0
lightx2v_platform/ops/attn/template.py                   +32  -0
lightx2v_platform/ops/mm/__init__.py                     +0   -0
lightx2v_platform/ops/mm/cambricon_mlu/__init__.py       +1   -0
lightx2v_platform/ops/mm/cambricon_mlu/mm_weight.py      +37  -0
lightx2v_platform/ops/mm/cambricon_mlu/q_linear.py       +47  -0
lightx2v_platform/ops/mm/template.py                     +56  -0
lightx2v_platform/ops/norm/__init__.py                   +0   -0
lightx2v_platform/ops/rope/__init__.py                   +0   -0
lightx2v_platform/registry_factory.py                    +58  -0
lightx2v_platform/set_ai_device.py                       +15  -0
lightx2v_platform/test/test_device.py                    +11  -0
lightx2v_platform/ops/attn/__init__.py  0 → 100755  (empty file)
lightx2v_platform/ops/attn/cambricon_mlu/__init__.py  0 → 100755

from .flash_attn import *
from .sage_attn import *
lightx2v_platform/ops/attn/cambricon_mlu/flash_attn.py  0 → 100644

import math

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


@PLATFORM_ATTN_WEIGHT_REGISTER("mlu_flash_attn")
class MluFlashAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}
        assert tmo is not None, "torch_mlu_ops is not installed."

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        if len(q.shape) == 3:
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        x = tmo.flash_attention(
            q=q,
            k=k,
            v=v,
            cu_seq_lens_q=cu_seqlens_q,
            cu_seq_lens_kv=cu_seqlens_kv,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_kv=max_seqlen_kv,
            softmax_scale=softmax_scale,
            return_lse=False,
            out_dtype=q.dtype,
            is_causal=False,
            out=None,
            alibi_slope=None,
            attn_bias=None,
        )
        x = x.reshape(bs * max_seqlen_q, -1)
        return x
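For orientation, a minimal usage sketch of this registered backend follows. It is not part of the commit: it assumes an available Cambricon MLU device (device string "mlu"), an installed torch_mlu_ops, and the packed (tokens, heads, head_dim) layout implied by the cu_seqlens/max_seqlen arguments. Importing the cambricon_mlu subpackage is what triggers registration (in normal use this happens via the package's device setup).

# Hypothetical usage sketch, not part of the commit.
import torch

import lightx2v_platform.ops.attn.cambricon_mlu  # registers "mlu_flash_attn" / "mlu_sage_attn"
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

attn = PLATFORM_ATTN_WEIGHT_REGISTER["mlu_flash_attn"]()  # -> MluFlashAttnWeight

seq_len, heads, head_dim = 1024, 16, 64
q = torch.randn(seq_len, heads, head_dim, dtype=torch.bfloat16, device="mlu")
k, v = torch.randn_like(q), torch.randn_like(q)
cu = torch.tensor([0, seq_len], dtype=torch.int32, device="mlu")  # a single packed sequence

out = attn.apply(q, k, v,
                 cu_seqlens_q=cu, cu_seqlens_kv=cu,
                 max_seqlen_q=seq_len, max_seqlen_kv=seq_len)
print(out.shape)  # (seq_len, heads * head_dim) after the final reshape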
lightx2v_platform/ops/attn/cambricon_mlu/sage_attn.py  0 → 100755

import math

import torch

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


@PLATFORM_ATTN_WEIGHT_REGISTER("mlu_sage_attn")
class MluSageAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}
        assert tmo is not None, "torch_mlu_ops is not installed."

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        if len(q.shape) == 3:
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        x = tmo.sage_attn(
            q=q,
            k=k,
            v=v,
            cu_seq_lens_q=None,
            cu_seq_lens_kv=None,
            max_seq_len_kv=max_seqlen_kv,
            max_seq_len_q=max_seqlen_q,
            is_causal=False,
            compute_dtype=torch.bfloat16,
            softmax_scale=softmax_scale,
        )
        x = x.reshape(bs * max_seqlen_q, -1)
        return x
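Because both MLU kernels register into the same table under different keys, switching between them is a key lookup rather than an import change. A short, hypothetical selection snippet (not part of the commit):

# Hypothetical backend selection by registry key.
import lightx2v_platform.ops.attn.cambricon_mlu  # registers both MLU backends
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

use_sage = True
key = "mlu_sage_attn" if use_sage else "mlu_flash_attn"
attn = PLATFORM_ATTN_WEIGHT_REGISTER[key]()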
lightx2v_platform/ops/attn/template.py  0 → 100755

from abc import ABCMeta, abstractmethod


class AttnWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.config = {}

    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        pass

    def to_cuda(self, non_blocking=False):
        pass

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_inde=None):
        return {}
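The template fixes the backend contract: construct without weights, then call apply on packed q/k/v. A minimal sketch of how another backend could plug into it, assuming plain PyTorch SDPA as the kernel; the "torch_sdpa" key and the class below are illustrations, not part of the commit:

# Hypothetical fallback backend registered against the same template.
import torch

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER


@PLATFORM_ATTN_WEIGHT_REGISTER("torch_sdpa")  # hypothetical key
class TorchSdpaAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None,
              max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        # (tokens, heads, dim) -> (1, heads, tokens, dim) as expected by SDPA.
        q, k, v = (t.unsqueeze(0).transpose(1, 2) for t in (q, k, v))
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # Back to the (tokens, heads * dim) layout the MLU backends return.
        return x.transpose(1, 2).reshape(max_seqlen_q, -1)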
lightx2v_platform/ops/mm/__init__.py  0 → 100755  (empty file)
lightx2v_platform/ops/mm/cambricon_mlu/__init__.py  0 → 100755

from .mm_weight import *
lightx2v_platform/ops/mm/cambricon_mlu/mm_weight.py  0 → 100644

from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate
from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


@PLATFORM_MM_WEIGHT_REGISTER("int8-tmo")
class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: mlu
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo

    def act_quant_int8_perchannel_sym_tmo(self, x):
        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
        return input_tensor_quant, input_tensor_scale

    def apply(self, input_tensor):
        dtype = input_tensor.dtype
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = tmo.scaled_matmul(
            input_tensor_quant,
            self.weight.contiguous(),
            input_tensor_scale,
            self.weight_scale.squeeze(-1),
            bias=self.bias if self.bias is not None else None,
            output_dtype=dtype,
            use_hp_active=True,
        )
        return output_tensor
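As a reference for what this class expects from the kernels, here is plain-PyTorch math for W8A8 per-channel symmetric quantization with dynamic activation scales. This is an assumption about the numerics behind tmo.scaled_quantize / tmo.scaled_matmul, not the actual MLU kernels:

# Reference math only; not the MLU implementation.
import torch


def quantize_per_channel_sym(x: torch.Tensor):
    # One scale per row (channel); symmetric, so no zero point.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1)


def int8_matmul_ref(x_q, w_q, x_scale, w_scale, bias=None, out_dtype=torch.bfloat16):
    # (M, K) int8 @ (N, K)^T int8 -> accumulate in int32, then rescale.
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).T
    out = acc.to(torch.float32) * x_scale[:, None] * w_scale[None, :]
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)

apply() above follows the same pattern: quantize the activation per row on the fly, then let the fused kernel do the int8 matmul and rescale with the per-output-channel weight_scale and the activation scale.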
lightx2v_platform/ops/mm/cambricon_mlu/q_linear.py  0 → 100755

import torch
import torch.nn as nn

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


class MluQuantLinearInt8(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        dtype = input_tensor.dtype
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = tmo.scaled_matmul(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale.squeeze(-1),
            output_dtype=dtype,
        )
        return output_tensor.unsqueeze(0)

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            if t is not None and t.device != fn(t).device:
                return fn(t)
            return t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self
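A hypothetical helper for filling this module's buffers from an existing nn.Linear, using per-channel symmetric quantization of the weight; it is not part of the commit, and the forward pass still requires torch_mlu_ops on an MLU device:

# Hypothetical conversion helper (assumption, not part of the commit).
import torch
import torch.nn as nn

from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8


def convert_linear(linear: nn.Linear) -> MluQuantLinearInt8:
    qlinear = MluQuantLinearInt8(linear.in_features, linear.out_features,
                                 bias=linear.bias is not None,
                                 dtype=linear.weight.dtype)
    w = linear.weight.detach().to(torch.float32)
    # Per-output-channel symmetric scales, matching the (out_features, 1) buffer.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    qlinear.weight.copy_(torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8))
    qlinear.weight_scale.copy_(scale)
    if linear.bias is not None:
        qlinear.bias.copy_(linear.bias.detach())
    return qlinear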
lightx2v_platform/ops/mm/template.py  0 → 100644

from abc import ABCMeta, abstractmethod


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.create_cuda_buffer = create_cuda_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self):
        pass

    def set_config(self, config={}):
        self.config = config

    def to_cuda(self, non_blocking=False):
        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)


class MMWeightQuantTemplate(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.infer_dtype = GET_DTYPE()
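To illustrate the contract these templates define (load from a weight dict, apply, and the pin_* host buffers that to_cpu/to_cuda shuttle through), here is a minimal hypothetical non-quantized subclass; it is a sketch, not part of the commit, and assumes an accelerator context is available for pinned memory:

# Minimal hypothetical subclass showing the load/apply contract and the
# pinned host copies used by the template's to_cpu()/to_cuda() offloading.
import torch

from lightx2v_platform.ops.mm.template import MMWeightTemplate


class PlainMMWeight(MMWeightTemplate):
    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name]
        self.bias = weight_dict.get(self.bias_name)
        # Pinned CPU copies so the async to_cuda()/to_cpu() round trip works.
        self.pin_weight = self.weight.detach().to("cpu").pin_memory()
        self.pin_bias = (self.bias.detach().to("cpu").pin_memory()
                         if self.bias is not None else None)

    def apply(self, input_tensor):
        return torch.nn.functional.linear(input_tensor, self.weight, self.bias)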
lightx2v_platform/ops/norm/__init__.py  0 → 100755  (empty file)
lightx2v_platform/ops/rope/__init__.py  0 → 100755  (empty file)
lightx2v_platform/registry_factory.py  0 → 100755

class Register(dict):
    def __init__(self, *args, **kwargs):
        super(Register, self).__init__(*args, **kwargs)
        self._dict = {}

    def __call__(self, target_or_name):
        if callable(target_or_name):
            return self.register(target_or_name)
        else:
            return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise Exception(f"Error: {target} must be callable!")
        if key is None:
            key = target.__name__
        if key in self._dict:
            raise Exception(f"{key} already exists.")
        self[key] = target
        return target

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def __str__(self):
        return str(self._dict)

    def keys(self):
        return self._dict.keys()

    def values(self):
        return self._dict.values()

    def items(self):
        return self._dict.items()

    def get(self, key, default=None):
        return self._dict.get(key, default)

    def merge(self, other_register):
        for key, value in other_register.items():
            if key in self._dict:
                raise Exception(f"{key} already exists in target register.")
            self[key] = value


PLATFORM_DEVICE_REGISTER = Register()
PLATFORM_ATTN_WEIGHT_REGISTER = Register()
PLATFORM_MM_WEIGHT_REGISTER = Register()
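Register supports both decorator forms (bare and keyed) plus merging of registries. A small self-contained usage sketch; the op names below are made up for illustration:

# Usage sketch of the Register class; the ops here are hypothetical.
from lightx2v_platform.registry_factory import Register

OPS = Register()


@OPS                       # bare form: key defaults to the callable's __name__
def double(x):
    return 2 * x


@OPS("triple_op")          # keyed form: explicit registry key
def triple(x):
    return 3 * x


assert "double" in OPS and OPS["triple_op"](2) == 6

EXTRA = Register()
EXTRA.register(lambda x: -x, key="negate")
OPS.merge(EXTRA)           # raises if any key collides
print(sorted(OPS.keys()))  # ['double', 'negate', 'triple_op']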
lightx2v_platform/set_ai_device.py  0 → 100644

import os

from lightx2v_platform import *


def set_ai_device():
    platform = os.getenv("PLATFORM", "cuda")
    init_ai_device(platform)
    from lightx2v_platform.base.global_var import AI_DEVICE

    check_ai_device(AI_DEVICE)


set_ai_device()
from lightx2v_platform.ops import *  # noqa: E402
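Platform selection is driven by the PLATFORM environment variable at import time. A hypothetical usage sketch follows; the set of accepted platform strings is defined elsewhere in the package (not on this page), so only the documented default "cuda" is shown and anything else would be an assumption:

# Hypothetical usage sketch, not part of the commit.
import os

os.environ.setdefault("PLATFORM", "cuda")  # set before importing the package

import lightx2v_platform.set_ai_device  # noqa: E402  runs set_ai_device() and imports the ops
from lightx2v_platform.base.global_var import AI_DEVICE  # noqa: E402

print(f"Selected device: {AI_DEVICE}")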
lightx2v_platform/test/test_device.py  0 → 100644

import os

from lightx2v_platform import *

init_ai_device(os.getenv("AI_DEVICE", "cuda"))
from lightx2v_platform.base.global_var import AI_DEVICE  # noqa E402

if __name__ == "__main__":
    print(f"AI_DEVICE : {AI_DEVICE}")
    is_available = check_ai_device(AI_DEVICE)
    print(f"Device available: {is_available}")