Commit 701075f4 authored by Yang Yong(雍洋), committed by GitHub

refactor compiler (#301)

parent 60c421f4
@@ -29,6 +29,7 @@ class DefaultRunner(BaseRunner):
         if not self.has_prompt_enhancer:
             self.config.use_prompt_enhancer = False
         self.set_init_device()
+        self.init_scheduler()
 
     def init_modules(self):
         logger.info("Initializing runner modules...")
@@ -36,6 +37,7 @@ class DefaultRunner(BaseRunner):
             self.load_model()
         elif self.config.get("lazy_load", False):
             assert self.config.get("cpu_offload", False)
+        self.model.set_scheduler(self.scheduler)  # set scheduler to model
         if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
         elif self.config["task"] == "flf2v":
@@ -44,6 +46,9 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_t2v
         elif self.config["task"] == "vace":
             self.run_input_encoder = self._run_input_encoder_local_vace
+        if self.config.get("compile", False):
+            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
+            self.model.compile(self.config.get("compile_shapes", []))
 
     def set_init_device(self):
         if self.config.cpu_offload:
@@ -214,7 +219,6 @@ class DefaultRunner(BaseRunner):
         self.get_video_segment_num()
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
-        self.init_scheduler()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
         if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None
@@ -222,6 +226,8 @@ class DefaultRunner(BaseRunner):
     @ProfilingContext4DebugL2("Run DiT")
     def run_main(self, total_steps=None):
         self.init_run()
+        if self.config.get("compile", False):
+            self.model.select_graph_for_compile()
         for segment_idx in range(self.video_segment_num):
             logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
             with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
......
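The compile path added above is driven purely by config keys (`compile`, `compile_shapes`) instead of the old `ENABLE_GRAPH_MODE` environment variable removed later in this commit. A minimal sketch of a config that would exercise it; the `compile_shapes` entry format is not visible in this diff, so the values are illustrative only:

```python
# Hypothetical runner config for the new compile path; "compile_shapes"
# values are placeholders, since their real format is defined by
# model.compile() and is not shown in this diff.
config = {
    "model_cls": "wan2.1",
    "task": "t2v",
    "compile": True,                     # init_modules(): model.compile(...)
    "compile_shapes": ["480p", "720p"],  # assumed shape tags to pre-compile
}
```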
-from loguru import logger
-
-from lightx2v.utils.profiler import *
-
-
-class GraphRunner:
-    def __init__(self, runner):
-        self.runner = runner
-        self.compile()
-
-    def compile(self):
-        logger.info("=" * 60)
-        logger.info("🚀 Starting Model Compilation - Please wait, this may take a while... 🚀")
-        logger.info("=" * 60)
-        with ProfilingContext4DebugL2("compile"):
-            self.runner.run_step()
-        logger.info("=" * 60)
-        logger.info("✅ Model Compilation Completed ✅")
-        logger.info("=" * 60)
-
-    def run_pipeline(self):
-        return self.runner.run_pipeline()
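For context, the `GraphRunner` deleted above was the entry point of the old `ENABLE_GRAPH_MODE` flow: wrapping a fully built runner triggered one warm-up `run_step()` inside a profiled compile pass. A sketch of that superseded wiring (the surrounding construction code is simplified and not part of this diff):

```python
# Old graph-mode wiring, superseded by this commit (sketch only).
runner = DefaultRunner(config)       # assumes a fully initialized runner
graph_runner = GraphRunner(runner)   # __init__ compiles via one run_step()
video = graph_runner.run_pipeline()  # delegates to the wrapped runner
```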
@@ -309,12 +309,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def init_scheduler(self):
         """Initialize consistency model scheduler"""
-        scheduler = EulerScheduler(self.config)
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
-            self.audio_adapter = self.load_audio_adapter()
-        self.model.set_audio_adapter(self.audio_adapter)
-        scheduler.set_audio_adapter(self.audio_adapter)
-        self.model.set_scheduler(scheduler)
+        self.scheduler = EulerScheduler(self.config)
 
     def read_audio_input(self):
         """Read audio input"""
@@ -469,13 +464,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
         mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
         mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
         mask = mask.view(mask.shape[1] // 4, 4, h, w)
-        return mask.transpose(0, 1)
+        return mask.transpose(0, 1).contiguous()
 
     def get_video_segment_num(self):
         self.video_segment_num = len(self.inputs["audio_segments"])
 
     def init_run(self):
         super().init_run()
+        self.scheduler.set_audio_adapter(self.audio_adapter)
 
         self.gen_video_list = []
         self.cut_audio_list = []
......
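The `.contiguous()` added to the mask path is presumably needed because `transpose` returns a strided view, and downstream reshapes or compiled kernels expect a dense buffer. A tiny standalone demonstration of the difference:

```python
import torch

m = torch.randn(4, 3, 8, 8)
t = m.transpose(0, 1)
print(t.is_contiguous())  # False: transpose only swaps strides
# t.view(...) would raise here; .contiguous() materializes the new layout
t = t.contiguous()
print(t.is_contiguous())  # True
```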
@@ -47,8 +47,7 @@ class WanCausVidRunner(WanRunner):
         self.num_fragments = self.config["num_fragments"]
 
     def init_scheduler(self):
-        scheduler = WanStepDistillScheduler(self.config)
-        self.model.set_scheduler(scheduler)
+        self.scheduler = WanStepDistillScheduler(self.config)
 
     def set_target_shape(self):
         if self.config.task == "i2v":
......
@@ -35,10 +35,9 @@ class WanDistillRunner(WanRunner):
     def init_scheduler(self):
         if self.config.feature_caching == "NoCaching":
-            scheduler = WanStepDistillScheduler(self.config)
+            self.scheduler = WanStepDistillScheduler(self.config)
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
-        self.model.set_scheduler(scheduler)
 
 
 class MultiDistillModelStruct(MultiModelStruct):
@@ -131,7 +130,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
     def init_scheduler(self):
         if self.config.feature_caching == "NoCaching":
-            scheduler = Wan22StepDistillScheduler(self.config)
+            self.scheduler = Wan22StepDistillScheduler(self.config)
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
-        self.model.set_scheduler(scheduler)
@@ -191,10 +191,9 @@ class WanRunner(DefaultRunner):
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
 
         if self.config.get("changing_resolution", False):
-            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
+            self.scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
         else:
-            scheduler = scheduler_class(self.config)
-        self.model.set_scheduler(scheduler)
+            self.scheduler = scheduler_class(self.config)
 
     def run_text_encoder(self, text, img=None):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
@@ -206,13 +205,17 @@ class WanRunner(DefaultRunner):
                 cfg_p_rank = dist.get_rank(cfg_p_group)
                 if cfg_p_rank == 0:
                     context = self.text_encoders[0].infer([text])
+                    context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
                     text_encoder_output = {"context": context}
                 else:
                     context_null = self.text_encoders[0].infer([n_prompt])
+                    context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
                     text_encoder_output = {"context_null": context_null}
             else:
                 context = self.text_encoders[0].infer([text])
+                context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
                 context_null = self.text_encoders[0].infer([n_prompt])
+                context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
                 text_encoder_output = {
                     "context": context,
                     "context_null": context_null,
......
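The four added lines in `run_text_encoder` all apply the same idiom: pad each variable-length text embedding with zeros up to a fixed `text_len`, then stack into a batch. A self-contained example of just that transformation (dimensions are illustrative):

```python
import torch

text_len = 512  # fixed length from config; value here is illustrative
# the text encoder is assumed to return one [seq_len, dim] tensor per prompt
context = [torch.randn(77, 4096)]
padded = torch.stack([
    torch.cat([u, u.new_zeros(text_len - u.size(0), u.size(1))])
    for u in context
])
print(padded.shape)  # torch.Size([1, 512, 4096])
```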
@@ -20,8 +20,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         super().__init__(config)
 
     def init_scheduler(self):
-        scheduler = WanSkyreelsV2DFScheduler(self.config)
-        self.model.set_scheduler(scheduler)
+        self.scheduler = WanSkyreelsV2DFScheduler(self.config)
 
     def run_image_encoder(self, config, image_encoder, vae_model):
         img = Image.open(config.image_path).convert("RGB")
@@ -126,6 +125,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
     def run_pipeline(self):
         self.init_scheduler()
+        self.model.set_scheduler(self.scheduler)
         self.run_input_encoder()
         self.model.scheduler.prepare()
         output_video = self.run()
......
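Every runner touched above follows the same refactor: `init_scheduler()` now only assigns `self.scheduler`, and attaching it to the model happens once in `DefaultRunner.init_modules()` (or explicitly, as in the SkyReels `run_pipeline` above). This lets the scheduler be built before the transformer is loaded. The new contract, sketched for a hypothetical subclass:

```python
# Post-refactor contract (MyRunner and MyScheduler are hypothetical names):
class MyRunner(WanRunner):
    def init_scheduler(self):
        # build and store only; no self.model access, so this can run
        # in __init__ before the model weights exist
        self.scheduler = MyScheduler(self.config)

# DefaultRunner.init_modules() later performs the single wiring step:
#     self.model.set_scheduler(self.scheduler)
```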
+import functools
+from typing import Dict, List, Optional
+
+import torch
+from loguru import logger
+
+
+def compiled_method(compile_options: Optional[Dict] = None):
+    def decorator(func):
+        func_name = func.__name__
+        compile_opts = compile_options or {}
+        state = {
+            "original_func": func,
+            "compiled_graphs": {},
+            "compile_mode": False,
+            "selected_graph": None,
+            "selected_compiled": None,
+        }
+
+        @functools.wraps(func)
+        def wrapper(self, *args, graph_name: Optional[str] = None, **kwargs):
+            if state["compile_mode"]:
+                if graph_name is None:
+                    graph_name = f"graph_{len(state['compiled_graphs']) + 1:02d}"
+                if graph_name not in state["compiled_graphs"]:
+                    logger.info(f"[Compile] Compiling {func_name} as '{graph_name}'...")
+                    compiled_func = torch.compile(state["original_func"], **compile_opts)
+                    try:
+                        result = compiled_func(self, *args, **kwargs)
+                        state["compiled_graphs"][graph_name] = compiled_func
+                        logger.info(f"[Compile] Compiled {func_name} as '{graph_name}'")
+                        return result
+                    except Exception as e:
+                        logger.info(f"[Compile] Failed to compile {func_name} as '{graph_name}': {e}")
+                        return state["original_func"](self, *args, **kwargs)
+                else:
+                    logger.info(f"[Compile] Using existing compiled graph '{graph_name}'")
+                    return state["compiled_graphs"][graph_name](self, *args, **kwargs)
+            elif state["selected_compiled"]:
+                return state["selected_compiled"](self, *args, **kwargs)
+            else:
+                return state["original_func"](self, *args, **kwargs)
+
+        def _enable_compile_mode():
+            logger.info(f"[Compile] Enabling compile mode for {func_name}")
+            state["compile_mode"] = True
+
+        def _disable_compile_mode():
+            logger.info(f"[Compile] Disabling compile mode for {func_name}")
+            state["compile_mode"] = False
+
+        def _select_graph(graph_name: str):
+            if graph_name not in state["compiled_graphs"]:
+                raise ValueError(f"Graph '{graph_name}' not found. Available graphs: {list(state['compiled_graphs'].keys())}")
+            logger.info(f"[Compile] Selecting graph '{graph_name}' for {func_name}")
+            state["selected_graph"] = graph_name
+            state["selected_compiled"] = state["compiled_graphs"][graph_name]
+            logger.info(f"[Compile] {func_name} will now use graph '{graph_name}' for inference")
+
+        def _unselect_graph():
+            logger.info(f"[Compile] Unselecting graph for {func_name}, returning to original function")
+            state["selected_graph"] = None
+            state["selected_compiled"] = None
+
+        def _get_status():
+            return {
+                "available_graphs": list(state["compiled_graphs"].keys()),
+                "compiled_count": len(state["compiled_graphs"]),
+                "selected_graph": state["selected_graph"],
+                "compile_mode": state["compile_mode"],
+                "mode": "compile" if state["compile_mode"] else ("inference" if state["selected_compiled"] else "original"),
+            }
+
+        def _clear_graphs():
+            state["compiled_graphs"].clear()
+            state["selected_graph"] = None
+            state["selected_compiled"] = None
+            state["compile_mode"] = False
+            logger.info(f"[Compile] Cleared all compiled graphs for {func_name}")
+
+        def _remove_graph(graph_name: str):
+            if graph_name in state["compiled_graphs"]:
+                del state["compiled_graphs"][graph_name]
+                if state["selected_graph"] == graph_name:
+                    state["selected_graph"] = None
+                    state["selected_compiled"] = None
+                logger.info(f"[Compile] Removed graph '{graph_name}' for {func_name}")
+            else:
+                logger.info(f"[Compile] Graph '{graph_name}' not found")
+
+        wrapper._enable_compile_mode = _enable_compile_mode
+        wrapper._disable_compile_mode = _disable_compile_mode
+        wrapper._select_graph = _select_graph
+        wrapper._unselect_graph = _unselect_graph
+        wrapper._get_status = _get_status
+        wrapper._clear_graphs = _clear_graphs
+        wrapper._remove_graph = _remove_graph
+        wrapper._func_name = func_name
+        return wrapper
+
+    return decorator
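A minimal sketch of the decorator in isolation (PyTorch 2.x assumed; `Block` is a toy class, not from this repo). Note that `state` lives in the decorator's closure, so the graph cache is shared by all instances of the class rather than stored per instance:

```python
# Toy usage of compiled_method, assuming it is imported from the new
# compiler module (its module path is not shown in this diff).
class Block:
    @compiled_method({"mode": "default"})
    def forward(self, x):
        return torch.relu(x)

b = Block()
b.forward._enable_compile_mode()
b.forward(torch.randn(4), graph_name="relu_4")  # compiles and caches "relu_4"
b.forward._disable_compile_mode()
b.forward._select_graph("relu_4")               # pin the cached graph
print(b.forward._get_status()["mode"])          # -> "inference"
```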
+
+class CompiledMethodsMixin:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._compiled_methods = {}
+        self._discover_compiled_methods()
+
+    def _discover_compiled_methods(self):
+        logger.info(f"[Compile] Discovering compiled methods for {self.__class__.__name__}...")
+        for attr_name in dir(self):
+            attr = getattr(self, attr_name)
+            if hasattr(attr, "_enable_compile_mode"):
+                logger.info(f"[Compile] Found compiled method: {attr_name}")
+                self._compiled_methods[attr_name] = attr
+
+    def enable_compile_mode(self, method_name: str = None):
+        if method_name:
+            if method_name not in self._compiled_methods:
+                raise ValueError(f"Method '{method_name}' is not a compiled method")
+            self._compiled_methods[method_name]._enable_compile_mode()
+        else:
+            for name, method in self._compiled_methods.items():
+                method._enable_compile_mode()
+            logger.info("[Compile] Enabled compile mode for all methods")
+
+    def disable_compile_mode(self, method_name: str = None):
+        if method_name:
+            if method_name not in self._compiled_methods:
+                raise ValueError(f"Method '{method_name}' is not a compiled method")
+            self._compiled_methods[method_name]._disable_compile_mode()
+        else:
+            for name, method in self._compiled_methods.items():
+                method._disable_compile_mode()
+            logger.info("[Compile] Disabled compile mode for all methods")
+
+    def select_graph(self, method_name: str, graph_name: str):
+        if method_name not in self._compiled_methods:
+            raise ValueError(f"Method '{method_name}' is not a compiled method")
+        method = self._compiled_methods[method_name]
+        method._select_graph(graph_name)
+
+    def unselect_graph(self, method_name: str):
+        if method_name not in self._compiled_methods:
+            raise ValueError(f"Method '{method_name}' is not a compiled method")
+        method = self._compiled_methods[method_name]
+        method._unselect_graph()
+
+    def get_compile_status(self) -> Dict:
+        status = {}
+        for method_name, method in self._compiled_methods.items():
+            status[method_name] = method._get_status()
+        return status
+
+    def get_compiled_methods(self) -> List[str]:
+        return list(self._compiled_methods.keys())
+
+    def clear_compiled_graphs(self, method_name: str = None):
+        if method_name:
+            if method_name in self._compiled_methods:
+                self._compiled_methods[method_name]._clear_graphs()
+            else:
+                logger.info(f"Method '{method_name}' not found")
+        else:
+            for method_name, method in self._compiled_methods.items():
+                method._clear_graphs()
+            logger.info("[Compile] Cleared all compiled graphs")
+
+    def remove_graph(self, method_name: str, graph_name: str):
+        if method_name not in self._compiled_methods:
+            raise ValueError(f"Method '{method_name}' is not a compiled method")
+        method = self._compiled_methods[method_name]
+        method._remove_graph(graph_name)
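Putting the pieces together: a model mixes in `CompiledMethodsMixin`, marks hot methods with `@compiled_method`, warms up one graph per shape, then pins a graph for inference. This is plausibly what `model.compile(shapes)` and `model.select_graph_for_compile()` in the runner map onto, though those methods are not shown in this diff. An end-to-end sketch with an invented `ToyModel`:

```python
import torch

class ToyModel(CompiledMethodsMixin):
    def __init__(self):
        super().__init__()  # discovers @compiled_method-decorated methods
        self.weight = torch.randn(8, 8)

    @compiled_method(compile_options={"dynamic": False})
    def infer(self, x):
        return x @ self.weight

model = ToyModel()
model.enable_compile_mode("infer")
# warm-up: one call per shape, each cached under its own graph name
model.infer(torch.randn(4, 8), graph_name="shape_4x8")
model.infer(torch.randn(16, 8), graph_name="shape_16x8")
model.disable_compile_mode("infer")
model.select_graph("infer", "shape_4x8")  # pin one graph for inference
out = model.infer(torch.randn(4, 8))      # runs the selected compiled graph
```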
@@ -22,12 +22,6 @@ def CHECK_PROFILING_DEBUG_LEVEL(target_level):
     return current_level >= target_level
 
-
-@lru_cache(maxsize=None)
-def CHECK_ENABLE_GRAPH_MODE():
-    ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"
-    return ENABLE_GRAPH_MODE
-
 
 @lru_cache(maxsize=None)
 def GET_RUNNING_FLAG():
     RUNNING_FLAG = os.getenv("RUNNING_FLAG", "infer")
......
@@ -26,6 +26,11 @@ def get_default_config():
         "cfg_parallel": False,
         "enable_cfg": False,
         "use_image_encoder": True,
+        "lat_h": None,
+        "lat_w": None,
+        "tgt_h": None,
+        "tgt_w": None,
+        "target_shape": None,
     }
     return default_config
......
@@ -29,7 +29,6 @@ export DTYPE=BF16
 # Used for layers requiring higher precision
 # Available options: [FP32, None]
 # If not set, default value: None (follows DTYPE setting)
-# Note: If set to FP32, it will be slower, so we recommend set ENABLE_GRAPH_MODE to true.
 export SENSITIVE_LAYER_DTYPE=FP32
 
 # Performance Profiling Debug Level (Debug Only)
@@ -39,14 +38,6 @@ export SENSITIVE_LAYER_DTYPE=FP32
 # Note: This option can be set to 0 for production.
 export PROFILING_DEBUG_LEVEL=2
 
-# Graph Mode Optimization (Performance Enhancement)
-# Enables torch.compile for graph optimization, can improve inference performance
-# Available options: [true, false]
-# If not set, default value: false
-# Note: First run may require compilation time, subsequent runs will be faster
-# Note: When you use lightx2v as a service, you can set this option to true.
-export ENABLE_GRAPH_MODE=true
-
 echo "==============================================================================="
 echo "LightX2V Base Environment Variables Summary:"
@@ -57,5 +48,4 @@ echo "--------------------------------------------------------------------------
 echo "Model Inference Data Type: ${DTYPE}"
 echo "Sensitive Layer Data Type: ${SENSITIVE_LAYER_DTYPE}"
 echo "Performance Profiling Debug Level: ${PROFILING_DEBUG_LEVEL}"
-echo "Graph Mode Optimization: ${ENABLE_GRAPH_MODE}"
 echo "==============================================================================="
@@ -28,7 +28,6 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
 export SENSITIVE_LAYER_DTYPE=FP32
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 
 python -m lightx2v.infer \
     --model_cls wan2.1 \
......
@@ -27,7 +27,6 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
 
 python -m lightx2v.infer \
......
@@ -27,7 +27,6 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
 
 python -m lightx2v.infer \
......
@@ -27,7 +27,6 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
 
 python -m lightx2v.infer \
......
@@ -27,7 +27,6 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
 
 python -m lightx2v.infer \
......
@@ -28,7 +28,6 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
 export SENSITIVE_LAYER_DTYPE=FP32
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 
 python -m lightx2v.infer \
     --model_cls wan2.1 \
......
@@ -28,7 +28,6 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
 export SENSITIVE_LAYER_DTYPE=FP32
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 
 python -m lightx2v.infer \
     --model_cls wan2.1_distill \
......
@@ -27,7 +27,6 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
 
 python -m lightx2v.infer \
......
@@ -27,7 +27,6 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 
 export PROFILING_DEBUG_LEVEL=2
-export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
 
 python -m lightx2v.infer \
......