Commit 32fd1c52 authored by LiangLiu, committed by GitHub
Browse files

deploy update (#309)



deploy update

---------
Co-authored-by: liuliang1 <liuliang1@sensetime.com>
Co-authored-by: qinxinyi <qinxinyi@sensetime.com>
parent f6e214bb
...@@ -31,6 +31,8 @@ class BaseWorker: ...@@ -31,6 +31,8 @@ class BaseWorker:
if config.parallel: if config.parallel:
self.rank = dist.get_rank() self.rank = dist.get_rank()
set_parallel_config(config) set_parallel_config(config)
seed_all(config.seed)
torch.set_grad_enabled(False)
self.runner = RUNNER_REGISTER[config.model_cls](config) self.runner = RUNNER_REGISTER[config.model_cls](config)
# fixed config # fixed config
self.fixed_config = copy.deepcopy(self.runner.config) self.fixed_config = copy.deepcopy(self.runner.config)
...@@ -189,15 +191,6 @@ class PipelineWorker(BaseWorker): ...@@ -189,15 +191,6 @@ class PipelineWorker(BaseWorker):
self.runner.init_modules() self.runner.init_modules()
self.run_func = self.runner.run_pipeline self.run_func = self.runner.run_pipeline
def init_temp_params(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(cur_dir, "../../.."))
self.runner.config["prompt"] = "The video features a old lady is saying something and knitting a sweater."
if self.runner.config.task == "i2v":
self.runner.config["image_path"] = os.path.join(base_dir, "assets", "inputs", "audio", "15.png")
if self.is_audio_model():
self.runner.config["audio_path"] = os.path.join(base_dir, "assets", "inputs", "audio", "15.wav")
@class_try_catch_async_with_thread @class_try_catch_async_with_thread
async def run(self, inputs, outputs, params, data_manager): async def run(self, inputs, outputs, params, data_manager):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
...@@ -323,7 +316,7 @@ class DiTWorker(BaseWorker): ...@@ -323,7 +316,7 @@ class DiTWorker(BaseWorker):
future = asyncio.Future() future = asyncio.Future()
self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_dit, self.rank) self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_dit, self.rank)
self.thread.start() self.thread.start()
status, (out, _) = await future status, out = await future
if not status: if not status:
return False return False
......
...@@ -141,7 +141,7 @@ class DefaultRunner(BaseRunner): ...@@ -141,7 +141,7 @@ class DefaultRunner(BaseRunner):
def end_run(self): def end_run(self):
self.model.scheduler.clear() self.model.scheduler.clear()
del self.inputs, self.model.scheduler del self.inputs
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model.transformer_infer, "weights_stream_mgr"): if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
self.model.transformer_infer.weights_stream_mgr.clear() self.model.transformer_infer.weights_stream_mgr.clear()
......
...@@ -481,7 +481,8 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -481,7 +481,8 @@ class WanAudioRunner(WanRunner): # type:ignore
def init_run_segment(self, segment_idx, audio_array=None): def init_run_segment(self, segment_idx, audio_array=None):
self.segment_idx = segment_idx self.segment_idx = segment_idx
if audio_array is not None: if audio_array is not None:
self.segment = AudioSegment(audio_array, 0, audio_array.shape[0]) end_idx = audio_array.shape[0] // self._audio_processor.audio_frame_rate - self.prev_frame_length
self.segment = AudioSegment(audio_array, 0, end_idx)
else: else:
self.segment = self.inputs["audio_segments"][segment_idx] self.segment = self.inputs["audio_segments"][segment_idx]
...@@ -584,6 +585,8 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -584,6 +585,8 @@ class WanAudioRunner(WanRunner): # type:ignore
self.va_reader.start() self.va_reader.start()
self.init_run() self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile()
self.video_segment_num = "unlimited" self.video_segment_num = "unlimited"
fetch_timeout = self.va_reader.segment_duration + 1 fetch_timeout = self.va_reader.segment_duration + 1
...@@ -611,7 +614,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -611,7 +614,7 @@ class WanAudioRunner(WanRunner): # type:ignore
segment_idx += 1 segment_idx += 1
finally: finally:
if hasattr(self.model, "scheduler"): if hasattr(self.model, "inputs"):
self.end_run() self.end_run()
if self.va_reader: if self.va_reader:
self.va_reader.stop() self.va_reader.stop()
......
...@@ -379,7 +379,7 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None): ...@@ -379,7 +379,7 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
cpu_weight_dict = {} cpu_weight_dict = {}
if is_weight_loader: if is_weight_loader:
logger.info(f"Loading weights from {checkpoint_path}") logger.info(f"Loading weights from {checkpoint_path}")
cpu_weight_dict = load_pt_safetensors(checkpoint_path) cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key)
for key in list(cpu_weight_dict.keys()): for key in list(cpu_weight_dict.keys()):
if remove_key and remove_key in key: if remove_key and remove_key in key:
cpu_weight_dict.pop(key) cpu_weight_dict.pop(key)
......
...@@ -28,3 +28,6 @@ fastapi ...@@ -28,3 +28,6 @@ fastapi
uvicorn uvicorn
PyJWT PyJWT
requests requests
alibabacloud_dypnsapi20170525==1.2.2
redis==6.4.0
tos
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment