Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
......@@ -12,11 +12,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .dp_plugin_base import DPPluginBase
__all__ = ['TorchDDPPlugin']
__all__ = ["TorchDDPPlugin"]
class TorchDDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
......@@ -49,25 +48,29 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
def save_sharded_model(
self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
def save_sharded_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
def save_sharded_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to checkpoint but only on master process.
"""
......@@ -76,7 +79,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
......@@ -109,20 +111,24 @@ class TorchDDPPlugin(DPPluginBase):
static_graph (bool, optional): Whether to use static graph. Defaults to False.
"""
def __init__(self,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False) -> None:
def __init__(
self,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
) -> None:
super().__init__()
self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)
self.ddp_kwargs = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
def support_no_sync(self) -> bool:
return True
......@@ -131,13 +137,13 @@ class TorchDDPPlugin(DPPluginBase):
return False
def supported_precisions(self) -> List[str]:
return ['fp16', 'fp16_apex', 'bf16', 'fp8']
return ["fp16", "fp16_apex", "bf16", "fp8"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
return ["cuda"]
def configure(
self,
......@@ -156,8 +162,7 @@ class TorchDDPPlugin(DPPluginBase):
# wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
......@@ -169,5 +174,5 @@ class TorchDDPPlugin(DPPluginBase):
return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync()
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
......@@ -31,11 +31,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .dp_plugin_base import DPPluginBase
__all__ = ['TorchFSDPPlugin']
__all__ = ["TorchFSDPPlugin"]
class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
......@@ -69,26 +68,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
"""
Save model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
def load_sharded_model(
self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save optimizer to checkpoint but only on master process.
"""
......@@ -109,7 +118,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
class TorchFSDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)
......@@ -119,7 +127,6 @@ class TorchFSDPModel(ModelWrapper):
class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)
......@@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
def __init__(
self,
......@@ -162,15 +169,18 @@ class TorchFSDPPlugin(DPPluginBase):
sync_module_states: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states)
self.fsdp_kwargs = dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
......@@ -184,13 +194,13 @@ class TorchFSDPPlugin(DPPluginBase):
return True
def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16']
return ["fp16", "bf16"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
return ["cuda"]
def configure(
self,
......@@ -200,14 +210,13 @@ class TorchFSDPPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
"TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
......
......@@ -3,4 +3,4 @@ from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
......@@ -11,7 +11,7 @@ from colossalai.interface import ModelWrapper
from .utils import has_index_file
__all__ = ['CheckpointIO']
__all__ = ["CheckpointIO"]
class CheckpointIO(ABC):
......@@ -61,10 +61,9 @@ class CheckpointIO(ABC):
# ======================================
# Public methods
# ======================================
def load_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True) -> Union[nn.Module, ModelWrapper]:
def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
......@@ -98,14 +97,16 @@ class CheckpointIO(ABC):
return origin_model
def save_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
def save_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
):
"""
Save model to checkpoint.
......@@ -157,7 +158,7 @@ class CheckpointIO(ABC):
if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error
raise ValueError(f'Cannot find index file in {checkpoint}')
raise ValueError(f"Cannot find index file in {checkpoint}")
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
......@@ -165,13 +166,15 @@ class CheckpointIO(ABC):
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024):
def save_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
......@@ -207,7 +210,6 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
......@@ -220,11 +222,17 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
"""
Save model to sharded checkpoint.
......@@ -236,7 +244,6 @@ class CheckpointIO(ABC):
size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors.
"""
pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
......@@ -249,7 +256,6 @@ class CheckpointIO(ABC):
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors.
"""
pass
# ========================================================
# Abstract methods for optimizer loading/saving implementation
......@@ -265,7 +271,6 @@ class CheckpointIO(ABC):
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
"""
pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
......@@ -276,11 +281,11 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
pass
@abstractmethod
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int):
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save optimizer to sharded checkpoint.
......@@ -291,7 +296,6 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
......@@ -303,7 +307,6 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
"""
pass
# ============================================
# methods for loading and saving lr scheduler
......
......@@ -3,9 +3,8 @@ import logging
import os
from functools import reduce
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
from typing import Optional
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
......@@ -16,7 +15,6 @@ from .index_file import CheckpointIndexFile
from .utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
......@@ -33,7 +31,7 @@ from .utils import (
unwrap_optimizer,
)
__all__ = ['GeneralCheckpointIO']
__all__ = ["GeneralCheckpointIO"]
class GeneralCheckpointIO(CheckpointIO):
......@@ -70,8 +68,10 @@ class GeneralCheckpointIO(CheckpointIO):
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory.')
raise RuntimeError(
f"Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory."
)
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
......@@ -123,19 +123,23 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False)
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
......@@ -150,13 +154,15 @@ class GeneralCheckpointIO(CheckpointIO):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
def save_sharded_model(
self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
"""
implement this method as it can be supported by Huggingface model,
save shard model, save model to multiple files
......@@ -175,26 +181,32 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
logging.info(
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_model(
self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
load shard model, load model from multiple files
"""
......@@ -219,7 +231,11 @@ class GeneralCheckpointIO(CheckpointIO):
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
import copy
import gc
import logging
import os
from pathlib import Path
......@@ -35,9 +34,9 @@ from .utils import (
)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class HybridParallelCheckpointIO(GeneralCheckpointIO):
......@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
"""
def __init__(self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True) -> None:
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__()
self.dp_group = dp_group
self.pp_group = pp_group
......@@ -68,17 +69,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = (zero_stage > 0)
self.use_zero = zero_stage > 0
self.verbose = verbose
self.working_to_master_map = None
self.master_to_working_map = None
self.coordinator = DistCoordinator()
@staticmethod
def _model_sharder(model: nn.Module,
prefix: str = '',
keep_vars: bool = False,
size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
) -> Iterator[Tuple[OrderedDict, int]]:
# An internel method that breaks state_dict of model into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
......@@ -103,8 +103,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
......@@ -114,20 +116,20 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
@staticmethod
def _optimizer_sharder(optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
size_per_shard: int = 1024):
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
size_per_shard: int = 1024,
):
# An internel method that breaks state_dict of optimizer into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
for param, state in optimizer.optim.state.items():
if param is None:
continue
......@@ -136,15 +138,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
else:
working_param = param
param_id = param_info['param2id'][id(working_param)]
original_shape = param_info['param2shape'][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False)
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
......@@ -153,13 +157,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False) -> None:
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
......@@ -194,24 +200,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.tp_rank == 0)
control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
......@@ -228,15 +238,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
......@@ -259,9 +273,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}.")
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
......@@ -305,11 +321,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
load_state_dict_into_model(model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True)
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
)
loaded_file.add(filename)
# Load parameters.
......@@ -319,15 +333,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
......@@ -352,12 +368,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_sharded_optimizer(self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
......@@ -393,18 +411,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group=self.dp_group,
tp_group=self.tp_group,
master_to_working_map=self.master_to_working_map,
size_per_shard=size_per_shard)
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
......@@ -415,9 +436,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose:
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
......@@ -433,15 +456,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
......@@ -451,7 +478,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# The global master rank integrates the index files and clean the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
......@@ -470,9 +496,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}.")
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
......@@ -484,20 +512,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix (str): Not used.
"""
def _get_param_id_from_optimizer_param(param: torch.Tensor,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info['param2id'][id(working_param)]
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg['params']:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
id_map[param_id] = param
......@@ -505,28 +534,30 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory.')
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change.
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({'param_groups': updated_groups})
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg['params']:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
......@@ -550,12 +581,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
working_param = self.master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info['param2shape'][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True)
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
......@@ -585,8 +614,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
def link_master_and_working_param(
self,
working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
):
"""
Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
This mapping can only be created when mixied precision is used.
......@@ -604,7 +636,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.working_to_master_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
self.master_to_working_map = dict()
for k, v in master_to_working_map.items():
......@@ -614,12 +647,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.master_to_working_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
@staticmethod
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
inplace: bool) -> OrderedDict:
def gather_from_sharded_optimizer_state(
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
......@@ -641,14 +681,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
......@@ -661,9 +700,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
return state_
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
original_shape: torch.Size, device: torch.device,
inplace: bool) -> OrderedDict:
def shard_from_complete_optimizer_state(
self,
state: OrderedDict,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
......@@ -681,8 +725,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
......
......@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union
from .utils import is_dtensor_checkpoint
__all__ = ['CheckpointIndexFile']
__all__ = ["CheckpointIndexFile"]
class CheckpointIndexFile:
......@@ -50,7 +50,7 @@ class CheckpointIndexFile:
json_path (str): path to the json file.
"""
# load the json file
with open(json_path, 'r') as f:
with open(json_path, "r") as f:
index = json.load(f)
# assign attributes if exists
......@@ -75,7 +75,7 @@ class CheckpointIndexFile:
index["weight_map"] = self.weight_map
# export the index file
with open(json_path, 'w') as f:
with open(json_path, "w") as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):
......
# coding=utf-8
import copy
import os
import re
from collections import abc as container_abcs
......@@ -12,7 +11,7 @@ import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface import OptimizerWrapper
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
......@@ -55,7 +54,6 @@ def is_safetensors_available() -> bool:
bool: whether safetensors is available.
"""
try:
import safetensors
return True
except ImportError:
return False
......@@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a dtensor checkpoint.
"""
if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
return True
else:
return False
......@@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a safetensor checkpoint.
"""
if checkpoint_file_path.endswith('.safetensors'):
if checkpoint_file_path.endswith(".safetensors"):
return True
else:
return False
......@@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
partition_dim = dim
break
if partition_dim is not None:
assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
f"The parameter isn't evenly distributed among tensor parallel group: \
assert (
original_shape[partition_dim] == tp_size * current_shape[partition_dim]
), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim
......@@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
# Helper classes and functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
'''
"""
Unwrap a wrapped optimizer.
This method should be used before saving/loading it to/from sharded checkpoints.
'''
"""
unwrapped_optim = optimizer.optim
return unwrapped_optim
class StateDictSharder:
def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard
self.current_block = OrderedDict()
self.current_block_size = 0
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
......@@ -159,13 +156,11 @@ class StateDictSharder:
return ret_block, ret_block_size
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
# A state might contain more than one tensors.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
......@@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
return param_
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False,
use_pp_format: bool = False) -> int:
'''
def save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False,
use_pp_format: bool = False,
) -> int:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
......@@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
Returns:
int: the total size of shards
'''
"""
total_size = 0
shard_filenames = []
......@@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
"""
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state']
states = state_dict["state"]
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
......@@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
assert checkpoint_file_path.endswith(
".safetensors"
), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
......@@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
torch.save(param_groups, group_file_path)
def clean_folder(checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False):
def clean_folder(
checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False,
):
"""
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
......@@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str,
else:
# When this checkpoint is created by pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in shard_filenames
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
......@@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor')
output_root_path = root_path.joinpath("dtensor")
# create directory
output_root_path.mkdir(exist_ok=True)
......@@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
# update the weight map
# * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
......@@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
str: checkpoint file suffix.
"""
if use_safetensors:
return '.safetensors'
return ".safetensors"
else:
return '.bin'
return ".bin"
def generate_checkpoint_shard_file_name(index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None) -> str:
def generate_checkpoint_shard_file_name(
index: int, total_number: int, use_safetensors: bool, prefix: str = None
) -> str:
"""
Generate checkpoint shard file name.
......@@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}'
return f"{param_name}.{index}.{suffix}"
# ========================================
......@@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file, map_location=torch.device('cpu'))
return torch.load(checkpoint_file, map_location=torch.device("cpu"))
def load_state_dict_into_model(model: nn.Module,
state_dict: torch.Tensor,
missing_keys: List,
strict: bool = False,
load_sub_module: bool = True):
def load_state_dict_into_model(
model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True
):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
......@@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module,
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata
......@@ -560,10 +564,12 @@ def load_state_dict_into_model(model: nn.Module,
if strict:
if len(unexpected_keys) > 0:
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
......@@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
# The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch.
......@@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups")
param_lens = (len(g['params']) for g in param_groups)
saved_lens = (len(g['params']) for g in saved_groups)
param_lens = (len(g["params"]) for g in param_groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
raise ValueError(
"loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
)
# Creating mapping from id to parameters.
id_map = {
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
)), chain.from_iterable((g['params'] for g in param_groups)))
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in param_groups)),
)
}
# Update parameter groups, setting their 'params' value.
def update_group(group, new_group):
new_group['params'] = group['params']
new_group["params"] = group["params"]
return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
optimizer.__dict__.update({'param_groups': updated_groups})
optimizer.__dict__.update({"param_groups": updated_groups})
return id_map
......@@ -628,7 +638,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
if (key != "step"):
if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
......@@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
"""
# Do the cleaning up as in src code of Pytorch.
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault('differentiable', False)
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault("differentiable", False)
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
......@@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return False, None
elif checkpoint_path.is_dir():
# check if there is only one a file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.*json'))
index_files = list(checkpoint_path.glob("*.index.*json"))
# if we found a .index.json file, make sure there is only one
if len(index_files) > 0:
assert len(
index_files
) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
assert (
len(index_files) == 1
), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
if len(index_files) == 1:
return True, index_files[0]
else:
return False, None
else:
raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def load_state_dict(checkpoint_file_path: Path):
......@@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path):
dict: state dict.
"""
assert not is_dtensor_checkpoint(checkpoint_file_path), \
f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
assert not is_dtensor_checkpoint(
checkpoint_file_path
), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
if is_safetensor_checkpoint(checkpoint_file_path):
assert is_safetensors_available(), \
f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
assert (
is_safetensors_available()
), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
# load with safetensors
from safetensors import safe_open
state_dict = {}
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
for k in f.keys():
......@@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
......
from .cli import cli
__all__ = ['cli']
__all__ = ["cli"]
import click
from .check_installation import check_installation
__all__ = ['check']
__all__ = ["check"]
@click.command(help="Check if Colossal-AI is correct based on the given option")
@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly")
@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly")
def check(installation):
if installation:
check_installation()
......
......@@ -9,7 +9,7 @@ import colossalai
def to_click_output(val):
# installation check output to understandable symbols for readability
VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'}
VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"}
if val in VAL_TO_SYMBOL:
return VAL_TO_SYMBOL[val]
......@@ -55,8 +55,8 @@ def check_installation():
else:
torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required])
click.echo(f'#### Installation Report ####')
click.echo(f'\n------------ Environment ------------')
click.echo(f"#### Installation Report ####")
click.echo(f"\n------------ Environment ------------")
click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}")
click.echo(f"PyTorch version: {to_click_output(torch_version)}")
click.echo(f"System CUDA version: {to_click_output(cuda_version)}")
......@@ -69,7 +69,7 @@ def check_installation():
f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version."
)
click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------')
click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------")
click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}")
click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}")
click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}")
......@@ -81,7 +81,7 @@ def check_installation():
click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime")
click.echo(f"\n------------ Compatibility ------------")
click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}')
click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}")
click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}")
click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}")
click.echo(f"")
......@@ -106,12 +106,12 @@ def _is_compatible(versions):
return False
# split version into [major, minor, patch]
versions = [version.split('.') for version in versions]
versions = [version.split(".") for version in versions]
for version in versions:
if len(version) == 2:
# x means unknown
version.append('x')
version.append("x")
for idx, version_values in enumerate(zip(*versions)):
equal = len(set(version_values)) == 1
......@@ -137,11 +137,11 @@ def _parse_colossalai_version():
# 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions)
# 2. X.X.X (when colossalai is not installed with CUDA extensions)
# where X represents an integer.
colossalai_version = colossalai.__version__.split('+')[0]
colossalai_version = colossalai.__version__.split("+")[0]
try:
torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0]
cuda_version_for_aot_build = colossalai.__version__.split('cu')[1]
torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0]
cuda_version_for_aot_build = colossalai.__version__.split("cu")[1]
except:
torch_version_for_aot_build = None
cuda_version_for_aot_build = None
......@@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed():
JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime.
"""
try:
import colossalai._C.fused_optim
found_aot_cuda_ext = True
except ImportError:
found_aot_cuda_ext = False
......@@ -175,14 +174,14 @@ def _check_torch_version():
# torch version can be of two formats
# - 1.13.1+cu113
# - 1.13.1.devxxx
torch_version = torch.__version__.split('+')[0]
torch_version = '.'.join(torch_version.split('.')[:3])
torch_version = torch.__version__.split("+")[0]
torch_version = ".".join(torch_version.split(".")[:3])
# get cuda version in pytorch build
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'
torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}"
except:
torch_cuda_version = None
......@@ -208,7 +207,7 @@ def _check_cuda_version():
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
cuda_version = f"{bare_metal_major}.{bare_metal_minor}"
except:
cuda_version = None
return cuda_version
......@@ -4,8 +4,7 @@ from .check import check
from .launcher import run
class Arguments():
class Arguments:
def __init__(self, arg_dict):
for k, v in arg_dict.items():
self.__dict__[k] = v
......@@ -19,5 +18,5 @@ def cli():
cli.add_command(run)
cli.add_command(check)
if __name__ == '__main__':
if __name__ == "__main__":
cli()
......@@ -5,56 +5,81 @@ from colossalai.context import Config
from .run import launch_multi_processes
@click.command(help="Launch distributed training on a single node or multiple nodes",
context_settings=dict(ignore_unknown_options=True))
@click.option("-H",
"-host",
"--host",
type=str,
default=None,
help="the list of hostnames to launch in the format <host1>,<host2>")
@click.command(
help="Launch distributed training on a single node or multiple nodes",
context_settings=dict(ignore_unknown_options=True),
)
@click.option(
"-H",
"-host",
"--host",
type=str,
default=None,
help="the list of hostnames to launch in the format <host1>,<host2>",
)
@click.option(
"--hostfile",
type=str,
default=None,
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
@click.option("--include",
type=str,
default=None,
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
" only effective when used with --hostfile.")
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname",
)
@click.option(
"--include",
type=str,
default=None,
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
" only effective when used with --hostfile.",
)
@click.option(
"--exclude",
type=str,
default=None,
help=
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
" only effective when used with --hostfile.")
@click.option("--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to use, only effective when used with --hostfile.")
help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
" only effective when used with --hostfile.",
)
@click.option(
"--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to use, only effective when used with --hostfile.",
)
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
@click.option("--master_port",
type=int,
default=29500,
help="(optional) Port used by PyTorch distributed for communication during distributed training.")
@click.option("--master_addr",
type=str,
default="127.0.0.1",
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
@click.option(
"--master_port",
type=int,
default=29500,
help="(optional) Port used by PyTorch distributed for communication during distributed training.",
)
@click.option(
"--master_addr",
type=str,
default="127.0.0.1",
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.",
)
@click.option(
"--extra_launch_args",
type=str,
default=None,
help=
"Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
"This will be converted to --arg1=1 --arg2=2 during execution")
help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
"This will be converted to --arg1=1 --arg2=2 during execution",
)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
@click.argument('user_args', nargs=-1)
def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
@click.argument("user_args", nargs=-1)
def run(
host: str,
hostfile: str,
num_nodes: int,
nproc_per_node: int,
include: str,
exclude: str,
master_addr: str,
master_port: int,
extra_launch_args: str,
ssh_port: int,
user_script: str,
user_args: str,
) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
......@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include:
# run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
"""
if not user_script.endswith('.py'):
click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
if not user_script.endswith(".py"):
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
args_dict = locals()
......
import socket
from typing import List
class HostInfo:
......@@ -34,7 +33,7 @@ class HostInfo:
"""
if port is None:
port = 22 # no port specified, lets just use the ssh port
port = 22 # no port specified, lets just use the ssh port
# socket.getfqdn("127.0.0.1") does not return localhost
# on some users' machines
......@@ -50,7 +49,7 @@ class HostInfo:
return localaddrs == targetaddrs
def __str__(self):
return f'hostname: {self.hostname}, port: {self.port}'
return f"hostname: {self.hostname}, port: {self.port}"
def __repr__(self):
return self.__str__()
......
......@@ -7,8 +7,13 @@ import fabric
from .hostinfo import HostInfo, HostInfoList
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
send_conn: mp_connection.Connection, env: dict) -> None:
def run_on_host(
hostinfo: HostInfo,
workdir: str,
recv_conn: mp_connection.Connection,
send_conn: mp_connection.Connection,
env: dict,
) -> None:
"""
Use fabric connection to execute command on local or remote hosts.
......@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
finish = False
env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()])
# keep listening until exit
while not finish:
# receive cmd
cmds = recv_conn.recv()
if cmds == 'exit':
if cmds == "exit":
# exit from the loop
finish = True
break
......@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
else:
# execute on the remote machine
fab_conn.run(cmds, hide=False)
send_conn.send('success')
send_conn.send("success")
except Exception as e:
click.echo(
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
)
send_conn.send('failure')
send_conn.send("failure")
# shutdown
send_conn.send("finish")
......@@ -96,8 +101,7 @@ class MultiNodeRunner:
cmd (str): the command to execute
"""
assert hostinfo.hostname in self.master_send_conns, \
f'{hostinfo} is not found in the current connections'
assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
conn = self.master_send_conns[hostinfo.hostname]
conn.send(cmd)
......@@ -107,7 +111,7 @@ class MultiNodeRunner:
"""
for hostname, conn in self.master_send_conns.items():
conn.send('exit')
conn.send("exit")
def recv_from_all(self) -> dict:
"""
......
......@@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
NODE_SEP = ','
NODE_SEP = ","
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
......@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit()
with open(hostfile_path, 'r') as fd:
with open(hostfile_path, "r") as fd:
device_pool = HostInfoList()
for line in fd.readlines():
line = line.strip()
if line == '':
if line == "":
# skip empty lines
continue
......@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
"""Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples:
include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
......@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
'''
"""
# Ensure include/exclude are mutually exclusive
if include_str and exclude_str:
......@@ -136,16 +136,16 @@ def get_launch_command(
for k, v in arg_dict.items():
if v:
ret.append(f'--{k}={v}')
ret.append(f"--{k}={v}")
else:
ret.append(f'--{k}')
ret.append(f"--{k}")
return ret
if extra_launch_args:
extra_launch_args_dict = dict()
for arg in extra_launch_args.split(','):
if '=' in arg:
k, v = arg.split('=')
for arg in extra_launch_args.split(","):
if "=" in arg:
k, v = arg.split("=")
extra_launch_args_dict[k] = v
else:
extra_launch_args_dict[arg] = None
......@@ -158,9 +158,14 @@ def get_launch_command(
if torch_version.minor < 9:
cmd = [
sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
f"--node_rank={node_rank}"
sys.executable,
"-m",
"torch.distributed.launch",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={master_addr}",
f"--master_port={master_port}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
]
else:
# extra launch args for torch distributed launcher with torch >= 1.9
......@@ -174,17 +179,24 @@ def get_launch_command(
if torch_version.minor < 10:
cmd = [
sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
sys.executable,
"-m",
"torch.distributed.run",
f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
]
else:
cmd = [
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
"torchrun",
f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
cmd = ' '.join(cmd)
cmd = " ".join(cmd)
return cmd
......@@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None:
# run on local node if not hosts or hostfile is given
# add local node to host info list
active_device_pool = HostInfoList()
localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
active_device_pool.append(localhost_info)
# launch distributed processes
runner = MultiNodeRunner()
curr_path = os.path.abspath('.')
curr_path = os.path.abspath(".")
# collect current path env
env = dict()
for k, v in os.environ.items():
# do not support multi-line env var
if v and '\n' not in v:
if v and "\n" not in v:
env[k] = v
# establish remote connection
......@@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None:
# execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool):
cmd = get_launch_command(master_addr=args.master_addr,
master_port=args.master_port,
nproc_per_node=args.nproc_per_node,
user_script=args.user_script,
user_args=args.user_args,
node_rank=node_id,
num_nodes=len(active_device_pool),
extra_launch_args=args.extra_launch_args)
cmd = get_launch_command(
master_addr=args.master_addr,
master_port=args.master_port,
nproc_per_node=args.nproc_per_node,
user_script=args.user_script,
user_args=args.user_args,
node_rank=node_id,
num_nodes=len(active_device_pool),
extra_launch_args=args.extra_launch_args,
)
runner.send(hostinfo=hostinfo, cmd=cmd)
# start training
......
......@@ -3,4 +3,4 @@ from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager
from .process_group_mesh import ProcessGroupMesh
__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh']
__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"]
......@@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh
@dataclass
class DeviceMeshInfo:
'''
"""
This class is used to store the information used to initialize the device mesh.
Args:
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2].
'''
"""
physical_ids: List[int]
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
......@@ -24,16 +25,18 @@ class DeviceMeshInfo:
if self.mesh_shape is not None:
world_size = len(self.physical_ids)
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
assert (
world_size == mesh_shape_numel
), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
'''
"""
This method is used to initialize the device mesh.
Args:
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
'''
"""
# parse the device mesh info
physical_devices = device_mesh_info.physical_ids
physical_mesh = torch.tensor(physical_devices)
......@@ -67,13 +70,13 @@ class DeviceMeshManager:
Args:
name (str): name of the device mesh
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
"""
"""
if name not in self.device_mesh_store:
device_mesh = initialize_device_mesh(device_mesh_info)
self.device_mesh_store[name] = device_mesh
return device_mesh
else:
raise ValueError(f'Device mesh {name} already exists.')
raise ValueError(f"Device mesh {name} already exists.")
def get(self, name: str) -> DeviceMesh:
"""
......@@ -88,7 +91,7 @@ class DeviceMeshManager:
if name in self.device_mesh_store:
return self.device_mesh_store[name]
else:
raise ValueError(f'Device mesh {name} does not exist.')
raise ValueError(f"Device mesh {name} does not exist.")
def destroy(self, name: str) -> None:
"""
......@@ -103,7 +106,7 @@ class DeviceMeshManager:
dist.destroy_process_group(pg)
del self.device_mesh_store[name]
else:
raise ValueError(f'Device mesh {name} does not exist.')
raise ValueError(f"Device mesh {name} does not exist.")
def destroy_all(self):
"""
......
......@@ -36,12 +36,13 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
def __init__(self):
assert dist.is_initialized(
), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
assert (
dist.is_initialized()
), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
self._local_rank = os.environ.get('LOCAL_RANK', -1)
self._local_rank = os.environ.get("LOCAL_RANK", -1)
@property
def rank(self) -> int:
......@@ -59,7 +60,9 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
Assert that the local rank is set. This is often passed by launchers such as torchrun.
"""
assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
assert (
self.local_rank >= 0
), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."
def is_master(self, process_group: ProcessGroup = None) -> bool:
"""
......@@ -183,7 +186,6 @@ class DistCoordinator(metaclass=SingletonMeta):
# define an inner function
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master:
......
......@@ -19,7 +19,7 @@ class ProcessGroupManager:
def __init__(self):
self.pg_store = dict()
def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup:
"""
Get a process group by name. If the process group does not exist, it will be created.
......@@ -36,7 +36,7 @@ class ProcessGroupManager:
self.pg_store[name] = pg
return pg
else:
raise ValueError(f'Process group {name} already exists.')
raise ValueError(f"Process group {name} already exists.")
def get(self, name: str) -> ProcessGroup:
"""
......@@ -51,7 +51,7 @@ class ProcessGroupManager:
if name in self.pg_store:
return self.pg_store[name]
else:
raise ValueError(f'Process group {name} does not exist.')
raise ValueError(f"Process group {name} does not exist.")
def destroy(self, name: str) -> None:
"""
......@@ -64,7 +64,7 @@ class ProcessGroupManager:
dist.destroy_process_group(self.pg_store[name])
del self.pg_store[name]
else:
raise ValueError(f'Process group {name} does not exist.')
raise ValueError(f"Process group {name} does not exist.")
def destroy_all(self) -> None:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment