Commit 2a31ba43 authored by helloyongyang

Add dist warmup code

parent 8c1d91e1
 import json
 import os
+import torch
 import torch.distributed as dist
 from loguru import logger
 from torch.distributed.tensor.device_mesh import init_device_mesh
@@ -91,6 +92,9 @@ def set_parallel_config(config):
     if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1:
         config["cfg_parallel"] = True
+    # warmup dist: issue one tiny all-reduce so communicator setup happens here, not inside the first real collective
+    _a = torch.zeros([1]).to(f"cuda:{dist.get_rank()}")
+    dist.all_reduce(_a)
 def print_config(config):
...
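The warmup is useful because the first collective on a fresh process group pays one-time setup costs (e.g. NCCL communicator creation and buffer allocation), which would otherwise land on the first real training step. Below is a minimal standalone sketch of the same pattern, assuming an NCCL backend and a torchrun launch; the script name warmup_demo.py and the main() wrapper are illustrative and not part of this commit.

# Minimal sketch of the warmup pattern in isolation. Assumed launch:
#   torchrun --nproc_per_node=<N> warmup_demo.py
# (warmup_demo.py is a hypothetical file name, not from the repo.)
import torch
import torch.distributed as dist

def main():
    # torchrun provides RANK / WORLD_SIZE / MASTER_ADDR via env vars
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Warmup, as in the commit: a one-element all-reduce forces NCCL to
    # build its communicators now, so later collectives run at full speed.
    _a = torch.zeros([1]).to(f"cuda:{rank}")
    dist.all_reduce(_a)

    # ... real distributed work would go here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()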