Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
b50498fa
Unverified
Commit
b50498fa
authored
Dec 02, 2025
by
Yang Yong (雍洋)
Committed by
GitHub
Dec 02, 2025
Browse files
Add lightx2v_platform (#541)
parent
31da6925
Changes
75
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
183 additions
and
64 deletions
+183
-64
lightx2v/models/schedulers/qwen_image/scheduler.py
lightx2v/models/schedulers/qwen_image/scheduler.py
+6
-6
lightx2v/models/schedulers/wan/audio/scheduler.py
lightx2v/models/schedulers/wan/audio/scheduler.py
+4
-3
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
...2v/models/schedulers/wan/changing_resolution/scheduler.py
+6
-4
lightx2v/models/schedulers/wan/scheduler.py
lightx2v/models/schedulers/wan/scheduler.py
+5
-5
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
+7
-7
lightx2v/models/schedulers/wan/step_distill/scheduler.py
lightx2v/models/schedulers/wan/step_distill/scheduler.py
+2
-1
lightx2v/models/video_encoders/hf/qwen_image/vae.py
lightx2v/models/video_encoders/hf/qwen_image/vae.py
+3
-1
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+10
-13
lightx2v/utils/profiler.py
lightx2v/utils/profiler.py
+7
-15
lightx2v/utils/registry_factory.py
lightx2v/utils/registry_factory.py
+15
-0
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+3
-3
lightx2v/utils/utils.py
lightx2v/utils/utils.py
+6
-6
lightx2v_platform/__init__.py
lightx2v_platform/__init__.py
+1
-0
lightx2v_platform/base/__init__.py
lightx2v_platform/base/__init__.py
+6
-0
lightx2v_platform/base/base.py
lightx2v_platform/base/base.py
+26
-0
lightx2v_platform/base/cambricon_mlu.py
lightx2v_platform/base/cambricon_mlu.py
+27
-0
lightx2v_platform/base/global_var.py
lightx2v_platform/base/global_var.py
+1
-0
lightx2v_platform/base/metax.py
lightx2v_platform/base/metax.py
+7
-0
lightx2v_platform/base/nvidia.py
lightx2v_platform/base/nvidia.py
+36
-0
lightx2v_platform/ops/__init__.py
lightx2v_platform/ops/__init__.py
+5
-0
No files found.
lightx2v/models/schedulers/qwen_image/scheduler.py
View file @
b50498fa
...
...
@@ -8,6 +8,7 @@ import torch
from
diffusers.schedulers.scheduling_flow_match_euler_discrete
import
FlowMatchEulerDiscreteScheduler
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v_platform.base.global_var
import
AI_DEVICE
def
calculate_shift
(
...
...
@@ -133,7 +134,6 @@ class QwenImageScheduler(BaseScheduler):
self
.
scheduler
=
FlowMatchEulerDiscreteScheduler
.
from_pretrained
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
))
with
open
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
,
"scheduler_config.json"
),
"r"
)
as
f
:
self
.
scheduler_config
=
json
.
load
(
f
)
self
.
run_device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
))
self
.
dtype
=
torch
.
bfloat16
self
.
guidance_scale
=
1.0
...
...
@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
shape
=
input_info
.
target_shape
width
,
height
=
shape
[
-
1
],
shape
[
-
2
]
latents
=
randn_tensor
(
shape
,
generator
=
self
.
generator
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
latents
=
randn_tensor
(
shape
,
generator
=
self
.
generator
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
latents
=
self
.
_pack_latents
(
latents
,
self
.
config
[
"batchsize"
],
self
.
config
[
"num_channels_latents"
],
height
,
width
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
self
.
config
[
"batchsize"
],
height
//
2
,
width
//
2
,
self
.
run_device
,
self
.
dtype
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
self
.
config
[
"batchsize"
],
height
//
2
,
width
//
2
,
AI_DEVICE
,
self
.
dtype
)
self
.
latents
=
latents
self
.
latent_image_ids
=
latent_image_ids
...
...
@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
self
.
run_device
,
AI_DEVICE
,
sigmas
=
sigmas
,
mu
=
mu
,
)
...
...
@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
def
prepare_guidance
(
self
):
# handle guidance
if
self
.
config
[
"guidance_embeds"
]:
guidance
=
torch
.
full
([
1
],
self
.
guidance_scale
,
device
=
self
.
run_device
,
dtype
=
torch
.
float32
)
guidance
=
torch
.
full
([
1
],
self
.
guidance_scale
,
device
=
AI_DEVICE
,
dtype
=
torch
.
float32
)
guidance
=
guidance
.
expand
(
self
.
latents
.
shape
[
0
])
else
:
guidance
=
None
...
...
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
if
self
.
config
[
"task"
]
==
"i2i"
:
self
.
generator
=
torch
.
Generator
().
manual_seed
(
input_info
.
seed
)
elif
self
.
config
[
"task"
]
==
"t2i"
:
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
input_info
.
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
input_info
.
seed
)
self
.
prepare_latents
(
input_info
)
self
.
prepare_guidance
()
self
.
set_timesteps
()
...
...
lightx2v/models/schedulers/wan/audio/scheduler.py
View file @
b50498fa
...
...
@@ -7,6 +7,7 @@ from loguru import logger
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.utils
import
masks_like
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
EulerScheduler
(
WanScheduler
):
...
...
@@ -58,14 +59,14 @@ class EulerScheduler(WanScheduler):
)
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents
=
torch
.
randn
(
latent_shape
[
0
],
latent_shape
[
1
],
latent_shape
[
2
],
latent_shape
[
3
],
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
)
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
:
...
...
@@ -77,7 +78,7 @@ class EulerScheduler(WanScheduler):
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
self
.
run_device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
AI_DEVICE
)
self
.
timesteps_ori
=
self
.
timesteps
.
clone
()
self
.
sigmas
=
self
.
timesteps_ori
/
self
.
num_train_timesteps
...
...
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
View file @
b50498fa
import
torch
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanScheduler4ChangingResolutionInterface
:
def
__new__
(
cls
,
father_scheduler
,
config
):
...
...
@@ -20,7 +22,7 @@ class WanScheduler4ChangingResolution:
assert
len
(
config
[
"resolution_rate"
])
==
len
(
config
[
"changing_resolution_steps"
])
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents_list
=
[]
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
self
.
latents_list
.
append
(
...
...
@@ -30,7 +32,7 @@ class WanScheduler4ChangingResolution:
int
(
latent_shape
[
2
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_shape
[
3
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
)
)
...
...
@@ -43,7 +45,7 @@ class WanScheduler4ChangingResolution:
latent_shape
[
2
],
latent_shape
[
3
],
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
)
)
...
...
@@ -83,7 +85,7 @@ class WanScheduler4ChangingResolution:
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + self.changing_resolution_index + 1 (more aggressive denoising)
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
run_device
,
shift
=
self
.
sample_shift
+
self
.
changing_resolution_index
+
1
)
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
AI_DEVICE
,
shift
=
self
.
sample_shift
+
self
.
changing_resolution_index
+
1
)
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sigma
=
self
.
sigmas
[
self
.
step_index
]
...
...
lightx2v/models/schedulers/wan/scheduler.py
View file @
b50498fa
...
...
@@ -7,12 +7,12 @@ from torch.nn import functional as F
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v.utils.utils
import
masks_like
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanScheduler
(
BaseScheduler
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
run_device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
))
self
.
infer_steps
=
self
.
config
[
"infer_steps"
]
self
.
target_video_length
=
self
.
config
[
"target_video_length"
]
self
.
sample_shift
=
self
.
config
[
"sample_shift"
]
...
...
@@ -36,7 +36,7 @@ class WanScheduler(BaseScheduler):
self
.
rope_params
(
1024
,
2
*
(
self
.
head_size
//
6
)),
],
dim
=
1
,
).
to
(
torch
.
device
(
self
.
run_device
))
).
to
(
torch
.
device
(
AI_DEVICE
))
def
rope_params
(
self
,
max_seq_len
,
dim
,
theta
=
10000
):
assert
dim
%
2
==
0
...
...
@@ -70,7 +70,7 @@ class WanScheduler(BaseScheduler):
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
run_device
,
shift
=
self
.
sample_shift
)
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
AI_DEVICE
,
shift
=
self
.
sample_shift
)
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
...
...
@@ -114,14 +114,14 @@ class WanScheduler(BaseScheduler):
return
cos_sin
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents
=
torch
.
randn
(
latent_shape
[
0
],
latent_shape
[
1
],
latent_shape
[
2
],
latent_shape
[
3
],
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
)
if
self
.
config
[
"model_cls"
]
==
"wan2.2"
and
self
.
config
[
"task"
]
in
[
"i2v"
,
"s2v"
]:
...
...
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
View file @
b50498fa
...
...
@@ -2,12 +2,12 @@ import torch
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanSFScheduler
(
WanScheduler
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
run_device
=
torch
.
device
(
config
.
get
(
"run_device"
,
"cuda"
))
self
.
dtype
=
torch
.
bfloat16
self
.
num_frame_per_block
=
self
.
config
[
"sf_config"
][
"num_frame_per_block"
]
self
.
num_output_frames
=
self
.
config
[
"sf_config"
][
"num_output_frames"
]
...
...
@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
self
.
context_noise
=
0
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
self
.
latents
=
torch
.
randn
(
latent_shape
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
self
.
latents
=
torch
.
randn
(
latent_shape
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
timesteps
=
[]
for
frame_block_idx
,
current_num_frames
in
enumerate
(
self
.
all_num_frames
):
frame_steps
=
[]
for
step_index
,
current_timestep
in
enumerate
(
self
.
denoising_step_list
):
timestep
=
torch
.
ones
([
self
.
num_frame_per_block
],
device
=
self
.
run_device
,
dtype
=
torch
.
int64
)
*
current_timestep
timestep
=
torch
.
ones
([
self
.
num_frame_per_block
],
device
=
AI_DEVICE
,
dtype
=
torch
.
int64
)
*
current_timestep
frame_steps
.
append
(
timestep
)
timesteps
.
append
(
frame_steps
)
self
.
timesteps
=
timesteps
self
.
noise_pred
=
torch
.
zeros
(
latent_shape
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
self
.
noise_pred
=
torch
.
zeros
(
latent_shape
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
sigma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
*
self
.
denoising_strength
if
self
.
extra_one_step
:
...
...
@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
self
.
sigmas_sf
=
self
.
sf_shift
*
self
.
sigmas_sf
/
(
1
+
(
self
.
sf_shift
-
1
)
*
self
.
sigmas_sf
)
if
self
.
reverse_sigmas
:
self
.
sigmas_sf
=
1
-
self
.
sigmas_sf
self
.
sigmas_sf
=
self
.
sigmas_sf
.
to
(
self
.
run_device
)
self
.
sigmas_sf
=
self
.
sigmas_sf
.
to
(
AI_DEVICE
)
self
.
timesteps_sf
=
self
.
sigmas_sf
*
self
.
num_train_timesteps
self
.
timesteps_sf
=
self
.
timesteps_sf
.
to
(
self
.
run_device
)
self
.
timesteps_sf
=
self
.
timesteps_sf
.
to
(
AI_DEVICE
)
self
.
stream_output
=
None
...
...
@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
# add noise
if
self
.
step_index
<
self
.
infer_steps
-
1
:
timestep_next
=
self
.
timesteps
[
self
.
seg_index
][
self
.
step_index
+
1
]
*
torch
.
ones
(
self
.
num_frame_per_block
,
device
=
self
.
run_device
,
dtype
=
torch
.
long
)
timestep_next
=
self
.
timesteps
[
self
.
seg_index
][
self
.
step_index
+
1
]
*
torch
.
ones
(
self
.
num_frame_per_block
,
device
=
AI_DEVICE
,
dtype
=
torch
.
long
)
timestep_id_next
=
torch
.
argmin
((
self
.
timesteps_sf
.
unsqueeze
(
0
)
-
timestep_next
.
unsqueeze
(
1
)).
abs
(),
dim
=
1
)
sigma_next
=
self
.
sigmas_sf
[
timestep_id_next
].
reshape
(
-
1
,
1
,
1
,
1
)
noise_next
=
torch
.
randn_like
(
x0_pred
)
...
...
lightx2v/models/schedulers/wan/step_distill/scheduler.py
View file @
b50498fa
...
...
@@ -4,6 +4,7 @@ from typing import Union
import
torch
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanStepDistillScheduler
(
WanScheduler
):
...
...
@@ -19,7 +20,7 @@ class WanStepDistillScheduler(WanScheduler):
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
self
.
set_denoising_timesteps
(
device
=
self
.
run_device
)
self
.
set_denoising_timesteps
(
device
=
AI_DEVICE
)
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
...
...
lightx2v/models/video_encoders/hf/qwen_image/vae.py
View file @
b50498fa
...
...
@@ -5,6 +5,8 @@ from typing import Optional
import
torch
from
lightx2v_platform.base.global_var
import
AI_DEVICE
try
:
from
diffusers
import
AutoencoderKLQwenImage
from
diffusers.image_processor
import
VaeImageProcessor
...
...
@@ -33,7 +35,7 @@ class AutoencoderKLQwenImageVAE:
if
self
.
cpu_offload
:
self
.
device
=
torch
.
device
(
"cpu"
)
else
:
self
.
device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
)
)
self
.
device
=
torch
.
device
(
AI_DEVICE
)
self
.
dtype
=
torch
.
bfloat16
self
.
latent_channels
=
config
[
"vae_z_dim"
]
self
.
load
()
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
b50498fa
...
...
@@ -8,6 +8,10 @@ from einops import rearrange
from
loguru
import
logger
from
lightx2v.utils.utils
import
load_weights
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
__all__
=
[
"WanVAE"
,
...
...
@@ -821,11 +825,9 @@ class WanVAE:
use_2d_split
=
True
,
load_from_rank0
=
False
,
use_lightvae
=
False
,
run_device
=
torch
.
device
(
"cuda"
),
):
self
.
dtype
=
dtype
self
.
device
=
device
self
.
run_device
=
run_device
self
.
parallel
=
parallel
self
.
use_tiling
=
use_tiling
self
.
cpu_offload
=
cpu_offload
...
...
@@ -955,11 +957,11 @@ class WanVAE:
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
def
to_cuda
(
self
):
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
self
.
run_device
)
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
self
.
run_device
)
self
.
model
=
self
.
model
.
to
(
self
.
run_device
)
self
.
mean
=
self
.
mean
.
cuda
(
)
self
.
inv_std
=
self
.
inv_std
.
cuda
(
)
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
AI_DEVICE
)
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
AI_DEVICE
)
self
.
model
=
self
.
model
.
to
(
AI_DEVICE
)
self
.
mean
=
self
.
mean
.
to
(
AI_DEVICE
)
self
.
inv_std
=
self
.
inv_std
.
to
(
AI_DEVICE
)
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
def
encode_dist
(
self
,
video
,
world_size
,
cur_rank
,
split_dim
):
...
...
@@ -1330,9 +1332,4 @@ class WanVAE:
def
device_synchronize
(
self
,
):
if
"cuda"
in
str
(
self
.
run_device
):
torch
.
cuda
.
synchronize
()
elif
"mlu"
in
str
(
self
.
run_device
):
torch
.
mlu
.
synchronize
()
elif
"npu"
in
str
(
self
.
run_device
):
torch
.
npu
.
synchronize
()
torch_device_module
.
synchronize
()
lightx2v/utils/profiler.py
View file @
b50498fa
...
...
@@ -7,6 +7,9 @@ import torch.distributed as dist
from
loguru
import
logger
from
lightx2v.utils.envs
import
*
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
class
_ProfilingContext
:
...
...
@@ -27,12 +30,12 @@ class _ProfilingContext:
self
.
metrics_labels
=
metrics_labels
def
__enter__
(
self
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
...
...
@@ -44,12 +47,12 @@ class _ProfilingContext:
return
False
async
def
__aenter__
(
self
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
return
self
async
def
__aexit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
...
...
@@ -78,17 +81,6 @@ class _ProfilingContext:
return
sync_wrapper
def
device_synchronize
(
self
,
):
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
elif
hasattr
(
torch
,
"mlu"
)
and
torch
.
mlu
.
is_available
():
torch
.
mlu
.
synchronize
()
elif
hasattr
(
torch
,
"npu"
)
and
torch
.
npu
.
is_available
():
torch
.
npu
.
synchronize
()
return
class
_NullContext
:
# Context manager without decision branch logic overhead
...
...
lightx2v/utils/registry_factory.py
View file @
b50498fa
from
lightx2v_platform.registry_factory
import
PLATFORM_ATTN_WEIGHT_REGISTER
,
PLATFORM_MM_WEIGHT_REGISTER
class
Register
(
dict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Register
,
self
).
__init__
(
*
args
,
**
kwargs
)
...
...
@@ -43,6 +46,15 @@ class Register(dict):
def items(self):
    """Return the ``(key, value)`` view of the backing ``_dict``."""
    return self._dict.items()
def get(self, key, default=None):
    """Look up ``key`` in the backing ``_dict``.

    Returns ``default`` (``None`` unless given) when ``key`` is not present.
    """
    return self._dict.get(key, default)
def merge(self, other_register):
    """Copy every entry of ``other_register`` into this register.

    Raises:
        Exception: if a key of ``other_register`` is already present in
            this register's ``_dict``.
    """
    for key, value in other_register.items():
        if key in self._dict:
            # Entries copied before the conflicting key stay in place; the
            # merge is not rolled back.
            raise Exception(f"{key} already exists in target register.")
        # Goes through this class's item assignment (not visible here);
        # presumably that is what stores into ``_dict`` -- TODO confirm.
        self[key] = value
MM_WEIGHT_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
=
Register
()
...
...
@@ -54,3 +66,6 @@ TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER
=
Register
()
EMBEDDING_WEIGHT_REGISTER
=
Register
()
RUNNER_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
.
merge
(
PLATFORM_ATTN_WEIGHT_REGISTER
)
MM_WEIGHT_REGISTER
.
merge
(
PLATFORM_MM_WEIGHT_REGISTER
)
lightx2v/utils/set_config.py
View file @
b50498fa
...
...
@@ -8,6 +8,7 @@ from torch.distributed.tensor.device_mesh import init_device_mesh
from
lightx2v.utils.input_info
import
ALL_INPUT_INFO_KEYS
from
lightx2v.utils.lockable_dict
import
LockableDict
from
lightx2v_platform.base.global_var
import
AI_DEVICE
def
get_default_config
():
...
...
@@ -92,8 +93,7 @@ def set_parallel_config(config):
cfg_p_size
=
config
[
"parallel"
].
get
(
"cfg_p_size"
,
1
)
seq_p_size
=
config
[
"parallel"
].
get
(
"seq_p_size"
,
1
)
assert
cfg_p_size
*
seq_p_size
==
dist
.
get_world_size
(),
f
"cfg_p_size * seq_p_size must be equal to world_size"
device_str
=
config
.
get
(
"run_device"
,
"cuda"
)
config
[
"device_mesh"
]
=
init_device_mesh
(
device_str
,
(
cfg_p_size
,
seq_p_size
),
mesh_dim_names
=
(
"cfg_p"
,
"seq_p"
))
config
[
"device_mesh"
]
=
init_device_mesh
(
AI_DEVICE
,
(
cfg_p_size
,
seq_p_size
),
mesh_dim_names
=
(
"cfg_p"
,
"seq_p"
))
if
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"seq_p_size"
,
False
)
and
config
[
"parallel"
][
"seq_p_size"
]
>
1
:
config
[
"seq_parallel"
]
=
True
...
...
@@ -101,7 +101,7 @@ def set_parallel_config(config):
if
config
.
get
(
"enable_cfg"
,
False
)
and
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"cfg_p_size"
,
False
)
and
config
[
"parallel"
][
"cfg_p_size"
]
>
1
:
config
[
"cfg_parallel"
]
=
True
# warmup dist
_a
=
torch
.
zeros
([
1
]).
to
(
f
"
{
device_str
}
:
{
dist
.
get_rank
()
}
"
)
_a
=
torch
.
zeros
([
1
]).
to
(
f
"
{
AI_DEVICE
}
:
{
dist
.
get_rank
()
}
"
)
dist
.
all_reduce
(
_a
)
...
...
lightx2v/utils/utils.py
View file @
b50498fa
...
...
@@ -13,18 +13,18 @@ import torchvision
from
einops
import
rearrange
from
loguru
import
logger
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
def
seed_all
(
seed
):
random
.
seed
(
seed
)
os
.
environ
[
"PYTHONHASHSEED"
]
=
str
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
elif
hasattr
(
torch
,
"mlu"
)
and
torch
.
mlu
.
is_available
():
torch
.
mlu
.
manual_seed
(
seed
)
torch
.
mlu
.
manual_seed_all
(
seed
)
torch_device_module
.
manual_seed
(
seed
)
torch_device_module
.
manual_seed_all
(
seed
)
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
deterministic
=
True
...
...
lightx2v_platform/__init__.py
0 → 100755
View file @
b50498fa
from
.base
import
*
lightx2v_platform/base/__init__.py
0 → 100755
View file @
b50498fa
from
lightx2v_platform.base.base
import
check_ai_device
,
init_ai_device
from
lightx2v_platform.base.cambricon_mlu
import
MluDevice
from
lightx2v_platform.base.metax
import
MetaxDevice
from
lightx2v_platform.base.nvidia
import
CudaDevice
__all__
=
[
"init_ai_device"
,
"check_ai_device"
,
"CudaDevice"
,
"MluDevice"
,
"MetaxDevice"
]
lightx2v_platform/base/base.py
0 → 100644
View file @
b50498fa
from
loguru
import
logger
from
lightx2v_platform.base
import
global_var
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
def init_ai_device(platform="cuda"):
    """Select the backend registered for ``platform`` and publish its device.

    The backend is looked up in ``PLATFORM_DEVICE_REGISTER``; the value of its
    ``get_device()`` is stored in ``global_var.AI_DEVICE``, logged and returned.

    Raises:
        RuntimeError: if nothing is registered under ``platform``.
    """
    backend = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if backend is not None:
        global_var.AI_DEVICE = backend.get_device()
        logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
        return global_var.AI_DEVICE

    # Unknown platform: list what is registered so the caller can fix the name.
    known_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
    raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {known_platforms}")
def check_ai_device(platform="cuda"):
    """Verify that the backend registered for ``platform`` is usable.

    Returns:
        ``True`` when the backend's ``is_available()`` is truthy.

    Raises:
        RuntimeError: if nothing is registered under ``platform``, or if the
            registered backend reports that it is not available.
    """
    backend = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if backend is None:
        known_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {known_platforms}")

    if backend.is_available():
        logger.info(f"AI device for platform '{platform}' is available.")
        return True

    raise RuntimeError(f"AI device for platform '{platform}' is not available. Please check your runtime environment.")
lightx2v_platform/base/cambricon_mlu.py
0 → 100644
View file @
b50498fa
import
torch
import
torch.distributed
as
dist
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("mlu")
class MluDevice:
    """Cambricon MLU backend, registered under the platform key ``"mlu"``."""

    name = "mlu"

    @staticmethod
    def is_available() -> bool:
        """Return whether an MLU is usable; ``False`` if ``torch_mlu`` is missing."""
        try:
            import torch_mlu

            return torch_mlu.mlu.is_available()
        except ImportError:
            return False

    @staticmethod
    def get_device() -> str:
        """Return the device-type string for this backend."""
        return "mlu"

    @staticmethod
    def init_parallel_env():
        """Create the CNCL process group and bind this process to its device."""
        import os

        dist.init_process_group(backend="cncl")
        # Device indices are per node, so the global rank is only a valid
        # index on a single node. Prefer the launcher-provided LOCAL_RANK and
        # fall back to the global rank when it is not set.
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        torch.mlu.set_device(local_rank)
lightx2v_platform/base/global_var.py
0 → 100644
View file @
b50498fa
AI_DEVICE
=
None
lightx2v_platform/base/metax.py
0 → 100644
View file @
b50498fa
from
lightx2v_platform.base.nvidia
import
CudaDevice
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("metax")
class MetaxDevice(CudaDevice):
    """Backend registered under the platform key ``"metax"``.

    Subclasses ``CudaDevice`` without overriding any method, so every
    operation is inherited from it.
    """

    # Same value as on CudaDevice, even though the registry key is "metax".
    name = "cuda"
lightx2v_platform/base/nvidia.py
0 → 100644
View file @
b50498fa
import
torch
import
torch.distributed
as
dist
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
try
:
from
torch.distributed
import
ProcessGroupNCCL
except
ImportError
:
ProcessGroupNCCL
=
None
@PLATFORM_DEVICE_REGISTER("cuda")
class CudaDevice:
    """NVIDIA CUDA backend, registered under the platform key ``"cuda"``."""

    name = "cuda"

    @staticmethod
    def is_available() -> bool:
        """Return ``True`` when PyTorch reports a usable CUDA device."""
        # torch is imported at module level, so no ImportError guard is
        # needed around this call.
        return torch.cuda.is_available()

    @staticmethod
    def get_device() -> str:
        """Return the device-type string for this backend."""
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """Create the NCCL process group and bind this process to its GPU.

        Raises:
            RuntimeError: if ``ProcessGroupNCCL`` could not be imported.
        """
        import os

        if ProcessGroupNCCL is None:
            raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.")

        pg_options = ProcessGroupNCCL.Options()
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)

        # CUDA device indices are per node, so the global rank is only a valid
        # index on a single node. Prefer the launcher-provided LOCAL_RANK and
        # fall back to the global rank when it is not set.
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        torch.cuda.set_device(local_rank)
lightx2v_platform/ops/__init__.py
0 → 100755
View file @
b50498fa
from
lightx2v_platform.base.global_var
import
AI_DEVICE
# AI_DEVICE was copied from global_var by the import above, i.e. at the moment
# this package is first imported. The MLU ops are therefore only pulled in if
# global_var.AI_DEVICE had already been set to "mlu" (it defaults to None).
if AI_DEVICE == "mlu":
    from .attn.cambricon_mlu import *
    from .mm.cambricon_mlu import *
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment