Commit e9e33065, authored by Zhuguanyu Wu and committed by GitHub
Browse files

Dev distill (#69)

* add step & cfg distillation wan model
parent 497ff9fe
...@@ -53,7 +53,7 @@ class DiTRunner: ...@@ -53,7 +53,7 @@ class DiTRunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls] self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config) self.runner = self.runner_cls(config)
self.runner.model = self.runner.load_transformer(self.runner.get_init_device()) self.runner.model = self.runner.load_transformer()
def _run_dit(self, inputs, kwargs): def _run_dit(self, inputs, kwargs):
self.runner.config.update(tensor_transporter.load_tensor(kwargs)) self.runner.config.update(tensor_transporter.load_tensor(kwargs))
......
...@@ -51,7 +51,7 @@ class ImageEncoderRunner: ...@@ -51,7 +51,7 @@ class ImageEncoderRunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls] self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config) self.runner = self.runner_cls(config)
self.runner.image_encoder = self.runner.load_image_encoder(self.runner.get_init_device()) self.runner.image_encoder = self.runner.load_image_encoder()
def _run_image_encoder(self, img): def _run_image_encoder(self, img):
img = image_transporter.load_image(img) img = image_transporter.load_image(img)
......
...@@ -53,7 +53,7 @@ class TextEncoderRunner: ...@@ -53,7 +53,7 @@ class TextEncoderRunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls] self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config) self.runner = self.runner_cls(config)
self.runner.text_encoders = self.runner.load_text_encoder(self.runner.get_init_device()) self.runner.text_encoders = self.runner.load_text_encoder()
def _run_text_encoder(self, text, img, n_prompt): def _run_text_encoder(self, text, img, n_prompt):
if img is not None: if img is not None:
......
...@@ -56,7 +56,7 @@ class VAERunner: ...@@ -56,7 +56,7 @@ class VAERunner:
self.runner_cls = RUNNER_REGISTER[self.config.model_cls] self.runner_cls = RUNNER_REGISTER[self.config.model_cls]
self.runner = self.runner_cls(config) self.runner = self.runner_cls(config)
self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae(self.runner.get_init_device()) self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae()
def _run_vae_encoder(self, img): def _run_vae_encoder(self, img):
img = image_transporter.load_image(img) img = image_transporter.load_image(img)
......
...@@ -16,19 +16,19 @@ class CogvideoxRunner(DefaultRunner): ...@@ -16,19 +16,19 @@ class CogvideoxRunner(DefaultRunner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
def load_transformer(self, init_device): def load_transformer(self):
model = CogvideoxModel(self.config) model = CogvideoxModel(self.config)
return model return model
def load_image_encoder(self, init_device): def load_image_encoder(self):
return None return None
def load_text_encoder(self, init_device): def load_text_encoder(self):
text_encoder = T5EncoderModel_v1_1_xxl(self.config) text_encoder = T5EncoderModel_v1_1_xxl(self.config)
text_encoders = [text_encoder] text_encoders = [text_encoder]
return text_encoders return text_encoders
def load_vae(self, init_device): def load_vae(self):
vae_model = CogvideoxVAE(self.config) vae_model = CogvideoxVAE(self.config)
return vae_model, vae_model return vae_model, vae_model
......
...@@ -24,10 +24,12 @@ class DefaultRunner: ...@@ -24,10 +24,12 @@ class DefaultRunner:
if not self.check_sub_servers("prompt_enhancer"): if not self.check_sub_servers("prompt_enhancer"):
self.has_prompt_enhancer = False self.has_prompt_enhancer = False
logger.warning("No prompt enhancer server available, disable prompt enhancer.") logger.warning("No prompt enhancer server available, disable prompt enhancer.")
if not self.has_prompt_enhancer:
self.config["use_prompt_enhancer"] = False
self.set_init_device()
def init_modules(self): def init_modules(self):
logger.info("Initializing runner modules...") logger.info("Initializing runner modules...")
self.set_init_device()
if self.config["mode"] == "split_server": if self.config["mode"] == "split_server":
self.tensor_transporter = TensorTransporter() self.tensor_transporter = TensorTransporter()
self.image_transporter = ImageTransporter() self.image_transporter = ImageTransporter()
...@@ -93,6 +95,7 @@ class DefaultRunner: ...@@ -93,6 +95,7 @@ class DefaultRunner:
def set_inputs(self, inputs): def set_inputs(self, inputs):
self.config["prompt"] = inputs.get("prompt", "") self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = False
if self.has_prompt_enhancer: if self.has_prompt_enhancer:
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from clinet side. self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from clinet side.
self.config["negative_prompt"] = inputs.get("negative_prompt", "") self.config["negative_prompt"] = inputs.get("negative_prompt", "")
......
...@@ -29,8 +29,8 @@ class WanCausVidRunner(WanRunner): ...@@ -29,8 +29,8 @@ class WanCausVidRunner(WanRunner):
self.infer_blocks = self.model.config.num_blocks self.infer_blocks = self.model.config.num_blocks
self.num_fragments = self.model.config.num_fragments self.num_fragments = self.model.config.num_fragments
def load_transformer(self, init_device): def load_transformer(self):
return WanCausVidModel(self.config.model_path, self.config, init_device) return WanCausVidModel(self.config.model_path, self.config, self.init_device)
def set_inputs(self, inputs): def set_inputs(self, inputs):
super().set_inputs(inputs) super().set_inputs(inputs)
......
...@@ -23,8 +23,8 @@ class WanDistillRunner(WanRunner): ...@@ -23,8 +23,8 @@ class WanDistillRunner(WanRunner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
def load_transformer(self, init_device): def load_transformer(self):
model = WanDistillModel(self.config.model_path, self.config, init_device) model = WanDistillModel(self.config.model_path, self.config, self.init_device)
if self.config.lora_path: if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model) lora_wrapper = WanLoraWrapper(model)
lora_name = lora_wrapper.load_lora(self.config.lora_path) lora_name = lora_wrapper.load_lora(self.config.lora_path)
......
#!/bin/bash #!/bin/bash
# set path and first # set path and first
lightx2v_path="/data/lightx2v-dev/" lightx2v_path=
model_path="/data/lightx2v-dev/Wan2.1-T2V-14B/" model_path=
# check section # check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment