Commit 8b213df0 authored by GoatWu
Browse files

bugs fixed for distill_model server

parent 6ac3cee7
...@@ -25,3 +25,4 @@ ...@@ -25,3 +25,4 @@
build/ build/
dist/ dist/
.cache/ .cache/
server_cache/
...@@ -36,6 +36,7 @@ def main(): ...@@ -36,6 +36,7 @@ def main():
choices=[ choices=[
"wan2.1", "wan2.1",
"hunyuan", "hunyuan",
"wan2.1_distill",
"wan2.1_causvid", "wan2.1_causvid",
"wan2.1_skyreels_v2_df", "wan2.1_skyreels_v2_df",
"wan2.1_audio", "wan2.1_audio",
...@@ -55,7 +56,7 @@ def main(): ...@@ -55,7 +56,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
logger.info(f"args: {args}") logger.info(f"args: {args}")
cache_dir = Path(__file__).parent.parent / ".cache" cache_dir = Path(__file__).parent.parent / "server_cache"
inference_service = DistributedInferenceService() inference_service = DistributedInferenceService()
api_server = ApiServer() api_server = ApiServer()
......
...@@ -9,7 +9,7 @@ class WanStepDistillScheduler(WanScheduler): ...@@ -9,7 +9,7 @@ class WanStepDistillScheduler(WanScheduler):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.denoising_step_list = config.denoising_step_list self.denoising_step_list = config.denoising_step_list
self.infer_steps = self.config.infer_steps self.infer_steps = len(self.denoising_step_list)
self.sample_shift = self.config.sample_shift self.sample_shift = self.config.sample_shift
def prepare(self, image_encoder_output): def prepare(self, image_encoder_output):
...@@ -40,10 +40,7 @@ class WanStepDistillScheduler(WanScheduler): ...@@ -40,10 +40,7 @@ class WanStepDistillScheduler(WanScheduler):
self.sigma_min = self.sigmas[-1].item() self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item() self.sigma_max = self.sigmas[0].item()
if len(self.denoising_step_list) == self.infer_steps: # 如果denoising_step_list有效既使用
self.set_denoising_timesteps(device=self.device) self.set_denoising_timesteps(device=self.device)
else:
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
def set_denoising_timesteps(self, device: Union[str, torch.device] = None): def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
self.timesteps = torch.tensor(self.denoising_step_list, device=device, dtype=torch.int64) self.timesteps = torch.tensor(self.denoising_step_list, device=device, dtype=torch.int64)
......
...@@ -46,12 +46,26 @@ def generate_task_id(): ...@@ -46,12 +46,26 @@ def generate_task_id():
def post_all_tasks(urls, messages): def post_all_tasks(urls, messages):
msg_num = len(messages) msg_num = len(messages)
msg_index = 0 msg_index = 0
while True: available_urls = []
for url in urls: for url in urls:
response = requests.get(f"{url}/v1/local/video/generate/service_status").json() try:
_ = requests.get(f"{url}/v1/service/status").json()
except Exception as e:
continue
available_urls.append(url)
if not available_urls:
logger.error("No available urls.")
return
logger.info(f"available_urls: {available_urls}")
while True:
for url in available_urls:
response = requests.get(f"{url}/v1/service/status").json()
if response["service_status"] == "idle": if response["service_status"] == "idle":
logger.info(f"{url} service is idle, start task...") logger.info(f"{url} service is idle, start task...")
response = requests.post(f"{url}/v1/local/video/generate", json=messages[msg_index]) response = requests.post(f"{url}/v1/tasks/", json=messages[msg_index])
logger.info(f"response: {response.json()}") logger.info(f"response: {response.json()}")
msg_index += 1 msg_index += 1
if msg_index == msg_num: if msg_index == msg_num:
......
...@@ -42,7 +42,6 @@ python -m lightx2v.api_server \ ...@@ -42,7 +42,6 @@ python -m lightx2v.api_server \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/wan/wan_i2v_dist.json \ --config_json ${lightx2v_path}/configs/wan/wan_i2v_dist.json \
--port 8000 \ --port 8000 \
--start_inference \
--nproc_per_node 1 --nproc_per_node 1
echo "Service stopped" echo "Service stopped"
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment