Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
b50498fa
Unverified
Commit
b50498fa
authored
Dec 02, 2025
by
Yang Yong (雍洋)
Committed by
GitHub
Dec 02, 2025
Browse files
Add lightx2v_platform (#541)
parent
31da6925
Changes
75
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
183 additions
and
64 deletions
+183
-64
lightx2v/models/schedulers/qwen_image/scheduler.py
lightx2v/models/schedulers/qwen_image/scheduler.py
+6
-6
lightx2v/models/schedulers/wan/audio/scheduler.py
lightx2v/models/schedulers/wan/audio/scheduler.py
+4
-3
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
...2v/models/schedulers/wan/changing_resolution/scheduler.py
+6
-4
lightx2v/models/schedulers/wan/scheduler.py
lightx2v/models/schedulers/wan/scheduler.py
+5
-5
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
+7
-7
lightx2v/models/schedulers/wan/step_distill/scheduler.py
lightx2v/models/schedulers/wan/step_distill/scheduler.py
+2
-1
lightx2v/models/video_encoders/hf/qwen_image/vae.py
lightx2v/models/video_encoders/hf/qwen_image/vae.py
+3
-1
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+10
-13
lightx2v/utils/profiler.py
lightx2v/utils/profiler.py
+7
-15
lightx2v/utils/registry_factory.py
lightx2v/utils/registry_factory.py
+15
-0
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+3
-3
lightx2v/utils/utils.py
lightx2v/utils/utils.py
+6
-6
lightx2v_platform/__init__.py
lightx2v_platform/__init__.py
+1
-0
lightx2v_platform/base/__init__.py
lightx2v_platform/base/__init__.py
+6
-0
lightx2v_platform/base/base.py
lightx2v_platform/base/base.py
+26
-0
lightx2v_platform/base/cambricon_mlu.py
lightx2v_platform/base/cambricon_mlu.py
+27
-0
lightx2v_platform/base/global_var.py
lightx2v_platform/base/global_var.py
+1
-0
lightx2v_platform/base/metax.py
lightx2v_platform/base/metax.py
+7
-0
lightx2v_platform/base/nvidia.py
lightx2v_platform/base/nvidia.py
+36
-0
lightx2v_platform/ops/__init__.py
lightx2v_platform/ops/__init__.py
+5
-0
No files found.
lightx2v/models/schedulers/qwen_image/scheduler.py
View file @
b50498fa
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
diffusers.schedulers.scheduling_flow_match_euler_discrete
import
FlowMatchEulerDiscreteScheduler
from
diffusers.schedulers.scheduling_flow_match_euler_discrete
import
FlowMatchEulerDiscreteScheduler
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v_platform.base.global_var
import
AI_DEVICE
def
calculate_shift
(
def
calculate_shift
(
...
@@ -133,7 +134,6 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -133,7 +134,6 @@ class QwenImageScheduler(BaseScheduler):
self
.
scheduler
=
FlowMatchEulerDiscreteScheduler
.
from_pretrained
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
))
self
.
scheduler
=
FlowMatchEulerDiscreteScheduler
.
from_pretrained
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
))
with
open
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
,
"scheduler_config.json"
),
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
config
[
"model_path"
],
"scheduler"
,
"scheduler_config.json"
),
"r"
)
as
f
:
self
.
scheduler_config
=
json
.
load
(
f
)
self
.
scheduler_config
=
json
.
load
(
f
)
self
.
run_device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
))
self
.
dtype
=
torch
.
bfloat16
self
.
dtype
=
torch
.
bfloat16
self
.
guidance_scale
=
1.0
self
.
guidance_scale
=
1.0
...
@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
shape
=
input_info
.
target_shape
shape
=
input_info
.
target_shape
width
,
height
=
shape
[
-
1
],
shape
[
-
2
]
width
,
height
=
shape
[
-
1
],
shape
[
-
2
]
latents
=
randn_tensor
(
shape
,
generator
=
self
.
generator
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
latents
=
randn_tensor
(
shape
,
generator
=
self
.
generator
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
latents
=
self
.
_pack_latents
(
latents
,
self
.
config
[
"batchsize"
],
self
.
config
[
"num_channels_latents"
],
height
,
width
)
latents
=
self
.
_pack_latents
(
latents
,
self
.
config
[
"batchsize"
],
self
.
config
[
"num_channels_latents"
],
height
,
width
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
self
.
config
[
"batchsize"
],
height
//
2
,
width
//
2
,
self
.
run_device
,
self
.
dtype
)
latent_image_ids
=
self
.
_prepare_latent_image_ids
(
self
.
config
[
"batchsize"
],
height
//
2
,
width
//
2
,
AI_DEVICE
,
self
.
dtype
)
self
.
latents
=
latents
self
.
latents
=
latents
self
.
latent_image_ids
=
latent_image_ids
self
.
latent_image_ids
=
latent_image_ids
...
@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
self
.
scheduler
,
num_inference_steps
,
num_inference_steps
,
self
.
run_device
,
AI_DEVICE
,
sigmas
=
sigmas
,
sigmas
=
sigmas
,
mu
=
mu
,
mu
=
mu
,
)
)
...
@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
def
prepare_guidance
(
self
):
def
prepare_guidance
(
self
):
# handle guidance
# handle guidance
if
self
.
config
[
"guidance_embeds"
]:
if
self
.
config
[
"guidance_embeds"
]:
guidance
=
torch
.
full
([
1
],
self
.
guidance_scale
,
device
=
self
.
run_device
,
dtype
=
torch
.
float32
)
guidance
=
torch
.
full
([
1
],
self
.
guidance_scale
,
device
=
AI_DEVICE
,
dtype
=
torch
.
float32
)
guidance
=
guidance
.
expand
(
self
.
latents
.
shape
[
0
])
guidance
=
guidance
.
expand
(
self
.
latents
.
shape
[
0
])
else
:
else
:
guidance
=
None
guidance
=
None
...
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
...
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
if
self
.
config
[
"task"
]
==
"i2i"
:
if
self
.
config
[
"task"
]
==
"i2i"
:
self
.
generator
=
torch
.
Generator
().
manual_seed
(
input_info
.
seed
)
self
.
generator
=
torch
.
Generator
().
manual_seed
(
input_info
.
seed
)
elif
self
.
config
[
"task"
]
==
"t2i"
:
elif
self
.
config
[
"task"
]
==
"t2i"
:
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
input_info
.
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
input_info
.
seed
)
self
.
prepare_latents
(
input_info
)
self
.
prepare_latents
(
input_info
)
self
.
prepare_guidance
()
self
.
prepare_guidance
()
self
.
set_timesteps
()
self
.
set_timesteps
()
...
...
lightx2v/models/schedulers/wan/audio/scheduler.py
View file @
b50498fa
...
@@ -7,6 +7,7 @@ from loguru import logger
...
@@ -7,6 +7,7 @@ from loguru import logger
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.utils
import
masks_like
from
lightx2v.utils.utils
import
masks_like
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
EulerScheduler
(
WanScheduler
):
class
EulerScheduler
(
WanScheduler
):
...
@@ -58,14 +59,14 @@ class EulerScheduler(WanScheduler):
...
@@ -58,14 +59,14 @@ class EulerScheduler(WanScheduler):
)
)
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents
=
torch
.
randn
(
self
.
latents
=
torch
.
randn
(
latent_shape
[
0
],
latent_shape
[
0
],
latent_shape
[
1
],
latent_shape
[
1
],
latent_shape
[
2
],
latent_shape
[
2
],
latent_shape
[
3
],
latent_shape
[
3
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
:
if
self
.
config
[
"model_cls"
]
==
"wan2.2_audio"
:
...
@@ -77,7 +78,7 @@ class EulerScheduler(WanScheduler):
...
@@ -77,7 +78,7 @@ class EulerScheduler(WanScheduler):
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
timesteps
=
np
.
linspace
(
self
.
num_train_timesteps
,
0
,
self
.
infer_steps
+
1
,
dtype
=
np
.
float32
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
self
.
run_device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
dtype
=
torch
.
float32
,
device
=
AI_DEVICE
)
self
.
timesteps_ori
=
self
.
timesteps
.
clone
()
self
.
timesteps_ori
=
self
.
timesteps
.
clone
()
self
.
sigmas
=
self
.
timesteps_ori
/
self
.
num_train_timesteps
self
.
sigmas
=
self
.
timesteps_ori
/
self
.
num_train_timesteps
...
...
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
View file @
b50498fa
import
torch
import
torch
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanScheduler4ChangingResolutionInterface
:
class
WanScheduler4ChangingResolutionInterface
:
def
__new__
(
cls
,
father_scheduler
,
config
):
def
__new__
(
cls
,
father_scheduler
,
config
):
...
@@ -20,7 +22,7 @@ class WanScheduler4ChangingResolution:
...
@@ -20,7 +22,7 @@ class WanScheduler4ChangingResolution:
assert
len
(
config
[
"resolution_rate"
])
==
len
(
config
[
"changing_resolution_steps"
])
assert
len
(
config
[
"resolution_rate"
])
==
len
(
config
[
"changing_resolution_steps"
])
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents_list
=
[]
self
.
latents_list
=
[]
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
self
.
latents_list
.
append
(
self
.
latents_list
.
append
(
...
@@ -30,7 +32,7 @@ class WanScheduler4ChangingResolution:
...
@@ -30,7 +32,7 @@ class WanScheduler4ChangingResolution:
int
(
latent_shape
[
2
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_shape
[
2
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_shape
[
3
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_shape
[
3
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
)
)
...
@@ -43,7 +45,7 @@ class WanScheduler4ChangingResolution:
...
@@ -43,7 +45,7 @@ class WanScheduler4ChangingResolution:
latent_shape
[
2
],
latent_shape
[
2
],
latent_shape
[
3
],
latent_shape
[
3
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
)
)
...
@@ -83,7 +85,7 @@ class WanScheduler4ChangingResolution:
...
@@ -83,7 +85,7 @@ class WanScheduler4ChangingResolution:
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + self.changing_resolution_index + 1 更激进的去噪
# 5. update timesteps using shift + self.changing_resolution_index + 1 更激进的去噪
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
run_device
,
shift
=
self
.
sample_shift
+
self
.
changing_resolution_index
+
1
)
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
AI_DEVICE
,
shift
=
self
.
sample_shift
+
self
.
changing_resolution_index
+
1
)
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
...
...
lightx2v/models/schedulers/wan/scheduler.py
View file @
b50498fa
...
@@ -7,12 +7,12 @@ from torch.nn import functional as F
...
@@ -7,12 +7,12 @@ from torch.nn import functional as F
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v.models.schedulers.scheduler
import
BaseScheduler
from
lightx2v.utils.utils
import
masks_like
from
lightx2v.utils.utils
import
masks_like
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanScheduler
(
BaseScheduler
):
class
WanScheduler
(
BaseScheduler
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
run_device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
))
self
.
infer_steps
=
self
.
config
[
"infer_steps"
]
self
.
infer_steps
=
self
.
config
[
"infer_steps"
]
self
.
target_video_length
=
self
.
config
[
"target_video_length"
]
self
.
target_video_length
=
self
.
config
[
"target_video_length"
]
self
.
sample_shift
=
self
.
config
[
"sample_shift"
]
self
.
sample_shift
=
self
.
config
[
"sample_shift"
]
...
@@ -36,7 +36,7 @@ class WanScheduler(BaseScheduler):
...
@@ -36,7 +36,7 @@ class WanScheduler(BaseScheduler):
self
.
rope_params
(
1024
,
2
*
(
self
.
head_size
//
6
)),
self
.
rope_params
(
1024
,
2
*
(
self
.
head_size
//
6
)),
],
],
dim
=
1
,
dim
=
1
,
).
to
(
torch
.
device
(
self
.
run_device
))
).
to
(
torch
.
device
(
AI_DEVICE
))
def
rope_params
(
self
,
max_seq_len
,
dim
,
theta
=
10000
):
def
rope_params
(
self
,
max_seq_len
,
dim
,
theta
=
10000
):
assert
dim
%
2
==
0
assert
dim
%
2
==
0
...
@@ -70,7 +70,7 @@ class WanScheduler(BaseScheduler):
...
@@ -70,7 +70,7 @@ class WanScheduler(BaseScheduler):
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
run_device
,
shift
=
self
.
sample_shift
)
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
AI_DEVICE
,
shift
=
self
.
sample_shift
)
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
...
@@ -114,14 +114,14 @@ class WanScheduler(BaseScheduler):
...
@@ -114,14 +114,14 @@ class WanScheduler(BaseScheduler):
return
cos_sin
return
cos_sin
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
def
prepare_latents
(
self
,
seed
,
latent_shape
,
dtype
=
torch
.
float32
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
run_device
).
manual_seed
(
seed
)
self
.
generator
=
torch
.
Generator
(
device
=
AI_DEVICE
).
manual_seed
(
seed
)
self
.
latents
=
torch
.
randn
(
self
.
latents
=
torch
.
randn
(
latent_shape
[
0
],
latent_shape
[
0
],
latent_shape
[
1
],
latent_shape
[
1
],
latent_shape
[
2
],
latent_shape
[
2
],
latent_shape
[
3
],
latent_shape
[
3
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
self
.
run_device
,
device
=
AI_DEVICE
,
generator
=
self
.
generator
,
generator
=
self
.
generator
,
)
)
if
self
.
config
[
"model_cls"
]
==
"wan2.2"
and
self
.
config
[
"task"
]
in
[
"i2v"
,
"s2v"
]:
if
self
.
config
[
"model_cls"
]
==
"wan2.2"
and
self
.
config
[
"task"
]
in
[
"i2v"
,
"s2v"
]:
...
...
lightx2v/models/schedulers/wan/self_forcing/scheduler.py
View file @
b50498fa
...
@@ -2,12 +2,12 @@ import torch
...
@@ -2,12 +2,12 @@ import torch
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanSFScheduler
(
WanScheduler
):
class
WanSFScheduler
(
WanScheduler
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
run_device
=
torch
.
device
(
config
.
get
(
"run_device"
,
"cuda"
))
self
.
dtype
=
torch
.
bfloat16
self
.
dtype
=
torch
.
bfloat16
self
.
num_frame_per_block
=
self
.
config
[
"sf_config"
][
"num_frame_per_block"
]
self
.
num_frame_per_block
=
self
.
config
[
"sf_config"
][
"num_frame_per_block"
]
self
.
num_output_frames
=
self
.
config
[
"sf_config"
][
"num_output_frames"
]
self
.
num_output_frames
=
self
.
config
[
"sf_config"
][
"num_output_frames"
]
...
@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
...
@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
self
.
context_noise
=
0
self
.
context_noise
=
0
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
self
.
latents
=
torch
.
randn
(
latent_shape
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
self
.
latents
=
torch
.
randn
(
latent_shape
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
timesteps
=
[]
timesteps
=
[]
for
frame_block_idx
,
current_num_frames
in
enumerate
(
self
.
all_num_frames
):
for
frame_block_idx
,
current_num_frames
in
enumerate
(
self
.
all_num_frames
):
frame_steps
=
[]
frame_steps
=
[]
for
step_index
,
current_timestep
in
enumerate
(
self
.
denoising_step_list
):
for
step_index
,
current_timestep
in
enumerate
(
self
.
denoising_step_list
):
timestep
=
torch
.
ones
([
self
.
num_frame_per_block
],
device
=
self
.
run_device
,
dtype
=
torch
.
int64
)
*
current_timestep
timestep
=
torch
.
ones
([
self
.
num_frame_per_block
],
device
=
AI_DEVICE
,
dtype
=
torch
.
int64
)
*
current_timestep
frame_steps
.
append
(
timestep
)
frame_steps
.
append
(
timestep
)
timesteps
.
append
(
frame_steps
)
timesteps
.
append
(
frame_steps
)
self
.
timesteps
=
timesteps
self
.
timesteps
=
timesteps
self
.
noise_pred
=
torch
.
zeros
(
latent_shape
,
device
=
self
.
run_device
,
dtype
=
self
.
dtype
)
self
.
noise_pred
=
torch
.
zeros
(
latent_shape
,
device
=
AI_DEVICE
,
dtype
=
self
.
dtype
)
sigma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
*
self
.
denoising_strength
sigma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
*
self
.
denoising_strength
if
self
.
extra_one_step
:
if
self
.
extra_one_step
:
...
@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
...
@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
self
.
sigmas_sf
=
self
.
sf_shift
*
self
.
sigmas_sf
/
(
1
+
(
self
.
sf_shift
-
1
)
*
self
.
sigmas_sf
)
self
.
sigmas_sf
=
self
.
sf_shift
*
self
.
sigmas_sf
/
(
1
+
(
self
.
sf_shift
-
1
)
*
self
.
sigmas_sf
)
if
self
.
reverse_sigmas
:
if
self
.
reverse_sigmas
:
self
.
sigmas_sf
=
1
-
self
.
sigmas_sf
self
.
sigmas_sf
=
1
-
self
.
sigmas_sf
self
.
sigmas_sf
=
self
.
sigmas_sf
.
to
(
self
.
run_device
)
self
.
sigmas_sf
=
self
.
sigmas_sf
.
to
(
AI_DEVICE
)
self
.
timesteps_sf
=
self
.
sigmas_sf
*
self
.
num_train_timesteps
self
.
timesteps_sf
=
self
.
sigmas_sf
*
self
.
num_train_timesteps
self
.
timesteps_sf
=
self
.
timesteps_sf
.
to
(
self
.
run_device
)
self
.
timesteps_sf
=
self
.
timesteps_sf
.
to
(
AI_DEVICE
)
self
.
stream_output
=
None
self
.
stream_output
=
None
...
@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
...
@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
# add noise
# add noise
if
self
.
step_index
<
self
.
infer_steps
-
1
:
if
self
.
step_index
<
self
.
infer_steps
-
1
:
timestep_next
=
self
.
timesteps
[
self
.
seg_index
][
self
.
step_index
+
1
]
*
torch
.
ones
(
self
.
num_frame_per_block
,
device
=
self
.
run_device
,
dtype
=
torch
.
long
)
timestep_next
=
self
.
timesteps
[
self
.
seg_index
][
self
.
step_index
+
1
]
*
torch
.
ones
(
self
.
num_frame_per_block
,
device
=
AI_DEVICE
,
dtype
=
torch
.
long
)
timestep_id_next
=
torch
.
argmin
((
self
.
timesteps_sf
.
unsqueeze
(
0
)
-
timestep_next
.
unsqueeze
(
1
)).
abs
(),
dim
=
1
)
timestep_id_next
=
torch
.
argmin
((
self
.
timesteps_sf
.
unsqueeze
(
0
)
-
timestep_next
.
unsqueeze
(
1
)).
abs
(),
dim
=
1
)
sigma_next
=
self
.
sigmas_sf
[
timestep_id_next
].
reshape
(
-
1
,
1
,
1
,
1
)
sigma_next
=
self
.
sigmas_sf
[
timestep_id_next
].
reshape
(
-
1
,
1
,
1
,
1
)
noise_next
=
torch
.
randn_like
(
x0_pred
)
noise_next
=
torch
.
randn_like
(
x0_pred
)
...
...
lightx2v/models/schedulers/wan/step_distill/scheduler.py
View file @
b50498fa
...
@@ -4,6 +4,7 @@ from typing import Union
...
@@ -4,6 +4,7 @@ from typing import Union
import
torch
import
torch
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v.models.schedulers.wan.scheduler
import
WanScheduler
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
WanStepDistillScheduler
(
WanScheduler
):
class
WanStepDistillScheduler
(
WanScheduler
):
...
@@ -19,7 +20,7 @@ class WanStepDistillScheduler(WanScheduler):
...
@@ -19,7 +20,7 @@ class WanStepDistillScheduler(WanScheduler):
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
def
prepare
(
self
,
seed
,
latent_shape
,
image_encoder_output
=
None
):
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
self
.
prepare_latents
(
seed
,
latent_shape
,
dtype
=
torch
.
float32
)
self
.
set_denoising_timesteps
(
device
=
self
.
run_device
)
self
.
set_denoising_timesteps
(
device
=
AI_DEVICE
)
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
self
.
cos_sin
=
self
.
prepare_cos_sin
((
latent_shape
[
1
]
//
self
.
patch_size
[
0
],
latent_shape
[
2
]
//
self
.
patch_size
[
1
],
latent_shape
[
3
]
//
self
.
patch_size
[
2
]))
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
...
...
lightx2v/models/video_encoders/hf/qwen_image/vae.py
View file @
b50498fa
...
@@ -5,6 +5,8 @@ from typing import Optional
...
@@ -5,6 +5,8 @@ from typing import Optional
import
torch
import
torch
from
lightx2v_platform.base.global_var
import
AI_DEVICE
try
:
try
:
from
diffusers
import
AutoencoderKLQwenImage
from
diffusers
import
AutoencoderKLQwenImage
from
diffusers.image_processor
import
VaeImageProcessor
from
diffusers.image_processor
import
VaeImageProcessor
...
@@ -33,7 +35,7 @@ class AutoencoderKLQwenImageVAE:
...
@@ -33,7 +35,7 @@ class AutoencoderKLQwenImageVAE:
if
self
.
cpu_offload
:
if
self
.
cpu_offload
:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
device
=
torch
.
device
(
"cpu"
)
else
:
else
:
self
.
device
=
torch
.
device
(
self
.
config
.
get
(
"run_device"
,
"cuda"
)
)
self
.
device
=
torch
.
device
(
AI_DEVICE
)
self
.
dtype
=
torch
.
bfloat16
self
.
dtype
=
torch
.
bfloat16
self
.
latent_channels
=
config
[
"vae_z_dim"
]
self
.
latent_channels
=
config
[
"vae_z_dim"
]
self
.
load
()
self
.
load
()
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
b50498fa
...
@@ -8,6 +8,10 @@ from einops import rearrange
...
@@ -8,6 +8,10 @@ from einops import rearrange
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v.utils.utils
import
load_weights
from
lightx2v.utils.utils
import
load_weights
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
__all__
=
[
__all__
=
[
"WanVAE"
,
"WanVAE"
,
...
@@ -821,11 +825,9 @@ class WanVAE:
...
@@ -821,11 +825,9 @@ class WanVAE:
use_2d_split
=
True
,
use_2d_split
=
True
,
load_from_rank0
=
False
,
load_from_rank0
=
False
,
use_lightvae
=
False
,
use_lightvae
=
False
,
run_device
=
torch
.
device
(
"cuda"
),
):
):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
self
.
run_device
=
run_device
self
.
parallel
=
parallel
self
.
parallel
=
parallel
self
.
use_tiling
=
use_tiling
self
.
use_tiling
=
use_tiling
self
.
cpu_offload
=
cpu_offload
self
.
cpu_offload
=
cpu_offload
...
@@ -955,11 +957,11 @@ class WanVAE:
...
@@ -955,11 +957,11 @@ class WanVAE:
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
def
to_cuda
(
self
):
def
to_cuda
(
self
):
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
self
.
run_device
)
self
.
model
.
encoder
=
self
.
model
.
encoder
.
to
(
AI_DEVICE
)
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
self
.
run_device
)
self
.
model
.
decoder
=
self
.
model
.
decoder
.
to
(
AI_DEVICE
)
self
.
model
=
self
.
model
.
to
(
self
.
run_device
)
self
.
model
=
self
.
model
.
to
(
AI_DEVICE
)
self
.
mean
=
self
.
mean
.
cuda
(
)
self
.
mean
=
self
.
mean
.
to
(
AI_DEVICE
)
self
.
inv_std
=
self
.
inv_std
.
cuda
(
)
self
.
inv_std
=
self
.
inv_std
.
to
(
AI_DEVICE
)
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
self
.
scale
=
[
self
.
mean
,
self
.
inv_std
]
def
encode_dist
(
self
,
video
,
world_size
,
cur_rank
,
split_dim
):
def
encode_dist
(
self
,
video
,
world_size
,
cur_rank
,
split_dim
):
...
@@ -1330,9 +1332,4 @@ class WanVAE:
...
@@ -1330,9 +1332,4 @@ class WanVAE:
def
device_synchronize
(
def
device_synchronize
(
self
,
self
,
):
):
if
"cuda"
in
str
(
self
.
run_device
):
torch_device_module
.
synchronize
()
torch
.
cuda
.
synchronize
()
elif
"mlu"
in
str
(
self
.
run_device
):
torch
.
mlu
.
synchronize
()
elif
"npu"
in
str
(
self
.
run_device
):
torch
.
npu
.
synchronize
()
lightx2v/utils/profiler.py
View file @
b50498fa
...
@@ -7,6 +7,9 @@ import torch.distributed as dist
...
@@ -7,6 +7,9 @@ import torch.distributed as dist
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
class
_ProfilingContext
:
class
_ProfilingContext
:
...
@@ -27,12 +30,12 @@ class _ProfilingContext:
...
@@ -27,12 +30,12 @@ class _ProfilingContext:
self
.
metrics_labels
=
metrics_labels
self
.
metrics_labels
=
metrics_labels
def
__enter__
(
self
):
def
__enter__
(
self
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
self
.
start_time
=
time
.
perf_counter
()
return
self
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
if
self
.
metrics_labels
:
...
@@ -44,12 +47,12 @@ class _ProfilingContext:
...
@@ -44,12 +47,12 @@ class _ProfilingContext:
return
False
return
False
async
def
__aenter__
(
self
):
async
def
__aenter__
(
self
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
self
.
start_time
=
time
.
perf_counter
()
return
self
return
self
async
def
__aexit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
async
def
__aexit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
device_synchronize
()
torch_
device_
module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
if
self
.
metrics_labels
:
...
@@ -78,17 +81,6 @@ class _ProfilingContext:
...
@@ -78,17 +81,6 @@ class _ProfilingContext:
return
sync_wrapper
return
sync_wrapper
def
device_synchronize
(
self
,
):
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
elif
hasattr
(
torch
,
"mlu"
)
and
torch
.
mlu
.
is_available
():
torch
.
mlu
.
synchronize
()
elif
hasattr
(
torch
,
"npu"
)
and
torch
.
npu
.
is_available
():
torch
.
npu
.
synchronize
()
return
class
_NullContext
:
class
_NullContext
:
# Context manager without decision branch logic overhead
# Context manager without decision branch logic overhead
...
...
lightx2v/utils/registry_factory.py
View file @
b50498fa
from
lightx2v_platform.registry_factory
import
PLATFORM_ATTN_WEIGHT_REGISTER
,
PLATFORM_MM_WEIGHT_REGISTER
class
Register
(
dict
):
class
Register
(
dict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Register
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
(
Register
,
self
).
__init__
(
*
args
,
**
kwargs
)
...
@@ -43,6 +46,15 @@ class Register(dict):
...
@@ -43,6 +46,15 @@ class Register(dict):
def
items
(
self
):
def
items
(
self
):
return
self
.
_dict
.
items
()
return
self
.
_dict
.
items
()
def
get
(
self
,
key
,
default
=
None
):
return
self
.
_dict
.
get
(
key
,
default
)
def
merge
(
self
,
other_register
):
for
key
,
value
in
other_register
.
items
():
if
key
in
self
.
_dict
:
raise
Exception
(
f
"
{
key
}
already exists in target register."
)
self
[
key
]
=
value
MM_WEIGHT_REGISTER
=
Register
()
MM_WEIGHT_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
=
Register
()
...
@@ -54,3 +66,6 @@ TENSOR_REGISTER = Register()
...
@@ -54,3 +66,6 @@ TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER
=
Register
()
CONVERT_WEIGHT_REGISTER
=
Register
()
EMBEDDING_WEIGHT_REGISTER
=
Register
()
EMBEDDING_WEIGHT_REGISTER
=
Register
()
RUNNER_REGISTER
=
Register
()
RUNNER_REGISTER
=
Register
()
ATTN_WEIGHT_REGISTER
.
merge
(
PLATFORM_ATTN_WEIGHT_REGISTER
)
MM_WEIGHT_REGISTER
.
merge
(
PLATFORM_MM_WEIGHT_REGISTER
)
lightx2v/utils/set_config.py
View file @
b50498fa
...
@@ -8,6 +8,7 @@ from torch.distributed.tensor.device_mesh import init_device_mesh
...
@@ -8,6 +8,7 @@ from torch.distributed.tensor.device_mesh import init_device_mesh
from
lightx2v.utils.input_info
import
ALL_INPUT_INFO_KEYS
from
lightx2v.utils.input_info
import
ALL_INPUT_INFO_KEYS
from
lightx2v.utils.lockable_dict
import
LockableDict
from
lightx2v.utils.lockable_dict
import
LockableDict
from
lightx2v_platform.base.global_var
import
AI_DEVICE
def
get_default_config
():
def
get_default_config
():
...
@@ -92,8 +93,7 @@ def set_parallel_config(config):
...
@@ -92,8 +93,7 @@ def set_parallel_config(config):
cfg_p_size
=
config
[
"parallel"
].
get
(
"cfg_p_size"
,
1
)
cfg_p_size
=
config
[
"parallel"
].
get
(
"cfg_p_size"
,
1
)
seq_p_size
=
config
[
"parallel"
].
get
(
"seq_p_size"
,
1
)
seq_p_size
=
config
[
"parallel"
].
get
(
"seq_p_size"
,
1
)
assert
cfg_p_size
*
seq_p_size
==
dist
.
get_world_size
(),
f
"cfg_p_size * seq_p_size must be equal to world_size"
assert
cfg_p_size
*
seq_p_size
==
dist
.
get_world_size
(),
f
"cfg_p_size * seq_p_size must be equal to world_size"
device_str
=
config
.
get
(
"run_device"
,
"cuda"
)
config
[
"device_mesh"
]
=
init_device_mesh
(
AI_DEVICE
,
(
cfg_p_size
,
seq_p_size
),
mesh_dim_names
=
(
"cfg_p"
,
"seq_p"
))
config
[
"device_mesh"
]
=
init_device_mesh
(
device_str
,
(
cfg_p_size
,
seq_p_size
),
mesh_dim_names
=
(
"cfg_p"
,
"seq_p"
))
if
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"seq_p_size"
,
False
)
and
config
[
"parallel"
][
"seq_p_size"
]
>
1
:
if
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"seq_p_size"
,
False
)
and
config
[
"parallel"
][
"seq_p_size"
]
>
1
:
config
[
"seq_parallel"
]
=
True
config
[
"seq_parallel"
]
=
True
...
@@ -101,7 +101,7 @@ def set_parallel_config(config):
...
@@ -101,7 +101,7 @@ def set_parallel_config(config):
if
config
.
get
(
"enable_cfg"
,
False
)
and
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"cfg_p_size"
,
False
)
and
config
[
"parallel"
][
"cfg_p_size"
]
>
1
:
if
config
.
get
(
"enable_cfg"
,
False
)
and
config
[
"parallel"
]
and
config
[
"parallel"
].
get
(
"cfg_p_size"
,
False
)
and
config
[
"parallel"
][
"cfg_p_size"
]
>
1
:
config
[
"cfg_parallel"
]
=
True
config
[
"cfg_parallel"
]
=
True
# warmup dist
# warmup dist
_a
=
torch
.
zeros
([
1
]).
to
(
f
"
{
device_str
}
:
{
dist
.
get_rank
()
}
"
)
_a
=
torch
.
zeros
([
1
]).
to
(
f
"
{
AI_DEVICE
}
:
{
dist
.
get_rank
()
}
"
)
dist
.
all_reduce
(
_a
)
dist
.
all_reduce
(
_a
)
...
...
lightx2v/utils/utils.py
View file @
b50498fa
...
@@ -13,18 +13,18 @@ import torchvision
...
@@ -13,18 +13,18 @@ import torchvision
from
einops
import
rearrange
from
einops
import
rearrange
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
def seed_all(seed):
    """Seed every RNG source (Python, hash, NumPy, torch CPU, torch accelerator).

    Also forces cuDNN into deterministic mode for reproducible runs.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch_device_module is resolved at module import time as
    # getattr(torch, AI_DEVICE) (e.g. torch.cuda or torch.mlu).
    # NOTE(review): assumes init_ai_device() ran before this module was
    # imported, otherwise AI_DEVICE is None — confirm import ordering.
    torch_device_module.manual_seed(seed)
    torch_device_module.manual_seed_all(seed)
    # Trade autotuner speed for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
...
...
lightx2v_platform/__init__.py
0 → 100755
View file @
b50498fa
from
.base
import
*
lightx2v_platform/base/__init__.py
0 → 100755
View file @
b50498fa
from
lightx2v_platform.base.base
import
check_ai_device
,
init_ai_device
from
lightx2v_platform.base.cambricon_mlu
import
MluDevice
from
lightx2v_platform.base.metax
import
MetaxDevice
from
lightx2v_platform.base.nvidia
import
CudaDevice
# Public API of lightx2v_platform.base: the device init/validation helpers
# plus the concrete per-vendor device classes (importing them above also
# registers each platform in PLATFORM_DEVICE_REGISTER via their decorators).
__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice"]
lightx2v_platform/base/base.py
0 → 100644
View file @
b50498fa
from
loguru
import
logger
from
lightx2v_platform.base
import
global_var
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
def _resolve_platform_device(platform):
    """Look up *platform* in PLATFORM_DEVICE_REGISTER.

    Shared by init_ai_device() and check_ai_device() so the lookup-and-raise
    logic lives in one place.

    Raises:
        RuntimeError: if *platform* is not registered.
    """
    platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None)
    if platform_device is None:
        available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
        raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")
    return platform_device


def init_ai_device(platform="cuda"):
    """Resolve *platform*'s device class and publish its device string.

    Sets the process-wide ``global_var.AI_DEVICE`` (e.g. "cuda" or "mlu").

    Returns:
        str: the device string that was stored in ``global_var.AI_DEVICE``.

    Raises:
        RuntimeError: if *platform* is not registered.
    """
    platform_device = _resolve_platform_device(platform)
    global_var.AI_DEVICE = platform_device.get_device()
    logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
    return global_var.AI_DEVICE


def check_ai_device(platform="cuda"):
    """Verify that *platform* is registered and its runtime is usable.

    Returns:
        bool: True when the platform's device reports itself available.

    Raises:
        RuntimeError: if *platform* is not registered, or registered but
            not available in the current runtime environment.
    """
    platform_device = _resolve_platform_device(platform)
    is_available = platform_device.is_available()
    if not is_available:
        raise RuntimeError(f"AI device for platform '{platform}' is not available. Please check your runtime environment.")
    logger.info(f"AI device for platform '{platform}' is available.")
    return True
lightx2v_platform/base/cambricon_mlu.py
0 → 100644
View file @
b50498fa
import
torch
import
torch.distributed
as
dist
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("mlu")
class MluDevice:
    """Cambricon MLU platform backend, registered under the key "mlu"."""

    name = "mlu"

    @staticmethod
    def is_available() -> bool:
        """Return True when the optional torch_mlu extension is importable and reports a usable MLU."""
        # Keep the try body minimal: only the optional import can raise
        # ImportError; the availability query runs outside the guard.
        try:
            import torch_mlu
        except ImportError:
            return False
        return torch_mlu.mlu.is_available()

    @staticmethod
    def get_device() -> str:
        """Return the torch device string for this platform."""
        return "mlu"

    @staticmethod
    def init_parallel_env():
        """Initialize torch.distributed with Cambricon's CNCL backend and bind this process to an MLU."""
        dist.init_process_group(backend="cncl")
        # NOTE(review): uses the global rank as the device index, which assumes
        # one process per MLU on a single node — confirm for multi-node runs.
        torch.mlu.set_device(dist.get_rank())
lightx2v_platform/base/global_var.py
0 → 100644
View file @
b50498fa
# Process-wide device identifier (e.g. "cuda" or "mlu").
# Remains None until init_ai_device() assigns it.
AI_DEVICE = None
lightx2v_platform/base/metax.py
0 → 100644
View file @
b50498fa
from
lightx2v_platform.base.nvidia
import
CudaDevice
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("metax")
class MetaxDevice(CudaDevice):
    """Metax GPU platform, registered as "metax" but reusing the CUDA backend.

    NOTE(review): ``name`` is deliberately "cuda", not "metax" — this presumably
    reflects a CUDA-compatible torch backend on Metax hardware, so all device
    strings and APIs come from the inherited CudaDevice. Confirm with the
    Metax runtime docs.
    """

    name = "cuda"
lightx2v_platform/base/nvidia.py
0 → 100644
View file @
b50498fa
import
torch
import
torch.distributed
as
dist
from
lightx2v_platform.registry_factory
import
PLATFORM_DEVICE_REGISTER
# ProcessGroupNCCL is only present in CUDA-enabled torch builds; fall back to
# None so importing this module never fails, and init_parallel_env() can raise
# a clear error instead.
try:
    from torch.distributed import ProcessGroupNCCL
except ImportError:
    ProcessGroupNCCL = None
@PLATFORM_DEVICE_REGISTER("cuda")
class CudaDevice:
    """NVIDIA CUDA platform backend, registered under the key "cuda"."""

    name = "cuda"

    @staticmethod
    def is_available() -> bool:
        """Return True when torch reports a usable CUDA device."""
        # torch is imported unconditionally at module top level, so the
        # previous local `import torch` + ImportError guard was dead code.
        return torch.cuda.is_available()

    @staticmethod
    def get_device() -> str:
        """Return the torch device string for this platform."""
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """Initialize torch.distributed with NCCL and bind this process to a GPU.

        Raises:
            RuntimeError: when the torch build has no ProcessGroupNCCL.
        """
        if ProcessGroupNCCL is None:
            raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.")
        pg_options = ProcessGroupNCCL.Options()
        # High-priority NCCL streams let collectives overtake compute kernels.
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)
        # NOTE(review): global rank as device index assumes one process per GPU
        # on a single node — confirm for multi-node launches.
        torch.cuda.set_device(dist.get_rank())
lightx2v_platform/ops/__init__.py
0 → 100755
View file @
b50498fa
from
lightx2v_platform.base.global_var
import
AI_DEVICE
# AI_DEVICE is captured here at import time; if this package is imported
# before init_ai_device() has run, the value is still None and no platform
# ops are loaded. NOTE(review): confirm callers always initialize the
# platform before importing lightx2v_platform.ops.
if AI_DEVICE == "mlu":
    from .attn.cambricon_mlu import *
    from .mm.cambricon_mlu import *
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment