"examples/language/llama2/test_ci.sh" did not exist on "0442f940f021d024ca390485f0cdf0856fe6cb36"
Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
__all__ = ['Plugin']
__all__ = ["Plugin"]
class Plugin(ABC):
@abstractmethod
def supported_devices(self) -> List[str]:
pass
......@@ -38,11 +37,11 @@ class Plugin(ABC):
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# implement this method
pass
......@@ -51,11 +50,31 @@ class Plugin(ABC):
"""
Whether the plugin controls the checkpoint io
"""
pass
@abstractmethod
def get_checkpoint_io(self) -> CheckpointIO:
"""
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
"""
pass
@abstractmethod
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
@abstractmethod
def prepare_dataloader(
self,
dataset: Dataset,
batch_size: int,
shuffle: bool = False,
seed: int = 1024,
drop_last: bool = False,
pin_memory: bool = False,
num_workers: int = 0,
**kwargs,
):
"""Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader`
"""
from abc import abstractmethod
from typing import Any, Callable, Iterator, Optional
import torch
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .plugin_base import Plugin
class PipelinePluginBase(Plugin):
@abstractmethod
def execute_pipeline(
self,
data_iter: Iterator,
model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = True,
return_outputs: bool = False,
) -> dict:
pass
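A hedged sketch of how `execute_pipeline` is meant to be driven from a training loop; the boosted `model`, `optimizer`, `criterion` and `dataloader` are assumed to come from `Booster.boost`, and the `"loss"` key in the returned dict is an assumption made for illustration.

```python
def train_step(plugin, model, optimizer, criterion, dataloader) -> float:
    data_iter = iter(dataloader)
    optimizer.zero_grad()
    # The plugin schedules forward/backward across pipeline stages and
    # returns a dict of results; "loss" is assumed to be one of its keys.
    outputs = plugin.execute_pipeline(data_iter, model, criterion,
                                      optimizer=optimizer, return_loss=True)
    optimizer.step()
    loss = outputs.get("loss")
    return loss.item() if loss is not None else float("nan")
```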
import random
from typing import Callable, List, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .plugin_base import Plugin
from .dp_plugin_base import DPPluginBase
__all__ = ['TorchDDPPlugin']
__all__ = ["TorchDDPPlugin"]
class TorchDDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
Load model from checkpoint.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
......@@ -55,9 +57,67 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)
class TorchDDPModel(ModelWrapper):
def load_sharded_model(
self,
model: ModelWrapper,
checkpoint_index_file: str,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
......@@ -66,20 +126,21 @@ class TorchDDPModel(ModelWrapper):
return self.module.module
class TorchDDPPlugin(Plugin):
class TorchDDPPlugin(DPPluginBase):
"""
Plugin for PyTorch DDP.
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import TorchDDPPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = TorchDDPPlugin()
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
model, train_dataset, optimizer, criterion = ...
plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
......@@ -90,24 +151,24 @@ class TorchDDPPlugin(Plugin):
static_graph (bool, optional): Whether to use static graph. Defaults to False.
"""
def __init__(self,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False) -> None:
assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)
def __init__(
self,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
) -> None:
super().__init__()
self.ddp_kwargs = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
def support_no_sync(self) -> bool:
return True
......@@ -116,73 +177,22 @@ class TorchDDPPlugin(Plugin):
return False
def supported_precisions(self) -> List[str]:
return ['fp16', 'fp16_apex', 'bf16', 'fp8']
return ["fp16", "fp16_apex", "bf16", "fp8"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Note:
1. Evaluation datasets should not be passed to this function.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
return ["cuda"]
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# cast model to cuda
model = model.cuda()
......@@ -192,7 +202,7 @@ class TorchDDPPlugin(Plugin):
# wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs)
if not isinstance(optimizer, OptimizerWrapper):
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
......@@ -202,3 +212,7 @@ class TorchDDPPlugin(Plugin):
def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync()
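A sketch of gradient accumulation built on the `no_sync` method above; `model` and `optimizer` are assumed to be the objects returned by `booster.boost` with a `TorchDDPPlugin`, `batches` is a list of `(inputs, targets)` tensors on the correct device, and plain `loss.backward()` is used for brevity.

```python
def accumulate_and_step(plugin, model, optimizer, criterion, batches):
    optimizer.zero_grad()
    # Inside no_sync, DDP skips the gradient all-reduce, so the first
    # micro-batches only accumulate local gradients.
    with plugin.no_sync(model, optimizer):
        for inputs, targets in batches[:-1]:
            loss = criterion(model(inputs), targets)
            loss.backward()
    # The last backward runs outside the context and triggers synchronization.
    inputs, targets = batches[-1]
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
```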
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse("1.12.0"):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .dp_plugin_base import DPPluginBase
__all__ = ["TorchFSDPPlugin"]
class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
model = model.unwrap()
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
model = model.unwrap()
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
"""
Save model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def load_sharded_model(
self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load model from sharded checkpoint.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
"""
Load optimizer from sharded checkpoint.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
class TorchFSDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)
def unwrap(self):
return self.module
class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)
def unwrap_model(self) -> nn.Module:
return self.model
class TorchFSDPPlugin(DPPluginBase):
"""
Plugin for PyTorch FSDP.
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin
model, train_dataset, optimizer, criterion = ...
plugin = TorchFSDPPlugin()
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
if version.parse(torch.__version__) >= version.parse("1.12.0"):
def __init__(
self,
process_group: Optional[ProcessGroup] = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Callable] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
def support_no_sync(self) -> bool:
return False
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
return ["fp16", "bf16"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ["cuda"]
def configure(
self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
"TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
if not isinstance(optimizer, FSDPOptimizerWrapper):
optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
def get_checkpoint_io(self) -> CheckpointIO:
return TorchFSDPCheckpointIO()
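A minimal configuration sketch for the plugin above using the standard `torch.distributed.fsdp` option objects that its constructor forwards to FSDP; the bf16 dtypes and the disabled CPU offload are illustrative choices, not recommendations from this change.

```python
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

plugin = TorchFSDPPlugin(
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # dtype of sharded parameters during compute
        reduce_dtype=torch.bfloat16,  # dtype used for gradient reduction
        buffer_dtype=torch.bfloat16,  # dtype of module buffers
    ),
    cpu_offload=CPUOffload(offload_params=False),
)
booster = Booster(plugin=plugin)
```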
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union
from typing import Optional
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -12,7 +11,7 @@ from colossalai.interface import ModelWrapper
from .utils import has_index_file
__all__ = ['CheckpointIO']
__all__ = ["CheckpointIO"]
class CheckpointIO(ABC):
......@@ -62,10 +61,9 @@ class CheckpointIO(ABC):
# ======================================
# Public methods
# ======================================
def load_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True) -> Union[nn.Module, ModelWrapper]:
def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
......@@ -84,15 +82,11 @@ class CheckpointIO(ABC):
# containing no distributed tensors, dtensor -> full tensor conversion
# should be done offline via our CLI
# the existence of index file means it is a sharded checkpoint
ckpt_path = Path(checkpoint)
index_file_exists, index_file_path = has_index_file(checkpoint)
# return the origin model instead of the unwrapped model
origin_model = model
if isinstance(model, ModelWrapper):
model = model.unwrap()
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
......@@ -100,14 +94,16 @@ class CheckpointIO(ABC):
return origin_model
def save_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
variant: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
def save_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
):
"""
Save model to checkpoint.
......@@ -130,46 +126,49 @@ class CheckpointIO(ABC):
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
if isinstance(model, ModelWrapper):
model = model.unwrap()
if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""
Load optimizer from checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error
raise ValueError(f'Cannot find index file in {checkpoint}')
raise ValueError(f"Cannot find index file in {checkpoint}")
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path)
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024):
def save_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
......@@ -185,6 +184,7 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else:
......@@ -204,7 +204,6 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
......@@ -217,11 +216,17 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
size_per_shard: int, use_safetensors: bool):
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
"""
Save model to sharded checkpoint.
......@@ -233,7 +238,6 @@ class CheckpointIO(ABC):
size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors.
"""
pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
......@@ -246,14 +250,13 @@ class CheckpointIO(ABC):
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors.
"""
pass
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load optimizer from sharded checkpoint.
......@@ -261,9 +264,7 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
......@@ -274,11 +275,11 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
pass
@abstractmethod
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int):
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save optimizer to sharded checkpoint.
......@@ -289,7 +290,6 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
......@@ -301,7 +301,6 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
"""
pass
# ============================================
# methods for loading and saving lr scheduler
......
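Putting the public interface above together, a hedged round-trip sketch using the concrete `GeneralCheckpointIO` defined in the next file; the paths are placeholders, and the shard file names in the comments follow the defaults described in the docstrings of this change.

```python
import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)
ckpt_io = GeneralCheckpointIO()

# Sharded save: the checkpoint argument must be a directory. The model side
# produces an index file plus weight shards; the optimizer side produces an
# index file (pytorch_optim.bin.index.json), a param-group file
# (pytorch_optim_group.bin) and state shards (pytorch_optim-000XX.bin).
ckpt_io.save_model(model, "./model_ckpt", shard=True, size_per_shard=1024)
ckpt_io.save_optimizer(optimizer, "./optim_ckpt", shard=True, size_per_shard=1024)

# Loading inspects the path: if an index file is found, the sharded loaders
# are dispatched to, otherwise the unsharded ones.
ckpt_io.load_model(model, "./model_ckpt")
ckpt_io.load_optimizer(optimizer, "./optim_ckpt")
```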
import gc
import logging
import os
from functools import reduce
from pathlib import Path
from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
import logging
import os
import json
import gc
from typing import Optional
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
has_index_file,
load_state_dict,
save_state_dict,
is_safetensors_available,
shard_checkpoint,
load_shard_state_dict,
load_state_dict_into_model,
add_variant
)
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
from .utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
)
__all__ = ['GeneralCheckpointIO']
__all__ = ["GeneralCheckpointIO"]
class GeneralCheckpointIO(CheckpointIO):
"""
Checkpoint IO
"""
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict)
......@@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load sharded optimizer with the given path to index file.
"""
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory."
)
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
def save_sharded_optimizer(
self,
......@@ -59,7 +83,56 @@ class GeneralCheckpointIO(CheckpointIO):
prefix: str,
size_per_shard: int,
):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store the optimizer's state tensors shard by shard
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Offload optimizer states. States are broken into shards within max_shard_size.
state_dict = optimizer.state_dict()
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
def save_unsharded_optimizer(
self,
......@@ -70,45 +143,59 @@ class GeneralCheckpointIO(CheckpointIO):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
"""
def save_sharded_model(
self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
"""
Save a sharded model checkpoint: the model state dict is split across multiple
files, in the same layout used for Hugging Face sharded checkpoints.
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
# shard checkpoint
state_dict = model.state_dict()
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
# Save the model
for shard_file, shard in shards.items():
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)
# save index file
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# Save shards of model weights.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
logging.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
def load_sharded_model(
self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load a sharded model checkpoint, restoring the model from multiple shard files.
"""
......@@ -118,21 +205,26 @@ class GeneralCheckpointIO(CheckpointIO):
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
missing_keys = ckpt_index_file.get_all_param_names()
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
missing_keys = []
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
load_state_dict_into_model(model, state_dict, missing_keys, strict)
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
del state_dict
gc.collect()
if strict and len(missing_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class HybridParallelCheckpointIO(GeneralCheckpointIO):
"""
CheckpointIO for Hybrid Parallel Training.
Args:
dp_group (ProcessGroup): Process group along data parallel dimension.
pp_group (ProcessGroup): Process group along pipeline parallel dimension.
tp_group (ProcessGroup): Process group along tensor parallel dimension.
zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
"""
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__()
self.dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.dp_rank = dist.get_rank(self.dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = zero_stage > 0
self.verbose = verbose
self.coordinator = DistCoordinator()
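For orientation, a hedged construction sketch; the process groups below are stand-ins (a real run derives separate dp/pp/tp groups from the parallel configuration, and this class is normally obtained through a plugin's `get_checkpoint_io()` rather than built by hand), and `boosted_model` is assumed to be a `ModelWrapper` returned by `Booster.boost`.

```python
import torch.distributed as dist

# Stand-in groups purely to show the constructor arguments; do not reuse the
# world group for dp/pp/tp in real hybrid-parallel training.
dp_group = pp_group = tp_group = dist.group.WORLD
ckpt_io = HybridParallelCheckpointIO(dp_group, pp_group, tp_group, zero_stage=1, verbose=True)
ckpt_io.save_sharded_model(boosted_model, "./hybrid_ckpt", size_per_shard=1024)
```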
@staticmethod
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks the model's state_dict into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Save buffers.
for name, buf in model.named_buffers():
if buf is not None and name not in model._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
@staticmethod
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
size_per_shard: int = 1024,
):
# An internal method that breaks the optimizer's state_dict into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Prefix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model = model.unwrap()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for weight, weight_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(weight, weight_filename)
final_index_file.write_index_file(final_index_file_path)
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()
if self.verbose and self.coordinator.is_master():
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
if not self.use_zero and self.dp_rank != 0:
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs are obtained through the param2id mapping saved earlier in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
# If this param's states have been loaded before, skip this file.
if filename in loaded_file:
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model state dict to a single file with given checkpointing path.
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model = model.unwrap()
if self.dp_rank != 0:
return
# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict()
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
state_dict_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
# Only the master rank does the saving.
if self.coordinator.is_master():
complete_state_dict = dict()
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
"""
Load model from a single file with the given path of checkpoint.
Args:
model (nn.Module): The model to be loaded.
checkpoint (str): Path to the checkpoint file.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
strict = False
model_before_wrapping = model
model = model.unwrap()
# Load from checkpoint. Since the logic of breaking parameter shards along tp degree
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
# model.load_state_dict can be directly called.
state_dict = load_state_dict(checkpoint)
model.load_state_dict(state_dict, strict=strict)
# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer state dict to a file with given path.
Args:
optimizer (OptimizerWrapper): Optimizer whose unsharded state_dict is to be saved.
checkpoint (str): Path to save optimizer state_dict.
gather_dtensor (bool): Whether to gather_dtensor, not used.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
# Optimizer states of the parameters kept by the local device (i.e. the local pipeline stage).
local_states = dict()
for param, state in optimizer.optim.state.items():
if param is None:
continue
# working param is needed for obtaining correct param_id
master_to_working_map = optimizer.get_master_to_working_map()
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
original_shape = optimizer.param_info["param2shape"][id(working_param)]
local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=self.dp_group,
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
device=torch.device("cuda"),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(states_list, local_states, self.pp_group)
# Only the master rank does the saving.
if self.coordinator.is_master():
state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from a file with given path.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint (str): Path to the checkpoint file.
"""
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
# Load param_groups.
updated_groups = []
saved_groups = state_dict["param_groups"]
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
master_to_working_map = optimizer.get_master_to_working_map()
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
@staticmethod
def gather_from_sharded_optimizer_state(
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
dp_size = dist.get_world_size(dp_group)
tp_size = dist.get_world_size(tp_group)
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
state_[k] = v.detach().clone().to(device)
return state_
def shard_from_complete_optimizer_state(
self,
state: OrderedDict,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
Args:
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
current_shape (torch.Size): The size of parameter after sharding.
original_shape (torch.Size): The size of parameter before sharding.
device (torch.device): The destination device of loaded optimizer states.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
# Shard state along data parallel group when using Zero.
if self.use_zero:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
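# Hedged, single-process illustration (not part of the class): the slicing that
# shard_from_complete_optimizer_state performs for one Adam-style state tensor,
# written out with plain tensors. The tp/dp sizes, ranks and shapes below are
# assumptions made purely for the example.
def _example_shard_complete_state():
    import torch

    full = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)  # gathered 'exp_avg'
    tp_size, tp_rank = 2, 0
    dp_size, dp_rank = 2, 1
    # Tensor-parallel slice along the partition dim (dim 0 here: 8 rows -> 4 per rank).
    v = full.split(full.shape[0] // tp_size, dim=0)[tp_rank]
    # ZeRO slice: flatten, pad to a multiple of dp_size, keep this rank's chunk.
    flat = v.flatten()
    pad = (dp_size - flat.numel() % dp_size) % dp_size
    flat = torch.nn.functional.pad(flat, [0, pad])
    local = flat.split(flat.numel() // dp_size, dim=0)[dp_rank]
    assert local.numel() == full.numel() // (tp_size * dp_size)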
import json
import os
from collections import OrderedDict
from pathlib import Path
from typing import Any, List, Union
from typing import Any, Dict, List, Union
from .utils import is_dtensor_checkpoint
__all__ = ['CheckpointIndexFile']
__all__ = ["CheckpointIndexFile"]
class CheckpointIndexFile:
......@@ -18,10 +20,12 @@ class CheckpointIndexFile:
>>> index.export('new_index.json')
"""
def __init__(self) -> None:
self.root_path = None
self.metadata: dict = dict()
self.weight_map: dict = dict()
def __init__(self, root_path=None) -> None:
self.root_path = root_path
# use ordered dict to preserve the tensor checkpoint order
self.metadata: Dict = OrderedDict()
self.weight_map: Dict = OrderedDict()
@staticmethod
def from_file(index_path: Union[str, Path]):
......@@ -46,7 +50,7 @@ class CheckpointIndexFile:
json_path (str): path to the json file.
"""
# load the json file
with open(json_path, 'r') as f:
with open(json_path, "r") as f:
index = json.load(f)
# assign attributes if exists
......@@ -71,7 +75,7 @@ class CheckpointIndexFile:
index["weight_map"] = self.weight_map
# export the index file
with open(json_path, 'w') as f:
with open(json_path, "w") as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):
......@@ -107,7 +111,7 @@ class CheckpointIndexFile:
return True
return False
def get_checkpoint_fileanames(self) -> List[str]:
def get_checkpoint_filenames(self) -> List[str]:
"""
Get the set of checkpoint filenames in the weight map.
......@@ -148,9 +152,31 @@ class CheckpointIndexFile:
"""
ckpt_path = self.weight_map[param_name]
return ckpt_path
def get_all_param_names(self):
"""
Get all the weight keys.
"""
return list(self.weight_map.keys())
def get_param_group_filename(self) -> Union[str, None]:
"""
Get the file name of param_group file if this is a checkpoint for optimizer.
Returns:
str: param_group file name
"""
filename = self.metadata.get("param_groups", None)
if filename:
return str(self.root_path.joinpath(filename))
else:
return None
def write_index_file(self, save_index_file):
"""
Write index file.
"""
save_index_file = os.path.join(self.root_path, save_index_file)
index = {"metadata": self.metadata, "weight_map": self.weight_map}
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2) + "\n"
f.write(content)
# coding=utf-8
import os
import re
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch
import torch.nn as nn
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
from colossalai.tensor.d_tensor.d_tensor import DTensor
import re
from packaging.version import Version
from torch.optim import Optimizer
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
to_global,
to_global_for_customized_distributed_tensor,
)
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"
# ======================================
# General helper functions
# ======================================
def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceeds the shard size.
If so, a new shard should be created.
Args:
tenosr (torch.Tensor): the tensor to calculate size for.
tensor (torch.Tensor): the tensor to calculate size for.
Returns:
float: size of the tensor in MB.
"""
return tensor.numel() * tensor.element_size() / 1024 / 1024
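# Hedged sanity check (illustration only): a float32 tensor with 1024 * 1024
# elements takes 1024 * 1024 * 4 bytes, i.e. exactly 4.0 MB by this definition.
def _example_calculate_tensor_size():
    import torch

    t = torch.zeros(1024, 1024, dtype=torch.float32)
    assert calculate_tensor_size(t) == 4.0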
def is_safetensors_available() -> bool:
"""
Check whether safetensors is available.
......@@ -36,7 +54,6 @@ def is_safetensors_available() -> bool:
bool: whether safetensors is available.
"""
try:
import safetensors
return True
except ImportError:
return False
......@@ -52,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a dtensor checkpoint.
"""
if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
return True
else:
return False
......@@ -68,136 +85,210 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a safetensor checkpoint.
"""
if checkpoint_file_path.endswith('.safetensors'):
if checkpoint_file_path.endswith(".safetensors"):
return True
else:
return False
def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
"""
Given the current shape of parameter and the shape of parameter before sharding,
return the dimension along which the parameter is sharded when using tensor parallel.
If tensor parallel is not used, return None.
Args:
current_shape (torch.Size): The current shape of parameter after sharding.
original_shape (torch.Size): The shape of parameter before sharding.
tp_size (int): The size of tp group.
Returns:
Optional[int]: The dimension along which parameter is partitioned.
"""
partition_dim = None
for dim, length in enumerate(original_shape):
if length > current_shape[dim]:
partition_dim = dim
break
if partition_dim is not None:
assert (
original_shape[partition_dim] == tp_size * current_shape[partition_dim]
), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim
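# Hedged example (illustration only): a weight of original shape (4096, 1024)
# that has been split across tp_size=4 into a local shard of shape (1024, 1024)
# is detected as partitioned along dim 0, while an unsharded parameter whose
# shape is unchanged yields None.
def _example_search_tp_partition_dim():
    import torch

    original = torch.Size([4096, 1024])
    sharded = torch.Size([1024, 1024])
    assert search_tp_partition_dim(sharded, original, tp_size=4) == 0
    assert search_tp_partition_dim(original, original, tp_size=4) is None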
# ======================================
# Helper functions for saving shard file
# Helper classes and functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
class StateDictSharder:
def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard
self.current_block = OrderedDict()
self.current_block_size = 0
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
self.current_block[name] = tensor
self.current_block_size += tensor_size
return ret_block, ret_block_size
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class,
# e.g., an SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
ret_block = None
ret_block_size = 0
# directly return if state is stored as distributed tensor
if isDTensor:
return ret_block, ret_block_size
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
self.current_block[param_id] = state
self.current_block_size += state_size
return ret_block, ret_block_size
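# Hedged usage sketch (illustration only): with a 1 MB shard budget, the second
# ~0.76 MB tensor pushed through append_param would overflow the open block, so
# the sharder hands the first block back to the caller; the caller is expected
# to flush the last, partially filled block itself (as shard_model_checkpoint
# below does with its final yield).
def _example_state_dict_sharder():
    import torch

    sharder = StateDictSharder(size_per_shard=1)  # 1 MB budget per shard
    w = torch.zeros(200_000, dtype=torch.float32)  # ~0.76 MB
    block, _ = sharder.append_param("w1", w)
    assert block is None                      # still fits in the open block
    block, block_size = sharder.append_param("w2", w.clone())
    assert block is not None and list(block) == ["w1"]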
def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
Gather the complete parameter for saving if the passed-in param is distributed under the tp setting.
Args:
param (torch.Tensor): A model parameter, might be d_tensor.
keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
Returns:
torch.Tensor: the complete parameter
"""
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
return to_global(param_)
elif is_customized_distributed_tensor(param_):
return to_global_for_customized_distributed_tensor(param_)
else:
return param_
def save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False,
use_pp_format: bool = False,
) -> int:
"""
Save sharded state dict only on the master rank; this method can be used for both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
for key, weight in state_dict.items():
if type(weight) != DTensor:
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
current_block = {}
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size
total_size += weight_size
# Add the last block
sharded_state_dicts.append(current_block)
# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None
# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
shards[shard_file] = shard
Returns:
int: the total size of shards
"""
total_size = 0
shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
weight_map[key] = shard_file
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
# Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
shard_filenames.append(shard_file)
del shard
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
# Clean folder: delete unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
load shard state dict into model
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
if use_safetensors:
from safetensors.torch import safe_open
from safetensors.torch import load_file as safe_load_file
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file)
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
state_dict_sharder = StateDictSharder(max_shard_size)
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
for key, weight in state_dict.items():
if not is_distributed_tensor(weight):
block, block_size = state_dict_sharder.append_param(key, weight)
unexpected_keys: List[str] = []
sub_missing_keys: List[str] = []
error_msgs: List[str] = []
if block is not None:
yield block, block_size
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
load(model, state_dict, "")
del load
# Only split state_dict['state']; state_dict['param_groups'] is not considered in this function.
states = state_dict["state"]
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
# deal with missing key
if len(missing_keys) > 0:
deleted_keys = []
for key in missing_keys:
if key not in sub_missing_keys:
deleted_keys.append(key)
for key in deleted_keys:
missing_keys.remove(key)
if strict:
if len(unexpected_keys) > 0:
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
# ======================================
# Helper functions for saving state dict
# ======================================
......@@ -214,14 +305,99 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
assert checkpoint_file_path.endswith(
".safetensors"
), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to given file path.
Args:
state_dict (dict): state dict.
group_file_path (str): path to the group file.
"""
param_groups = state_dict["param_groups"]
torch.save(param_groups, group_file_path)
def clean_folder(
checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False,
):
"""
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
Args:
checkpoint_path (str): Path to the checkpoint directory.
weights_name (str): Decides the prefix of filenames of weight shards.
shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
is_master (bool, optional): Whether current rank is main process. Defaults to True.
use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
"""
if is_master:
for filename in os.listdir(checkpoint_path):
full_filename = os.path.join(checkpoint_path, filename)
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
if not use_pp_format:
reg = re.compile(r"(.*?)-\d{5}")
else:
# When this checkpoint is created by pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in shard_filenames
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
"""
Save config.json/generation_config.json if model is a Huggingface pretrained model.
This method can only be called when a model is saved in a sharded way.
Args:
model (nn.Module): The model whose config should be saved if it's a huggingface model.
checkpoint_path (str): Path to the checkpoint directory.
is_master (bool): Whether current rank is main process.
"""
try:
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
except ImportError:
return
if not isinstance(model, PreTrainedModel):
return
model = unwrap_huggingface_model(model)
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
dtype = get_parameter_dtype(model)
model.config.torch_dtype = str(dtype).split(".")[1]
# Attach architecture to the config
model.config.architectures = [model.__class__.__name__]
# Save the config
if is_master:
model.config.save_pretrained(checkpoint_path)
if model.can_generate():
model.generation_config.save_pretrained(checkpoint_path)
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
......@@ -233,7 +409,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor')
output_root_path = root_path.joinpath("dtensor")
# create directory
output_root_path.mkdir(exist_ok=True)
......@@ -253,7 +429,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
# update the weight map
# * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
......@@ -268,15 +444,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
str: checkpoint file suffix.
"""
if use_safetensors:
return '.safetensors'
return ".safetensors"
else:
return '.bin'
return ".bin"
def generate_checkpoint_shard_file_name(index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None) -> str:
def generate_checkpoint_shard_file_name(
index: int, total_number: int, use_safetensors: bool, prefix: str = None
) -> str:
"""
Generate checkpoint shard file name.
......@@ -310,39 +485,190 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}'
return f"{param_name}.{index}.{suffix}"
def save_state_dict_as_shard(
state_dict: dict,
checkpoint_path: str,
index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None,
) -> None:
# ========================================
# Helper functions for loading state dict
# ========================================
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
"""
Load the state dict of one shard from a checkpoint file.
"""
Save state dict as shard.
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file, map_location=torch.device("cpu"))
def load_state_dict_into_model(
model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True
):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
Args:
state_dict (dict): state dict.
checkpoint_path (str): path to the checkpoint file.
index (int): index of the shard.
total_number (int): total number of shards.
prefix (str): prefix of the shard file name.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
# generate the shard name
shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
unexpected_keys: List[str] = []
sub_missing_keys: List[str] = []
error_msgs: List[str] = []
# save the shard
save_state_dict(state_dict, str(shard_file_path), use_safetensors)
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
if load_sub_module:
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
# ========================================
# Helper functions for loading state dict
# ========================================
load(model, state_dict, "", load_sub_module)
del load
# Record the keys reported missing during this (possibly partial) load.
missing_keys.append(sub_missing_keys)
if strict:
if len(unexpected_keys) > 0:
error_msgs = "Unexpected key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in unexpected_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
"""
Load information of param_groups into an initialized optimizer.
"""
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
if not isinstance(saved_groups, List):
raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
# The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch.
param_groups = optimizer.param_groups
# Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups")
param_lens = (len(g["params"]) for g in param_groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
)
# Creating mapping from id to parameters.
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in param_groups)),
)
}
# Update parameter groups, setting their 'params' value.
def update_group(group, new_group):
new_group["params"] = group["params"]
return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
optimizer.__dict__.update({"param_groups": updated_groups})
return id_map
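# Hedged round-trip sketch (illustration only): saving the param_groups of a toy
# SGD optimizer with save_param_groups (defined above) and loading them back.
# The temporary file path is an assumption made for the example.
def _example_param_group_round_trip(group_file="/tmp/pytorch_optim_group.bin"):
    import torch

    model = torch.nn.Linear(4, 4)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    save_param_groups(optim.state_dict(), group_file)
    id_map = load_param_groups_into_optimizer(optim, group_file)
    # One entry per parameter: the weight and the bias of the linear layer.
    assert len(id_map) == 2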
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
r"""Copies states from `state_dict` into an Optimizer object.
Args:
optimizer(Optimizer): An initialized Optimizer object to be loaded
state_dict(dict): A mapping from tensor index (an integer)
to its states to be loaded (a mapping from state name to a tensor).
id_map(dict): A mapping from tensor index (an integer)
to its corresponding parameter (a tensor) whose states will be updated.
strict(bool, optional): If set to True, only load states whose parameter id is present in id_map. Defaults to False.
"""
# Ensure that the keys of state_dict are integers.
state_dict = {int(k): v for k, v in state_dict.items()}
def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
return value
elif isinstance(value, dict):
return {k: cast(param, v, key=k) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
new_states = defaultdict(dict)
for k, v in state_dict.items():
if k in id_map:
param = id_map[k]
new_states[param] = cast(param, v)
elif not strict:
new_states[k] = v
optimizer.state.update(new_states)
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
r"""Do the cleaning up work after state_dict has been loaded into optimizer
Args:
optimizer(Optimizer): An optimizer object whose state has just been loaded.
"""
# Do the cleaning up as in src code of Pytorch.
if Version(torch.__version__) >= Version("2.0.0"):
optimizer._patch_step_function() # To support multiprocessing pickle/unpickle
else:
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault("differentiable", False)
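# Hedged end-to-end sketch (illustration only): how the helpers above are meant
# to be chained when restoring an optimizer from a sharded checkpoint. The caller
# is assumed to have resolved the param_group file and the shard file names
# (in the library this is driven by a CheckpointIndexFile).
def _example_restore_sharded_optimizer(optimizer, group_file, shard_files):
    id_map = load_param_groups_into_optimizer(optimizer, group_file)
    for shard_file in shard_files:
        state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
        load_states_into_optimizer(optimizer, state_dict, id_map, strict=False)
    sharded_optimizer_loading_epilogue(optimizer)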
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
......@@ -365,18 +691,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return False, None
elif checkpoint_path.is_dir():
# check if there is only one file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.*json'))
index_files = list(checkpoint_path.glob("*.index.*json"))
# if we found a .index.json file, make sure there is only one
if len(index_files) > 0:
assert len(
index_files
) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
assert (
len(index_files) == 1
), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
if len(index_files) == 1:
return True, index_files[0]
else:
return False, None
else:
raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def load_state_dict(checkpoint_file_path: Path):
......@@ -390,14 +718,17 @@ def load_state_dict(checkpoint_file_path: Path):
dict: state dict.
"""
assert not is_dtensor_checkpoint(checkpoint_file_path), \
f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
assert not is_dtensor_checkpoint(
checkpoint_file_path
), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
if is_safetensor_checkpoint(checkpoint_file_path):
assert is_safetensors_available(), \
f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
assert (
is_safetensors_available()
), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
# load with safetensors
from safetensors import safe_open
state_dict = {}
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
for k in f.keys():
......@@ -406,14 +737,51 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
return torch.load(checkpoint_file_path)
return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None and len(variant) > 0:
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
if prefix is not None and len(prefix) > 0:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
splits = splits[:-1] + [prefix] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
"""
generate base model weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_prefix(weights_name, prefix)
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)
return weights_name, save_index_file
def get_optimizer_base_filenames(prefix: str = None):
"""
generate base optimizer state filenames
"""
states_name = STATES_NAME
states_name = add_prefix(states_name, prefix)
save_index_file = STATES_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)
param_group_file = GROUP_FILE_NAME
param_group_file = add_prefix(param_group_file, prefix)
return states_name, save_index_file, param_group_file
def get_shard_filename(weights_name: str, idx: int):
"""
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file
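# Hedged example (illustration only): how the naming helpers above compose for a
# plain .bin checkpoint without a prefix.
def _example_checkpoint_filenames():
    weights_name, index_name = get_model_base_filenames(use_safetensors=False)
    assert weights_name == "pytorch_model.bin"
    assert index_name == "pytorch_model.bin.index.json"
    assert get_shard_filename(weights_name, 0) == "pytorch_model-00001.bin"
    assert add_prefix(weights_name, "ema") == "pytorch_model.ema.bin"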
from .cli import cli
__all__ = ['cli']
__all__ = ["cli"]
import click
from colossalai.context import Config
from .benchmark import run_benchmark
from .utils import *
__all__ = ['benchmark']
@click.command()
@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
@click.option("-l", "--layers", type=int, default=2)
@click.option("-m",
"--model",
type=click.Choice(['mlp'], case_sensitive=False),
default='mlp',
help="Select the model to benchmark, currently only supports MLP")
def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
layers: int, model: str):
args_dict = locals()
args = Config(args_dict)
run_benchmark(args)
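# Hedged usage example (assumed CLI invocation; requires a multi-GPU machine):
#   colossalai benchmark --gpus 4 --batch_size 8 --seq_len 512 \
#       --dimension 1024 --warmup_steps 10 --profile_steps 50 --layers 2 --model mlp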
from functools import partial
from typing import Dict, List
import click
import torch.multiprocessing as mp
import colossalai
from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model
from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.testing import free_port
from colossalai.utils import MultiTimer
from .models import MLP
def run_benchmark(args: Config) -> None:
"""
Run benchmarking with torch.multiprocessing.
"""
# sanity checks
if args.gpus is None:
click.echo("Error: --num_gpus is not given")
exit()
if args.gpus <= 1:
click.echo("Warning: tensor parallel will be activated with at least 2 devices.")
click.echo("=== Benchmarking Parameters ===")
for k, v in args.items():
click.echo(f'{k}: {v}')
click.echo('')
config_list = find_all_configs(args.gpus)
avail_ports = [free_port() for _ in range(len(config_list))]
run_func = partial(run_dist_profiling,
world_size=args.gpus,
port_list=avail_ports,
config_list=config_list,
hyperparams=args)
mp.spawn(run_func, nprocs=args.gpus)
def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
hyperparams: Config) -> None:
"""
A function executed for profiling, this function should be spawn by torch.multiprocessing.
Args:
rank (int): rank of the process
world_size (int): the number of processes
port_list (List[int]): a list of free ports for initializing distributed networks
config_list (List[Dict]): a list of configuration
hyperparams (Config): the hyperparameters given by the user
"""
# disable logging for clean output
disable_existing_loggers()
logger = get_dist_logger()
logger.set_level('WARNING')
for config, port in zip(config_list, port_list):
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
timer = MultiTimer()
# 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size.
if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0:
click.echo(
"1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size."
)
continue
if hyperparams.model == 'mlp':
model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
else:
if gpc.get_global_rank() == 0:
click.echo("Error: Invalid argument for --model")
exit()
data_func = partial(get_batch_data,
dim=hyperparams.dimension,
batch_size=hyperparams.batch_size,
seq_length=hyperparams.seq_len,
mode=config.parallel.tensor.mode)
fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
warmup_steps=hyperparams.warmup_steps,
profile_steps=hyperparams.profile_steps,
data_func=data_func,
timer=timer)
gpc.destroy()
reset_seeds()
if gpc.get_global_rank() == 0:
config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
click.echo(f"=== {config_str} ===")
click.echo(f"Average forward time: {fwd_time}")
click.echo(f"Average backward time: {bwd_time}")
click.echo(f"Max allocated GPU memory: {max_allocated}")
click.echo(f"Max cached GPU memory: {max_cached}\n")
import torch
import colossalai.nn as col_nn
class MLP(torch.nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(layers):
self.layers.append(col_nn.Linear(dim, dim))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
import math
import time
import torch
from colossalai.utils import MultiTimer
from colossalai.context import ParallelMode, Config
from typing import List, Dict, Tuple, Callable
def get_time_stamp() -> int:
"""
Return the time stamp for profiling.
Returns:
time_stamp (int): the time given by time.time()
"""
torch.cuda.synchronize()
time_stamp = time.time()
return time_stamp
def get_memory_states() -> Tuple[float]:
"""
Return the memory statistics.
Returns:
max_allocated (float): the maximum allocated CUDA memory in GB
max_cached (float): the maximum reserved (cached) CUDA memory in GB
"""
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
max_cached = torch.cuda.max_memory_reserved() / (1024**3)
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
return max_allocated, max_cached
def find_all_configs(device_cnt: int) -> List[Dict]:
"""
Find all possible configurations for tensor parallelism
Args:
device_cnt (int): the number of devices
Returns:
config_list (List[Dict]): a list of configurations
"""
def _is_square(num):
# 2D parallel should be implemented with at least 2 devices.
if num <= 1:
return False
return math.floor(math.sqrt(num))**2 == num
def _is_cube(num):
# 3D parallel should be implemented with at least 2 devices.
if num <= 1:
return False
return math.floor(num**(1. / 3.))**3 == num
config_list = []
# add non-parallel config
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
config_list.append(config)
# add 1D config
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
config_list.append(config)
# add 2D config only if device_cnt is a square
if _is_square(device_cnt):
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
config_list.append(config)
# check for 2.5D
# iterate over depth
for depth in range(1, device_cnt):
if device_cnt % depth == 0 and _is_square(device_cnt // depth):
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
config_list.append(config)
# check for 3D if device_cnt is a cube
if _is_cube(device_cnt):
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
config_list.append(config)
config_list = [Config(cfg) for cfg in config_list]
return config_list
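# Hedged example (illustration only): for device_cnt=4 the search yields the
# serial config plus 1d, 2d (4 is a perfect square) and one 2.5d config with
# depth=1; 3d is skipped because 4 is not a cube.
def _example_find_all_configs():
    modes = [cfg.parallel.tensor.mode for cfg in find_all_configs(4)]
    assert modes == [None, '1d', '2d', '2.5d']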
def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
timer: MultiTimer) -> Tuple[float]:
"""
Profile the forward and backward of a model
Args:
model (torch.nn.Module): a PyTorch model
warmup_steps (int): the number of steps for warmup
profile_steps (int): the number of steps for profiling
data_func (Callable): a function to generate random data
timer (colossalai.utils.MultiTimer): a timer instance for time recording
Returns:
fwd_time (float): the average time taken by the forward pass, in seconds
bwd_time (float): the average time taken by the backward pass, in seconds
max_allocated (float): the maximum GPU memory allocated in GB
max_cached (float): the maximum GPU memory cached in GB
"""
def _run_step(data):
timer.start('forward')
out = model(data)
timer.stop('forward', keep_in_history=True)
timer.start('backward')
out.mean().backward()
timer.stop('backward', keep_in_history=True)
data_list = [data_func() for _ in range(warmup_steps)]
for data in data_list:
_run_step(data)
timer.reset('forward')
timer.reset('backward')
for _ in range(profile_steps):
data = data_func()
_run_step(data)
max_allocated, max_cached = get_memory_states()
fwd_time = timer.get_timer('forward').get_history_mean()
bwd_time = timer.get_timer('backward').get_history_mean()
return fwd_time, bwd_time, max_allocated, max_cached
def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
"""
Return a random data of shape (batch_size, seq_length, dim) for profiling.
Args:
dim (int): hidden size
batch_size (int): the number of data samples
seq_length (int): the number of tokens
mode (ParallelMode): Colossal-AI ParallelMode enum
Returns:
data (torch.Tensor): random data
"""
if mode in ['2d', '2.5d']:
batch_size = batch_size // 2
dim = dim // 2
elif mode == '3d':
batch_size = batch_size // 4
dim = dim // 2
data = torch.rand(batch_size, seq_length, dim).cuda()
return data
import click
from .check_installation import check_installation
__all__ = ['check']
__all__ = ["check"]
@click.command(help="Check if Colossal-AI is correct based on the given option")
@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly")
@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly")
def check(installation):
if installation:
check_installation()
......
......@@ -9,7 +9,7 @@ import colossalai
def to_click_output(val):
# installation check output to understandable symbols for readability
VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'}
VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"}
if val in VAL_TO_SYMBOL:
return VAL_TO_SYMBOL[val]
......@@ -31,7 +31,7 @@ def check_installation():
found_aot_cuda_ext = _check_aot_built_cuda_extension_installed()
cuda_version = _check_cuda_version()
torch_version, torch_cuda_version = _check_torch_version()
colossalai_verison, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version()
colossalai_version, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version()
# if cuda_version is None, that means either
# CUDA_HOME is not found, thus cannot compare the version compatibility
......@@ -55,9 +55,9 @@ def check_installation():
else:
torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required])
click.echo(f'#### Installation Report ####')
click.echo(f'\n------------ Environment ------------')
click.echo(f"Colossal-AI version: {to_click_output(colossalai_verison)}")
click.echo(f"#### Installation Report ####")
click.echo(f"\n------------ Environment ------------")
click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}")
click.echo(f"PyTorch version: {to_click_output(torch_version)}")
click.echo(f"System CUDA version: {to_click_output(cuda_version)}")
click.echo(f"CUDA version required by PyTorch: {to_click_output(torch_cuda_version)}")
......@@ -69,7 +69,7 @@ def check_installation():
f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version."
)
click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------')
click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------")
click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}")
click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}")
click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}")
......@@ -81,7 +81,7 @@ def check_installation():
click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime")
click.echo(f"\n------------ Compatibility ------------")
click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}')
click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}")
click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}")
click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}")
click.echo(f"")
......@@ -106,12 +106,12 @@ def _is_compatible(versions):
return False
# split version into [major, minor, patch]
versions = [version.split('.') for version in versions]
versions = [version.split(".") for version in versions]
for version in versions:
if len(version) == 2:
# x means unknown
version.append('x')
version.append("x")
for idx, version_values in enumerate(zip(*versions)):
equal = len(set(version_values)) == 1
......@@ -137,15 +137,15 @@ def _parse_colossalai_version():
# 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions)
# 2. X.X.X (when colossalai is not installed with CUDA extensions)
# where X represents an integer.
colossalai_verison = colossalai.__version__.split('+')[0]
colossalai_version = colossalai.__version__.split("+")[0]
try:
torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0]
cuda_version_for_aot_build = colossalai.__version__.split('cu')[1]
torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0]
cuda_version_for_aot_build = colossalai.__version__.split("cu")[1]
except:
torch_version_for_aot_build = None
cuda_version_for_aot_build = None
return colossalai_verison, torch_version_for_aot_build, cuda_version_for_aot_build
return colossalai_version, torch_version_for_aot_build, cuda_version_for_aot_build
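# A worked example of the two version formats described in the comment above
# (the version string is made up for illustration):
version = "0.3.0+torch2.0cu11.8"                            # hypothetical AOT build tag
base = version.split("+")[0]                                # "0.3.0"
torch_for_aot = version.split("torch")[1].split("cu")[0]    # "2.0"
cuda_for_aot = version.split("cu")[1]                       # "11.8"
assert (base, torch_for_aot, cuda_for_aot) == ("0.3.0", "2.0", "11.8")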
def _check_aot_built_cuda_extension_installed():
......@@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed():
JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` at runtime.
"""
try:
import colossalai._C.fused_optim
found_aot_cuda_ext = True
except ImportError:
found_aot_cuda_ext = False
......@@ -175,14 +174,14 @@ def _check_torch_version():
# torch version can be of two formats
# - 1.13.1+cu113
# - 1.13.1.devxxx
torch_version = torch.__version__.split('+')[0]
torch_version = '.'.join(torch_version.split('.')[:3])
torch_version = torch.__version__.split("+")[0]
torch_version = ".".join(torch_version.split(".")[:3])
# get cuda version in pytorch build
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'
torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}"
except:
torch_cuda_version = None
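# Worked examples of the normalization above, covering both version formats
# mentioned in the comment (the strings are illustrative):
for raw, expected in [("1.13.1+cu117", "1.13.1"), ("2.0.0.dev20230101", "2.0.0")]:
    normalized = ".".join(raw.split("+")[0].split(".")[:3])
    assert normalized == expected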
......@@ -208,7 +207,7 @@ def _check_cuda_version():
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
cuda_version = f"{bare_metal_major}.{bare_metal_minor}"
except:
cuda_version = None
return cuda_version
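# A worked example of the release parsing above, assuming release_idx points at the
# token right after "release" in the nvcc output (the sample tokens are illustrative):
tokens = ["Cuda", "compilation", "tools,", "release", "11.7,", "V11.7.64"]
release = tokens[tokens.index("release") + 1].split(".")   # ["11", "7,"]
cuda_version = f"{release[0]}.{release[1][0]}"             # "11.7"
assert cuda_version == "11.7"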
import click
from .benchmark import benchmark
from .check import check
from .launcher import run
class Arguments():
class Arguments:
def __init__(self, arg_dict):
for k, v in arg_dict.items():
self.__dict__[k] = v
......@@ -19,7 +17,6 @@ def cli():
cli.add_command(run)
cli.add_command(check)
cli.add_command(benchmark)
if __name__ == '__main__':
if __name__ == "__main__":
cli()
......@@ -5,56 +5,81 @@ from colossalai.context import Config
from .run import launch_multi_processes
@click.command(help="Launch distributed training on a single node or multiple nodes",
context_settings=dict(ignore_unknown_options=True))
@click.option("-H",
"-host",
"--host",
type=str,
default=None,
help="the list of hostnames to launch in the format <host1>,<host2>")
@click.command(
help="Launch distributed training on a single node or multiple nodes",
context_settings=dict(ignore_unknown_options=True),
)
@click.option(
"-H",
"-host",
"--host",
type=str,
default=None,
help="the list of hostnames to launch in the format <host1>,<host2>",
)
@click.option(
"--hostfile",
type=str,
default=None,
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
@click.option("--include",
type=str,
default=None,
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
" only effective when used with --hostfile.")
help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname",
)
@click.option(
"--include",
type=str,
default=None,
help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
" only effective when used with --hostfile.",
)
@click.option(
"--exclude",
type=str,
default=None,
help=
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ,"
" only effective when used with --hostfile.")
@click.option("--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to use, only effective when used with --hostfile.")
help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
" only effective when used with --hostfile.",
)
@click.option(
"--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to use, only effective when used with --hostfile.",
)
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
@click.option("--master_port",
type=int,
default=29500,
help="(optional) Port used by PyTorch distributed for communication during distributed training.")
@click.option("--master_addr",
type=str,
default="127.0.0.1",
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
@click.option(
"--master_port",
type=int,
default=29500,
help="(optional) Port used by PyTorch distributed for communication during distributed training.",
)
@click.option(
"--master_addr",
type=str,
default="127.0.0.1",
help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.",
)
@click.option(
"--extra_launch_args",
type=str,
default=None,
help=
"Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
"This will be converted to --arg1=1 --arg2=2 during execution")
help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
"This will be converted to --arg1=1 --arg2=2 during execution",
)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
@click.argument('user_args', nargs=-1)
def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
@click.argument("user_args", nargs=-1)
def run(
host: str,
hostfile: str,
num_nodes: int,
nproc_per_node: int,
include: str,
exclude: str,
master_addr: str,
master_port: int,
extra_launch_args: str,
ssh_port: int,
user_script: str,
user_args: str,
) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
......@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include:
# run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nproc_per_node 4 train.py
"""
if not user_script.endswith('.py'):
click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
if not user_script.endswith(".py"):
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
args_dict = locals()
......
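# For clarity, the --extra_launch_args conversion described in the option help can be
# sketched like this (illustrative only; the real conversion is performed inside
# launch_multi_processes):
def _expand_extra_launch_args(extra: str) -> list:
    # "arg1=1,arg2=2" -> ["--arg1=1", "--arg2=2"]; bare flags become "--flag".
    return [f"--{item}" for item in extra.split(",")]

assert _expand_extra_launch_args("standalone,nnodes=2") == ["--standalone", "--nnodes=2"]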
import socket
from typing import List
class HostInfo:
......@@ -34,11 +33,11 @@ class HostInfo:
"""
if port is None:
port = 22 # no port specified, lets just use the ssh port
port = 22 # no port specified, lets just use the ssh port
# socket.getfqdn("127.0.0.1") does not return localhost
# on some users' machines
# thus, we directly return True if hostname is locahost, 127.0.0.1 or 0.0.0.0
# thus, we directly return True if hostname is localhost, 127.0.0.1 or 0.0.0.0
if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
return True
......@@ -46,14 +45,11 @@ class HostInfo:
localhost = socket.gethostname()
localaddrs = socket.getaddrinfo(localhost, port)
targetaddrs = socket.getaddrinfo(hostname, port)
for (family, socktype, proto, canonname, sockaddr) in localaddrs:
for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs:
if rsockaddr[0] == sockaddr[0]:
return True
return False
return localaddrs == targetaddrs
def __str__(self):
return f'hostname: {self.hostname}, port: {self.port}'
return f"hostname: {self.hostname}, port: {self.port}"
def __repr__(self):
return self.__str__()
......