Commit e684202c authored by helloyongyang's avatar helloyongyang
Browse files

fix bug

parent 8e7490b9
......@@ -63,7 +63,7 @@ def main():
config = set_config(args)
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
if "parallel" in config:
if config.parallel:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
set_parallel_config(config)
......
......@@ -22,6 +22,8 @@ def get_default_config():
"mm_config": {},
"use_prompt_enhancer": False,
"parallel": False,
"seq_parallel": False,
"cfg_parallel": False,
"enable_cfg": False,
}
return default_config
......@@ -65,8 +67,6 @@ def set_config(args):
def set_parallel_config(config):
config["seq_parallel"] = False
config["cfg_parallel"] = False
if config.parallel:
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment