Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
from typing import Callable, Iterator, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
...@@ -12,11 +12,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper ...@@ -12,11 +12,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .dp_plugin_base import DPPluginBase from .dp_plugin_base import DPPluginBase
__all__ = ['TorchDDPPlugin'] __all__ = ["TorchDDPPlugin"]
class TorchDDPCheckpointIO(GeneralCheckpointIO): class TorchDDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.coordinator = DistCoordinator() self.coordinator = DistCoordinator()
...@@ -49,25 +48,29 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): ...@@ -49,25 +48,29 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master(): if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint) super().save_lr_scheduler(lr_scheduler, checkpoint)
def save_sharded_model(self, def save_sharded_model(
model: nn.Module, self,
checkpoint_path: str, model: nn.Module,
gather_dtensor: bool = True, checkpoint_path: str,
prefix: Optional[str] = None, gather_dtensor: bool = True,
max_shard_size: int = 1024, prefix: Optional[str] = None,
use_safetensors: bool = False): max_shard_size: int = 1024,
use_safetensors: bool = False,
):
""" """
Save model to checkpoint but only on master process. Save model to checkpoint but only on master process.
""" """
if self.coordinator.is_master(): if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
def save_sharded_optimizer(self, def save_sharded_optimizer(
optimizer: Optimizer, self,
checkpoint: str, optimizer: Optimizer,
gather_dtensor: bool = True, checkpoint: str,
prefix: Optional[str] = None, gather_dtensor: bool = True,
size_per_shard: int = 1024): prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
""" """
Save optimizer to checkpoint but only on master process. Save optimizer to checkpoint but only on master process.
""" """
...@@ -76,7 +79,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): ...@@ -76,7 +79,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
class TorchDDPModel(ModelWrapper): class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None: def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module) super().__init__(module)
self.module = DDP(module, *args, **kwargs) self.module = DDP(module, *args, **kwargs)
...@@ -109,20 +111,24 @@ class TorchDDPPlugin(DPPluginBase): ...@@ -109,20 +111,24 @@ class TorchDDPPlugin(DPPluginBase):
static_graph (bool, optional): Whether to use static graph. Defaults to False. static_graph (bool, optional): Whether to use static graph. Defaults to False.
""" """
def __init__(self, def __init__(
broadcast_buffers: bool = True, self,
bucket_cap_mb: int = 25, broadcast_buffers: bool = True,
find_unused_parameters: bool = False, bucket_cap_mb: int = 25,
check_reduction: bool = False, find_unused_parameters: bool = False,
gradient_as_bucket_view: bool = False, check_reduction: bool = False,
static_graph: bool = False) -> None: gradient_as_bucket_view: bool = False,
static_graph: bool = False,
) -> None:
super().__init__() super().__init__()
self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers, self.ddp_kwargs = dict(
bucket_cap_mb=bucket_cap_mb, broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters, bucket_cap_mb=bucket_cap_mb,
check_reduction=check_reduction, find_unused_parameters=find_unused_parameters,
gradient_as_bucket_view=gradient_as_bucket_view, check_reduction=check_reduction,
static_graph=static_graph) gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
def support_no_sync(self) -> bool: def support_no_sync(self) -> bool:
return True return True
...@@ -131,13 +137,13 @@ class TorchDDPPlugin(DPPluginBase): ...@@ -131,13 +137,13 @@ class TorchDDPPlugin(DPPluginBase):
return False return False
def supported_precisions(self) -> List[str]: def supported_precisions(self) -> List[str]:
return ['fp16', 'fp16_apex', 'bf16', 'fp8'] return ["fp16", "fp16_apex", "bf16", "fp8"]
def control_device(self) -> bool: def control_device(self) -> bool:
return True return True
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ['cuda'] return ["cuda"]
def configure( def configure(
self, self,
...@@ -156,8 +162,7 @@ class TorchDDPPlugin(DPPluginBase): ...@@ -156,8 +162,7 @@ class TorchDDPPlugin(DPPluginBase):
# wrap the model with PyTorch DDP # wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs) model = TorchDDPModel(model, **self.ddp_kwargs)
if optimizer is not None and \ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer) optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler return model, optimizer, criterion, dataloader, lr_scheduler
...@@ -169,5 +174,5 @@ class TorchDDPPlugin(DPPluginBase): ...@@ -169,5 +174,5 @@ class TorchDDPPlugin(DPPluginBase):
return TorchDDPCheckpointIO() return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.' assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync() return model.module.no_sync()
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union from typing import Callable, Iterable, Iterator, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from packaging import version from packaging import version
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse('1.12.0'): if version.parse(torch.__version__) >= version.parse("1.12.0"):
from torch.distributed.fsdp import FullStateDictConfig from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType from torch.distributed.fsdp import StateDictType
...@@ -31,11 +31,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper ...@@ -31,11 +31,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .dp_plugin_base import DPPluginBase from .dp_plugin_base import DPPluginBase
__all__ = ['TorchFSDPPlugin'] __all__ = ["TorchFSDPPlugin"]
class TorchFSDPCheckpointIO(GeneralCheckpointIO): class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.coordinator = DistCoordinator() self.coordinator = DistCoordinator()
...@@ -69,26 +68,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): ...@@ -69,26 +68,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], def save_sharded_model(
size_per_shard: int, use_safetensors: bool): self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
""" """
Save model to checkpoint but only on master process. Save model to checkpoint but only on master process.
""" """
raise NotImplementedError("Sharded model checkpoint is not supported yet.") raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def load_sharded_model(self, def load_sharded_model(
model: nn.Module, self,
checkpoint_index_file: Path, model: nn.Module,
strict: bool = False, checkpoint_index_file: Path,
use_safetensors: bool = False, strict: bool = False,
load_sub_module: bool = True): use_safetensors: bool = False,
load_sub_module: bool = True,
):
""" """
Load model to checkpoint but only on master process. Load model to checkpoint but only on master process.
""" """
raise NotImplementedError("Sharded model checkpoint is not supported yet.") raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, def save_sharded_optimizer(
size_per_shard: int): self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
):
""" """
Save optimizer to checkpoint but only on master process. Save optimizer to checkpoint but only on master process.
""" """
...@@ -109,7 +118,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): ...@@ -109,7 +118,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
class TorchFSDPModel(ModelWrapper): class TorchFSDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None: def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module) super().__init__(module)
self.module = FSDP(module, *args, **kwargs) self.module = FSDP(module, *args, **kwargs)
...@@ -119,7 +127,6 @@ class TorchFSDPModel(ModelWrapper): ...@@ -119,7 +127,6 @@ class TorchFSDPModel(ModelWrapper):
class FSDPOptimizerWrapper(OptimizerWrapper): class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module): def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model self.model = model
super().__init__(optimizer) super().__init__(optimizer)
...@@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase): ...@@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details. See https://pytorch.org/docs/stable/fsdp.html for details.
""" """
if version.parse(torch.__version__) >= version.parse('1.12.0'): if version.parse(torch.__version__) >= version.parse("1.12.0"):
def __init__( def __init__(
self, self,
...@@ -162,15 +169,18 @@ class TorchFSDPPlugin(DPPluginBase): ...@@ -162,15 +169,18 @@ class TorchFSDPPlugin(DPPluginBase):
sync_module_states: bool = False, sync_module_states: bool = False,
): ):
super().__init__() super().__init__()
self.fsdp_kwargs = dict(process_group=process_group, self.fsdp_kwargs = dict(
sharding_strategy=sharding_strategy, process_group=process_group,
cpu_offload=cpu_offload, sharding_strategy=sharding_strategy,
auto_wrap_policy=auto_wrap_policy, cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch, auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision, backward_prefetch=backward_prefetch,
ignored_modules=ignored_modules, mixed_precision=mixed_precision,
param_init_fn=param_init_fn, ignored_modules=ignored_modules,
sync_module_states=sync_module_states) param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
else: else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
...@@ -184,13 +194,13 @@ class TorchFSDPPlugin(DPPluginBase): ...@@ -184,13 +194,13 @@ class TorchFSDPPlugin(DPPluginBase):
return True return True
def supported_precisions(self) -> List[str]: def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16'] return ["fp16", "bf16"]
def control_device(self) -> bool: def control_device(self) -> bool:
return True return True
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ['cuda'] return ["cuda"]
def configure( def configure(
self, self,
...@@ -200,14 +210,13 @@ class TorchFSDPPlugin(DPPluginBase): ...@@ -200,14 +210,13 @@ class TorchFSDPPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None, dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None, lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# wrap the model with PyTorch FSDP # wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if optimizer is not None: if optimizer is not None:
if len(optimizer.param_groups) > 1: if len(optimizer.param_groups) > 1:
warnings.warn( warnings.warn(
'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.' "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
) )
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
......
...@@ -3,4 +3,4 @@ from .general_checkpoint_io import GeneralCheckpointIO ...@@ -3,4 +3,4 @@ from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile from .index_file import CheckpointIndexFile
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] __all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
...@@ -11,7 +11,7 @@ from colossalai.interface import ModelWrapper ...@@ -11,7 +11,7 @@ from colossalai.interface import ModelWrapper
from .utils import has_index_file from .utils import has_index_file
__all__ = ['CheckpointIO'] __all__ = ["CheckpointIO"]
class CheckpointIO(ABC): class CheckpointIO(ABC):
...@@ -61,10 +61,9 @@ class CheckpointIO(ABC): ...@@ -61,10 +61,9 @@ class CheckpointIO(ABC):
# ====================================== # ======================================
# Public methods # Public methods
# ====================================== # ======================================
def load_model(self, def load_model(
model: Union[nn.Module, ModelWrapper], self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
checkpoint: str, ) -> Union[nn.Module, ModelWrapper]:
strict: bool = True) -> Union[nn.Module, ModelWrapper]:
""" """
Load model from checkpoint. Load model from checkpoint.
...@@ -98,14 +97,16 @@ class CheckpointIO(ABC): ...@@ -98,14 +97,16 @@ class CheckpointIO(ABC):
return origin_model return origin_model
def save_model(self, def save_model(
model: Union[nn.Module, ModelWrapper], self,
checkpoint: str, model: Union[nn.Module, ModelWrapper],
shard: bool = False, checkpoint: str,
gather_dtensor: bool = True, shard: bool = False,
prefix: str = None, gather_dtensor: bool = True,
size_per_shard: int = 1024, prefix: str = None,
use_safetensors: bool = False): size_per_shard: int = 1024,
use_safetensors: bool = False,
):
""" """
Save model to checkpoint. Save model to checkpoint.
...@@ -157,7 +158,7 @@ class CheckpointIO(ABC): ...@@ -157,7 +158,7 @@ class CheckpointIO(ABC):
if Path(checkpoint).is_dir() and not index_file_exists: if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error # if the checkpoint is a directory and there is no index file, raise error
raise ValueError(f'Cannot find index file in {checkpoint}') raise ValueError(f"Cannot find index file in {checkpoint}")
if index_file_exists: if index_file_exists:
# the existence of index file means it is a sharded checkpoint # the existence of index file means it is a sharded checkpoint
...@@ -165,13 +166,15 @@ class CheckpointIO(ABC): ...@@ -165,13 +166,15 @@ class CheckpointIO(ABC):
else: else:
self.load_unsharded_optimizer(optimizer, checkpoint) self.load_unsharded_optimizer(optimizer, checkpoint)
def save_optimizer(self, def save_optimizer(
optimizer: Optimizer, self,
checkpoint: str, optimizer: Optimizer,
shard: bool = False, checkpoint: str,
gather_dtensor=True, shard: bool = False,
prefix: str = None, gather_dtensor=True,
size_per_shard: int = 1024): prefix: str = None,
size_per_shard: int = 1024,
):
""" """
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
...@@ -207,7 +210,6 @@ class CheckpointIO(ABC): ...@@ -207,7 +210,6 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's. the checkpoint match the keys returned by this module's.
""" """
pass
@abstractmethod @abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
...@@ -220,11 +222,17 @@ class CheckpointIO(ABC): ...@@ -220,11 +222,17 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's. the checkpoint match the keys returned by this module's.
""" """
pass
@abstractmethod @abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], def save_sharded_model(
size_per_shard: int, use_safetensors: bool): self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
""" """
Save model to sharded checkpoint. Save model to sharded checkpoint.
...@@ -236,7 +244,6 @@ class CheckpointIO(ABC): ...@@ -236,7 +244,6 @@ class CheckpointIO(ABC):
size_per_shard (int): size per shard in MB. size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors. use_safetensors (bool): whether to use safe tensors.
""" """
pass
@abstractmethod @abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
...@@ -249,7 +256,6 @@ class CheckpointIO(ABC): ...@@ -249,7 +256,6 @@ class CheckpointIO(ABC):
gather_dtensor (bool): whether to gather the distributed tensor to the first device. gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors. use_safetensors (bool): whether to use safe tensors.
""" """
pass
# ======================================================== # ========================================================
# Abstract methods for optimizer loading/saving implementation # Abstract methods for optimizer loading/saving implementation
...@@ -265,7 +271,6 @@ class CheckpointIO(ABC): ...@@ -265,7 +271,6 @@ class CheckpointIO(ABC):
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint. prefix (str): prefix for the optimizer checkpoint.
""" """
pass
@abstractmethod @abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
...@@ -276,11 +281,11 @@ class CheckpointIO(ABC): ...@@ -276,11 +281,11 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded. optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
""" """
pass
@abstractmethod @abstractmethod
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, def save_sharded_optimizer(
size_per_shard: int): self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
""" """
Save optimizer to sharded checkpoint. Save optimizer to sharded checkpoint.
...@@ -291,7 +296,6 @@ class CheckpointIO(ABC): ...@@ -291,7 +296,6 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint. prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB. size_per_shard (int): size per shard in MB.
""" """
pass
@abstractmethod @abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
...@@ -303,7 +307,6 @@ class CheckpointIO(ABC): ...@@ -303,7 +307,6 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. gather_dtensor (bool): whether to gather the distributed tensor to the first device.
""" """
pass
# ============================================ # ============================================
# methods for loading and saving lr scheduler # methods for loading and saving lr scheduler
......
...@@ -3,9 +3,8 @@ import logging ...@@ -3,9 +3,8 @@ import logging
import os import os
from functools import reduce from functools import reduce
from pathlib import Path from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple from typing import Optional
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
...@@ -16,7 +15,6 @@ from .index_file import CheckpointIndexFile ...@@ -16,7 +15,6 @@ from .index_file import CheckpointIndexFile
from .utils import ( from .utils import (
get_model_base_filenames, get_model_base_filenames,
get_optimizer_base_filenames, get_optimizer_base_filenames,
get_shard_filename,
is_safetensors_available, is_safetensors_available,
load_param_groups_into_optimizer, load_param_groups_into_optimizer,
load_shard_state_dict, load_shard_state_dict,
...@@ -33,7 +31,7 @@ from .utils import ( ...@@ -33,7 +31,7 @@ from .utils import (
unwrap_optimizer, unwrap_optimizer,
) )
__all__ = ['GeneralCheckpointIO'] __all__ = ["GeneralCheckpointIO"]
class GeneralCheckpointIO(CheckpointIO): class GeneralCheckpointIO(CheckpointIO):
...@@ -70,8 +68,10 @@ class GeneralCheckpointIO(CheckpointIO): ...@@ -70,8 +68,10 @@ class GeneralCheckpointIO(CheckpointIO):
# Load param_groups # Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename() param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None: if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \ raise RuntimeError(
Lacking param group file under current directory.') f"Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory."
)
id_map = load_param_groups_into_optimizer(optimizer, param_group_path) id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
...@@ -123,19 +123,23 @@ class GeneralCheckpointIO(CheckpointIO): ...@@ -123,19 +123,23 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states. # Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior. # In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=sharded_state, total_size = save_state_dict_shards(
checkpoint=checkpoint, sharded_state_dict=sharded_state,
index_file=index_file, checkpoint=checkpoint,
base_filename=states_name, index_file=index_file,
is_master=True, base_filename=states_name,
use_safetensors=False) is_master=True,
use_safetensors=False,
)
# Wrap up index file. # Wrap up index file.
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The optimizer is going to be split to checkpoint shards. "
f"index located at {save_index_file}.") f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint) checkpoint = load_state_dict(checkpoint)
...@@ -150,13 +154,15 @@ class GeneralCheckpointIO(CheckpointIO): ...@@ -150,13 +154,15 @@ class GeneralCheckpointIO(CheckpointIO):
# TODO(FrankLeeeee): handle distributed tensors # TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
def save_sharded_model(self, def save_sharded_model(
model: nn.Module, self,
checkpoint_path: str, model: nn.Module,
gather_dtensor: bool = False, checkpoint_path: str,
prefix: Optional[str] = None, gather_dtensor: bool = False,
max_shard_size: int = 1024, prefix: Optional[str] = None,
use_safetensors: bool = False): max_shard_size: int = 1024,
use_safetensors: bool = False,
):
""" """
implement this method as it can be supported by Huggingface model, implement this method as it can be supported by Huggingface model,
save shard model, save model to multiple files save shard model, save model to multiple files
...@@ -175,26 +181,32 @@ class GeneralCheckpointIO(CheckpointIO): ...@@ -175,26 +181,32 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states. # Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior. # In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, total_size = save_state_dict_shards(
checkpoint=checkpoint_path, sharded_state_dict=state_dict_shard,
index_file=index_file, checkpoint=checkpoint_path,
base_filename=weights_name, index_file=index_file,
is_master=True, base_filename=weights_name,
use_safetensors=use_safetensors) is_master=True,
use_safetensors=use_safetensors,
)
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True) save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The model is going to be split to checkpoint shards. "
f"index located at {save_index_file}.") f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
def load_sharded_model(self, )
model: nn.Module,
checkpoint_index_file: Path, def load_sharded_model(
strict: bool = False, self,
use_safetensors: bool = False, model: nn.Module,
load_sub_module: bool = True): checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
""" """
load shard model, load model from multiple files load shard model, load model from multiple files
""" """
...@@ -219,7 +231,11 @@ class GeneralCheckpointIO(CheckpointIO): ...@@ -219,7 +231,11 @@ class GeneralCheckpointIO(CheckpointIO):
if strict: if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0: if len(remain_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join( error_msgs = "Missing key(s) in state_dict: {}. ".format(
'"{}"'.format(k) for k in missing_keys)) ", ".join('"{}"'.format(k) for k in missing_keys)
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( )
self.__class__.__name__, "\n\t".join(error_msgs))) raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
import copy import copy
import gc
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
...@@ -35,9 +34,9 @@ from .utils import ( ...@@ -35,9 +34,9 @@ from .utils import (
) )
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError: except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state' _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class HybridParallelCheckpointIO(GeneralCheckpointIO): class HybridParallelCheckpointIO(GeneralCheckpointIO):
...@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
""" """
def __init__(self, def __init__(
dp_group: ProcessGroup, self,
pp_group: ProcessGroup, dp_group: ProcessGroup,
tp_group: ProcessGroup, pp_group: ProcessGroup,
zero_stage: int, tp_group: ProcessGroup,
verbose: bool = True) -> None: zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__() super().__init__()
self.dp_group = dp_group self.dp_group = dp_group
self.pp_group = pp_group self.pp_group = pp_group
...@@ -68,17 +69,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -68,17 +69,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.dp_size = dist.get_world_size(dp_group) self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group) self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group) self.tp_size = dist.get_world_size(tp_group)
self.use_zero = (zero_stage > 0) self.use_zero = zero_stage > 0
self.verbose = verbose self.verbose = verbose
self.working_to_master_map = None self.working_to_master_map = None
self.master_to_working_map = None self.master_to_working_map = None
self.coordinator = DistCoordinator() self.coordinator = DistCoordinator()
@staticmethod @staticmethod
def _model_sharder(model: nn.Module, def _model_sharder(
prefix: str = '', model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
keep_vars: bool = False, ) -> Iterator[Tuple[OrderedDict, int]]:
size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
# An internel method that breaks state_dict of model into shards within limited size. # An internel method that breaks state_dict of model into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard) state_dict_sharder = StateDictSharder(size_per_shard)
...@@ -103,8 +103,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -103,8 +103,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Save extra states. # Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state", if (
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state() extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None: if block is not None:
...@@ -114,20 +116,20 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -114,20 +116,20 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
@staticmethod @staticmethod
def _optimizer_sharder(optimizer: OptimizerWrapper, def _optimizer_sharder(
use_zero: bool, optimizer: OptimizerWrapper,
dp_group: ProcessGroup, use_zero: bool,
tp_group: ProcessGroup, dp_group: ProcessGroup,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, tp_group: ProcessGroup,
size_per_shard: int = 1024): master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
size_per_shard: int = 1024,
):
# An internel method that breaks state_dict of optimizer into shards within limited size. # An internel method that breaks state_dict of optimizer into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard) state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info param_info = optimizer.param_info
for param, state in optimizer.optim.state.items(): for param, state in optimizer.optim.state.items():
if param is None: if param is None:
continue continue
...@@ -136,15 +138,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -136,15 +138,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
else: else:
working_param = param working_param = param
param_id = param_info['param2id'][id(working_param)] param_id = param_info["param2id"][id(working_param)]
original_shape = param_info['param2shape'][id(working_param)] original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
working_param, state,
original_shape=original_shape, working_param,
dp_group=dp_group, original_shape=original_shape,
tp_group=tp_group, dp_group=dp_group,
use_zero=use_zero, tp_group=tp_group,
inplace=False) use_zero=use_zero,
inplace=False,
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_) block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None: if block is not None:
...@@ -153,13 +157,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -153,13 +157,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Return the last block in sharder. # Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(self, def save_sharded_model(
model: nn.Module, self,
checkpoint: str, model: nn.Module,
gather_dtensor: bool = True, checkpoint: str,
prefix: Optional[str] = None, gather_dtensor: bool = True,
size_per_shard: int = 1024, prefix: Optional[str] = None,
use_safetensors: bool = False) -> None: size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
""" """
Save sharded model checkpoint under the given checkpointing path. Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path: The following files will be created under the path:
...@@ -194,24 +200,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -194,24 +200,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint) index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.tp_rank == 0) control_saving = self.tp_rank == 0
if self.pp_size == 1: if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO # When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, total_size = save_state_dict_shards(
checkpoint=checkpoint, sharded_state_dict=state_dict_shard,
index_file=index_file, checkpoint=checkpoint,
base_filename=weights_name, index_file=index_file,
is_master=control_saving, base_filename=weights_name,
use_safetensors=use_safetensors) is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving: if control_saving:
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint) save_config_file(model, checkpoint)
if self.verbose: if self.verbose:
logging.info(f"The model is split into checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The model is split into checkpoint shards. "
f"index located at {save_index_file}.") f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else: else:
# When pipeline is used, each stage produces its own shard files and index files. # When pipeline is used, each stage produces its own shard files and index files.
...@@ -228,15 +238,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -228,15 +238,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file) save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, total_size = save_state_dict_shards(
checkpoint=checkpoint, sharded_state_dict=state_dict_shard,
index_file=index_file, checkpoint=checkpoint,
base_filename=weights_name, index_file=index_file,
is_master=control_saving, base_filename=weights_name,
use_safetensors=use_safetensors, is_master=control_saving,
use_pp_format=True) use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving: if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
else: else:
...@@ -259,9 +273,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -259,9 +273,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_config_file(model, checkpoint) save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder) rmtree(tmp_index_file_folder)
if self.verbose: if self.verbose:
logging.info(f"The model is split into checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The model is split into checkpoint shards. "
f"index located at {final_index_file_path}.") f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
""" """
...@@ -305,11 +321,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -305,11 +321,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = load_shard_state_dict(Path(file_path), use_safetensors) state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = [] missing_keys = []
load_state_dict_into_model(model, load_state_dict_into_model(
state_dict, model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
missing_keys=missing_keys, )
strict=strict,
load_sub_module=True)
loaded_file.add(filename) loaded_file.add(filename)
# Load parameters. # Load parameters.
...@@ -319,15 +333,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -319,15 +333,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Load buffers. # Load buffers.
non_persistent_buffers = set() non_persistent_buffers = set()
for n, m in model.named_modules(): for n, m in model.named_modules():
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers(): for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers: if buf is not None and name not in non_persistent_buffers:
_load(name) _load(name)
# Load extra states. # Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state", if (
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
_load(extra_state_key) _load(extra_state_key)
# Update master params if mixed-precision training is enabled. # Update master params if mixed-precision training is enabled.
...@@ -352,12 +368,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -352,12 +368,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.verbose: if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_sharded_optimizer(self, def save_sharded_optimizer(
optimizer: OptimizerWrapper, self,
checkpoint: str, optimizer: OptimizerWrapper,
gather_dtensor: bool = True, checkpoint: str,
prefix: Optional[str] = None, gather_dtensor: bool = True,
size_per_shard: int = 1024): prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
""" """
Save sharded optimizer checkpoint under the given checkpointing path. Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path: The following files will be created under the path:
...@@ -393,18 +411,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -393,18 +411,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group=self.dp_group, dp_group=self.dp_group,
tp_group=self.tp_group, tp_group=self.tp_group,
master_to_working_map=self.master_to_working_map, master_to_working_map=self.master_to_working_map,
size_per_shard=size_per_shard) size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint) index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.dp_rank == 0 and self.tp_rank == 0) control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1: if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO # When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, total_size = save_state_dict_shards(
checkpoint=checkpoint, sharded_state_dict=state_dict_shard,
index_file=index_file, checkpoint=checkpoint,
base_filename=states_name, index_file=index_file,
is_master=control_saving) base_filename=states_name,
is_master=control_saving,
)
if control_saving: if control_saving:
# Store param groups. # Store param groups.
...@@ -415,9 +436,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -415,9 +436,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
if self.verbose: if self.verbose:
logging.info(f"The optimizer is going to be split to checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The optimizer is going to be split to checkpoint shards. "
f"index located at {save_index_file}.") f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else: else:
# When pipeline is used, each stage produces its own shard files and index files. # When pipeline is used, each stage produces its own shard files and index files.
...@@ -433,15 +456,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -433,15 +456,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file) save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, total_size = save_state_dict_shards(
checkpoint=checkpoint, sharded_state_dict=state_dict_shard,
index_file=index_file, checkpoint=checkpoint,
base_filename=states_name, index_file=index_file,
is_master=control_saving, base_filename=states_name,
use_pp_format=True) is_master=control_saving,
use_pp_format=True,
)
if control_saving: if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
else: else:
...@@ -451,7 +478,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -451,7 +478,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# The global master rank integrates the index files and clean the folder. # The global master rank integrates the index files and clean the folder.
if self.pp_rank == 0: if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint) final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0) final_index_file.append_meta_data("total_size", 0)
...@@ -470,9 +496,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -470,9 +496,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
rmtree(tmp_index_file_folder) rmtree(tmp_index_file_folder)
if self.verbose: if self.verbose:
logging.info(f"The model is split into checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The model is split into checkpoint shards. "
f"index located at {final_index_file_path}.") f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
""" """
...@@ -484,20 +512,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -484,20 +512,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix (str): Not used. prefix (str): Not used.
""" """
def _get_param_id_from_optimizer_param(param: torch.Tensor, def _get_param_id_from_optimizer_param(
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None: if master_to_working_map is not None:
working_param = master_to_working_map[id(param)] working_param = master_to_working_map[id(param)]
else: else:
working_param = param working_param = param
return optimizer.param_info['param2id'][id(working_param)] return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters. # When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {} id_map = {}
for pg in optimizer.optim.param_groups: for pg in optimizer.optim.param_groups:
for param in pg['params']: for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
id_map[param_id] = param id_map[param_id] = param
...@@ -505,28 +534,30 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -505,28 +534,30 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups # Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename() param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None: if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ raise RuntimeError(
Lacking param group file under current directory.') f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path) saved_groups = torch.load(param_group_path)
updated_groups = [] updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group # obtain updated param group
new_pg = copy.deepcopy(saved_pg) new_pg = copy.deepcopy(saved_pg)
new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change. new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg) updated_groups.append(new_pg)
optimizer.optim.__dict__.update({'param_groups': updated_groups}) optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer. # Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded. # Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set() loaded_file = set()
for pg in optimizer.optim.param_groups: for pg in optimizer.optim.param_groups:
for param in pg['params']: for param in pg["params"]:
if param is None: if param is None:
continue continue
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
...@@ -550,12 +581,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -550,12 +581,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
working_param = self.master_to_working_map[id(param)] working_param = self.master_to_working_map[id(param)]
else: else:
working_param = param working_param = param
original_shape = optimizer.param_info['param2shape'][id(working_param)] original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(state, sharded_state = self.shard_from_complete_optimizer_state(
current_shape=working_param.shape, state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
original_shape=original_shape, )
device=device,
inplace=True)
optimizer.optim.state[param] = sharded_state optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim) sharded_optimizer_loading_epilogue(optimizer.optim)
...@@ -585,8 +614,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -585,8 +614,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master(): if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint) super().save_lr_scheduler(lr_scheduler, checkpoint)
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], def link_master_and_working_param(
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]): self,
working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
):
""" """
Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
This mapping can only be created when mixied precision is used. This mapping can only be created when mixied precision is used.
...@@ -604,7 +636,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -604,7 +636,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.working_to_master_map[k] = v self.working_to_master_map[k] = v
else: else:
raise ValueError( raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
self.master_to_working_map = dict() self.master_to_working_map = dict()
for k, v in master_to_working_map.items(): for k, v in master_to_working_map.items():
...@@ -614,12 +647,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -614,12 +647,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.master_to_working_map[k] = v self.master_to_working_map[k] = v
else: else:
raise ValueError( raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
@staticmethod @staticmethod
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, def gather_from_sharded_optimizer_state(
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, state: OrderedDict,
inplace: bool) -> OrderedDict: param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
) -> OrderedDict:
""" """
With given parameter and its optimizer states, gather the complete optimizer state for saving. With given parameter and its optimizer states, gather the complete optimizer state for saving.
...@@ -641,14 +681,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -641,14 +681,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state) state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items(): for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step': if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards. # First gather Zero shards.
if use_zero: if use_zero:
v = v.cuda() v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group) dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards. # Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
...@@ -661,9 +700,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -661,9 +700,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
return state_ return state_
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, def shard_from_complete_optimizer_state(
original_shape: torch.Size, device: torch.device, self,
inplace: bool) -> OrderedDict: state: OrderedDict,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
) -> OrderedDict:
""" """
With complete optimizer states of a specific parameter loaded from checkpoint, With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device. slice out the sharded optimizer states kept by current device.
...@@ -681,8 +725,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -681,8 +725,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state) state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items(): for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step': if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group. # Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None: if partition_dim is not None:
......
...@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union ...@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union
from .utils import is_dtensor_checkpoint from .utils import is_dtensor_checkpoint
__all__ = ['CheckpointIndexFile'] __all__ = ["CheckpointIndexFile"]
class CheckpointIndexFile: class CheckpointIndexFile:
...@@ -50,7 +50,7 @@ class CheckpointIndexFile: ...@@ -50,7 +50,7 @@ class CheckpointIndexFile:
json_path (str): path to the json file. json_path (str): path to the json file.
""" """
# load the json file # load the json file
with open(json_path, 'r') as f: with open(json_path, "r") as f:
index = json.load(f) index = json.load(f)
# assign attributes if exists # assign attributes if exists
...@@ -75,7 +75,7 @@ class CheckpointIndexFile: ...@@ -75,7 +75,7 @@ class CheckpointIndexFile:
index["weight_map"] = self.weight_map index["weight_map"] = self.weight_map
# export the index file # export the index file
with open(json_path, 'w') as f: with open(json_path, "w") as f:
json.dump(index, f, indent=4) json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str): def append_weight_map(self, param_name: str, shard_file: str):
......
# coding=utf-8 # coding=utf-8
import copy
import os import os
import re import re
from collections import abc as container_abcs from collections import abc as container_abcs
...@@ -12,7 +11,7 @@ import torch ...@@ -12,7 +11,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.tensor.d_tensor import ( from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor, is_customized_distributed_tensor,
is_distributed_tensor, is_distributed_tensor,
...@@ -55,7 +54,6 @@ def is_safetensors_available() -> bool: ...@@ -55,7 +54,6 @@ def is_safetensors_available() -> bool:
bool: whether safetensors is available. bool: whether safetensors is available.
""" """
try: try:
import safetensors
return True return True
except ImportError: except ImportError:
return False return False
...@@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool: ...@@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns: Returns:
bool: whether the checkpoint file is a dtensor checkpoint. bool: whether the checkpoint file is a dtensor checkpoint.
""" """
if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'): if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
return True return True
else: else:
return False return False
...@@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: ...@@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns: Returns:
bool: whether the checkpoint file is a safetensor checkpoint. bool: whether the checkpoint file is a safetensor checkpoint.
""" """
if checkpoint_file_path.endswith('.safetensors'): if checkpoint_file_path.endswith(".safetensors"):
return True return True
else: else:
return False return False
...@@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz ...@@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
partition_dim = dim partition_dim = dim
break break
if partition_dim is not None: if partition_dim is not None:
assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ assert (
f"The parameter isn't evenly distributed among tensor parallel group: \ original_shape[partition_dim] == tp_size * current_shape[partition_dim]
), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}" shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim return partition_dim
...@@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz ...@@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
# Helper classes and functions for saving shard file # Helper classes and functions for saving shard file
# ====================================== # ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper): def unwrap_optimizer(optimizer: OptimizerWrapper):
''' """
Unwrap a wrapped optimizer. Unwrap a wrapped optimizer.
This method should be used before saving/loading it to/from sharded checkpoints. This method should be used before saving/loading it to/from sharded checkpoints.
''' """
unwrapped_optim = optimizer.optim unwrapped_optim = optimizer.optim
return unwrapped_optim return unwrapped_optim
class StateDictSharder: class StateDictSharder:
def __init__(self, size_per_shard: int) -> None: def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard self.max_shard_size = size_per_shard
self.current_block = OrderedDict() self.current_block = OrderedDict()
self.current_block_size = 0 self.current_block_size = 0
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
tensor_size = calculate_tensor_size(tensor) tensor_size = calculate_tensor_size(tensor)
ret_block = None ret_block = None
ret_block_size = 0 ret_block_size = 0
...@@ -159,13 +156,11 @@ class StateDictSharder: ...@@ -159,13 +156,11 @@ class StateDictSharder:
return ret_block, ret_block_size return ret_block, ret_block_size
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
# A state might contain more than one tensors. # A state might contain more than one tensors.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0 state_size = 0
isDTensor = False isDTensor = False
for state_tensor in state.values(): for state_tensor in state.values():
# When state_tensor is not of Tensor class, # When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state # e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error. # The calculation of tensor size should be skipped to avoid error.
...@@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to ...@@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
return param_ return param_
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], def save_state_dict_shards(
checkpoint: str, sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
index_file: "CheckpointIndexFile", checkpoint: str,
base_filename: str, index_file: "CheckpointIndexFile",
is_master: bool, base_filename: str,
use_safetensors: bool = False, is_master: bool,
use_pp_format: bool = False) -> int: use_safetensors: bool = False,
''' use_pp_format: bool = False,
) -> int:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states. Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args: Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
...@@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] ...@@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
Returns: Returns:
int: the total size of shards int: the total size of shards
''' """
total_size = 0 total_size = 0
shard_filenames = [] shard_filenames = []
...@@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> ...@@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
""" """
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function. # Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state'] states = state_dict["state"]
state_dict_sharder = StateDictSharder(max_shard_size) state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items(): for param_id, state in states.items():
...@@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors ...@@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
""" """
if use_safetensors: if use_safetensors:
assert is_safetensors_available(), "safetensors is not available." assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \ assert checkpoint_file_path.endswith(
"safetensors only supports .safetensors suffix for checkpoint file." ".safetensors"
), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else: else:
torch.save(state_dict, checkpoint_file_path) torch.save(state_dict, checkpoint_file_path)
...@@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None: ...@@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
torch.save(param_groups, group_file_path) torch.save(param_groups, group_file_path)
def clean_folder(checkpoint_path: str, def clean_folder(
weights_name: str, checkpoint_path: str,
shard_filenames: List[str], weights_name: str,
is_master: bool = True, shard_filenames: List[str],
use_pp_format: bool = False): is_master: bool = True,
use_pp_format: bool = False,
):
""" """
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved. Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
...@@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str, ...@@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str,
else: else:
# When this checkpoint is created by pipeline parallel process, the pattern is a little different. # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}") reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) if (
and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None): filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in shard_filenames
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename) os.remove(full_filename)
...@@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi ...@@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
size_per_shard (int): size per shard in MB. size_per_shard (int): size per shard in MB.
""" """
root_path = index_file.root_path root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor') output_root_path = root_path.joinpath("dtensor")
# create directory # create directory
output_root_path.mkdir(exist_ok=True) output_root_path.mkdir(exist_ok=True)
...@@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi ...@@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
# update the weight map # update the weight map
# * means all shards # * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map) index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
...@@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str: ...@@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
str: checkpoint file suffix. str: checkpoint file suffix.
""" """
if use_safetensors: if use_safetensors:
return '.safetensors' return ".safetensors"
else: else:
return '.bin' return ".bin"
def generate_checkpoint_shard_file_name(index: int, def generate_checkpoint_shard_file_name(
total_number: int, index: int, total_number: int, use_safetensors: bool, prefix: str = None
use_safetensors: bool, ) -> str:
prefix: str = None) -> str:
""" """
Generate checkpoint shard file name. Generate checkpoint shard file name.
...@@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo ...@@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
str: dtensor file name. str: dtensor file name.
""" """
suffix = get_checkpoint_file_suffix(use_safetensors) suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}' return f"{param_name}.{index}.{suffix}"
# ======================================== # ========================================
...@@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): ...@@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
if use_safetensors: if use_safetensors:
from safetensors.torch import load_file as safe_load_file from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f: with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
if metadata["format"] != "pt": if metadata["format"] != "pt":
raise NotImplementedError( raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.") f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file) return safe_load_file(checkpoint_file)
else: else:
return torch.load(checkpoint_file, map_location=torch.device('cpu')) return torch.load(checkpoint_file, map_location=torch.device("cpu"))
def load_state_dict_into_model(model: nn.Module, def load_state_dict_into_model(
state_dict: torch.Tensor, model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True
missing_keys: List, ):
strict: bool = False,
load_sub_module: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. this module and its descendants.
...@@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module, ...@@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module,
error_msgs: List[str] = [] error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it # copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None) metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict) state_dict = OrderedDict(state_dict)
if metadata is not None: if metadata is not None:
state_dict._metadata = metadata state_dict._metadata = metadata
...@@ -560,10 +564,12 @@ def load_state_dict_into_model(model: nn.Module, ...@@ -560,10 +564,12 @@ def load_state_dict_into_model(model: nn.Module,
if strict: if strict:
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join( error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
'"{}"'.format(k) for k in unexpected_keys)) ", ".join('"{}"'.format(k) for k in unexpected_keys)
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( )
model.__class__.__name__, "\n\t".join(error_msgs))) raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
...@@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str ...@@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path. # Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices. # The params in saved_groups are in the form of integer indices.
saved_groups = torch.load(param_group_path, map_location=torch.device('cpu')) saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
if not isinstance(saved_groups, List): if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type') raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
# The params in param_groups are in the form of pytorch tensors. # The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch. # For more details, please view source code of Optimizer class in pytorch.
...@@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str ...@@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Check the compatibility of saved_groups and param_groups. # Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups): if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups") raise ValueError("loaded state dict has a different number of original parameter groups")
param_lens = (len(g['params']) for g in param_groups) param_lens = (len(g["params"]) for g in param_groups)
saved_lens = (len(g['params']) for g in saved_groups) saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group " raise ValueError(
"that doesn't match the size of optimizer's group") "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
)
# Creating mapping from id to parameters. # Creating mapping from id to parameters.
id_map = { id_map = {
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups old_id: p
)), chain.from_iterable((g['params'] for g in param_groups))) for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in param_groups)),
)
} }
# Update parameter groups, setting their 'params' value. # Update parameter groups, setting their 'params' value.
def update_group(group, new_group): def update_group(group, new_group):
new_group['params'] = group['params'] new_group["params"] = group["params"]
return new_group return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
optimizer.__dict__.update({'param_groups': updated_groups}) optimizer.__dict__.update({"param_groups": updated_groups})
return id_map return id_map
...@@ -628,7 +638,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d ...@@ -628,7 +638,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
# Floating-point types are a bit special here. They are the only ones # Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params. # that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
if (key != "step"): if key != "step":
if param.is_floating_point(): if param.is_floating_point():
value = value.to(param.dtype) value = value.to(param.dtype)
value = value.to(param.device) value = value.to(param.device)
...@@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): ...@@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
""" """
# Do the cleaning up as in src code of Pytorch. # Do the cleaning up as in src code of Pytorch.
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault('differentiable', False) optimizer.defaults.setdefault("differentiable", False)
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
...@@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: ...@@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return False, None return False, None
elif checkpoint_path.is_dir(): elif checkpoint_path.is_dir():
# check if there is only one a file ending with .index.json in this directory # check if there is only one a file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.*json')) index_files = list(checkpoint_path.glob("*.index.*json"))
# if we found a .index.json file, make sure there is only one # if we found a .index.json file, make sure there is only one
if len(index_files) > 0: if len(index_files) > 0:
assert len( assert (
index_files len(index_files) == 1
) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}' ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
if len(index_files) == 1: if len(index_files) == 1:
return True, index_files[0] return True, index_files[0]
else: else:
return False, None return False, None
else: else:
raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.') raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def load_state_dict(checkpoint_file_path: Path): def load_state_dict(checkpoint_file_path: Path):
...@@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path): ...@@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path):
dict: state dict. dict: state dict.
""" """
assert not is_dtensor_checkpoint(checkpoint_file_path), \ assert not is_dtensor_checkpoint(
f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.' checkpoint_file_path
), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
if is_safetensor_checkpoint(checkpoint_file_path): if is_safetensor_checkpoint(checkpoint_file_path):
assert is_safetensors_available(), \ assert (
f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.' is_safetensors_available()
), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
# load with safetensors # load with safetensors
from safetensors import safe_open from safetensors import safe_open
state_dict = {} state_dict = {}
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
for k in f.keys(): for k in f.keys():
...@@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path): ...@@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path):
else: else:
# load with torch # load with torch
return torch.load(checkpoint_file_path, map_location=torch.device('cpu')) return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
......
from .cli import cli from .cli import cli
__all__ = ['cli'] __all__ = ["cli"]
import click import click
from .check_installation import check_installation from .check_installation import check_installation
__all__ = ['check'] __all__ = ["check"]
@click.command(help="Check if Colossal-AI is correct based on the given option") @click.command(help="Check if Colossal-AI is correct based on the given option")
@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly") @click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly")
def check(installation): def check(installation):
if installation: if installation:
check_installation() check_installation()
......
...@@ -9,7 +9,7 @@ import colossalai ...@@ -9,7 +9,7 @@ import colossalai
def to_click_output(val): def to_click_output(val):
# installation check output to understandable symbols for readability # installation check output to understandable symbols for readability
VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'} VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"}
if val in VAL_TO_SYMBOL: if val in VAL_TO_SYMBOL:
return VAL_TO_SYMBOL[val] return VAL_TO_SYMBOL[val]
...@@ -55,8 +55,8 @@ def check_installation(): ...@@ -55,8 +55,8 @@ def check_installation():
else: else:
torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required]) torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required])
click.echo(f'#### Installation Report ####') click.echo(f"#### Installation Report ####")
click.echo(f'\n------------ Environment ------------') click.echo(f"\n------------ Environment ------------")
click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}") click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}")
click.echo(f"PyTorch version: {to_click_output(torch_version)}") click.echo(f"PyTorch version: {to_click_output(torch_version)}")
click.echo(f"System CUDA version: {to_click_output(cuda_version)}") click.echo(f"System CUDA version: {to_click_output(cuda_version)}")
...@@ -69,7 +69,7 @@ def check_installation(): ...@@ -69,7 +69,7 @@ def check_installation():
f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version." f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version."
) )
click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------') click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------")
click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}") click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}")
click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}") click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}")
click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}") click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}")
...@@ -81,7 +81,7 @@ def check_installation(): ...@@ -81,7 +81,7 @@ def check_installation():
click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime") click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime")
click.echo(f"\n------------ Compatibility ------------") click.echo(f"\n------------ Compatibility ------------")
click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}') click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}")
click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}") click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}")
click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}") click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}")
click.echo(f"") click.echo(f"")
...@@ -106,12 +106,12 @@ def _is_compatible(versions): ...@@ -106,12 +106,12 @@ def _is_compatible(versions):
return False return False
# split version into [major, minor, patch] # split version into [major, minor, patch]
versions = [version.split('.') for version in versions] versions = [version.split(".") for version in versions]
for version in versions: for version in versions:
if len(version) == 2: if len(version) == 2:
# x means unknown # x means unknown
version.append('x') version.append("x")
for idx, version_values in enumerate(zip(*versions)): for idx, version_values in enumerate(zip(*versions)):
equal = len(set(version_values)) == 1 equal = len(set(version_values)) == 1
...@@ -137,11 +137,11 @@ def _parse_colossalai_version(): ...@@ -137,11 +137,11 @@ def _parse_colossalai_version():
# 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions) # 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions)
# 2. X.X.X (when colossalai is not installed with CUDA extensions) # 2. X.X.X (when colossalai is not installed with CUDA extensions)
# where X represents an integer. # where X represents an integer.
colossalai_version = colossalai.__version__.split('+')[0] colossalai_version = colossalai.__version__.split("+")[0]
try: try:
torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0] torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0]
cuda_version_for_aot_build = colossalai.__version__.split('cu')[1] cuda_version_for_aot_build = colossalai.__version__.split("cu")[1]
except: except:
torch_version_for_aot_build = None torch_version_for_aot_build = None
cuda_version_for_aot_build = None cuda_version_for_aot_build = None
...@@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed(): ...@@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed():
JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime. JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime.
""" """
try: try:
import colossalai._C.fused_optim
found_aot_cuda_ext = True found_aot_cuda_ext = True
except ImportError: except ImportError:
found_aot_cuda_ext = False found_aot_cuda_ext = False
...@@ -175,14 +174,14 @@ def _check_torch_version(): ...@@ -175,14 +174,14 @@ def _check_torch_version():
# torch version can be of two formats # torch version can be of two formats
# - 1.13.1+cu113 # - 1.13.1+cu113
# - 1.13.1.devxxx # - 1.13.1.devxxx
torch_version = torch.__version__.split('+')[0] torch_version = torch.__version__.split("+")[0]
torch_version = '.'.join(torch_version.split('.')[:3]) torch_version = ".".join(torch_version.split(".")[:3])
# get cuda version in pytorch build # get cuda version in pytorch build
try: try:
torch_cuda_major = torch.version.cuda.split(".")[0] torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1] torch_cuda_minor = torch.version.cuda.split(".")[1]
torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}"
except: except:
torch_cuda_version = None torch_cuda_version = None
...@@ -208,7 +207,7 @@ def _check_cuda_version(): ...@@ -208,7 +207,7 @@ def _check_cuda_version():
release = output[release_idx].split(".") release = output[release_idx].split(".")
bare_metal_major = release[0] bare_metal_major = release[0]
bare_metal_minor = release[1][0] bare_metal_minor = release[1][0]
cuda_version = f'{bare_metal_major}.{bare_metal_minor}' cuda_version = f"{bare_metal_major}.{bare_metal_minor}"
except: except:
cuda_version = None cuda_version = None
return cuda_version return cuda_version
...@@ -4,8 +4,7 @@ from .check import check ...@@ -4,8 +4,7 @@ from .check import check
from .launcher import run from .launcher import run
class Arguments(): class Arguments:
def __init__(self, arg_dict): def __init__(self, arg_dict):
for k, v in arg_dict.items(): for k, v in arg_dict.items():
self.__dict__[k] = v self.__dict__[k] = v
...@@ -19,5 +18,5 @@ def cli(): ...@@ -19,5 +18,5 @@ def cli():
cli.add_command(run) cli.add_command(run)
cli.add_command(check) cli.add_command(check)
if __name__ == '__main__': if __name__ == "__main__":
cli() cli()
...@@ -5,56 +5,81 @@ from colossalai.context import Config ...@@ -5,56 +5,81 @@ from colossalai.context import Config
from .run import launch_multi_processes from .run import launch_multi_processes
@click.command(help="Launch distributed training on a single node or multiple nodes", @click.command(
context_settings=dict(ignore_unknown_options=True)) help="Launch distributed training on a single node or multiple nodes",
@click.option("-H", context_settings=dict(ignore_unknown_options=True),
"-host", )
"--host", @click.option(
type=str, "-H",
default=None, "-host",
help="the list of hostnames to launch in the format <host1>,<host2>") "--host",
type=str,
default=None,
help="the list of hostnames to launch in the format <host1>,<host2>",
)
@click.option( @click.option(
"--hostfile", "--hostfile",
type=str, type=str,
default=None, default=None,
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname") help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname",
@click.option("--include", )
type=str, @click.option(
default=None, "--include",
help="Specify computing devices to use during execution. String format is <host1>,<host2>," type=str,
" only effective when used with --hostfile.") default=None,
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
" only effective when used with --hostfile.",
)
@click.option( @click.option(
"--exclude", "--exclude",
type=str, type=str,
default=None, default=None,
help= help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include," " only effective when used with --hostfile.",
" only effective when used with --hostfile.") )
@click.option("--num_nodes", @click.option(
type=int, "--num_nodes",
default=-1, type=int,
help="Total number of worker nodes to use, only effective when used with --hostfile.") default=-1,
help="Total number of worker nodes to use, only effective when used with --hostfile.",
)
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.") @click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
@click.option("--master_port", @click.option(
type=int, "--master_port",
default=29500, type=int,
help="(optional) Port used by PyTorch distributed for communication during distributed training.") default=29500,
@click.option("--master_addr", help="(optional) Port used by PyTorch distributed for communication during distributed training.",
type=str, )
default="127.0.0.1", @click.option(
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") "--master_addr",
type=str,
default="127.0.0.1",
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.",
)
@click.option( @click.option(
"--extra_launch_args", "--extra_launch_args",
type=str, type=str,
default=None, default=None,
help= help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
"Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " "This will be converted to --arg1=1 --arg2=2 during execution",
"This will be converted to --arg1=1 --arg2=2 during execution") )
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") @click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str) @click.argument("user_script", type=str)
@click.argument('user_args', nargs=-1) @click.argument("user_args", nargs=-1)
def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, def run(
master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None: host: str,
hostfile: str,
num_nodes: int,
nproc_per_node: int,
include: str,
exclude: str,
master_addr: str,
master_port: int,
extra_launch_args: str,
ssh_port: int,
user_script: str,
user_args: str,
) -> None:
""" """
To launch multiple processes on a single node or multiple nodes via command line. To launch multiple processes on a single node or multiple nodes via command line.
...@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: ...@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include:
# run with hostfile excluding the hosts selected # run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
""" """
if not user_script.endswith('.py'): if not user_script.endswith(".py"):
click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help') click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit() exit()
args_dict = locals() args_dict = locals()
......
import socket import socket
from typing import List
class HostInfo: class HostInfo:
...@@ -34,7 +33,7 @@ class HostInfo: ...@@ -34,7 +33,7 @@ class HostInfo:
""" """
if port is None: if port is None:
port = 22 # no port specified, lets just use the ssh port port = 22 # no port specified, lets just use the ssh port
# socket.getfqdn("127.0.0.1") does not return localhost # socket.getfqdn("127.0.0.1") does not return localhost
# on some users' machines # on some users' machines
...@@ -50,7 +49,7 @@ class HostInfo: ...@@ -50,7 +49,7 @@ class HostInfo:
return localaddrs == targetaddrs return localaddrs == targetaddrs
def __str__(self): def __str__(self):
return f'hostname: {self.hostname}, port: {self.port}' return f"hostname: {self.hostname}, port: {self.port}"
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
......
...@@ -7,8 +7,13 @@ import fabric ...@@ -7,8 +7,13 @@ import fabric
from .hostinfo import HostInfo, HostInfoList from .hostinfo import HostInfo, HostInfoList
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection, def run_on_host(
send_conn: mp_connection.Connection, env: dict) -> None: hostinfo: HostInfo,
workdir: str,
recv_conn: mp_connection.Connection,
send_conn: mp_connection.Connection,
env: dict,
) -> None:
""" """
Use fabric connection to execute command on local or remote hosts. Use fabric connection to execute command on local or remote hosts.
...@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne ...@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port) fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
finish = False finish = False
env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()]) env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()])
# keep listening until exit # keep listening until exit
while not finish: while not finish:
# receive cmd # receive cmd
cmds = recv_conn.recv() cmds = recv_conn.recv()
if cmds == 'exit': if cmds == "exit":
# exit from the loop # exit from the loop
finish = True finish = True
break break
...@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne ...@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
else: else:
# execute on the remote machine # execute on the remote machine
fab_conn.run(cmds, hide=False) fab_conn.run(cmds, hide=False)
send_conn.send('success') send_conn.send("success")
except Exception as e: except Exception as e:
click.echo( click.echo(
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}" f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
) )
send_conn.send('failure') send_conn.send("failure")
# shutdown # shutdown
send_conn.send("finish") send_conn.send("finish")
...@@ -96,8 +101,7 @@ class MultiNodeRunner: ...@@ -96,8 +101,7 @@ class MultiNodeRunner:
cmd (str): the command to execute cmd (str): the command to execute
""" """
assert hostinfo.hostname in self.master_send_conns, \ assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
f'{hostinfo} is not found in the current connections'
conn = self.master_send_conns[hostinfo.hostname] conn = self.master_send_conns[hostinfo.hostname]
conn.send(cmd) conn.send(cmd)
...@@ -107,7 +111,7 @@ class MultiNodeRunner: ...@@ -107,7 +111,7 @@ class MultiNodeRunner:
""" """
for hostname, conn in self.master_send_conns.items(): for hostname, conn in self.master_send_conns.items():
conn.send('exit') conn.send("exit")
def recv_from_all(self) -> dict: def recv_from_all(self) -> dict:
""" """
......
...@@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList ...@@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner from .multinode_runner import MultiNodeRunner
# Constants that define our syntax # Constants that define our syntax
NODE_SEP = ',' NODE_SEP = ","
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
...@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: ...@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit() exit()
with open(hostfile_path, 'r') as fd: with open(hostfile_path, "r") as fd:
device_pool = HostInfoList() device_pool = HostInfoList()
for line in fd.readlines(): for line in fd.readlines():
line = line.strip() line = line.strip()
if line == '': if line == "":
# skip empty lines # skip empty lines
continue continue
...@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: ...@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList: def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
'''Parse an inclusion or exclusion string and filter a hostfile dictionary. """Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples: Examples:
include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1. include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
...@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str ...@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
Returns: Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
''' """
# Ensure include/exclude are mutually exclusive # Ensure include/exclude are mutually exclusive
if include_str and exclude_str: if include_str and exclude_str:
...@@ -136,16 +136,16 @@ def get_launch_command( ...@@ -136,16 +136,16 @@ def get_launch_command(
for k, v in arg_dict.items(): for k, v in arg_dict.items():
if v: if v:
ret.append(f'--{k}={v}') ret.append(f"--{k}={v}")
else: else:
ret.append(f'--{k}') ret.append(f"--{k}")
return ret return ret
if extra_launch_args: if extra_launch_args:
extra_launch_args_dict = dict() extra_launch_args_dict = dict()
for arg in extra_launch_args.split(','): for arg in extra_launch_args.split(","):
if '=' in arg: if "=" in arg:
k, v = arg.split('=') k, v = arg.split("=")
extra_launch_args_dict[k] = v extra_launch_args_dict[k] = v
else: else:
extra_launch_args_dict[arg] = None extra_launch_args_dict[arg] = None
...@@ -158,9 +158,14 @@ def get_launch_command( ...@@ -158,9 +158,14 @@ def get_launch_command(
if torch_version.minor < 9: if torch_version.minor < 9:
cmd = [ cmd = [
sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", sys.executable,
f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", "-m",
f"--node_rank={node_rank}" "torch.distributed.launch",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={master_addr}",
f"--master_port={master_port}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
] ]
else: else:
# extra launch args for torch distributed launcher with torch >= 1.9 # extra launch args for torch distributed launcher with torch >= 1.9
...@@ -174,17 +179,24 @@ def get_launch_command( ...@@ -174,17 +179,24 @@ def get_launch_command(
if torch_version.minor < 10: if torch_version.minor < 10:
cmd = [ cmd = [
sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", sys.executable,
f"--nnodes={num_nodes}", f"--node_rank={node_rank}" "-m",
"torch.distributed.run",
f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
] ]
else: else:
cmd = [ cmd = [
"torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" "torchrun",
f"--nproc_per_node={nproc_per_node}",
f"--nnodes={num_nodes}",
f"--node_rank={node_rank}",
] ]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args) cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
cmd = ' '.join(cmd) cmd = " ".join(cmd)
return cmd return cmd
...@@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None: ...@@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None:
# run on local node if not hosts or hostfile is given # run on local node if not hosts or hostfile is given
# add local node to host info list # add local node to host info list
active_device_pool = HostInfoList() active_device_pool = HostInfoList()
localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port) localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
active_device_pool.append(localhost_info) active_device_pool.append(localhost_info)
# launch distributed processes # launch distributed processes
runner = MultiNodeRunner() runner = MultiNodeRunner()
curr_path = os.path.abspath('.') curr_path = os.path.abspath(".")
# collect current path env # collect current path env
env = dict() env = dict()
for k, v in os.environ.items(): for k, v in os.environ.items():
# do not support multi-line env var # do not support multi-line env var
if v and '\n' not in v: if v and "\n" not in v:
env[k] = v env[k] = v
# establish remote connection # establish remote connection
...@@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None: ...@@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None:
# execute distributed launching command # execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool): for node_id, hostinfo in enumerate(active_device_pool):
cmd = get_launch_command(master_addr=args.master_addr, cmd = get_launch_command(
master_port=args.master_port, master_addr=args.master_addr,
nproc_per_node=args.nproc_per_node, master_port=args.master_port,
user_script=args.user_script, nproc_per_node=args.nproc_per_node,
user_args=args.user_args, user_script=args.user_script,
node_rank=node_id, user_args=args.user_args,
num_nodes=len(active_device_pool), node_rank=node_id,
extra_launch_args=args.extra_launch_args) num_nodes=len(active_device_pool),
extra_launch_args=args.extra_launch_args,
)
runner.send(hostinfo=hostinfo, cmd=cmd) runner.send(hostinfo=hostinfo, cmd=cmd)
# start training # start training
......
...@@ -3,4 +3,4 @@ from .dist_coordinator import DistCoordinator ...@@ -3,4 +3,4 @@ from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager from .process_group_manager import ProcessGroupManager
from .process_group_mesh import ProcessGroupMesh from .process_group_mesh import ProcessGroupMesh
__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh'] __all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"]
...@@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh ...@@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh
@dataclass @dataclass
class DeviceMeshInfo: class DeviceMeshInfo:
''' """
This class is used to store the information used to initialize the device mesh. This class is used to store the information used to initialize the device mesh.
Args: Args:
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7]. physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2].
''' """
physical_ids: List[int] physical_ids: List[int]
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
...@@ -24,16 +25,18 @@ class DeviceMeshInfo: ...@@ -24,16 +25,18 @@ class DeviceMeshInfo:
if self.mesh_shape is not None: if self.mesh_shape is not None:
world_size = len(self.physical_ids) world_size = len(self.physical_ids)
mesh_shape_numel = torch.Size(self.mesh_shape).numel() mesh_shape_numel = torch.Size(self.mesh_shape).numel()
assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}' assert (
world_size == mesh_shape_numel
), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo): def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
''' """
This method is used to initialize the device mesh. This method is used to initialize the device mesh.
Args: Args:
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh. device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
''' """
# parse the device mesh info # parse the device mesh info
physical_devices = device_mesh_info.physical_ids physical_devices = device_mesh_info.physical_ids
physical_mesh = torch.tensor(physical_devices) physical_mesh = torch.tensor(physical_devices)
...@@ -67,13 +70,13 @@ class DeviceMeshManager: ...@@ -67,13 +70,13 @@ class DeviceMeshManager:
Args: Args:
name (str): name of the device mesh name (str): name of the device mesh
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
""" """
if name not in self.device_mesh_store: if name not in self.device_mesh_store:
device_mesh = initialize_device_mesh(device_mesh_info) device_mesh = initialize_device_mesh(device_mesh_info)
self.device_mesh_store[name] = device_mesh self.device_mesh_store[name] = device_mesh
return device_mesh return device_mesh
else: else:
raise ValueError(f'Device mesh {name} already exists.') raise ValueError(f"Device mesh {name} already exists.")
def get(self, name: str) -> DeviceMesh: def get(self, name: str) -> DeviceMesh:
""" """
...@@ -88,7 +91,7 @@ class DeviceMeshManager: ...@@ -88,7 +91,7 @@ class DeviceMeshManager:
if name in self.device_mesh_store: if name in self.device_mesh_store:
return self.device_mesh_store[name] return self.device_mesh_store[name]
else: else:
raise ValueError(f'Device mesh {name} does not exist.') raise ValueError(f"Device mesh {name} does not exist.")
def destroy(self, name: str) -> None: def destroy(self, name: str) -> None:
""" """
...@@ -103,7 +106,7 @@ class DeviceMeshManager: ...@@ -103,7 +106,7 @@ class DeviceMeshManager:
dist.destroy_process_group(pg) dist.destroy_process_group(pg)
del self.device_mesh_store[name] del self.device_mesh_store[name]
else: else:
raise ValueError(f'Device mesh {name} does not exist.') raise ValueError(f"Device mesh {name} does not exist.")
def destroy_all(self): def destroy_all(self):
""" """
......
...@@ -36,12 +36,13 @@ class DistCoordinator(metaclass=SingletonMeta): ...@@ -36,12 +36,13 @@ class DistCoordinator(metaclass=SingletonMeta):
""" """
def __init__(self): def __init__(self):
assert dist.is_initialized( assert (
), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.' dist.is_initialized()
), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
self._rank = dist.get_rank() self._rank = dist.get_rank()
self._world_size = dist.get_world_size() self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun # this is often passed by launchers such as torchrun
self._local_rank = os.environ.get('LOCAL_RANK', -1) self._local_rank = os.environ.get("LOCAL_RANK", -1)
@property @property
def rank(self) -> int: def rank(self) -> int:
...@@ -59,7 +60,9 @@ class DistCoordinator(metaclass=SingletonMeta): ...@@ -59,7 +60,9 @@ class DistCoordinator(metaclass=SingletonMeta):
""" """
Assert that the local rank is set. This is often passed by launchers such as torchrun. Assert that the local rank is set. This is often passed by launchers such as torchrun.
""" """
assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.' assert (
self.local_rank >= 0
), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."
def is_master(self, process_group: ProcessGroup = None) -> bool: def is_master(self, process_group: ProcessGroup = None) -> bool:
""" """
...@@ -183,7 +186,6 @@ class DistCoordinator(metaclass=SingletonMeta): ...@@ -183,7 +186,6 @@ class DistCoordinator(metaclass=SingletonMeta):
# define an inner function # define an inner function
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if is_master: if is_master:
......
...@@ -19,7 +19,7 @@ class ProcessGroupManager: ...@@ -19,7 +19,7 @@ class ProcessGroupManager:
def __init__(self): def __init__(self):
self.pg_store = dict() self.pg_store = dict()
def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup: def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup:
""" """
Get a process group by name. If the process group does not exist, it will be created. Get a process group by name. If the process group does not exist, it will be created.
...@@ -36,7 +36,7 @@ class ProcessGroupManager: ...@@ -36,7 +36,7 @@ class ProcessGroupManager:
self.pg_store[name] = pg self.pg_store[name] = pg
return pg return pg
else: else:
raise ValueError(f'Process group {name} already exists.') raise ValueError(f"Process group {name} already exists.")
def get(self, name: str) -> ProcessGroup: def get(self, name: str) -> ProcessGroup:
""" """
...@@ -51,7 +51,7 @@ class ProcessGroupManager: ...@@ -51,7 +51,7 @@ class ProcessGroupManager:
if name in self.pg_store: if name in self.pg_store:
return self.pg_store[name] return self.pg_store[name]
else: else:
raise ValueError(f'Process group {name} does not exist.') raise ValueError(f"Process group {name} does not exist.")
def destroy(self, name: str) -> None: def destroy(self, name: str) -> None:
""" """
...@@ -64,7 +64,7 @@ class ProcessGroupManager: ...@@ -64,7 +64,7 @@ class ProcessGroupManager:
dist.destroy_process_group(self.pg_store[name]) dist.destroy_process_group(self.pg_store[name])
del self.pg_store[name] del self.pg_store[name]
else: else:
raise ValueError(f'Process group {name} does not exist.') raise ValueError(f"Process group {name} does not exist.")
def destroy_all(self) -> None: def destroy_all(self) -> None:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment