Commit e684202c authored by helloyongyang's avatar helloyongyang
Browse files

fix bug

parent 8e7490b9
...@@ -63,7 +63,7 @@ def main(): ...@@ -63,7 +63,7 @@ def main():
config = set_config(args) config = set_config(args)
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}") logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
if "parallel" in config: if config.parallel:
dist.init_process_group(backend="nccl") dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank()) torch.cuda.set_device(dist.get_rank())
set_parallel_config(config) set_parallel_config(config)
......
...@@ -22,6 +22,8 @@ def get_default_config(): ...@@ -22,6 +22,8 @@ def get_default_config():
"mm_config": {}, "mm_config": {},
"use_prompt_enhancer": False, "use_prompt_enhancer": False,
"parallel": False, "parallel": False,
"seq_parallel": False,
"cfg_parallel": False,
"enable_cfg": False, "enable_cfg": False,
} }
return default_config return default_config
...@@ -65,8 +67,6 @@ def set_config(args): ...@@ -65,8 +67,6 @@ def set_config(args):
def set_parallel_config(config): def set_parallel_config(config):
config["seq_parallel"] = False
config["cfg_parallel"] = False
if config.parallel: if config.parallel:
if not dist.is_initialized(): if not dist.is_initialized():
dist.init_process_group(backend="nccl") dist.init_process_group(backend="nccl")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment