Commit 88b7a2dd authored by helloyongyang

Support cfg parallel for the T5 text encoder

parent 2e5794c7
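Classifier-free guidance (CFG) runs the diffusion model twice per step, once with the text condition and once without, then blends the two noise predictions. CFG parallelism places the two branches on different ranks so each GPU performs a single forward pass per step; this commit extends the split to the T5 text encoder, so each rank also encodes only the prompt its own branch needs. The blend is the standard guidance formula, sketched here with illustrative names (not code from this commit):

    # cond / uncond: noise predictions from the conditional and unconditional branches
    noise_pred = uncond + guidance_scale * (cond - uncond)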
import argparse
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
import json
from lightx2v.utils.envs import *
from lightx2v.utils.utils import seed_all
from lightx2v.utils.profiler import ProfilingContext
-from lightx2v.utils.set_config import set_config
+from lightx2v.utils.set_config import set_config, print_config
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
@@ -26,15 +25,6 @@ from loguru import logger


def init_runner(config):
    seed_all(config.seed)
-    if config.parallel:
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cfg_p_size = config.parallel.get("cfg_p_size", 1)
-        seq_p_size = config.parallel.get("seq_p_size", 1)
-        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
-        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
    if CHECK_ENABLE_GRAPH_MODE():
        default_runner = RUNNER_REGISTER[config.model_cls](config)
        runner = GraphRunner(default_runner)
@@ -73,7 +63,7 @@ def main():
    with ProfilingContext("Total Cost"):
        config = set_config(args)
-        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        print_config(config)
        runner = init_runner(config)
        runner.run_pipeline()
......
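For orientation: the 2-D device mesh built from the parallel config factorizes the world size into a CFG dimension and a sequence dimension. A minimal sketch, assuming 4 GPUs with cfg_p_size=2 and seq_p_size=2 (illustrative values, not defaults):

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # Ranks are laid out row-major over (cfg_p, seq_p): ranks {0, 1} and {2, 3}
    # form the two seq_p groups; ranks {0, 2} and {1, 3} form the two cfg_p groups.
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("cfg_p", "seq_p"))
    cfg_p_group = mesh.get_group(mesh_dim="cfg_p")  # this rank's CFG peer group
    seq_p_group = mesh.get_group(mesh_dim="seq_p")  # this rank's sequence-parallel group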
@@ -367,7 +367,7 @@ class WanTransformerInfer(BaseTransformerInfer):
            del freqs_i, norm1_out, norm1_weight, norm1_bias
            torch.cuda.empty_cache()
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
            attn_out = weights.self_attn_1_parallel.apply(
                q=q,
                k=k,
......
@@ -70,7 +70,7 @@ class WanModel:
    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
            self.transformer_infer_class = WanTransformerDistInfer
        else:
            if self.config["feature_caching"] == "NoCaching":
@@ -187,7 +187,7 @@ class WanModel:
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)
-        if self.config["enable_cfg"] and self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
+        if self.config["cfg_parallel"]:
            self.infer_func = self.infer_with_cfg_parallel
        else:
            self.infer_func = self.infer_wo_cfg_parallel
......
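The body of infer_with_cfg_parallel is not part of this diff; a plausible minimal sketch of the combine step such a method must perform, assuming each cfg_p rank holds the noise prediction for its own branch (names are illustrative, not the project's API):

    import torch
    import torch.distributed as dist

    def cfg_parallel_combine(local_pred, guidance_scale, cfg_p_group):
        # Rank 0 of the group holds the conditional prediction, rank 1 the
        # unconditional one; all_gather orders results by group rank.
        gathered = [torch.empty_like(local_pred) for _ in range(2)]
        dist.all_gather(gathered, local_pred, group=cfg_p_group)
        cond, uncond = gathered
        return uncond + guidance_scale * (cond - uncond)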
@@ -191,7 +191,7 @@ class WanSelfAttention(WeightModule):
        else:
            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
        if self.quant_method in ["advanced_ptq"]:
......
@@ -176,16 +176,30 @@ class WanRunner(DefaultRunner):
    def run_text_encoder(self, text, img):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = {}
        n_prompt = self.config.get("negative_prompt", "")
-        context = self.text_encoders[0].infer([text])
-        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
+        if self.config["cfg_parallel"]:
+            # Under CFG parallelism each rank encodes only the prompt its own branch consumes.
+            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
+            cfg_p_rank = dist.get_rank(cfg_p_group)
+            if cfg_p_rank == 0:
+                context = self.text_encoders[0].infer([text])
+                text_encoder_output = {"context": context}
+            else:
+                context_null = self.text_encoders[0].infer([n_prompt])
+                text_encoder_output = {"context_null": context_null}
+        else:
+            context = self.text_encoders[0].infer([text])
+            context_null = self.text_encoders[0].infer([n_prompt])
+            text_encoder_output = {
+                "context": context,
+                "context_null": context_null,
+            }
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()
-        text_encoder_output["context"] = context
-        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def run_image_encoder(self, img):
......
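The split above halves the text-encoder cost under CFG parallelism: with two ranks in the cfg_p group, each pays for one T5 forward pass instead of two. The same pattern in isolation, with a hypothetical encode stand-in (not the project's API):

    import torch.distributed as dist

    def split_prompt_encoding(encode, text, n_prompt, cfg_p_group):
        # Rank 0 encodes the positive prompt for the conditional branch;
        # the other rank encodes the negative prompt for the unconditional branch.
        if dist.get_rank(cfg_p_group) == 0:
            return {"context": encode(text)}
        return {"context_null": encode(n_prompt)}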
@@ -2,6 +2,8 @@ import json
import os
from easydict import EasyDict
from loguru import logger
+import torch.distributed as dist
+from torch.distributed.device_mesh import init_device_mesh


def get_default_config():
@@ -19,6 +21,7 @@ def get_default_config():
        "mm_config": {},
        "use_prompt_enhancer": False,
        "parallel": False,
+        "enable_cfg": False,
    }
    return default_config
@@ -57,4 +60,31 @@ def set_config(args):
        logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
        config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1

+    set_parallel_config(config)  # parallel config
    return config
+
+
+def set_parallel_config(config):
+    config["seq_parallel"] = False
+    config["cfg_parallel"] = False
+    if config.parallel:
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
+        cfg_p_size = config.parallel.get("cfg_p_size", 1)
+        seq_p_size = config.parallel.get("seq_p_size", 1)
+        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size ({cfg_p_size} * {seq_p_size}) must equal world_size ({dist.get_world_size()})"
+        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
+    if config.parallel and config.parallel.get("seq_p_size", 0) > 1:
+        config["seq_parallel"] = True
+    if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", 0) > 1:
+        config["cfg_parallel"] = True
+
+
+def print_config(config):
+    config_to_print = config.copy()
+    config_to_print.pop("device_mesh", None)  # DeviceMesh objects are not JSON-serializable
+    logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")