Commit 32fd1c52 authored by LiangLiu's avatar LiangLiu Committed by GitHub
Browse files

deploy update (#309)



deploy update

---------
Co-authored-by: default avatarliuliang1 <liuliang1@sensetime.com>
Co-authored-by: default avatarqinxinyi <qinxinyi@sensetime.com>
parent f6e214bb
......@@ -31,6 +31,8 @@ class BaseWorker:
if config.parallel:
self.rank = dist.get_rank()
set_parallel_config(config)
seed_all(config.seed)
torch.set_grad_enabled(False)
self.runner = RUNNER_REGISTER[config.model_cls](config)
# fixed config
self.fixed_config = copy.deepcopy(self.runner.config)
......@@ -189,15 +191,6 @@ class PipelineWorker(BaseWorker):
self.runner.init_modules()
self.run_func = self.runner.run_pipeline
def init_temp_params(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(cur_dir, "../../.."))
self.runner.config["prompt"] = "The video features a old lady is saying something and knitting a sweater."
if self.runner.config.task == "i2v":
self.runner.config["image_path"] = os.path.join(base_dir, "assets", "inputs", "audio", "15.png")
if self.is_audio_model():
self.runner.config["audio_path"] = os.path.join(base_dir, "assets", "inputs", "audio", "15.wav")
@class_try_catch_async_with_thread
async def run(self, inputs, outputs, params, data_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
......@@ -323,7 +316,7 @@ class DiTWorker(BaseWorker):
future = asyncio.Future()
self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_dit, self.rank)
self.thread.start()
status, (out, _) = await future
status, out = await future
if not status:
return False
......
......@@ -141,7 +141,7 @@ class DefaultRunner(BaseRunner):
def end_run(self):
self.model.scheduler.clear()
del self.inputs, self.model.scheduler
del self.inputs
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
self.model.transformer_infer.weights_stream_mgr.clear()
......
......@@ -481,7 +481,8 @@ class WanAudioRunner(WanRunner): # type:ignore
def init_run_segment(self, segment_idx, audio_array=None):
self.segment_idx = segment_idx
if audio_array is not None:
self.segment = AudioSegment(audio_array, 0, audio_array.shape[0])
end_idx = audio_array.shape[0] // self._audio_processor.audio_frame_rate - self.prev_frame_length
self.segment = AudioSegment(audio_array, 0, end_idx)
else:
self.segment = self.inputs["audio_segments"][segment_idx]
......@@ -584,6 +585,8 @@ class WanAudioRunner(WanRunner): # type:ignore
self.va_reader.start()
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile()
self.video_segment_num = "unlimited"
fetch_timeout = self.va_reader.segment_duration + 1
......@@ -611,7 +614,7 @@ class WanAudioRunner(WanRunner): # type:ignore
segment_idx += 1
finally:
if hasattr(self.model, "scheduler"):
if hasattr(self.model, "inputs"):
self.end_run()
if self.va_reader:
self.va_reader.stop()
......
......@@ -379,7 +379,7 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
cpu_weight_dict = {}
if is_weight_loader:
logger.info(f"Loading weights from {checkpoint_path}")
cpu_weight_dict = load_pt_safetensors(checkpoint_path)
cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key)
for key in list(cpu_weight_dict.keys()):
if remove_key and remove_key in key:
cpu_weight_dict.pop(key)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment