Commit e9e33065 authored by Zhuguanyu Wu, committed by GitHub
Browse files

Dev distill (#69)

* add step & cfg distillation wan model
parent 497ff9fe
......@@ -53,7 +53,7 @@ class DiTRunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config)
self.runner.model = self.runner.load_transformer(self.runner.get_init_device())
self.runner.model = self.runner.load_transformer()
def _run_dit(self, inputs, kwargs):
self.runner.config.update(tensor_transporter.load_tensor(kwargs))
......
......@@ -51,7 +51,7 @@ class ImageEncoderRunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config)
self.runner.image_encoder = self.runner.load_image_encoder(self.runner.get_init_device())
self.runner.image_encoder = self.runner.load_image_encoder()
def _run_image_encoder(self, img):
img = image_transporter.load_image(img)
......
......@@ -53,7 +53,7 @@ class TextEncoderRunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config)
self.runner.text_encoders = self.runner.load_text_encoder(self.runner.get_init_device())
self.runner.text_encoders = self.runner.load_text_encoder()
def _run_text_encoder(self, text, img, n_prompt):
if img is not None:
......
......@@ -56,7 +56,7 @@ class VAERunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config)
self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae(self.runner.get_init_device())
self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae()
def _run_vae_encoder(self, img):
img = image_transporter.load_image(img)
......
......@@ -16,19 +16,19 @@ class CogvideoxRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self, init_device):
def load_transformer(self):
model = CogvideoxModel(self.config)
return model
def load_image_encoder(self, init_device):
def load_image_encoder(self):
return None
def load_text_encoder(self, init_device):
def load_text_encoder(self):
text_encoder = T5EncoderModel_v1_1_xxl(self.config)
text_encoders = [text_encoder]
return text_encoders
def load_vae(self, init_device):
def load_vae(self):
vae_model = CogvideoxVAE(self.config)
return vae_model, vae_model
......
......@@ -24,10 +24,12 @@ class DefaultRunner:
if not self.check_sub_servers("prompt_enhancer"):
self.has_prompt_enhancer = False
logger.warning("No prompt enhancer server available, disable prompt enhancer.")
if not self.has_prompt_enhancer:
self.config["use_prompt_enhancer"] = False
self.set_init_device()
def init_modules(self):
logger.info("Initializing runner modules...")
self.set_init_device()
if self.config["mode"] == "split_server":
self.tensor_transporter = TensorTransporter()
self.image_transporter = ImageTransporter()
......@@ -93,6 +95,7 @@ class DefaultRunner:
def set_inputs(self, inputs):
self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = False
if self.has_prompt_enhancer:
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from client side.
self.config["negative_prompt"] = inputs.get("negative_prompt", "")
......
......@@ -29,8 +29,8 @@ class WanCausVidRunner(WanRunner):
self.infer_blocks = self.model.config.num_blocks
self.num_fragments = self.model.config.num_fragments
def load_transformer(self, init_device):
return WanCausVidModel(self.config.model_path, self.config, init_device)
def load_transformer(self):
return WanCausVidModel(self.config.model_path, self.config, self.init_device)
def set_inputs(self, inputs):
super().set_inputs(inputs)
......
......@@ -23,8 +23,8 @@ class WanDistillRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self, init_device):
model = WanDistillModel(self.config.model_path, self.config, init_device)
def load_transformer(self):
model = WanDistillModel(self.config.model_path, self.config, self.init_device)
if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
......
#!/bin/bash
# set the paths first
lightx2v_path="/data/lightx2v-dev/"
model_path="/data/lightx2v-dev/Wan2.1-T2V-14B/"
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment