Commit 88b7a2dd authored by helloyongyang

Support cfg parallel for T5 model

parent 2e5794c7
 import argparse
 import torch
 import torch.distributed as dist
-from torch.distributed.device_mesh import init_device_mesh
 import json
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import seed_all
 from lightx2v.utils.profiler import ProfilingContext
-from lightx2v.utils.set_config import set_config
+from lightx2v.utils.set_config import set_config, print_config
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
@@ -26,15 +25,6 @@ from loguru import logger
 def init_runner(config):
     seed_all(config.seed)
-    if config.parallel:
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        cfg_p_size = config.parallel.get("cfg_p_size", 1)
-        seq_p_size = config.parallel.get("seq_p_size", 1)
-        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
-        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
     if CHECK_ENABLE_GRAPH_MODE():
         default_runner = RUNNER_REGISTER[config.model_cls](config)
         runner = GraphRunner(default_runner)
@@ -73,7 +63,7 @@ def main():
     with ProfilingContext("Total Cost"):
         config = set_config(args)
-        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        print_config(config)
         runner = init_runner(config)
         runner.run_pipeline()
......
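Note (not part of the commit): the distributed setup removed from init_runner above reappears as set_parallel_config in the last hunk, so the 2-D device mesh is now built while the config is resolved. A minimal, self-contained sketch of what that mesh looks like at runtime, assuming 4 GPUs with cfg_p_size=2 and seq_p_size=2 (values chosen purely for illustration):

# Illustrative sketch: the ("cfg_p", "seq_p") mesh on 4 ranks, one process per GPU
# (e.g. torchrun --nproc_per_node=4 sketch.py).
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("cfg_p", "seq_p"))
# Ranks are laid out row-major over (cfg_p, seq_p):
#   cfg_p index 0 -> global ranks 0, 1
#   cfg_p index 1 -> global ranks 2, 3
cfg_p_group = mesh.get_group(mesh_dim="cfg_p")  # {0, 2} or {1, 3}, depending on this rank
seq_p_group = mesh.get_group(mesh_dim="seq_p")  # {0, 1} or {2, 3}
dist.destroy_process_group()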
@@ -367,7 +367,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             del freqs_i, norm1_out, norm1_weight, norm1_bias
             torch.cuda.empty_cache()
-            if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+            if self.config["seq_parallel"]:
                 attn_out = weights.self_attn_1_parallel.apply(
                     q=q,
                     k=k,
......
@@ -70,7 +70,7 @@ class WanModel:
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
             self.transformer_infer_class = WanTransformerDistInfer
         else:
             if self.config["feature_caching"] == "NoCaching":
@@ -187,7 +187,7 @@ class WanModel:
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
-        if self.config["enable_cfg"] and self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
+        if self.config["cfg_parallel"]:
             self.infer_func = self.infer_with_cfg_parallel
         else:
             self.infer_func = self.infer_wo_cfg_parallel
......
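Note (illustration only): infer_with_cfg_parallel itself is not shown in this diff. The general idea behind such a path is that each cfg_p rank runs the denoising forward for only one of the two guidance branches, and the branch outputs are exchanged inside the cfg_p group before the usual guidance formula is applied. The sketch below shows that pattern with hypothetical names (denoise_step, guide_scale); it is not the repository's implementation:

# Hypothetical sketch of a CFG-parallel denoising step (names are illustrative,
# not taken from this repository).
import torch
import torch.distributed as dist

def cfg_parallel_step(latents, t, text_encoder_output, cfg_p_group, guide_scale, denoise_step):
    # cfg_p rank 0 received "context" from run_text_encoder, rank 1 received "context_null".
    rank_in_group = dist.get_rank(cfg_p_group)
    cond = text_encoder_output["context"] if rank_in_group == 0 else text_encoder_output["context_null"]
    local_pred = denoise_step(latents, t, cond)               # each rank computes one branch
    gathered = [torch.empty_like(local_pred) for _ in range(2)]
    dist.all_gather(gathered, local_pred, group=cfg_p_group)  # ordered by rank inside the group
    noise_cond, noise_uncond = gathered[0], gathered[1]
    return noise_uncond + guide_scale * (noise_cond - noise_uncond)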
@@ -191,7 +191,7 @@ class WanSelfAttention(WeightModule):
         else:
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
-        if self.config.parallel and self.config.parallel.get("seq_p_size", False) and self.config.parallel.seq_p_size > 1:
+        if self.config["seq_parallel"]:
             self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
         if self.quant_method in ["advanced_ptq"]:
......
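Background note: the "ulysses" entry registered above refers to the repository's own sequence-parallel attention weights, which are not part of this diff. As background, here is a minimal sketch of the Ulysses-style all-to-all pattern the name refers to, assuming [local_seq, heads, head_dim] tensors and a head count divisible by the seq_p group size (shapes and helper names are assumptions, not the registered implementation):

# Background sketch only: the Ulysses communication pattern, not the registered
# "ulysses" attention weight class from the repository.
import torch
import torch.distributed as dist

def _seq_to_head(x, group):
    # [local_seq, heads, dim] -> [full_seq, heads // world, dim]
    world = dist.get_world_size(group)
    local_seq, heads, dim = x.shape
    x = x.reshape(local_seq, world, heads // world, dim).permute(1, 0, 2, 3).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    return out.reshape(world * local_seq, heads // world, dim)

def _head_to_seq(x, group):
    # [full_seq, heads // world, dim] -> [local_seq, heads, dim]
    world = dist.get_world_size(group)
    full_seq, local_heads, dim = x.shape
    x = x.reshape(world, full_seq // world, local_heads, dim).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    return out.permute(1, 0, 2, 3).reshape(full_seq // world, world * local_heads, dim)

def ulysses_attention(q, k, v, group, attn_fn):
    # Each rank starts with its sequence shard and all heads; after the first
    # all-to-all it holds the full sequence for a subset of heads, runs ordinary
    # attention via attn_fn, then converts back to the sequence-sharded layout.
    q, k, v = (_seq_to_head(t, group) for t in (q, k, v))
    out = attn_fn(q, k, v)
    return _head_to_seq(out, group)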
@@ -176,16 +176,30 @@ class WanRunner(DefaultRunner):
     def run_text_encoder(self, text, img):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
-        text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
-        context = self.text_encoders[0].infer([text])
-        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
+        if self.config["cfg_parallel"]:
+            cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
+            cfg_p_rank = dist.get_rank(cfg_p_group)
+            if cfg_p_rank == 0:
+                context = self.text_encoders[0].infer([text])
+                text_encoder_output = {"context": context}
+            else:
+                context_null = self.text_encoders[0].infer([n_prompt])
+                text_encoder_output = {"context_null": context_null}
+        else:
+            context = self.text_encoders[0].infer([text])
+            context_null = self.text_encoders[0].infer([n_prompt])
+            text_encoder_output = {
+                "context": context,
+                "context_null": context_null,
+            }
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.text_encoders[0]
             torch.cuda.empty_cache()
             gc.collect()
-        text_encoder_output["context"] = context
-        text_encoder_output["context_null"] = context_null
         return text_encoder_output

     def run_image_encoder(self, img):
......
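Note: the rank test above uses dist.get_rank(cfg_p_group), i.e. the rank inside the cfg_p group rather than the global rank, so the positive/negative split is independent of seq_p_size and each rank runs the T5 encoder for only one branch. An illustrative check of the mapping on a (cfg_p=2, seq_p=2) mesh (a standalone sketch, not part of the commit):

# Standalone sketch: which CFG branch each global rank would encode on a
# (cfg_p=2, seq_p=2) mesh. Run with torchrun --nproc_per_node=4.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("cfg_p", "seq_p"))
cfg_p_rank = dist.get_rank(mesh.get_group(mesh_dim="cfg_p"))
branch = "context (positive prompt)" if cfg_p_rank == 0 else "context_null (negative prompt)"
print(f"global rank {dist.get_rank()} -> {branch}")
# Expected: global ranks 0 and 1 print "context", ranks 2 and 3 print "context_null".
dist.destroy_process_group()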
@@ -2,6 +2,8 @@ import json
 import os
 from easydict import EasyDict
 from loguru import logger
+import torch.distributed as dist
+from torch.distributed.tensor.device_mesh import init_device_mesh


 def get_default_config():
@@ -19,6 +21,7 @@ def get_default_config():
         "mm_config": {},
         "use_prompt_enhancer": False,
         "parallel": False,
+        "enable_cfg": False,
     }
     return default_config
@@ -57,4 +60,31 @@ def set_config(args):
         logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
         config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1

+    set_parallel_config(config)  # parallel config
+
     return config
+
+
+def set_parallel_config(config):
+    config["seq_parallel"] = False
+    config["cfg_parallel"] = False
+    if config.parallel:
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
+        cfg_p_size = config.parallel.get("cfg_p_size", 1)
+        seq_p_size = config.parallel.get("seq_p_size", 1)
+        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
+        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
+    if config.parallel and config.parallel.get("seq_p_size", False) and config.parallel.seq_p_size > 1:
+        config["seq_parallel"] = True
+    if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
+        config["cfg_parallel"] = True
+
+
+def print_config(config):
+    config_to_print = config.copy()
+    config_to_print.pop("device_mesh", None)  # Remove device_mesh if it exists
+    logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")