Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
b50498fa
"router/vscode:/vscode.git/clone" did not exist on "cb3ae30284ada6d15822a4ccde9156b8e93ef2b6"
Unverified
Commit
b50498fa
authored
Dec 02, 2025
by
Yang Yong (雍洋)
Committed by
GitHub
Dec 02, 2025
Browse files
Add lightx2v_platform (#541)
parent
31da6925
Changes
75
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
183 additions
and
64 deletions
+183
-64
lightx2v/models/schedulers/qwen_image/scheduler.py
lightx2v/models/schedulers/qwen_image/scheduler.py
+6
-6
lightx2v/models/schedulers/wan/audio/scheduler.py
lightx2v/models/schedulers/wan/audio/scheduler.py
+4
-3
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
...2v/models/schedulers/wan/changing_resolution/scheduler.py
+6
-4
lightx2v/models/schedulers/wan/scheduler.py
lightx2v/models/schedulers/wan/scheduler.py
+5
-5
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
+7
-7
lightx2v/models/schedulers/wan/step_distill/scheduler.py
lightx2v/models/schedulers/wan/step_distill/scheduler.py
+2
-1
lightx2v/models/video_encoders/hf/qwen_image/vae.py
lightx2v/models/video_encoders/hf/qwen_image/vae.py
+3
-1
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+10
-13
lightx2v/utils/profiler.py
lightx2v/utils/profiler.py
+7
-15
lightx2v/utils/registry_factory.py
lightx2v/utils/registry_factory.py
+15
-0
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+3
-3
lightx2v/utils/utils.py
lightx2v/utils/utils.py
+6
-6
lightx2v_platform/__init__.py
lightx2v_platform/__init__.py
+1
-0
lightx2v_platform/base/__init__.py
lightx2v_platform/base/__init__.py
+6
-0
lightx2v_platform/base/base.py
lightx2v_platform/base/base.py
+26
-0
lightx2v_platform/base/cambricon_mlu.py
lightx2v_platform/base/cambricon_mlu.py
+27
-0
lightx2v_platform/base/global_var.py
lightx2v_platform/base/global_var.py
+1
-0
lightx2v_platform/base/metax.py
lightx2v_platform/base/metax.py
+7
-0
lightx2v_platform/base/nvidia.py
lightx2v_platform/base/nvidia.py
+36
-0
lightx2v_platform/ops/__init__.py
lightx2v_platform/ops/__init__.py
+5
-0
No files found.
lightx2v/models/schedulers/qwen_image/scheduler.py
View file @
b50498fa
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
diffusers.schedulers.scheduling_flow_match_euler_discrete
import
FlowMatchEulerDiscreteScheduler
from
diffusers.schedulers.scheduling_flow_match_euler_discrete
import
FlowMatchEulerDiscreteScheduler
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v_platform.base.global_var
import
AI_DEVICE
def
calculate_shift
(
def
calculate_shift
(
...
@@ -133,7 +134,6 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -133,7 +134,6 @@ class QwenImageScheduler(BaseScheduler):
self
.
scheduler
=
FlowMatchEulerDiscreteScheduler
.
from_pretrained
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
))
self
.
scheduler
=
FlowMatchEulerDiscreteScheduler
.
from_pretrained
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
))
with
open
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
,
"scheduler_config.json"
),
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
,
"scheduler_config.json"
),
"r"
)
as
f
:
self
.
scheduler_config
=
json
.
load
(
f
)
self
.
scheduler_config
=
json
.
load
(
f
)
self
.
run_device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
))
self
.
dtype
=
torch
.
bfloat16
self
.
dtype
=
torch
.
bfloat16
self
.
guidance_scale
=
1.0
self
.
guidance_scale
=
1.0
...
@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
shape
=
input_info
.
target_shape
shape
=
input_info
.
target_shape
width
,
height
=
shape
[
-
1
],
shape
[
-
2
]
width
,
height
=
shape
[
-
1
],
shape
[
-
2
]
latents
=
randn_tensor
(
shape
,
generator
=
self
.
generator
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
latents
=
randn_tensor
(
shape
,
generator
=
self
.
generator
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
latents
=
self
.
_pack_latents
(
latents
,
self
.
config
[
"batchsize"
],
self
.
config
[
"num_channels_latents"
],
height
,
width
)
latents
=
self
.
_pack_latents
(
latents
,
self
.
config
[
"batchsize"
],
self
.
config
[
"num_channels_latents"
],
height
,
width
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
self
.
config
[
"batchsize"
],
height
//
2
,
width
//
2
,
self
.
run_device
,
self
.
dtype
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
self
.
config
[
"batchsize"
],
height
//
2
,
width
//
2
,
AI_DEVICE
,
self
.
dtype
)
self
.
latents
=
latents
self
.
latents
=
latents
self
.
latent_image_ids
=
latent_image_ids
self
.
latent_image_ids
=
latent_image_ids
...
@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
self
.
scheduler
,
num_inference_steps
,
num_inference_steps
,
self
.
run_device
,
AI_DEVICE
,
sigmas
=
sigmas
,
sigmas
=
sigmas
,
mu
=
mu
,
mu
=
mu
,
)
)
...
@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
def
prepare_guidance
(
self
):
def
prepare_guidance
(
self
):
# handle guidance
# handle guidance
if
self
.
config
[
"guidance_embeds"
]:
if
self
.
config
[
"guidance_embeds"
]:
guidance
=
torch
.
full
([
1
],
self
.
guidance_scale
,
device
=
self
.
run_device
,
dtype
=
torch
.
float32
)
guidance
=
torch
.
full
([
1
],
self
.
guidance_scale
,
device
=
AI_DEVICE
,
dtype
=
torch
.
float32
)
guidance
=
guidance
.
expand
(
self
.
latents
.
shape
[
0
])
guidance
=
guidance
.
expand
(
self
.
latents
.
shape
[
0
])
else
:
else
:
guidance
=
None
guidance
=
None
...
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
if
self
.
config
[
"task"
]
==
"i2i"
:
if
self
.
config
[
"task"
]
==
"i2i"
:
self
.
generator
=
torch
.
Generator
().
manual_seed
(
input_info
.
seed
)
self
.
generator
=
torch
.
Generator
().
manual_seed
(
input_info
.
seed
)
elif
self
.
config
[
"task"
]
==
"t2i"
:
elif
self
.
config
[
"task"
]
==
"t2i"
:
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
input_info
.
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
input_info
.
seed
)
self
.
prepare_latents
(
input_info
)
self
.
prepare_latents
(
input_info
)
self
.
prepare_guidance
()
self
.
prepare_guidance
()
self
.
set_timesteps
()
self
.
set_timesteps
()
...
...
lightx2v/models/schedulers/wan/audio/scheduler.py
View file @
b50498fa
...
@@ -7,6 +7,7 @@ from loguru import logger
...
@@ -7,6 +7,7 @@ from loguru import logger
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.utils
import
masks_like
from
lightx2v.utils.utils
import
masks_like
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
EulerScheduler
(
WanScheduler
):
class
EulerScheduler
(
WanScheduler
):
...
@@ -58,14 +59,14 @@ class EulerScheduler(WanScheduler):
...
@@ -58,14 +59,14 @@ class EulerScheduler(WanScheduler):
)
)
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents
=
torch
.
randn
(
self
.
latents
=
torch
.
randn
(
latent_shape
[
0
],
latent_shape
[
0
],
latent_shape
[
1
],
latent_shape
[
1
],
latent_shape
[
2
],
latent_shape
[
2
],
latent_shape
[
3
],
latent_shape
[
3
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
:
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
:
...
@@ -77,7 +78,7 @@ class EulerScheduler(WanScheduler):
...
@@ -77,7 +78,7 @@ class EulerScheduler(WanScheduler):
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
self
.
run_device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
AI_DEVICE
)
self
.
timesteps_ori
=
self
.
timesteps
.
clone
()
self
.
timesteps_ori
=
self
.
timesteps
.
clone
()
self
.
sigmas
=
self
.
timesteps_ori
/
self
.
num_train_timesteps
self
.
sigmas
=
self
.
timesteps_ori
/
self
.
num_train_timesteps
...
...
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
View file @
b50498fa
import
torch
import
torch
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanScheduler4ChangingResolutionInterface
:
class
WanScheduler4ChangingResolutionInterface
:
def
__new__
(
cls
,
father_scheduler
,
config
):
def
__new__
(
cls
,
father_scheduler
,
config
):
...
@@ -20,7 +22,7 @@ class WanScheduler4ChangingResolution:
...
@@ -20,7 +22,7 @@ class WanScheduler4ChangingResolution:
assert
len
(
config
[
"resolution_rate"
])
==
len
(
config
[
"changing_resolution_steps"
])
assert
len
(
config
[
"resolution_rate"
])
==
len
(
config
[
"changing_resolution_steps"
])
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents_list
=
[]
self
.
latents_list
=
[]
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
self
.
latents_list
.
append
(
self
.
latents_list
.
append
(
...
@@ -30,7 +32,7 @@ class WanScheduler4ChangingResolution:
...
@@ -30,7 +32,7 @@ class WanScheduler4ChangingResolution:
int
(
latent_shape
[
2
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_shape
[
2
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_shape
[
3
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_shape
[
3
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
)
)
...
@@ -43,7 +45,7 @@ class WanScheduler4ChangingResolution:
...
@@ -43,7 +45,7 @@ class WanScheduler4ChangingResolution:
latent_shape
[
2
],
latent_shape
[
2
],
latent_shape
[
3
],
latent_shape
[
3
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
)
)
...
@@ -83,7 +85,7 @@ class WanScheduler4ChangingResolution:
...
@@ -83,7 +85,7 @@ class WanScheduler4ChangingResolution:
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + self.changing_resolution_index + 1 更激进的去噪
# 5. update timesteps using shift + self.changing_resolution_index + 1 更激进的去噪
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
run_device
,
shift
=
self
.
sample_shift
+
self
.
changing_resolution_index
+
1
)
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
AI_DEVICE
,
shift
=
self
.
sample_shift
+
self
.
changing_resolution_index
+
1
)
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
...
...
lightx2v/models/schedulers/wan/scheduler.py
View file @
b50498fa
...
@@ -7,12 +7,12 @@ from torch.nn import functional as F
...
@@ -7,12 +7,12 @@ from torch.nn import functional as F
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v.utils.utils
import
masks_like
from
lightx2v.utils.utils
import
masks_like
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanScheduler
(
BaseScheduler
):
class
WanScheduler
(
BaseScheduler
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
run_device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
))
self
.
infer_steps
=
self
.
config
[
"infer_steps"
]
self
.
infer_steps
=
self
.
config
[
"infer_steps"
]
self
.
target_video_length
=
self
.
config
[
"target_video_length"
]
self
.
target_video_length
=
self
.
config
[
"target_video_length"
]
self
.
sample_shift
=
self
.
config
[
"sample_shift"
]
self
.
sample_shift
=
self
.
config
[
"sample_shift"
]
...
@@ -36,7 +36,7 @@ class WanScheduler(BaseScheduler):
...
@@ -36,7 +36,7 @@ class WanScheduler(BaseScheduler):
self
.
rope_params
(
1024
,
2
*
(
self
.
head_size
//
6
)),
self
.
rope_params
(
1024
,
2
*
(
self
.
head_size
//
6
)),
],
],
dim
=
1
,
dim
=
1
,
).
to
(
torch
.
device
(
self
.
run_device
))
).
to
(
torch
.
device
(
AI_DEVICE
))
def
rope_params
(
self
,
max_seq_len
,
dim
,
theta
=
10000
):
def
rope_params
(
self
,
max_seq_len
,
dim
,
theta
=
10000
):
assert
dim
%
2
==
0
assert
dim
%
2
==
0
...
@@ -70,7 +70,7 @@ class WanScheduler(BaseScheduler):
...
@@ -70,7 +70,7 @@ class WanScheduler(BaseScheduler):
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
run_device
,
shift
=
self
.
sample_shift
)
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
AI_DEVICE
,
shift
=
self
.
sample_shift
)
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
...
@@ -114,14 +114,14 @@ class WanScheduler(BaseScheduler):
...
@@ -114,14 +114,14 @@ class WanScheduler(BaseScheduler):
return
cos_sin
return
cos_sin
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents
=
torch
.
randn
(
self
.
latents
=
torch
.
randn
(
latent_shape
[
0
],
latent_shape
[
0
],
latent_shape
[
1
],
latent_shape
[
1
],
latent_shape
[
2
],
latent_shape
[
2
],
latent_shape
[
3
],
latent_shape
[
3
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
if
self
.
config
[
"model_cls"
]
==
"wan2.2"
and
self
.
config
[
"task"
]
in
[
"i2v"
,
"s2v"
]:
if
self
.
config
[
"model_cls"
]
==
"wan2.2"
and
self
.
config
[
"task"
]
in
[
"i2v"
,
"s2v"
]:
...
...
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
View file @
b50498fa
...
@@ -2,12 +2,12 @@ import torch
...
@@ -2,12 +2,12 @@ import torch
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanSFScheduler
(
WanScheduler
):
class
WanSFScheduler
(
WanScheduler
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
run_device
=
torch
.
device
(
config
.
get
(
"run_device"
,
"cuda"
))
self
.
dtype
=
torch
.
bfloat16
self
.
dtype
=
torch
.
bfloat16
self
.
num_frame_per_block
=
self
.
config
[
"sf_config"
][
"num_frame_per_block"
]
self
.
num_frame_per_block
=
self
.
config
[
"sf_config"
][
"num_frame_per_block"
]
self
.
num_output_frames
=
self
.
config
[
"sf_config"
][
"num_output_frames"
]
self
.
num_output_frames
=
self
.
config
[
"sf_config"
][
"num_output_frames"
]
...
@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
...
@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
self
.
context_noise
=
0
self
.
context_noise
=
0
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
self
.
latents
=
torch
.
randn
(
latent_shape
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
self
.
latents
=
torch
.
randn
(
latent_shape
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
timesteps
=
[]
timesteps
=
[]
for
frame_block_idx
,
current_num_frames
in
enumerate
(
self
.
all_num_frames
):
for
frame_block_idx
,
current_num_frames
in
enumerate
(
self
.
all_num_frames
):
frame_steps
=
[]
frame_steps
=
[]
for
step_index
,
current_timestep
in
enumerate
(
self
.
denoising_step_list
):
for
step_index
,
current_timestep
in
enumerate
(
self
.
denoising_step_list
):
timestep
=
torch
.
ones
([
self
.
num_frame_per_block
],
device
=
self
.
run_device
,
dtype
=
torch
.
int64
)
*
current_timestep
timestep
=
torch
.
ones
([
self
.
num_frame_per_block
],
device
=
AI_DEVICE
,
dtype
=
torch
.
int64
)
*
current_timestep
frame_steps
.
append
(
timestep
)
frame_steps
.
append
(
timestep
)
timesteps
.
append
(
frame_steps
)
timesteps
.
append
(
frame_steps
)
self
.
timesteps
=
timesteps
self
.
timesteps
=
timesteps
self
.
noise_pred
=
torch
.
zeros
(
latent_shape
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
self
.
noise_pred
=
torch
.
zeros
(
latent_shape
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
sigma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
*
self
.
denoising_strength
sigma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
*
self
.
denoising_strength
if
self
.
extra_one_step
:
if
self
.
extra_one_step
:
...
@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
...
@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
self
.
sigmas_sf
=
self
.
sf_shift
*
self
.
sigmas_sf
/
(
1
+
(
self
.
sf_shift
-
1
)
*
self
.
sigmas_sf
)
self
.
sigmas_sf
=
self
.
sf_shift
*
self
.
sigmas_sf
/
(
1
+
(
self
.
sf_shift
-
1
)
*
self
.
sigmas_sf
)
if
self
.
reverse_sigmas
:
if
self
.
reverse_sigmas
:
self
.
sigmas_sf
=
1
-
self
.
sigmas_sf
self
.
sigmas_sf
=
1
-
self
.
sigmas_sf
self
.
sigmas_sf
=
self
.
sigmas_sf
.
to
(
self
.
run_device
)
self
.
sigmas_sf
=
self
.
sigmas_sf
.
to
(
AI_DEVICE
)
self
.
timesteps_sf
=
self
.
sigmas_sf
*
self
.
num_train_timesteps
self
.
timesteps_sf
=
self
.
sigmas_sf
*
self
.
num_train_timesteps
self
.
timesteps_sf
=
self
.
timesteps_sf
.
to
(
self
.
run_device
)
self
.
timesteps_sf
=
self
.
timesteps_sf
.
to
(
AI_DEVICE
)
self
.
stream_output
=
None
self
.
stream_output
=
None
...
@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
...
@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
# add noise
# add noise
if
self
.
step_index
<
self
.
infer_steps
-
1
:
if
self
.
step_index
<
self
.
infer_steps
-
1
:
timestep_next
=
self
.
timesteps
[
self
.
seg_index
][
self
.
step_index
+
1
]
*
torch
.
ones
(
self
.
num_frame_per_block
,
device
=
self
.
run_device
,
dtype
=
torch
.
long
)
timestep_next
=
self
.
timesteps
[
self
.
seg_index
][
self
.
step_index
+
1
]
*
torch
.
ones
(
self
.
num_frame_per_block
,
device
=
AI_DEVICE
,
dtype
=
torch
.
long
)
timestep_id_next
=
torch
.
argmin
((
self
.
timesteps_sf
.
unsqueeze
(
0
)
-
timestep_next
.
unsqueeze
(
1
)).
abs
(),
dim
=
1
)
timestep_id_next
=
torch
.
argmin
((
self
.
timesteps_sf
.
unsqueeze
(
0
)
-
timestep_next
.
unsqueeze
(
1
)).
abs
(),
dim
=
1
)
sigma_next
=
self
.
sigmas_sf
[
timestep_id_next
].
reshape
(
-
1
,
1
,
1
,
1
)
sigma_next
=
self
.
sigmas_sf
[
timestep_id_next
].
reshape
(
-
1
,
1
,
1
,
1
)
noise_next
=
torch
.
randn_like
(
x0_pred
)
noise_next
=
torch
.
randn_like
(
x0_pred
)
...
...
lightx2v/models/schedulers/wan/step_distill/scheduler.py
View file @
b50498fa
...
@@ -4,6 +4,7 @@ from typing import Union
...
@@ -4,6 +4,7 @@ from typing import Union
import
torch
import
torch
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanStepDistillScheduler
(
WanScheduler
):
class
WanStepDistillScheduler
(
WanScheduler
):
...
@@ -19,7 +20,7 @@ class WanStepDistillScheduler(WanScheduler):
...
@@ -19,7 +20,7 @@ class WanStepDistillScheduler(WanScheduler):
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
self
.
set_denoising_timesteps
(
device
=
self
.
run_device
)
self
.
set_denoising_timesteps
(
device
=
AI_DEVICE
)
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
...
...
lightx2v/models/video_encoders/hf/qwen_image/vae.py
View file @
b50498fa
...
@@ -5,6 +5,8 @@ from typing import Optional
...
@@ -5,6 +5,8 @@ from typing import Optional
import
torch
import
torch
from
lightx2v_platform.base.global_var
import
AI_DEVICE
try
:
try
:
from
diffusers
import
AutoencoderKLQwenImage
from
diffusers
import
AutoencoderKLQwenImage
from
diffusers.image_processor
import
VaeImageProcessor
from
diffusers.image_processor
import
VaeImageProcessor
...
@@ -33,7 +35,7 @@ class AutoencoderKLQwenImageVAE:
...
@@ -33,7 +35,7 @@ class AutoencoderKLQwenImageVAE:
if
self
.
cpu_offload
:
if
self
.
cpu_offload
:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
device
=
torch
.
device
(
"cpu"
)
else
:
else
:
self
.
device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
)
)
self
.
device
=
torch
.
device
(
AI_DEVICE
)
self
.
dtype
=
torch
.
bfloat16
self
.
dtype
=
torch
.
bfloat16
self
.
latent_channels
=
config
[
"vae_z_dim"
]
self
.
latent_channels
=
config
[
"vae_z_dim"
]
self
.
load
()
self
.
load
()
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
b50498fa
...
@@ -8,6 +8,10 @@ from einops import rearrange
...
@@ -8,6 +8,10 @@ from einops import rearrange
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v.utils.utils
import
load_weights
from
lightx2v.utils.utils
import
load_weights
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
__all__
=
[
__all__
=
[
"WanVAE"
,
"WanVAE"
,
...
@@ -821,11 +825,9 @@ class WanVAE:
...
@@ -821,11 +825,9 @@ class WanVAE:
use_2d_split
=
True
,
use_2d_split
=
True
,
load_from_rank0
=
False
,
load_from_rank0
=
False
,
use_lightvae
=
False
,
use_lightvae
=
False
,
run_device
=
torch
.
device
(
"cuda"
),
):
):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
self
.
run_device
=
run_device
self
.
parallel
=
parallel
self
.
parallel
=
parallel
self
.
use_tiling
=
use_tiling
self
.
use_tiling
=
use_tiling
self
.
cpu_offload
=
cpu_offload
self
.
cpu_offload
=
cpu_offload
...
@@ -955,11 +957,11 @@ class WanVAE:
...
@@ -955,11 +957,11 @@ class WanVAE:
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
def
to_cuda
(
self
):
def
to_cuda
(
self
):
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
self
.
run_device
)
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
AI_DEVICE
)
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
self
.
run_device
)
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
AI_DEVICE
)
self
.
model
=
self
.
model
.
to
(
self
.
run_device
)
self
.
model
=
self
.
model
.
to
(
AI_DEVICE
)
self
.
mean
=
self
.
mean
.
cuda
(
)
self
.
mean
=
self
.
mean
.
to
(
AI_DEVICE
)
self
.
inv_std
=
self
.
inv_std
.
cuda
(
)
self
.
inv_std
=
self
.
inv_std
.
to
(
AI_DEVICE
)
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
def
encode_dist
(
self
,
video
,
world_size
,
cur_rank
,
split_dim
):
def
encode_dist
(
self
,
video
,
world_size
,
cur_rank
,
split_dim
):
...
@@ -1330,9 +1332,4 @@ class WanVAE:
...
@@ -1330,9 +1332,4 @@ class WanVAE:
def
device_synchronize
(
def
device_synchronize
(
self
,
self
,
):
):
if
"cuda"
in
str
(
self
.
run_device
):
torch_device_module
.
synchronize
()
torch
.
cuda
.
synchronize
()
elif
"mlu"
in
str
(
self
.
run_device
):
torch
.
mlu
.
synchronize
()
elif
"npu"
in
str
(
self
.
run_device
):
torch
.
npu
.
synchronize
()
lightx2v/utils/profiler.py
View file @
b50498fa
...
@@ -7,6 +7,9 @@ import torch.distributed as dist
...
@@ -7,6 +7,9 @@ import torch.distributed as dist
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
class
_ProfilingContext
:
class
_ProfilingContext
:
...
@@ -27,12 +30,12 @@ class _ProfilingContext:
...
@@ -27,12 +30,12 @@ class _ProfilingContext:
self
.
metrics_labels
=
metrics_labels
self
.
metrics_labels
=
metrics_labels
def
__enter__
(
self
):
def
__enter__
(
self
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
self
.
start_time
=
time
.
perf_counter
()
return
self
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
if
self
.
metrics_labels
:
...
@@ -44,12 +47,12 @@ class _ProfilingContext:
...
@@ -44,12 +47,12 @@ class _ProfilingContext:
return
False
return
False
async
def
__aenter__
(
self
):
async
def
__aenter__
(
self
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
self
.
start_time
=
time
.
perf_counter
()
return
self
return
self
async
def
__aexit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
async
def
__aexit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
if
self
.
metrics_labels
:
...
@@ -78,17 +81,6 @@ class _ProfilingContext:
...
@@ -78,17 +81,6 @@ class _ProfilingContext:
return
sync_wrapper
return
sync_wrapper
def
device_synchronize
(
self
,
):
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
elif
hasattr
(
torch
,
"mlu"
)
and
torch
.
mlu
.
is_available
():
torch
.
mlu
.
synchronize
()
elif
hasattr
(
torch
,
"npu"
)
and
torch
.
npu
.
is_available
():
torch
.
npu
.
synchronize
()
return
class
_NullContext
:
class
_NullContext
:
# Context manager without decision branch logic overhead
# Context manager without decision branch logic overhead
...
...
lightx2v/utils/registry_factory.py
View file @
b50498fa
from
lightx2v_platform.registry_factory
import
PLATFORM_ATTN_WEIGHT_REGISTER
,
PLATFORM_MM_WEIGHT_REGISTER
class
Register
(
dict
):
class
Register
(
dict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Register
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
(
Register
,
self
).
__init__
(
*
args
,
**
kwargs
)
...
@@ -43,6 +46,15 @@ class Register(dict):
...
@@ -43,6 +46,15 @@ class Register(dict):
def
items
(
self
):
def
items
(
self
):
return
self
.
_dict
.
items
()
return
self
.
_dict
.
items
()
def
get
(
self
,
key
,
default
=
None
):
return
self
.
_dict
.
get
(
key
,
default
)
def
merge
(
self
,
other_register
):
for
key
,
value
in
other_register
.
items
():
if
key
in
self
.
_dict
:
raise
Exception
(
f
"
{
key
}
already exists in target register."
)
self
[
key
]
=
value
MM_WEIGHT_REGISTER
=
Register
()
MM_WEIGHT_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
=
Register
()
...
@@ -54,3 +66,6 @@ TENSOR_REGISTER = Register()
...
@@ -54,3 +66,6 @@ TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER
=
Register
()
CONVERT_WEIGHT_REGISTER
=
Register
()
EMBEDDING_WEIGHT_REGISTER
=
Register
()
EMBEDDING_WEIGHT_REGISTER
=
Register
()
RUNNER_REGISTER
=
Register
()
RUNNER_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
.
merge
(
PLATFORM_ATTN_WEIGHT_REGISTER
)
MM_WEIGHT_REGISTER
.
merge
(
PLATFORM_MM_WEIGHT_REGISTER
)
lightx2v/utils/set_config.py
View file @
b50498fa
...
@@ -8,6 +8,7 @@ from torch.distributed.tensor.device_mesh import init_device_mesh
...
@@ -8,6 +8,7 @@ from torch.distributed.tensor.device_mesh import init_device_mesh
from
lightx2v.utils.input_info
import
ALL_INPUT_INFO_KEYS
from
lightx2v.utils.input_info
import
ALL_INPUT_INFO_KEYS
from
lightx2v.utils.lockable_dict
import
LockableDict
from
lightx2v.utils.lockable_dict
import
LockableDict
from
lightx2v_platform.base.global_var
import
AI_DEVICE
def
get_default_config
():
def
get_default_config
():
...
@@ -92,8 +93,7 @@ def set_parallel_config(config):
...
@@ -92,8 +93,7 @@ def set_parallel_config(config):
cfg_p_size
=
config
[
"parallel"
].
get
(
"cfg_p_size"
,
1
)
cfg_p_size
=
config
[
"parallel"
].
get
(
"cfg_p_size"
,
1
)
seq_p_size
=
config
[
"parallel"
].
get
(
"seq_p_size"
,
1
)
seq_p_size
=
config
[
"parallel"
].
get
(
"seq_p_size"
,
1
)
assert
cfg_p_size
*
seq_p_size
==
dist
.
get_world_size
(),
f
"cfg_p_size * seq_p_size must be equal to world_size"
assert
cfg_p_size
*
seq_p_size
==
dist
.
get_world_size
(),
f
"cfg_p_size * seq_p_size must be equal to world_size"
device_str
=
config
.
get
(
"run_device"
,
"cuda"
)
config
[
"device_mesh"
]
=
init_device_mesh
(
AI_DEVICE
,
(
cfg_p_size
,
seq_p_size
),
mesh_dim_names
=
(
"cfg_p"
,
"seq_p"
))
config
[
"device_mesh"
]
=
init_device_mesh
(
device_str
,
(
cfg_p_size
,
seq_p_size
),
mesh_dim_names
=
(
"cfg_p"
,
"seq_p"
))
if
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"seq_p_size"
,
False
)
and
config
[
"parallel"
][
"seq_p_size"
]
>
1
:
if
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"seq_p_size"
,
False
)
and
config
[
"parallel"
][
"seq_p_size"
]
>
1
:
config
[
"seq_parallel"
]
=
True
config
[
"seq_parallel"
]
=
True
...
@@ -101,7 +101,7 @@ def set_parallel_config(config):
...
@@ -101,7 +101,7 @@ def set_parallel_config(config):
if
config
.
get
(
"enable_cfg"
,
False
)
and
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"cfg_p_size"
,
False
)
and
config
[
"parallel"
][
"cfg_p_size"
]
>
1
:
if
config
.
get
(
"enable_cfg"
,
False
)
and
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"cfg_p_size"
,
False
)
and
config
[
"parallel"
][
"cfg_p_size"
]
>
1
:
config
[
"cfg_parallel"
]
=
True
config
[
"cfg_parallel"
]
=
True
# warmup dist
# warmup dist
_a
=
torch
.
zeros
([
1
]).
to
(
f
"
{
device_str
}
:
{
dist
.
get_rank
()
}
"
)
_a
=
torch
.
zeros
([
1
]).
to
(
f
"
{
AI_DEVICE
}
:
{
dist
.
get_rank
()
}
"
)
dist
.
all_reduce
(
_a
)
dist
.
all_reduce
(
_a
)
...
...
lightx2v/utils/utils.py
View file @
b50498fa
...
@@ -13,18 +13,18 @@ import torchvision
...
@@ -13,18 +13,18 @@ import torchvision
from
einops
import
rearrange
from
einops
import
rearrange
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
def
seed_all
(
seed
):
def
seed_all
(
seed
):
random
.
seed
(
seed
)
random
.
seed
(
seed
)
os
.
environ
[
"PYTHONHASHSEED"
]
=
str
(
seed
)
os
.
environ
[
"PYTHONHASHSEED"
]
=
str
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch_device_module
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch_device_module
.
manual_seed_all
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
elif
hasattr
(
torch
,
"mlu"
)
and
torch
.
mlu
.
is_available
():
torch
.
mlu
.
manual_seed
(
seed
)
torch
.
mlu
.
manual_seed_all
(
seed
)
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
deterministic
=
True
...
...
lightx2v_platform/__init__.py
0 → 100755
View file @
b50498fa
from
.base
import
*
lightx2v_platform/base/__init__.py
0 → 100755
View file @
b50498fa
# Public surface of lightx2v_platform.base: the device bootstrap helpers and
# the per-vendor device backends (importing each backend module also runs its
# PLATFORM_DEVICE_REGISTER registration as a side effect).
from lightx2v_platform.base.base import check_ai_device, init_ai_device
from lightx2v_platform.base.cambricon_mlu import MluDevice
from lightx2v_platform.base.metax import MetaxDevice
from lightx2v_platform.base.nvidia import CudaDevice

__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice"]
lightx2v_platform/base/base.py
0 → 100644
View file @
b50498fa
from
loguru
import
logger
from
lightx2v_platform.base
import
global_var
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
def init_ai_device(platform="cuda"):
    """Resolve the registered backend for *platform* and publish its device
    string as the process-wide ``global_var.AI_DEVICE``.

    Args:
        platform: registry key of the device backend (e.g. "cuda", "mlu").

    Returns:
        The device string now stored in ``global_var.AI_DEVICE``.

    Raises:
        RuntimeError: if no backend is registered under *platform*.
    """
    backend = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if backend is None:
        supported = list(PLATFORM_DEVICE_REGISTER.keys())
        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {supported}")
    global_var.AI_DEVICE = backend.get_device()
    logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
    return global_var.AI_DEVICE
def check_ai_device(platform="cuda"):
    """Verify that the backend registered under *platform* is usable in the
    current runtime environment.

    Args:
        platform: registry key of the device backend (e.g. "cuda", "mlu").

    Returns:
        True when the backend reports itself available.

    Raises:
        RuntimeError: if *platform* is not registered, or the backend's
            ``is_available()`` check fails.
    """
    backend = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if backend is None:
        supported = list(PLATFORM_DEVICE_REGISTER.keys())
        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {supported}")
    if not backend.is_available():
        raise RuntimeError(f"AI device for platform '{platform}' is not available. Please check your runtime environment.")
    logger.info(f"AI device for platform '{platform}' is available.")
    return True
lightx2v_platform/base/cambricon_mlu.py
0 → 100644
View file @
b50498fa
import
torch
import
torch.distributed
as
dist
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("mlu")
class MluDevice:
    """Device backend for Cambricon MLU accelerators."""

    # Torch device string / registry identity for this backend.
    name = "mlu"

    @staticmethod
    def is_available() -> bool:
        """Return True when the torch_mlu extension imports and reports an
        available MLU device; False when torch_mlu is not installed."""
        try:
            import torch_mlu

            return torch_mlu.mlu.is_available()
        except ImportError:
            return False

    @staticmethod
    def get_device() -> str:
        """Return the torch device string for MLU tensors."""
        return "mlu"

    @staticmethod
    def init_parallel_env():
        """Initialize torch.distributed with Cambricon's CNCL backend and
        bind this process to an MLU device."""
        dist.init_process_group(backend="cncl")
        # NOTE(review): uses the global rank as the device index — assumes one
        # process per MLU on a single node; confirm for multi-node setups.
        torch.mlu.set_device(dist.get_rank())
lightx2v_platform/base/global_var.py
0 → 100644
View file @
b50498fa
# Process-wide device string (e.g. "cuda" or "mlu"). Stays None until
# lightx2v_platform.base.base.init_ai_device() assigns it.
AI_DEVICE = None
lightx2v_platform/base/metax.py
0 → 100644
View file @
b50498fa
from
lightx2v_platform.base.nvidia
import
CudaDevice
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("metax")
class MetaxDevice(CudaDevice):
    """Device backend for Metax GPUs, registered under "metax" but reusing
    the CUDA backend wholesale — presumably Metax hardware runs through the
    CUDA-compatible torch stack (device string stays "cuda"); verify against
    the Metax runtime docs."""

    name = "cuda"
lightx2v_platform/base/nvidia.py
0 → 100644
View file @
b50498fa
import
torch
import
torch.distributed
as
dist
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
try
:
from
torch.distributed
import
ProcessGroupNCCL
except
ImportError
:
ProcessGroupNCCL
=
None
@PLATFORM_DEVICE_REGISTER("cuda")
class CudaDevice:
    """Device backend for NVIDIA (CUDA) GPUs."""

    # Torch device string / registry identity for this backend.
    name = "cuda"

    @staticmethod
    def is_available() -> bool:
        """Return True when a usable CUDA device is present.

        torch is a hard dependency of this module (imported at file top),
        so the original's try/except ImportError around a local
        ``import torch`` was unreachable and has been dropped.
        """
        return torch.cuda.is_available()

    @staticmethod
    def get_device() -> str:
        """Return the torch device string for CUDA tensors."""
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """Initialize torch.distributed with NCCL and bind this process to
        its GPU.

        Raises:
            RuntimeError: if ProcessGroupNCCL could not be imported (see the
                guarded import at module top).
        """
        if ProcessGroupNCCL is None:
            raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.")
        pg_options = ProcessGroupNCCL.Options()
        # Run NCCL collectives on a high-priority stream so they are
        # scheduled ahead of compute kernels.
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)
        # NOTE(review): uses the global rank as the device index — assumes one
        # process per GPU on a single node; confirm for multi-node runs.
        torch.cuda.set_device(dist.get_rank())
lightx2v_platform/ops/__init__.py
0 → 100755
View file @
b50498fa
# AI_DEVICE is bound by value at import time: init_ai_device() must have run
# before this package is imported, otherwise AI_DEVICE is still None here and
# the MLU kernel modules below are never loaded.
from lightx2v_platform.base.global_var import AI_DEVICE

if AI_DEVICE == "mlu":
    # Register MLU-specific attention and matmul ops (registration happens
    # as an import side effect in these modules).
    from .attn.cambricon_mlu import *
    from .mm.cambricon_mlu import *
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment