Commit a39a5c66 authored by Hongxin Liu, committed by GitHub

Merge branch 'main' into feature/shardformer

parents e79b1e80 aaeb520c
@@ -9,13 +9,15 @@ from coati.models.bloom import BLOOMActor
 from coati.models.gpt import GPTActor
 from coati.models.llama import LlamaActor
 from coati.models.opt import OPTActor
+from coati.models.chatglm import ChatGLMActor
 from coati.trainer import SFTTrainer
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
 from datasets import load_dataset
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
 from transformers.trainer import get_scheduler
@@ -58,6 +60,8 @@ def train(args):
         model = LlamaActor(pretrained=args.pretrain,
                            lora_rank=args.lora_rank,
                            checkpoint=args.grad_checkpoint)
+    elif args.model == 'chatglm':
+        model = ChatGLMActor(pretrained=args.pretrain)
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -81,6 +85,9 @@ def train(args):
             "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
         tokenizer.eos_token = '<\s>'
         tokenizer.pad_token = tokenizer.unk_token
+    elif args.model == 'chatglm':
+        tokenizer = ChatGLMTokenizer.from_pretrained(
+            "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -99,7 +106,6 @@ def train(args):
         optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
     else:
         optim = Adam(model.parameters(), lr=args.lr)
-
     logger = get_dist_logger()
 
     # configure dataset
@@ -185,7 +191,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
                         default='colossalai_zero2')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default=None)
...
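For readers skimming the diff: the two `chatglm` branches only extend the existing model/tokenizer dispatch. A minimal sketch of what they construct, assuming a local coati install and the `THUDM/chatglm-6b` default used above:

```python
# Hedged sketch, not part of the diff: mirrors the two added branches.
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer

# Note: unlike the other actors, ChatGLMActor is built without
# lora_rank/checkpoint arguments in this diff.
model = ChatGLMActor(pretrained="THUDM/chatglm-6b")
tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```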
 pytest
+colossalai==0.3.1
\ No newline at end of file
@@ -2,7 +2,7 @@ transformers>=4.20.1
 tqdm
 datasets
 loralib
-colossalai>=0.2.4
+colossalai==0.3.1
 torch<2.0.0, >=1.12.1
 langchain
 tokenizers
...
@@ -11,7 +11,7 @@ from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDatase
 from datasets import load_dataset
 from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
 
 SFT_DATASET = [
     {
         "instruction":
@@ -80,6 +80,8 @@ def make_tokenizer(model: str):
     elif model == "llama":
         tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
         tokenizer.pad_token = tokenizer.unk_token
+    elif model == "chatglm":
+        tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
     else:
         raise ValueError(f"Unsupported model '{model}'")
     return tokenizer
@@ -93,13 +95,19 @@ def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokeniz
     elif model == "llama":
         assert input_ids_stripped[0] == tokenizer.bos_token_id
         input_ids_stripped = input_ids_stripped[1:]
+    elif model == "chatglm":
+        assert input_ids_stripped[0] == tokenizer.bos_token_id
+        assert input_ids_stripped[-1] == tokenizer.eos_token_id
+        input_ids_stripped = input_ids_stripped[1:-1]
     assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
     assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
     assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
     assert input_ids_stripped != tokenizer.sep_token_id
     assert input_ids_stripped != tokenizer.cls_token_id
-    assert input_ids_stripped != tokenizer.mask_token_id
+    if model == "chatglm":
+        assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
+    else:
+        assert input_ids_stripped != tokenizer.mask_token_id
 
 
 @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@@ -190,7 +198,8 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma
     assert torch.all(r_mask)
 
 
-@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
 @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
 @pytest.mark.parametrize("max_dataset_size", [2])
 @pytest.mark.parametrize("max_length", [32, 1024])
@@ -211,6 +220,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size:
                              max_length=max_length)
     assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
 
+    if isinstance(tokenizer, ChatGLMTokenizer):
+        for i in range(max_dataset_size):
+            assert isinstance(sft_dataset[i], dict)
+            assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
+            input_ids = sft_dataset[i]["input_ids"]
+            labels = sft_dataset[i]["labels"]
+            assert input_ids.shape == labels.shape == torch.Size([max_length])
+
+            ignore_mask = labels == IGNORE_INDEX
+            assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
+            check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
+        return
+
     for i in range(max_dataset_size):
         assert isinstance(sft_dataset[i], dict)
         assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
@@ -238,4 +260,7 @@ if __name__ == "__main__":
                      max_datasets_size=8,
                      max_length=256)
 
-    test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
+    test_prompt_dataset(model="opt",
+                        max_datasets_size=2,
+                        max_length=128)
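The chatglm assertions above encode ChatGLM's framing convention: the first non-ignored token is `bos` and the last is `eos`, with no special tokens in between. A self-contained illustration with hypothetical token ids (the real values come from `ChatGLMTokenizer`):

```python
import torch

# Illustrative ids only; these values are made up for the example.
BOS, EOS, PAD = 130004, 130005, 3

input_ids = torch.tensor([BOS, 11, 12, 13, EOS])

# Same shape of check as the chatglm branch of check_content: bos first,
# eos last, and no pad tokens once both are stripped.
assert input_ids[0] == BOS and input_ids[-1] == EOS
stripped = input_ids[1:-1]
assert torch.all(stripped != PAD)
```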
@@ -9,11 +9,12 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
 from coati.models.generation import generate
 from coati.models.gpt import GPTRM, GPTActor, GPTCritic
 from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.chatglm import ChatGLMActor
 from coati.models.lora import LoraLinear, convert_to_lora_module
 from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
 from coati.models.opt import OPTRM, OPTActor, OPTCritic
 from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
 
 
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seq_len", [32])
@@ -24,8 +25,10 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea
     lambda: GPTActor(),
     # HACK: skip llama due to long execution time
     # lambda: LlamaActor(),
-    lambda: OPTActor()
-])
+    lambda: OPTActor(),
+    # lambda: ChatGLMActor(),
+])
 @pytest.mark.parametrize("generate_kwargs", [{
     "max_length": 64,
     "use_cache": True,
@@ -115,11 +118,13 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int):
     lambda: (GPTActor(), GPTCritic(), GPTRM()),
     # HACK: skip llama due to long execution time
     # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
     lambda: (OPTActor(), OPTCritic(), OPTRM()),
-])
+    lambda: (ChatGLMActor(), None, None),
+])
 @torch.no_grad()
-def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
+def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
+                batch_size: int,
+                seq_len: int):
     actor_input = {
         "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
         "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
@@ -135,20 +140,30 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], b
     }
 
     actor, critic, rm = models_maker()
+    if isinstance(actor, ChatGLMActor):
+        actor = actor.float()
+        tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+        chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
+        actor_input = {
+            "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len // 2)), chatglm_special_token,
+                                    torch.randint(0, 100, (batch_size, seq_len // 2 - 2))), dim=1),
+            "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
+        }
     assert isinstance(actor, Actor)
     base_actor_model = get_base_model(actor)
-    assert isinstance(critic, Critic)
-    base_critic_model = get_base_model(critic)
-    assert isinstance(rm, RewardModel)
-    base_rm_model = get_base_model(rm)
 
     actor_output = actor(**actor_input)
-    critic_output = critic(**critic_input)
-    rm_output = rm(**rm_input)
 
     assert actor_output.logits.shape[:2] == (batch_size, seq_len)
-    assert critic_output.shape == (batch_size,)
-    assert rm_output.shape == (batch_size,)
+
+    if critic:
+        assert isinstance(critic, Critic)
+        base_critic_model = get_base_model(critic)
+        critic_output = critic(**critic_input)
+        assert critic_output.shape == (batch_size, )
+
+    if rm:
+        assert isinstance(rm, RewardModel)
+        base_rm_model = get_base_model(rm)
+        rm_output = rm(**rm_input)
+        assert rm_output.shape == (batch_size, )
 
 
 @pytest.mark.parametrize("batch_size", [16])
@@ -203,4 +218,4 @@ if __name__ == "__main__":
     test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
-    test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
+    test_loss(batch_size=8, seq_len=128, num_labels=100)
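The ChatGLM branch builds a different input layout than the other actors: a `[gmask, bos]` pair spliced into the middle of the ids, and a full 4-D attention mask instead of a 2-D one. A standalone sketch with stand-in token ids (the test reads the real `gmask_token_id`/`bos_token_id` from `ChatGLMTokenizer`):

```python
import torch

batch_size, seq_len = 4, 32
# Stand-in special-token ids, invented for illustration.
gmask_token_id, bos_token_id = 130001, 130004

special = torch.tensor([gmask_token_id, bos_token_id]).repeat(batch_size, 1)
input_ids = torch.cat((torch.randint(0, 100, (batch_size, seq_len // 2)),
                       special,
                       torch.randint(0, 100, (batch_size, seq_len // 2 - 2))), dim=1)
assert input_ids.shape == (batch_size, seq_len)    # 16 + 2 + 14 = 32

# ChatGLM consumes a (batch, 1, seq, seq) attention mask.
attention_mask = torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
```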
@@ -144,7 +144,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
     # DeviceMesh information instructs the scaling of the size value
     device_mesh_info = {}
-    for dim, dim_size in enumerate(device_mesh.mesh_shape):
+    for dim, dim_size in enumerate(device_mesh.shape):
         device_mesh_info[dim] = dim_size
 
     def _extract_target_dim(node):
...
 import gc
 import logging
 import os
-import warnings
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
@@ -16,7 +14,6 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralC
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    get_shard_filename,
     load_shard_state_dict,
     save_config_file,
     save_state_dict,
@@ -25,8 +22,7 @@ from colossalai.checkpoint_io.utils import (
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini import ZeroOptimizer
+from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
 from .dp_plugin_base import DPPluginBase
@@ -134,11 +130,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         As there is communication when getting state dict, this must be called on all processes.
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.unwrap()
-
-        assert isinstance(optimizer, ZeroOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer)
 
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -185,11 +177,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if not os.path.isfile(checkpoint_index_file):
             logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.unwrap()
-
-        assert isinstance(optimizer, ZeroOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer)
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
@@ -222,47 +210,6 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         super().save_lr_scheduler(lr_scheduler, checkpoint)
 
 
-class GeminiModel(ModelWrapper):
-
-    def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
-        super().__init__(module)
-        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
-
-    def unwrap(self):
-        # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
-        return self.module
-
-
-class GeminiOptimizer(OptimizerWrapper):
-
-    def __init__(self,
-                 module: GeminiDDP,
-                 optimizer: Optimizer,
-                 zero_optim_config: dict,
-                 optim_kwargs: dict,
-                 verbose: bool = False) -> None:
-        optimizer = zero_optim_wrapper(module,
-                                       optimizer,
-                                       optim_config=zero_optim_config,
-                                       **optim_kwargs,
-                                       verbose=verbose)
-        super().__init__(optimizer)
-
-    def backward(self, loss: Tensor, *args, **kwargs):
-        self.optim.backward(loss)
-
-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> Tensor:
-        warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
-
-    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
-        raise NotImplementedError('Gemini does not support clip_grad_by_value')
-
-
 class GeminiPlugin(DPPluginBase):
     """
     Plugin for Gemini.
@@ -279,8 +226,20 @@ class GeminiPlugin(DPPluginBase):
        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
     Args:
-        device (torch.device): device to place the model.
-        placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        chunk_config_dict (dict, optional): chunk configuration dictionary.
+        chunk_init_device (torch.device, optional): device to initialize the chunk.
+        placement_policy (str, optional): "static" and "auto". Defaults to "static".
+        shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
+            If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
+        offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
+            If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
+        offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
+            For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
+            If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
+            When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
+            Defaults to 0.0.
+        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
+        steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
         precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
@@ -312,8 +271,14 @@ class GeminiPlugin(DPPluginBase):
     def __init__(
         self,
-        device: Optional[torch.device] = None,
-        placement_policy: str = "cpu",
+        chunk_config_dict: Optional[dict] = None,
+        chunk_init_device: Optional[torch.device] = None,
+        placement_policy: str = "static",
+        shard_param_frac: float = 1.0,    # only for static placement
+        offload_optim_frac: float = 0.0,    # only for static placement
+        offload_param_frac: float = 0.0,    # only for static placement
+        warmup_non_model_data_ratio: float = 0.8,    # only for auto placement
+        steady_cuda_cap_ratio: float = 0.9,    # only for auto placement
         precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
@@ -337,8 +302,14 @@ class GeminiPlugin(DPPluginBase):
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
-            device=(device or get_current_device()),
+            chunk_config_dict=chunk_config_dict,
+            chunk_init_device=(chunk_init_device or get_current_device()),
             placement_policy=placement_policy,
+            shard_param_frac=shard_param_frac,
+            offload_optim_frac=offload_optim_frac,
+            offload_param_frac=offload_param_frac,
+            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+            steady_cuda_cap_ratio=steady_cuda_cap_ratio,
            pin_memory=pin_memory,
            force_outputs_fp32=force_outputs_fp32,
            strict_ddp_mode=strict_ddp_mode,
@@ -395,12 +366,15 @@ class GeminiPlugin(DPPluginBase):
             # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
 
             # wrap the model with Gemini
-            model = GeminiModel(model, self.gemini_config, self.verbose)
+            model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
 
             if optimizer is not None and \
                     not isinstance(optimizer, OptimizerWrapper):
-                optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                            self.verbose)
+                optimizer = GeminiOptimizer(optimizer,
+                                            model.unwrap(),
+                                            **self.zero_optim_config,
+                                            **self.optim_kwargs,
+                                            verbose=self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
...
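The rewritten docstring replaces the old "cpu"/"cuda" placement strings with fractional static placement. A hedged usage sketch, assuming the `colossalai.booster` import paths (they are not part of this diff), with illustrative values rather than tuned recommendations:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# shard_param_frac=1.0 matches ZeRO-3-style full parameter sharding;
# raising offload_optim_frac moves that fraction of optimizer states to CPU.
plugin = GeminiPlugin(placement_policy='static',
                      shard_param_frac=1.0,
                      offload_optim_frac=0.0,
                      offload_param_frac=0.0,
                      precision='fp16')
booster = Booster(plugin=plugin)
```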
@@ -17,8 +17,13 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
+    load_param_groups_into_optimizer,
+    load_shard_state_dict,
+    load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
+    sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -126,19 +131,39 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
-        current_rank_state_dict = optimizer.optim.state_dict()['state']
-        for param_idx, state in current_rank_state_dict.items():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != 'step':
-                    padding_size = (self.coordinator.world_size -
-                                    v.numel() % self.coordinator.world_size) % self.coordinator.world_size
-                    with torch.no_grad():
-                        v = v.flatten()
-                        if padding_size > 0:
-                            v = torch.nn.functional.pad(v, [0, padding_size])
-                        v_list = v.split(v.numel() // self.coordinator.world_size)
-                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            # shard state dict
+            for param_idx, state in state_dict.items():
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != 'step':
+                        padding_size = (self.coordinator.world_size -
+                                        v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                        with torch.no_grad():
+                            v = v.flatten()
+                            if padding_size > 0:
+                                v = torch.nn.functional.pad(v, [0, padding_size])
+                            v_list = v.split(v.numel() // self.coordinator.world_size)
+                            state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
 
 class LowLevelZeroModel(ModelWrapper):
...
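The loader now pads each flattened optimizer state to a multiple of the world size and keeps only the local rank's slice. The core arithmetic, extracted into a small self-contained sketch (the helper name is invented for illustration):

```python
import torch

def shard_flat_state(v: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Same pad-then-split scheme as the loader above: flatten, zero-pad to a
    # multiple of world_size, and keep only this rank's contiguous slice.
    padding_size = (world_size - v.numel() % world_size) % world_size
    v = v.flatten()
    if padding_size > 0:
        v = torch.nn.functional.pad(v, [0, padding_size])
    return v.split(v.numel() // world_size)[rank].detach().clone()

# e.g. a 10-element Adam moment padded to 12 and split across 4 ranks
# yields 3 elements per rank; rank 3 holds the tail plus padding zeros.
assert shard_flat_state(torch.arange(10.), world_size=4, rank=3).numel() == 3
```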
@@ -79,8 +79,6 @@ class GeneralCheckpointIO(CheckpointIO):
         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             load_states_into_optimizer(optimizer, state_dict, id_map)
-            del state_dict
-            gc.collect()
 
         sharded_optimizer_loading_epilogue(optimizer)
...
@@ -514,7 +514,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
                 f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
-        return torch.load(checkpoint_file)
+        return torch.load(checkpoint_file, map_location=torch.device('cpu'))
 
 
 def load_state_dict_into_model(model: nn.Module,
@@ -574,7 +574,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     # Load list of param_groups from given file path.
     # The params in saved_groups are in the form of integer indices.
-    saved_groups = torch.load(param_group_path)
+    saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
     if not isinstance(saved_groups, List):
         raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
@@ -730,7 +730,7 @@ def load_state_dict(checkpoint_file_path: Path):
     else:
         # load with torch
-        return torch.load(checkpoint_file_path)
+        return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
 
 
 def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
...
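All three `torch.load` call sites gain `map_location=torch.device('cpu')`, which decouples deserialization from the device the checkpoint was saved on. A minimal demonstration:

```python
import tempfile

import torch

# Tensors land on CPU regardless of where they lived when saved, so a
# GPU-saved checkpoint loads on a CPU-only host without errors.
with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    torch.save({"w": torch.randn(2, 2)}, f.name)
    loaded = torch.load(f.name, map_location=torch.device('cpu'))
assert loaded["w"].device.type == "cpu"
```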
@@ -265,6 +265,10 @@ def launch_multi_processes(args: Config) -> None:
     # establish remote connection
     runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
 
+    # overwrite master addr when num_nodes > 1 and not specified
+    if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
+        args.master_addr = active_device_pool.hostinfo_list[0].hostname
+
     # execute distributed launching command
     for node_id, hostinfo in enumerate(active_device_pool):
         cmd = get_launch_command(master_addr=args.master_addr,
...
@@ -2,7 +2,13 @@ import warnings
 HAS_MEM_EFF_ATTN = False
 try:
-    from xformers.ops.fmha import memory_efficient_attention
+    from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
+    from xformers.ops.fmha.attn_bias import (
+        BlockDiagonalCausalMask,
+        BlockDiagonalMask,
+        LowerTriangularMask,
+        LowerTriangularMaskWithTensorBias,
+    )
     HAS_MEM_EFF_ATTN = True
 except ImportError:
     warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
@@ -16,13 +22,6 @@ if HAS_MEM_EFF_ATTN:
     from typing import Optional
 
     import torch
-    from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
-    from xformers.ops.fmha.attn_bias import (
-        BlockDiagonalCausalMask,
-        BlockDiagonalMask,
-        LowerTriangularMask,
-        LowerTriangularMaskWithTensorBias,
-    )
 
     from .utils import SeqLenInfo
...
@@ -3,9 +3,15 @@ from typing import Optional
 import torch
 
 from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.const import TensorType
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+
+from .colo_tensor import _convert_output
+
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+
+
+def is_no_hook_op(func) -> bool:
+    return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS
 
 
 def filter_colo_parameters(*args, **kwargs):
@@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     """
 
-    def __new__(cls,
-                data: Optional[torch.Tensor] = None,
-                requires_grad: bool = True,
-                spec: ColoTensorSpec = None) -> 'ColoParameter':
+    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
-    def __init__(self,
-                 data: Optional[torch.Tensor] = None,
-                 requires_grad: bool = True,
-                 spec: ColoTensorSpec = None) -> None:
-        ColoTensor.__init__(self, data, spec)
-        self._type = TensorType.MODEL
-        # a list contains modules sharing this ColoParameter with others.
-        self._shared_param_modules = []
-
-    @property
-    def shared_param_modules(self):
-        return self._shared_param_modules
-
-    @staticmethod
-    def from_torch_tensor(tensor: torch.Tensor,
-                          requires_grad: bool = True,
-                          spec: ColoTensorSpec = None) -> 'ColoParameter':
-        tensor = tensor.as_subclass(ColoParameter)
-        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
-        return tensor
-
-    def __repr__(self):
-        return super(ColoParameter, self).__repr__()
-
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ColoParamOpHookManager.has_hook():
-            if not func.__name__.startswith('__'):
-                if kwargs is None:
-                    kwargs = {}
-                params = filter_colo_parameters(*args, **kwargs)
-                if len(params) > 0:
-                    with torch._C.DisableTorchFunction():
-                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
-                    args, kwargs = replace_args(args, kwargs, new_args)
-                    ret = super().__torch_function__(func, types, args, kwargs)
-                    with torch._C.DisableTorchFunction():
-                        ret = ColoParamOpHookManager.post_op(params, ret)
-                    return ret
+        if kwargs is None:
+            kwargs = {}
+        if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
+            params = filter_colo_parameters(*args, **kwargs)
+            if len(params) > 0:
+                with torch._C.DisableTorchFunction():
+                    new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                args, kwargs = replace_args(args, kwargs, new_args)
+                ret = super().__torch_function__(func, types, args, kwargs)
+                with torch._C.DisableTorchFunction():
+                    ret = ColoParamOpHookManager.post_op(params, ret)
+                return _convert_output(ret, func)
         return super().__torch_function__(func, types, args, kwargs)
 
     def __deepcopy__(self, memo):
@@ -96,9 +74,6 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoParameter(data,
-                                   self.requires_grad,
-                                   spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
+            tensor = ColoParameter(data, self.requires_grad)
             memo[id(self)] = tensor
             return tensor
...
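`is_no_hook_op` is the new filter that lets dunder calls bypass the parameter op hooks while keeping `__getitem__` hooked. Reproducing it standalone (the function body is copied from the diff) shows the intended behavior:

```python
import torch

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}

def is_no_hook_op(func) -> bool:
    # Dunder ops (e.g. __repr__) skip the param hooks, except whitelisted ones.
    return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS

assert is_no_hook_op(torch.Tensor.__repr__)
assert not is_no_hook_op(torch.Tensor.__getitem__)    # whitelisted: hooks still fire
assert not is_no_hook_op(torch.Tensor.add)            # ordinary op: hooks fire
```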
-import operator
-from copy import copy
-from functools import lru_cache, reduce
-from typing import Callable, Optional, Set
+from functools import lru_cache
+from typing import Callable, Set
 
 import torch
 
-from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
-from colossalai.tensor.process_group import ProcessGroup
-from colossalai.tensor.tensor_spec import ColoTensorSpec
-
-from .const import TensorType
-from .op_wrapper import _COLOSSAL_OPS
+INPALCE_MAPPING = {
+    torch.Tensor.add_: torch.Tensor.add,
+    torch.Tensor.sub_: torch.Tensor.sub,
+    torch.Tensor.mul_: torch.Tensor.mul,
+    torch.Tensor.div_: torch.Tensor.div
+}
 
 
 @lru_cache(None)
@@ -25,61 +22,37 @@ def _get_my_nowrap_functions() -> Set[Callable]:
     }
 
 
-def _convert_output(output, colo_spec: ColoTensorSpec):
-    if type(output) == torch.Tensor:
-        return ColoTensor.from_torch_tensor(output, colo_spec)
+def _convert(output):
+    if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
+        output.__class__ = ColoTensor
     elif isinstance(output, (list, tuple)):
-        return type(output)(_convert_output(o, colo_spec) for o in output)
-    else:
-        return output
+        output = type(output)(_convert(o) for o in output)
+    return output
 
 
-def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
-    for elem in args:
-        if isinstance(elem, ColoTensor):
-            pg = elem.get_process_group()
-            dp = elem.dist_spec
-            return ColoTensorSpec(pg, dp)
-        elif isinstance(elem, (list, tuple)):
-            spec = _get_spec_from_args(elem, {})
-            if spec is not None:
-                return spec
-    for k, v in kwargs.items():
-        if isinstance(v, ColoTensor):
-            pg = v.get_process_group()
-            dp = v.dist_spec
-            return ColoTensorSpec(pg, dp)
-    return None
+def _convert_output(output, func):
+    if func in _get_my_nowrap_functions():
+        return output
+    return _convert(output)
 
 
 class ColoTensor(torch.Tensor):
     """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
 
-    The Colotensor can be initialized with a PyTorch tensor in the following ways.
-
-        >>> pg = ProcessGroup()
-        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
-        >>> # The tensor passed in is a tensor after sharding but not a global tensor.
-        >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
-        >>>                        dims=[0],
-        >>>                        num_partitions=[world_size])
-        >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
-        >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    It is only used to trigger the torch function hook.
 
     Args:
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
-        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
     torch_major = int(torch.__version__.split('.')[0])
     torch_minor = int(torch.__version__.split('.')[1])
 
-    def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
+    def __new__(cls, data: torch.Tensor) -> 'ColoTensor':
         """
         The signature of the __new__ has to be consistent with the torch.Tensor.
 
         Args:
             data (torch.Tensor): a torch tensor used as the payload the colotensor.
-            spec (TensorSpec, optional): the tensor spec of initialization.
 
         Returns:
             ColoTensor: a ColoTensor wrappers the data.
@@ -88,86 +61,6 @@ class ColoTensor(torch.Tensor):
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, data.requires_grad)
 
-    def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
-        # If not set spec, use a DP process group and replicate dist spec
-        if spec is None:
-            self.has_initialized = False
-            self.dist_spec = ReplicaSpec()
-            self.compute_spec = None
-            self.process_group = ProcessGroup()
-        else:
-            self.has_initialized = True
-            self.dist_spec = spec.dist_attr
-            self.compute_spec = spec.compute_attr
-            if spec.pg is None:
-                self.process_group = ProcessGroup()
-            else:
-                self.process_group = spec.pg
-
-        self._type = TensorType.NONMODEL
-
-    def has_compute_spec(self) -> bool:
-        return self.compute_spec is not None
-
-    def is_model_data(self) -> bool:
-        return self._type == TensorType.MODEL
-
-    def get_process_group(self) -> 'ProcessGroup':
-        return self.process_group
-
-    def set_process_group(self, pg: ProcessGroup):
-        """set_process_group
-        change the pg of the ColoTensor. Note that the valid use cases is limited.
-        It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.
-
-        Args:
-            pg (ProcessGroup): target pg
-        """
-        assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
-        # if the new pg is the same as the old pg, just returns
-        if self.process_group == pg:
-            return
-        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
-        assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"
-
-        self.process_group = pg
-
-    def get_tp_world_size(self) -> int:
-        return self.process_group.tp_world_size()
-
-    def get_dp_world_size(self) -> int:
-        """get_dp_world_size
-        get the dp world size of the tensor.
-
-        Returns:
-            int: dp world size
-        """
-        return self.process_group.dp_world_size()
-
-    def set_dist_spec(self, dist_spec: _DistSpec):
-        """set_dist_spec
-        set dist spec and change the payloads.
-
-        Args:
-            dist_spec (_DistSpec): target dist spec.
-        """
-        assert isinstance(dist_spec, _DistSpec)
-        assert self.process_group is not None
-        self._redistribute(dist_spec)
-
-    def set_tensor_spec(self, dist_spec, compute_spec):
-        if dist_spec is not None:
-            assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
-            self.set_dist_spec(dist_spec)
-        if compute_spec is not None:
-            self.compute_spec = compute_spec
-
-    def has_compute_pattern(self, compute_pattern):
-        return self.compute_spec.compute_pattern == compute_pattern
-
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
@@ -175,9 +68,6 @@ class ColoTensor(torch.Tensor):
         if not all(issubclass(cls, t) for t in types):
             return NotImplemented
 
-        global _COLOSSAL_OPS
-        if func in _COLOSSAL_OPS:
-            func = _COLOSSAL_OPS[func]
-
         if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
             # in order to trigger pre-op hook in the forward of checkpoint module
@@ -189,94 +79,16 @@ class ColoTensor(torch.Tensor):
                 tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
                 return backward_tensor.backward(**tensor_kwargs)
 
+        # replace the in-place function
+        if func in INPALCE_MAPPING:
+            func = INPALCE_MAPPING[func]
+            # set the 'inplace' kwargs to False
+            if 'inplace' in kwargs:
+                kwargs['inplace'] = False
+
         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-        if func in _get_my_nowrap_functions():
-            return ret
-        else:
-            colo_spec = _get_spec_from_args(args, kwargs)
-            return _convert_output(ret, colo_spec)
-
-    def __repr__(self):
-        output_list = [super(ColoTensor, self).__repr__()]
-        output_list.append(str(self.process_group))
-        output_list.append(str(self.dist_spec))
-        if self.compute_spec is not None:
-            output_list.append(str(self.compute_spec))
-        return "\n".join(output_list)
-
-    def _redistribute(self, dist_spec: _DistSpec) -> None:
-        """_redistribute
-        Note the function will not handle the logic of backward propagation!
-        It is used during model tensor initializations as an internal function.
-
-        Args:
-            dist_spec (_DistSpec): the target dist. spec.
-        """
-        assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
-        with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
-        self.dist_spec = dist_spec
-
-    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
-        """redistribute
-        Redistribute the tensor among processes. The rule is like this:
-
-        1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
-        DP process group not changed.
-
-        2. If the pg is not not None and not equal to the current process group.
-        First, convert the tensor as replicated among the TP process group.
-        Second, reset the process group to the new pg.
-        Third, convert the tensor (new replicated both among the tp process group) to the new dist_spec.
-
-        Args:
-            dist_spec (_DistSpec): the new dist spec.
-            pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
-
-        Returns:
-            ColoTensor: a redistributed colotensor
-        """
-        if pg is not None and pg != self.get_process_group():
-            # if the pg is not equal, convert the current tensor to replicated
-            handled = self.redistribute(ReplicaSpec())
-        else:
-            handled = self
-            pg = self.process_group
-
-        ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
-        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
-
-    def to_replicate_(self):
-        """to_replicate_
-        an inline member function, converting dist spec of the tensor to REPLICATE
-        """
-        self._redistribute(dist_spec=ReplicaSpec())
-
-    def to_replicate(self) -> 'ColoTensor':
-        """to_replicate
-        converting dist spec of the tensor to ReplicaSpec()
-        """
-        return self.redistribute(ReplicaSpec())
-
-    @staticmethod
-    def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
-        """from_torch_tensor
-        A static method builds a `ColoTensor` from a PyTorch Tensor.
-
-        Args:
-            tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor.
-            spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.
-
-        Returns:
-            ColoTensor: a ColoTensor
-        """
-        tensor = tensor.as_subclass(ColoTensor)
-        tensor.__init__(tensor, spec=spec)
-        return tensor
+        return _convert_output(ret, func)
 
     def __deepcopy__(self, memo):
         if id(self) in memo:
@@ -284,60 +96,6 @@ class ColoTensor(torch.Tensor):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
+            tensor = ColoTensor(data)
             memo[id(self)] = tensor
             return tensor
-
-    # override builtin functions which must use tensor in replicate placement #
-
-    def size_local(self, *args) -> torch.Size:
-        with torch._C.DisableTorchFunction():
-            return super().size(*args)
-
-    def size_global(self, *args) -> torch.Size:
-        """size_global
-        override the torch building size()
-        the shape passed in must be in a replicate placement.
-
-        Returns:
-            torch.Size: the global tensor shape
-        """
-        if self.is_replicate():
-            return self.size_local(*args)
-        spec = self.dist_spec
-        dims = spec.dims
-        num_partitions = spec.num_partitions
-        # import inspect
-        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
-        size_list = list(self.size_local())
-        for dim, num_partition in zip(dims, num_partitions):
-            size_list[dim] *= num_partition
-        if args == ():
-            return torch.Size(size_list)
-        else:
-            return size_list[args[0]]
-
-    def numel_global(self):
-        """Returns the number of elements in the tensor when it's replicated.
-        """
-        return reduce(operator.mul, self.size_global(), 1)
-
-    # Some API for dist spec check
-
-    def is_replicate(self):
-        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
-            or (len(self.dist_spec.num_partitions) == 1
-                and self.dist_spec.num_partitions[0] == 1) \
-            or (self.process_group.tp_world_size() == 1)
-
-    def is_shard_1dcol(self):
-        return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
-
-    def is_shard_1drow(self):
-        return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
-
-    def is_sharded(self):
-        return self.dist_spec.placement == DistPlacementPattern.SHARD
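`__torch_function__` now reroutes in-place arithmetic through `INPALCE_MAPPING` (the identifier's spelling is the commit's own) so the result can be rewrapped as a `ColoTensor`. The redirection itself is plain PyTorch and easy to verify standalone:

```python
import torch

# Same table as INPALCE_MAPPING above, reproduced under a corrected name.
INPLACE_TO_OUT_OF_PLACE = {
    torch.Tensor.add_: torch.Tensor.add,
    torch.Tensor.sub_: torch.Tensor.sub,
    torch.Tensor.mul_: torch.Tensor.mul,
    torch.Tensor.div_: torch.Tensor.div,
}

x = torch.ones(3)
func = INPLACE_TO_OUT_OF_PLACE[torch.Tensor.add_]
y = func(x, 1)    # out-of-place: x is unchanged, y is a fresh tensor
assert torch.equal(x, torch.ones(3)) and torch.equal(y, torch.full((3,), 2.0))
```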
@@ -3,9 +3,7 @@ from contextlib import contextmanager
 from typing import Any, List, Tuple
 
 import torch
-
-from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
 
 
 class ColoParamOpHook(ABC):
@@ -82,26 +80,18 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        grad_args, rear_args = _get_grad_args(*args)
-        colo_info = _get_colo_tensors_info(*grad_args)
-        rets = PreFwdPostBwd.apply(params, *grad_args)
-        update_args = _update_colo_tensors(colo_info, *rets)
-        if rear_args is None:
-            return update_args
-        else:
-            arg_zero = (tuple(update_args),)
-            return arg_zero + rear_args
+        # auto grad function can only recognize torch.Tensor, thus we have to flatten the input
+        # if one of the input requires grad, all the output will be treated as requires grad
+        # and will have grad fn even the corresponding input does not require grad
+        # we have to extract tensors requiring grad into flat list and then merge them back
+        grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
+        new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
+        return _merge_args(new_grad_args, other_args, grad_flags, spec)
 
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        colo_info = _get_colo_tensors_info(arg)
-        ret = PostFwdPreBwd.apply(params, arg)
-        res = _update_colo_tensors(colo_info, ret)
-        if len(res) == 1:
-            return res[0]
-        else:
-            return res
+        return PostFwdPreBwd.apply(params, arg)
 
     @staticmethod
     def has_hook() -> bool:
@@ -141,57 +131,24 @@ def _is_grad_tensor(obj) -> bool:
     return False
 
 
-def _has_grad_tensor(obj) -> bool:
-    if isinstance(obj, tuple) or isinstance(obj, list):
-        for x in obj:
-            if _has_grad_tensor(x):
-                return True
-        return False
-    elif isinstance(obj, dict):
-        for x in obj.values():
-            if _has_grad_tensor(x):
-                return True
-        return False
-    else:
-        return _is_grad_tensor(obj)
-
-
-def _get_grad_args(*args):
-    # if there is no grad tensors, do nothing
-    if not _has_grad_tensor(args):
-        return args, None
-    # returns the identical args if there is a grad tensor
-    for obj in args:
-        if _is_grad_tensor(obj):
-            return args, None
-    # otherwise, the first argument should be a tuple of grad tensors
-    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
-    arg_zero = args[0]
-    if not isinstance(arg_zero, tuple):
-        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
-    check_grad_flag = False
-    for obj in arg_zero:
-        check_grad_flag |= _is_grad_tensor(obj)
-    if not check_grad_flag:
-        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
-    return arg_zero, args[1:]
-
-
-def _get_colo_tensors_info(*args) -> list:
-    info = []
-    for arg in args:
-        if isinstance(arg, ColoTensor):
-            info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
-        else:
-            info.append(None)
-    return info
-
-
-def _update_colo_tensors(info, *args) -> list:
-    ret = []
-    for t_info, arg in zip(info, args):
-        if t_info is not None:
-            t_cls, spec = t_info
-            arg = t_cls.from_torch_tensor(arg, spec=spec)
-        ret.append(arg)
-    return ret
+def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
+    flat_args, spec = tree_flatten(args)
+    grad_args = []
+    other_args = []
+    grad_flags = []
+    for arg in flat_args:
+        flag = _is_grad_tensor(arg)
+        grad_flags.append(flag)
+        if flag:
+            grad_args.append(arg)
+        else:
+            other_args.append(arg)
+    assert len(grad_args) > 0
+    return grad_args, other_args, grad_flags, spec
+
+
+def _merge_args(grad_args, other_args, grad_flags, spec):
+    grad_iter = iter(grad_args)
+    other_iter = iter(other_args)
+    flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
+    return tree_unflatten(flat_args, spec)
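`pre_op` now round-trips arguments through `torch.utils._pytree` (a private but long-standing PyTorch module): flatten nested args, partition by `requires_grad`, run the autograd hook on the grad tensors only, and merge everything back. The flatten/merge round trip in isolation:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Nested args mixing grad tensors, plain values, and a dict.
args = ((torch.randn(2, requires_grad=True), 3), {"k": torch.zeros(1)})

flat, spec = tree_flatten(args)
grad_flags = [isinstance(a, torch.Tensor) and a.requires_grad for a in flat]

# tree_unflatten restores the exact original structure from the flat list.
rebuilt = tree_unflatten(flat, spec)
assert isinstance(rebuilt, tuple) and rebuilt[1]["k"].shape == (1,)
assert grad_flags == [True, False, False]
```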
@@ -2,8 +2,7 @@ from .gemini import (
     ColoInitContext,
     GeminiAdamOptimizer,
     GeminiDDP,
-    ZeroDDP,
-    ZeroOptimizer,
+    GeminiOptimizer,
     get_static_torch_model,
     post_process_colo_init_ctx,
 )
@@ -11,6 +10,6 @@ from .low_level import LowLevelZeroOptimizer
 from .wrapper import zero_model_wrapper, zero_optim_wrapper
 
 __all__ = [
-    'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
+    'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
     'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
 ]
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .gemini_ddp import GeminiDDP, ZeroDDP from .gemini_ddp import GeminiDDP
from .gemini_mgr import GeminiManager from .gemini_mgr import GeminiManager
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
from .utils import get_static_torch_model from .utils import get_static_torch_model
__all__ = [ __all__ = [
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' 'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
] ]
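For downstream code the rename is mechanical; a minimal migration sketch, assuming the package is imported as `colossalai.zero` as in the hunks above:

```python
# Before this change:
#   from colossalai.zero import ZeroDDP, ZeroOptimizer
# After it, the same classes are exported under the Gemini names:
from colossalai.zero import GeminiDDP, GeminiOptimizer
```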
@@ -4,8 +4,8 @@ from typing import Dict, List, Optional
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup

-from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.utils import get_current_device
@@ -55,7 +55,7 @@ class Chunk:
     def __init__(self,
                  chunk_size: int,
-                 process_group: ColoProcessGroup,
+                 process_group: ProcessGroup,
                  dtype: torch.dtype,
                  init_device: Optional[torch.device] = None,
                  cpu_shard_init: bool = False,
@@ -69,7 +69,7 @@ class Chunk:
         Args:
             chunk_size (int): the number of elements in the chunk
-            process_group (ColoProcessGroup): the process group of this chunk
+            process_group (ProcessGroup): the process group of this chunk
             dtype (torch.dtype): the data type of the chunk
             init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
                 The default value is None, which is the current GPU
@@ -83,7 +83,7 @@ class Chunk:
         self.chunk_size = chunk_size
         self.utilized_size = 0

-        self.torch_pg = process_group.dp_process_group()
+        self.torch_pg = process_group
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)
@@ -218,7 +218,7 @@ class Chunk:
             return False
         else:
             return self.tensor_state_cnter[TensorState.HOLD] + \
                 self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors

     @property
     def can_reduce(self):
...
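With `ColoProcessGroup` removed, a `Chunk` is now built around a plain `torch.distributed.ProcessGroup`, and its world size and rank are derived via `dist.get_world_size` / `dist.get_rank`. A minimal sketch, assuming a process group is already initialized and that the module path is `colossalai.zero.gemini.chunk` (the file path is not shown in this view):

```python
import torch
import torch.distributed as dist
from colossalai.zero.gemini.chunk import Chunk  # module path assumed


def make_chunk(num_elements: int, pg: dist.ProcessGroup) -> Chunk:
    # Any plain ProcessGroup works now; previously a ColoProcessGroup was
    # required and the torch group was pulled out via dp_process_group().
    return Chunk(
        chunk_size=num_elements,  # caller pads this to a multiple of the dp size
        process_group=pg,
        dtype=torch.float16,
    )
```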
@@ -2,8 +2,9 @@ from collections import deque
 from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple

 import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup

-from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device

 from .chunk import Chunk, ChunkFullError, TensorState
@@ -27,16 +28,17 @@ class ChunkManager:
             self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
             v['init_device'] = self.device

-        self.chunk_groups: Dict[str, Deque] = dict()
+        self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
         self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
         self.accessed_chunks: Set[Chunk] = set()
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

     def register_tensor(self,
-                        tensor: ColoTensor,
+                        tensor: torch.Tensor,
                         group_type: str,
                         config_key: int,
+                        process_group: ProcessGroup,
                         cpu_offload: bool = False,
                         pin_memory: bool = False) -> None:
         """
@@ -51,7 +53,7 @@ class ChunkManager:
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
-        assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
+        assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
         assert config_key in self.dp_degree_chunk_size_dict
         chunk_size = self.dp_degree_chunk_size_dict[config_key]
@@ -73,12 +75,12 @@ class ChunkManager:
             if tensor.numel() > chunk_size:
                 chunk_size = tensor.numel()
-            dp_size = tensor.get_dp_world_size()
+            dp_size = dist.get_world_size(process_group)
             chunk_size = chunk_size + (-chunk_size % dp_size)

             chunk = Chunk(
                 chunk_size=chunk_size,
-                process_group=tensor.process_group,
+                process_group=process_group,
                 dtype=tensor.dtype,
                 cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
@@ -220,7 +222,7 @@ class ChunkManager:
             msg.append(f'[{i}] {chunk}\n')
         return ''.join(msg)

-    def __get_chunk_group(self, group_name: str) -> Deque:
+    def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
         """Register a chunk group.
         """
         if group_name not in self.chunk_groups:
...
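`register_tensor` now accepts any `torch.Tensor` together with an explicit `process_group`, and the chunk size is padded so every rank holds an equal shard: `chunk_size + (-chunk_size % dp_size)` rounds up to the next multiple of `dp_size` (for example, 1000 elements across 16 ranks becomes 1008). A usage sketch; the module path, the `'fp16_param'` group name, and the surrounding setup are assumptions, only the `register_tensor` signature comes from the hunk above:

```python
import torch.distributed as dist
import torch.nn as nn
from colossalai.zero.gemini.chunk import ChunkManager  # module path assumed


def register_params(chunk_manager: ChunkManager, model: nn.Module,
                    pg: dist.ProcessGroup) -> None:
    # Register every parameter of `model` into chunks owned by `pg`.
    dp_size = dist.get_world_size(pg)
    for param in model.parameters():
        chunk_manager.register_tensor(
            tensor=param,             # any torch.Tensor is accepted now
            group_type='fp16_param',  # assumed group name
            config_key=dp_size,       # chunk-size configs are keyed by dp degree
            process_group=pg,         # new explicit argument
        )
```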
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup

 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_ddp_ignored
@@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     return left + acc

-def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
+def _tensor_numel(local_param: ColoParameter) -> int:
     """_tensor_numel

     Get the number of elements of a tensor.
@@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
     Returns:
         int: the number of elements.
     """
-    if strict_ddp_flag and type(local_param) is ColoParameter:
-        return local_param.numel_global()
-    else:
-        # if local_param is not ColoParameter, we assume it's replicated
-        return local_param.numel()
+    # TODO(ver217): support dtensor here
+    return local_param.numel()


 def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
-                                 strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
+                                 process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
     """classify_params_by_dp_degree

     Classify the parameters by their dp degree
@@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
         # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
         if is_ddp_ignored(param):
             continue

-        if strict_ddp_flag or type(param) is not ColoParameter:
-            # if model is not initialized with ColoInitContext, we assume it's replicated
-            # TODO(ver217): integrate DTensor
-            param_key = dist.get_world_size()
-        else:
-            param_key = param.process_group.dp_world_size()
+        param_key = dist.get_world_size(process_group)

         if param_key not in params_dict:
             params_dict[param_key] = []
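With one explicit `process_group`, every non-ignored parameter now receives the same key, so the returned dict collapses to a single bucket keyed by the group's world size; the removed branch could produce several buckets when parameters carried different dp degrees. A toy, dependency-free illustration of the new grouping rule:

```python
from collections import defaultdict


def classify(params, world_size):
    # Mirrors the new logic: param_key = dist.get_world_size(process_group)
    # is constant, so all parameters land in one bucket.
    params_dict = defaultdict(list)
    for p in params:
        params_dict[world_size].append(p)
    return dict(params_dict)


print(classify(['w1', 'w2', 'b1'], world_size=8))
# {8: ['w1', 'w2', 'b1']}
```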
@@ -119,6 +111,7 @@ def search_chunk_configuration(
         min_chunk_size_m: float = 32,
         filter_exlarge_params: bool = True,
         strict_ddp_flag: bool = False,
+        process_group: Optional[ProcessGroup] = None,
         memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
     """search_chunk_configuration
@@ -149,7 +142,7 @@ def search_chunk_configuration(
     min_chunk_size = round(min_chunk_size_m * 1024**2)
     assert search_range >= 0

-    params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
+    params_dict = classify_params_by_dp_degree(param_order, process_group)
     size_lcm = np.lcm.reduce(list(params_dict.keys()))
     config_dict: Dict[int, Dict] = dict()
     total_param_size = 0
@@ -157,7 +150,7 @@ def search_chunk_configuration(
     size_dict: Dict[int, List[int]] = dict()
     for dp_degree in params_dict:
         params_list = params_dict[dp_degree]
-        size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
+        size_list = [_tensor_numel(p) for p in params_list]
         group_acc_size = sum(size_list)
         total_param_size += group_acc_size
...
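Callers opt in by threading a group through `search_chunk_configuration`, which then flows into `classify_params_by_dp_degree`. A hedged sketch; only `min_chunk_size_m`, `filter_exlarge_params`, and `process_group` appear in the hunks above, while the other parameter names, the module path, and the meaning of the returned ints are assumptions:

```python
import torch.distributed as dist
import torch.nn as nn
from colossalai.zero.gemini.chunk import search_chunk_configuration  # path assumed


def find_chunk_config(model: nn.Module, pg: dist.ProcessGroup):
    # Returns a Tuple[Dict, int, int] per the annotation in the hunk above,
    # assumed here to be (config_dict, total_param_size, wasted_size).
    return search_chunk_configuration(
        model,
        search_range_m=1,      # assumed parameter, not shown in this hunk
        search_interval=1024,  # assumed parameter, not shown in this hunk
        min_chunk_size_m=32,
        filter_exlarge_params=True,
        process_group=pg,      # new optional argument
    )
```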