Unverified commit 079bf3cb, authored by Hongxin Liu and committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import random import random
import socket import socket
from collections import Counter from collections import Counter
from threading import local
from typing import Union from typing import Union
import numpy as np import numpy as np
...@@ -95,8 +94,9 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -95,8 +94,9 @@ class ParallelContext(metaclass=SingletonMeta):
@staticmethod @staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode): def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode), \ assert isinstance(
f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}' parallel_mode, ParallelMode
), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}"
def get_global_rank(self): def get_global_rank(self):
"""Returns the global rank of the current device. """Returns the global rank of the current device.
...@@ -239,8 +239,10 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -239,8 +239,10 @@ class ParallelContext(metaclass=SingletonMeta):
def is_pipeline_last_stage(self, ignore_virtual=False): def is_pipeline_last_stage(self, ignore_virtual=False):
if not ignore_virtual: if not ignore_virtual:
if self.virtual_pipeline_parallel_size \ if (
is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: self.virtual_pipeline_parallel_size is not None
and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1
):
return False return False
return self.is_last_rank(ParallelMode.PIPELINE) return self.is_last_rank(ParallelMode.PIPELINE)
...@@ -371,12 +373,12 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -371,12 +373,12 @@ class ParallelContext(metaclass=SingletonMeta):
port (str): the master port for distributed training port (str): the master port for distributed training
""" """
# initialize the default process group # initialize the default process group
init_method = f'tcp://[{host}]:{port}' init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# None will give the default global process group for pytorch dist operations # None will give the default global process group for pytorch dist operations
ranks = list(range(world_size)) ranks = list(range(world_size))
cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
self.add_global_rank(ParallelMode.GLOBAL, rank) self.add_global_rank(ParallelMode.GLOBAL, rank)
...@@ -398,10 +400,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -398,10 +400,11 @@ class ParallelContext(metaclass=SingletonMeta):
pps = self.pipeline_parallel_size pps = self.pipeline_parallel_size
tps = self.tensor_parallel_size tps = self.tensor_parallel_size
ws = self.world_size ws = self.world_size
assert ws == dps * pps * \ assert ws == dps * pps * tps, (
tps, f"Expected the world size {ws} to be equal to data" \ f"Expected the world size {ws} to be equal to data"
f" parallel size ({dps}) * pipeline parallel size " \ f" parallel size ({dps}) * pipeline parallel size "
f"({pps}) * tensor parallel size ({tps})" f"({pps}) * tensor parallel size ({tps})"
)
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
if key in config: if key in config:
...@@ -409,10 +412,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -409,10 +412,11 @@ class ParallelContext(metaclass=SingletonMeta):
if isinstance(ele, int): if isinstance(ele, int):
setattr(self, attr_name, ele) setattr(self, attr_name, ele)
elif isinstance(ele, dict): elif isinstance(ele, dict):
setattr(self, attr_name, ele['size']) setattr(self, attr_name, ele["size"])
else: else:
raise NotImplementedError( raise NotImplementedError(
f'{"Parallel configuration does not support this kind of argument, please use int or dict"}') f'{"Parallel configuration does not support this kind of argument, please use int or dict"}'
)
def init_parallel_groups(self): def init_parallel_groups(self):
"""Initializes the parallel groups. """Initializes the parallel groups.
...@@ -427,10 +431,10 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -427,10 +431,10 @@ class ParallelContext(metaclass=SingletonMeta):
self.world_size = world_size self.world_size = world_size
# set parallel size as attributes for global context # set parallel size as attributes for global context
parallel_config = self.config.get('parallel', None) parallel_config = self.config.get("parallel", None)
if parallel_config is not None: if parallel_config is not None:
self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size') self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size') self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
# the user should not set the data parallel size manually # the user should not set the data parallel size manually
# instead, it should be calculated based on other parallel config # instead, it should be calculated based on other parallel config
...@@ -438,33 +442,33 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -438,33 +442,33 @@ class ParallelContext(metaclass=SingletonMeta):
# get the tensor parallel mode and check # get the tensor parallel mode and check
tensor_parallel_mode = None tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in \ if parallel_config is not None and "tensor" in parallel_config and "mode" in parallel_config["tensor"]:
parallel_config and 'mode' in parallel_config['tensor']: tensor_parallel_mode = parallel_config["tensor"]["mode"]
tensor_parallel_mode = parallel_config['tensor']['mode'] assert (
assert tensor_parallel_mode in ALLOWED_MODES, \ tensor_parallel_mode in ALLOWED_MODES
f"mode in the parallel config must be set to one of {ALLOWED_MODES}" ), f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
env.mode = tensor_parallel_mode env.mode = tensor_parallel_mode
self.check_sanity() self.check_sanity()
pg_init = [] pg_init = []
# LSG: init data parallel process group for compatibility with other parallel module such as zero # LSG: init data parallel process group for compatibility with other parallel module such as zero
pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) pg_init.append(dict(type=INITIALIZER_MAPPING["data"]))
# LSG: init model parallel process group for compatibility with amp and clip grad # LSG: init model parallel process group for compatibility with amp and clip grad
pg_init.append(dict(type=INITIALIZER_MAPPING['model'])) pg_init.append(dict(type=INITIALIZER_MAPPING["model"]))
if self.pipeline_parallel_size > 1: if self.pipeline_parallel_size > 1:
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) pg_init.append(dict(type=INITIALIZER_MAPPING["pipeline"]))
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) pg_init.append(dict(type=INITIALIZER_MAPPING["tensor"]))
# init specific tensor parallel group # init specific tensor parallel group
if tensor_parallel_mode is not None: if tensor_parallel_mode is not None:
tensor_parallel_cfg = parallel_config['tensor'].copy() tensor_parallel_cfg = parallel_config["tensor"].copy()
# remove duplicate parameters # remove duplicate parameters
tensor_parallel_cfg.pop('mode') tensor_parallel_cfg.pop("mode")
tensor_parallel_cfg.pop('size') tensor_parallel_cfg.pop("size")
# add this config to initialize later # add this config to initialize later
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))
...@@ -472,11 +476,16 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -472,11 +476,16 @@ class ParallelContext(metaclass=SingletonMeta):
# run initialization of different process groups # run initialization of different process groups
for initializer_cfg in pg_init: for initializer_cfg in pg_init:
cfg = initializer_cfg.copy() cfg = initializer_cfg.copy()
initializer_type = cfg.pop('type') initializer_type = cfg.pop("type")
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config, initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
self.data_parallel_size, rank,
self.pipeline_parallel_size, world_size,
self.tensor_parallel_size, **cfg) self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size,
**cfg,
)
parallel_setting = initializer.init_dist_group() parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list): if isinstance(parallel_setting, list):
for args in parallel_setting: for args in parallel_setting:
...@@ -497,8 +506,7 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -497,8 +506,7 @@ class ParallelContext(metaclass=SingletonMeta):
return parallel_mode in self._groups return parallel_mode in self._groups
def destroy(self): def destroy(self):
"""Destroys the current distributed parallel environment. """Destroys the current distributed parallel environment."""
"""
for mode, group in self._groups.items(): for mode, group in self._groups.items():
if mode is not ParallelMode.GLOBAL: if mode is not ParallelMode.GLOBAL:
dist.destroy_process_group(group) dist.destroy_process_group(group)
...@@ -519,7 +527,7 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -519,7 +527,7 @@ class ParallelContext(metaclass=SingletonMeta):
torch.cuda.set_device(device_ordinal) torch.cuda.set_device(device_ordinal)
if self._verbose: if self._verbose:
self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}') self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}")
def set_seed(self, seed: int): def set_seed(self, seed: int):
"""Sets seeds for all random libraries. """Sets seeds for all random libraries.
...@@ -552,21 +560,25 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -552,21 +560,25 @@ class ParallelContext(metaclass=SingletonMeta):
set_mode(ParallelMode.DATA) set_mode(ParallelMode.DATA)
seeds = get_seeds() seeds = get_seeds()
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()]) seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
if self._verbose: if self._verbose:
self._logger.info(f"initialized seed on rank {global_rank}, " self._logger.info(
f"numpy: {seed}, python random: {seed}, {seed_str}," f"initialized seed on rank {global_rank}, "
f"the default parallel seed is {ParallelMode.DATA}.") f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}."
)
else: else:
if self._verbose: if self._verbose:
self._logger.info( self._logger.info(
f"initialized seed on rank {global_rank}, " f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}", f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
ranks=[0]) ranks=[0],
)
self._logger.info( self._logger.info(
'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states",
ranks=[0]) ranks=[0],
)
def set_virtual_pipeline_parallel_size(self, size): def set_virtual_pipeline_parallel_size(self, size):
self.virtual_pipeline_parallel_size = size self.virtual_pipeline_parallel_size = size
......
...@@ -6,44 +6,43 @@ from enum import Enum ...@@ -6,44 +6,43 @@ from enum import Enum
# parallel modes # parallel modes
class ParallelMode(Enum): class ParallelMode(Enum):
"""This is an enumeration class containing all possible parallel modes. """This is an enumeration class containing all possible parallel modes."""
"""
GLOBAL = 'global' GLOBAL = "global"
# common parallel # common parallel
DATA = 'data' DATA = "data"
# model parallel - containing tensor and pipeline parallel groups # model parallel - containing tensor and pipeline parallel groups
# this is added to facilitate amp and grad clipping in hybrid parallel # this is added to facilitate amp and grad clipping in hybrid parallel
MODEL = 'model' MODEL = "model"
# pipeline parallel # pipeline parallel
PIPELINE = 'pipe' PIPELINE = "pipe"
# containing all ranks in tensor parallel # containing all ranks in tensor parallel
TENSOR = 'tensor' TENSOR = "tensor"
# sequence parallel # sequence parallel
SEQUENCE = 'sequence' SEQUENCE = "sequence"
SEQUENCE_DP = 'sequence_dp' SEQUENCE_DP = "sequence_dp"
# 1D Parallel # 1D Parallel
PARALLEL_1D = '1d' PARALLEL_1D = "1d"
# 2D parallel # 2D parallel
PARALLEL_2D_ROW = '2d_row' PARALLEL_2D_ROW = "2d_row"
PARALLEL_2D_COL = '2d_col' PARALLEL_2D_COL = "2d_col"
# 3D parallel # 3D parallel
PARALLEL_3D_INPUT = '3d_input' PARALLEL_3D_INPUT = "3d_input"
PARALLEL_3D_WEIGHT = '3d_weight' PARALLEL_3D_WEIGHT = "3d_weight"
PARALLEL_3D_OUTPUT = '3d_output' PARALLEL_3D_OUTPUT = "3d_output"
PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight"
PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight"
# 2.5D parallel # 2.5D parallel
PARALLEL_2P5D_ROW = '2p5d_row' PARALLEL_2P5D_ROW = "2p5d_row"
PARALLEL_2P5D_COL = '2p5d_col' PARALLEL_2P5D_COL = "2p5d_col"
PARALLEL_2P5D_DEP = '2p5d_dep' PARALLEL_2P5D_DEP = "2p5d_dep"
PARALLEL_2P5D_XZ = '2p5d_xz' PARALLEL_2P5D_XZ = "2p5d_xz"
...@@ -10,6 +10,14 @@ from .initializer_tensor import Initializer_Tensor ...@@ -10,6 +10,14 @@ from .initializer_tensor import Initializer_Tensor
from .process_group_initializer import ProcessGroupInitializer from .process_group_initializer import ProcessGroupInitializer
__all__ = [ __all__ = [
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', "Initializer_Tensor",
'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' "Initializer_Sequence",
"Initializer_Pipeline",
"Initializer_Data",
"Initializer_2p5D",
"Initializer_2D",
"Initializer_3D",
"Initializer_1D",
"ProcessGroupInitializer",
"Initializer_Model",
] ]
...@@ -45,7 +45,7 @@ class Initializer_1D(ProcessGroupInitializer): ...@@ -45,7 +45,7 @@ class Initializer_1D(ProcessGroupInitializer):
for i in range(self.num_group): for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
......
...@@ -14,9 +14,10 @@ def _check_summa_env_var(summa_dim): ...@@ -14,9 +14,10 @@ def _check_summa_env_var(summa_dim):
env_summa_dim = env.summa_dim env_summa_dim = env.summa_dim
if env_summa_dim: if env_summa_dim:
assert int(env_summa_dim) == summa_dim, \ assert int(env_summa_dim) == summa_dim, (
'SUMMA_DIM has been set in the current environment and ' \ "SUMMA_DIM has been set in the current environment and "
'does not match with the value passed to this initialized' "does not match with the value passed to this initialized"
)
else: else:
env.summa_dim = summa_dim env.summa_dim = summa_dim
...@@ -57,7 +58,7 @@ class Initializer_2D_Row(ProcessGroupInitializer): ...@@ -57,7 +58,7 @@ class Initializer_2D_Row(ProcessGroupInitializer):
for j in range(self.summa_dim): for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)] ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -106,7 +107,7 @@ class Initializer_2D_Col(ProcessGroupInitializer): ...@@ -106,7 +107,7 @@ class Initializer_2D_Col(ProcessGroupInitializer):
for j in range(self.summa_dim): for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)] ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -137,8 +138,9 @@ class Initializer_2D(ProcessGroupInitializer): ...@@ -137,8 +138,9 @@ class Initializer_2D(ProcessGroupInitializer):
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
assert self.tensor_parallel_size == self.summa_dim ** 2, \ assert (
"2D summa dim should equal to tensor parallel size ^ 0.5" self.tensor_parallel_size == self.summa_dim**2
), "2D summa dim should equal to tensor parallel size ^ 0.5"
_check_summa_env_var(self.summa_dim) _check_summa_env_var(self.summa_dim)
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
......
...@@ -19,12 +19,14 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): ...@@ -19,12 +19,14 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int):
env_tesseract_dep = env.tesseract_dep env_tesseract_dep = env.tesseract_dep
if env_tesseract_dim and env_tesseract_dep: if env_tesseract_dim and env_tesseract_dep:
assert int(env_tesseract_dim) == tesseract_dim, \ assert int(env_tesseract_dim) == tesseract_dim, (
'TESSERACT_DIM has been set in the current environment and ' \ "TESSERACT_DIM has been set in the current environment and "
'does not match with the value passed to this initialized' "does not match with the value passed to this initialized"
assert int(env_tesseract_dep) == tesseract_dep, \ )
'TESSERACT_DEP has been set in the current environment and ' \ assert int(env_tesseract_dep) == tesseract_dep, (
'does not match with the value passed to this initialized' "TESSERACT_DEP has been set in the current environment and "
"does not match with the value passed to this initialized"
)
else: else:
env.tesseract_dim = tesseract_dim env.tesseract_dim = tesseract_dim
env.tesseract_dep = tesseract_dep env.tesseract_dep = tesseract_dep
...@@ -50,8 +52,9 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): ...@@ -50,8 +52,9 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dep = tesseract_dep self.tesseract_dep = tesseract_dep
self.tesseract_dim = tesseract_dim self.tesseract_dim = tesseract_dim
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ assert (
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep
), "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
...@@ -75,7 +78,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): ...@@ -75,7 +78,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
for i in range(self.tesseract_dim) for i in range(self.tesseract_dim)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -129,7 +132,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer): ...@@ -129,7 +132,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
for j in range(self.tesseract_dim) for j in range(self.tesseract_dim)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -183,7 +186,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer): ...@@ -183,7 +186,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
for k in range(self.tesseract_dep) for k in range(self.tesseract_dep)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -238,7 +241,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer): ...@@ -238,7 +241,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
for j in range(self.tesseract_dim) for j in range(self.tesseract_dim)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -265,16 +268,25 @@ class Initializer_2p5D(ProcessGroupInitializer): ...@@ -265,16 +268,25 @@ class Initializer_2p5D(ProcessGroupInitializer):
depth (int): The depth of 2.5d parallel. depth (int): The depth of 2.5d parallel.
""" """
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, def __init__(
tensor_parallel_size: int, depth: int): self,
rank: int,
world_size: int,
config: Config,
data_parallel_size: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
depth: int,
):
args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size)
super().__init__(*args) super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth)) self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth))
self.tesseract_dep = depth self.tesseract_dep = depth
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ assert (
"2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep
), "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5"
_check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep) _check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep)
self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args) self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args)
...@@ -293,6 +305,6 @@ class Initializer_2p5D(ProcessGroupInitializer): ...@@ -293,6 +305,6 @@ class Initializer_2p5D(ProcessGroupInitializer):
self.col_initializer.init_dist_group(), self.col_initializer.init_dist_group(),
self.row_initializer.init_dist_group(), self.row_initializer.init_dist_group(),
self.dep_initializer.init_dist_group(), self.dep_initializer.init_dist_group(),
self.xz_initializer.init_dist_group() self.xz_initializer.init_dist_group(),
] ]
return parallel_setting return parallel_setting
...@@ -17,9 +17,10 @@ def _check_depth_env_var(depth): ...@@ -17,9 +17,10 @@ def _check_depth_env_var(depth):
env_depth = env.depth_3d env_depth = env.depth_3d
if env_depth: if env_depth:
assert int(env_depth) == depth, \ assert int(env_depth) == depth, (
'DEPTH_3D has been set in the current environment and ' \ "DEPTH_3D has been set in the current environment and "
'does not match with the value passed to this initialized' "does not match with the value passed to this initialized"
)
else: else:
env.depth_3d = depth env.depth_3d = depth
...@@ -63,7 +64,7 @@ class Initializer_3D_Input(ProcessGroupInitializer): ...@@ -63,7 +64,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
for k in range(self.depth): for k in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -114,7 +115,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer): ...@@ -114,7 +115,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
for j in range(self.depth): for j in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -165,7 +166,7 @@ class Initializer_3D_Output(ProcessGroupInitializer): ...@@ -165,7 +166,7 @@ class Initializer_3D_Output(ProcessGroupInitializer):
for j in range(self.depth): for j in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -219,7 +220,7 @@ class Initializer_3D_InputxWeight(ProcessGroupInitializer): ...@@ -219,7 +220,7 @@ class Initializer_3D_InputxWeight(ProcessGroupInitializer):
for i in range(self.depth) for i in range(self.depth)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -273,7 +274,7 @@ class Initializer_3D_OutputxWeight(ProcessGroupInitializer): ...@@ -273,7 +274,7 @@ class Initializer_3D_OutputxWeight(ProcessGroupInitializer):
for i in range(self.depth) for i in range(self.depth)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -302,8 +303,9 @@ class Initializer_3D(ProcessGroupInitializer): ...@@ -302,8 +303,9 @@ class Initializer_3D(ProcessGroupInitializer):
super().__init__(*args) super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3)) self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3))
assert self.tensor_parallel_size == self.depth ** 3, \ assert (
f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' self.tensor_parallel_size == self.depth**3
), f"3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})"
_check_depth_env_var(self.depth) _check_depth_env_var(self.depth)
self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
...@@ -324,6 +326,6 @@ class Initializer_3D(ProcessGroupInitializer): ...@@ -324,6 +326,6 @@ class Initializer_3D(ProcessGroupInitializer):
self.weight_initializer.init_dist_group(), self.weight_initializer.init_dist_group(),
self.output_initializer.init_dist_group(), self.output_initializer.init_dist_group(),
self.input_x_weight_initializer.init_dist_group(), self.input_x_weight_initializer.init_dist_group(),
self.output_x_weight_initializer.init_dist_group() self.output_x_weight_initializer.init_dist_group(),
] ]
return parallel_setting return parallel_setting
...@@ -43,7 +43,7 @@ class Initializer_Data(ProcessGroupInitializer): ...@@ -43,7 +43,7 @@ class Initializer_Data(ProcessGroupInitializer):
for i in range(self.num_data_parallel_group): for i in range(self.num_data_parallel_group):
ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
......
...@@ -45,7 +45,7 @@ class Initializer_Model(ProcessGroupInitializer): ...@@ -45,7 +45,7 @@ class Initializer_Model(ProcessGroupInitializer):
for i in range(self.num_group): for i in range(self.num_group):
ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
......
...@@ -38,10 +38,11 @@ class Initializer_Pipeline(ProcessGroupInitializer): ...@@ -38,10 +38,11 @@ class Initializer_Pipeline(ProcessGroupInitializer):
for i in range(self.data_parallel_size): for i in range(self.data_parallel_size):
for j in range(self.pipeline_stage_size): for j in range(self.pipeline_stage_size):
pipe_ranks = list( pipe_ranks = list(
range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)) range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)
)
pipe_group_size = len(pipe_ranks) pipe_group_size = len(pipe_ranks)
pipe_group = dist.new_group(pipe_ranks) pipe_group = dist.new_group(pipe_ranks)
group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group group_cpu = dist.new_group(pipe_ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
if self.rank in pipe_ranks: if self.rank in pipe_ranks:
local_rank = pipe_ranks.index(self.rank) local_rank = pipe_ranks.index(self.rank)
...@@ -50,7 +51,16 @@ class Initializer_Pipeline(ProcessGroupInitializer): ...@@ -50,7 +51,16 @@ class Initializer_Pipeline(ProcessGroupInitializer):
cpu_group = group_cpu cpu_group = group_cpu
ranks_in_group = pipe_ranks ranks_in_group = pipe_ranks
dist_settings.append( dist_settings.append(
tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, tuple(
ParallelMode.PIPELINE))) (
local_rank,
group_world_size,
process_group,
cpu_group,
ranks_in_group,
ParallelMode.PIPELINE,
)
)
)
return dist_settings return dist_settings
...@@ -46,7 +46,7 @@ class Initializer_Sequence_DP(ProcessGroupInitializer): ...@@ -46,7 +46,7 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
for i in range(self.num_group): for i in range(self.num_group):
ranks = [i * self.dp_size + j for j in range(self.dp_size)] ranks = [i * self.dp_size + j for j in range(self.dp_size)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
...@@ -91,8 +91,14 @@ class Initializer_Sequence(ProcessGroupInitializer): ...@@ -91,8 +91,14 @@ class Initializer_Sequence(ProcessGroupInitializer):
parallel_setting = [] parallel_setting = []
local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode = \ (
self._sequence_initializer.init_dist_group() local_rank,
group_world_size,
process_group,
cpu_group,
ranks_in_group,
mode,
) = self._sequence_initializer.init_dist_group()
# change mode to sequence # change mode to sequence
mode = ParallelMode.SEQUENCE mode = ParallelMode.SEQUENCE
......
...@@ -43,7 +43,7 @@ class Initializer_Tensor(ProcessGroupInitializer): ...@@ -43,7 +43,7 @@ class Initializer_Tensor(ProcessGroupInitializer):
for i in range(self.num_tensor_parallel_group): for i in range(self.num_tensor_parallel_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks) group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
if self.rank in ranks: if self.rank in ranks:
local_rank = ranks.index(self.rank) local_rank = ranks.index(self.rank)
......
...@@ -18,8 +18,15 @@ class ProcessGroupInitializer(ABC): ...@@ -18,8 +18,15 @@ class ProcessGroupInitializer(ABC):
tensor_parallel_size (int): Size of tensor parallel. tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, def __init__(
tensor_parallel_size: int): self,
rank: int,
world_size: int,
config: Config,
data_parallel_size: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
):
self.rank = rank self.rank = rank
self.world_size = world_size self.world_size = world_size
self.data_parallel_size = data_parallel_size self.data_parallel_size = data_parallel_size
......
...@@ -13,6 +13,15 @@ from ._helper import ( ...@@ -13,6 +13,15 @@ from ._helper import (
) )
__all__ = [ __all__ = [
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', "seed",
'sync_states', 'moe_set_seed', 'reset_seeds' "set_mode",
"with_seed",
"add_seed",
"get_seeds",
"get_states",
"get_current_mode",
"set_seed_states",
"sync_states",
"moe_set_seed",
"reset_seeds",
] ]
...@@ -100,7 +100,7 @@ def sync_states(): ...@@ -100,7 +100,7 @@ def sync_states():
@contextmanager @contextmanager
def seed(parallel_mode: ParallelMode): def seed(parallel_mode: ParallelMode):
""" A context for seed switch """A context for seed switch
Examples: Examples:
...@@ -162,6 +162,7 @@ def with_seed(func, parallel_mode: ParallelMode): ...@@ -162,6 +162,7 @@ def with_seed(func, parallel_mode: ParallelMode):
def moe_set_seed(seed): def moe_set_seed(seed):
if torch.cuda.is_available(): if torch.cuda.is_available():
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
global_rank = gpc.get_global_rank() global_rank = gpc.get_global_rank()
diff_seed = seed + global_rank diff_seed = seed + global_rank
add_seed(ParallelMode.TENSOR, diff_seed, True) add_seed(ParallelMode.TENSOR, diff_seed, True)
......
...@@ -42,7 +42,7 @@ class SeedManager: ...@@ -42,7 +42,7 @@ class SeedManager:
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
""" """
assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager' assert parallel_mode in self._seed_states, f"Parallel mode {parallel_mode} is not found in the seed manager"
self._seed_states[parallel_mode] = state self._seed_states[parallel_mode] = state
def set_mode(self, parallel_mode: ParallelMode): def set_mode(self, parallel_mode: ParallelMode):
...@@ -71,9 +71,9 @@ class SeedManager: ...@@ -71,9 +71,9 @@ class SeedManager:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode`
or the seed for `parallel_mode` has been added. or the seed for `parallel_mode` has been added.
""" """
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' assert isinstance(parallel_mode, ParallelMode), "A valid ParallelMode must be provided"
if overwrite is False: if overwrite is False:
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added' assert parallel_mode not in self._seed_states, f"The seed for {parallel_mode} has been added"
elif parallel_mode in self._seed_states: elif parallel_mode in self._seed_states:
print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True) print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)
......
...@@ -3,4 +3,4 @@ ...@@ -3,4 +3,4 @@
from colossalai.legacy.context.parallel_context import global_context from colossalai.legacy.context.parallel_context import global_context
__all__ = ['global_context'] __all__ = ["global_context"]
from ._base_engine import Engine from ._base_engine import Engine
from .gradient_handler import * from .gradient_handler import *
__all__ = ['Engine'] __all__ = ["Engine"]
...@@ -59,15 +59,17 @@ class Engine: ...@@ -59,15 +59,17 @@ class Engine:
`Run resnet cifar10 with engine <https://github.com/hpcaitech/ColossalAI-Examples/blob/main/image/resnet/run_resnet_cifar10_with_engine.py>`_. `Run resnet cifar10 with engine <https://github.com/hpcaitech/ColossalAI-Examples/blob/main/image/resnet/run_resnet_cifar10_with_engine.py>`_.
""" """
def __init__(self, def __init__(
model: Module, self,
optimizer: "OptimizerWrapper", model: Module,
criterion: Optional[_Loss] = None, optimizer: "OptimizerWrapper",
gradient_handlers: Optional[List[BaseGradientHandler]] = None, criterion: Optional[_Loss] = None,
clip_grad_norm: float = 0.0, gradient_handlers: Optional[List[BaseGradientHandler]] = None,
ophook_list: Optional[List[BaseOpHook]] = None, clip_grad_norm: float = 0.0,
verbose: bool = True, ophook_list: Optional[List[BaseOpHook]] = None,
schedule: Optional[BaseSchedule] = None): verbose: bool = True,
schedule: Optional[BaseSchedule] = None,
):
self._model = model self._model = model
self._optimizer = optimizer self._optimizer = optimizer
self._criterion = criterion self._criterion = criterion
...@@ -76,7 +78,7 @@ class Engine: ...@@ -76,7 +78,7 @@ class Engine:
self._logger = get_dist_logger() self._logger = get_dist_logger()
# state # state
self.training = True # default self.training = True # default
# build gradient handler # build gradient handler
if gradient_handlers: if gradient_handlers:
...@@ -91,8 +93,9 @@ class Engine: ...@@ -91,8 +93,9 @@ class Engine:
# build schedule # build schedule
if schedule: if schedule:
assert isinstance(schedule, BaseSchedule), \ assert isinstance(
f'expected schedule to be of type BaseSchedule, but got {type(schedule)}' schedule, BaseSchedule
), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
self._schedule = schedule self._schedule = schedule
else: else:
self._schedule = NonPipelineSchedule() self._schedule = NonPipelineSchedule()
...@@ -149,13 +152,11 @@ class Engine: ...@@ -149,13 +152,11 @@ class Engine:
logger.warning(f"removing hooks is currently not supported") logger.warning(f"removing hooks is currently not supported")
def zero_grad(self): def zero_grad(self):
"""Set the gradient of parameters to zero """Set the gradient of parameters to zero"""
"""
self.optimizer.zero_grad() self.optimizer.zero_grad()
def step(self): def step(self):
"""Execute parameter update """Execute parameter update"""
"""
self._all_reduce_gradients() self._all_reduce_gradients()
self.optimizer.clip_grad_by_norm(self._clip_grad_norm) self.optimizer.clip_grad_by_norm(self._clip_grad_norm)
return self.optimizer.step() return self.optimizer.step()
...@@ -192,8 +193,7 @@ class Engine: ...@@ -192,8 +193,7 @@ class Engine:
return self.model(*args, **kwargs) return self.model(*args, **kwargs)
def _all_reduce_gradients(self): def _all_reduce_gradients(self):
"""Handles all-reduce operations of gradients across different parallel groups. """Handles all-reduce operations of gradients across different parallel groups."""
"""
for handler in self._gradient_handlers: for handler in self._gradient_handlers:
handler.handle_gradient() handler.handle_gradient()
...@@ -208,13 +208,11 @@ class Engine: ...@@ -208,13 +208,11 @@ class Engine:
return output, label, loss return output, label, loss
def train(self): def train(self):
"""Sets the model to training mode. """Sets the model to training mode."""
"""
self.training = True self.training = True
self._model.train() self._model.train()
def eval(self): def eval(self):
"""Sets the model to evaluation mode. """Sets the model to evaluation mode."""
"""
self.training = False self.training = False
self._model.eval() self._model.eval()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment