Unverified Commit a39a5c66 authored by Hongxin Liu, committed by GitHub

Merge branch 'main' into feature/shardformer

parents e79b1e80 aaeb520c
#!/bin/bash
################
#Load your environments and modules here
################
HOSTFILE=$(realpath hosts.txt)
cd ../..
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -g -x -b 16
#!/bin/bash
################
#Load your environments and modules here
################
HOSTFILE=$(realpath hosts.txt)
cd ../..
export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -p gemini_auto -g -x -b 16
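Both launch scripts resolve hosts.txt next to them and hand it to colossalai run via --hostfile. The hostfile itself is not part of this commit; as an assumption for illustration only, it is normally a plain list of reachable node names, one per line:

node-001
node-002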
import time
import torch
import tqdm
import transformers
from args import parse_benchmark_args
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
...@@ -61,11 +57,11 @@ def main():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Whether to set limit of memory capacity
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM(config=config)
...@@ -81,11 +77,7 @@ def main():
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])
...@@ -96,18 +88,18 @@ def main():
    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    # Start training.
    logger.info(f"Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    for _ in range(args.max_train_steps):
        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
...@@ -119,18 +111,19 @@ def main():
        torch.cuda.synchronize()
        progress_bar.update(1)

    # Compute Statistics
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput}, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0])


if __name__ == "__main__":
...
import time

import datasets
import torch
import transformers
from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
from transformers.utils.versions import require_version

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext

require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
...@@ -30,18 +25,18 @@ def move_to_cuda(batch, device):

def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):

    torch.cuda.synchronize()
    model.train()

    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:

        for batch in pbar:

            # Forward
            optimizer.zero_grad()
            batch = move_to_cuda(batch, torch.cuda.current_device())

            outputs = model(use_cache=False, **batch)
            loss = outputs['loss']

...@@ -72,7 +67,7 @@ def main():
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
...@@ -88,43 +83,35 @@ def main():
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)
    dataloader = plugin.prepare_dataloader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=netflix_collator)

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

    # Set lr scheduler
    total_steps = len(dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=num_warmup_steps,
                                                   num_training_steps=len(dataloader) * args.num_epoch)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
                                                                  optimizer=optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)

    # Start finetuning
...
import gzip
import random
from contextlib import nullcontext
from functools import partial
from time import time

...@@ -8,20 +8,17 @@ import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP

# constants
...@@ -44,23 +41,10 @@ def parse_args():
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        type=bool,
        default=False,
        help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--offload_optim_frac",
        type=float,
        default=1.0,
        help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
    )
    parser.add_argument('-p',
                        '--plugin',

...@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
    return total_numel
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'net.0' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_q' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_kv' in mn:
split_param_row_tp1d(param, pg) # row slice
elif 'to_out' in mn:
split_param_row_tp1d(param, pg) # row slice
elif '1.1' in mn:
split_param_col_tp1d(param, pg) # column slice
elif '1.2' in mn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise TypeError(f"{args.distplan} is error")

...@@ -212,23 +151,18 @@ if args.distplan == "colossalai":
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2**5)
        plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"plugin: {plugin}")

    booster = Booster(plugin=plugin, **booster_kwargs)

    default_pg = ProcessGroup(tp_degree=args.tp_degree)
    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()

    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

    pg = default_pg
    tensor_parallelize(model, pg)

    # optimizer
    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
...
...@@ -3,5 +3,5 @@ torch >= 1.8.1
datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2
accelerate >= 0.20.3
transformers
...@@ -30,7 +30,7 @@ from itertools import chain
import datasets
import torch
import torch.distributed as dist
import transformers
import transformers.utils.logging as logging
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset

...@@ -57,7 +57,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero import GeminiOptimizer

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

...@@ -292,10 +292,10 @@ def main():
    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
        logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        logging.set_verbosity_error()

    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)
...@@ -391,16 +391,28 @@ def main():
    else:
        init_dev = get_current_device()
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
# build model
if version.parse(cai_version) >= version.parse("0.3.1"):
from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
ctx = LazyInitContext(
default_device=init_dev
) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext()
else:
from colossalai.zero import ColoInitContext
ctx = ColoInitContext(device=init_dev)
    if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
        # currently, there is a bug in pretrained opt-13b
        # we cannot import it until huggingface fixes it
        logger.info("Train a new model from scratch", ranks=[0])
        with ColoInitContext(device=init_dev):
        with ctx:
            model = OPTForCausalLM(config)
    else:
        logger.info("Finetune a pre-trained model", ranks=[0])
        with ColoInitContext(device=init_dev):
        with ctx:
            model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                   from_tf=bool(".ckpt" in args.model_name_or_path),
                                                   config=config,
...@@ -410,9 +422,10 @@ def main():
    model.gradient_checkpointing_enable()

    PLACEMENT_POLICY = 'auto'
    cai_version = colossalai.__version__
    logger.info(f'using Colossal-AI version {cai_version}')
    if version.parse(cai_version) >= version.parse("0.3.1"):
        from colossalai.zero import GeminiDDP
        model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True)
    elif version.parse(cai_version) > version.parse("0.1.10"):
        try:
            from colossalai.nn.parallel import GeminiDDP
        except ImportError:

...@@ -536,7 +549,6 @@ def main():
    ]
    optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
...@@ -551,6 +563,7 @@ def main():
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
    optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
...
...@@ -4,9 +4,9 @@ set -xue
pip install -r requirements.txt

BS=8
BS=4
MEMCAP=0
GPUNUM=2
GPUNUM=4
MODLE="facebook/opt-125m"

torchrun \
...
...@@ -197,11 +197,12 @@ def get_cuda_cc_flag() -> List[str]:
    import torch

    cc_flag = []
    max_arch = ''.join(str(i) for i in torch.cuda.get_device_capability())
    for arch in torch.cuda.get_arch_list():
        res = re.search(r'sm_(\d+)', arch)
        if res:
            arch_cap = res[1]
            if int(arch_cap) >= 60:
            if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
                cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])

    return cc_flag
...
...@@ -2,4 +2,4 @@
markers =
    dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs)
    largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs)
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx
\ No newline at end of file
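With the markers above registered, the distributed suites can be selected explicitly through standard pytest marker filtering, for example pytest -m dist or pytest -m largedist, while the addopts line keeps the listed test directories ignored by default.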
...@@ -17,6 +17,13 @@ def data_gen_fn(): ...@@ -17,6 +17,13 @@ def data_gen_fn():
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def data_gen_for_pretrain():
inputs = data_gen_fn()
inputs['labels'] = inputs['input_ids'].clone()
inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
return inputs
output_transform_fn = lambda x: x

config = transformers.AlbertConfig(embedding_size=128,
...@@ -26,14 +33,14 @@ config = transformers.AlbertConfig(embedding_size=128,
                                   intermediate_size=256)

model_zoo.register(name='transformers_albert',
                   model_fn=lambda: transformers.AlbertModel(config),
                   model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_pretraining',
                   model_fn=lambda: transformers.AlbertForPreTraining(config),
                   data_gen_fn=data_gen_fn,
                   data_gen_fn=data_gen_for_pretrain,
                   output_transform_fn=output_transform_fn,
                   output_transform_fn=lambda x: dict(loss=x.loss),
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_masked_lm',
                   model_fn=lambda: transformers.AlbertForMaskedLM(config),
...
...@@ -113,6 +113,7 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))
loss_fn = lambda x: x.loss

...@@ -126,7 +127,7 @@ config = transformers.BertConfig(hidden_size=128,
# register the BERT variants
model_zoo.register(name='transformers_bert',
                   model_fn=lambda: transformers.BertModel(config),
                   model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_bert_model,
...
...@@ -57,6 +57,12 @@ def data_gen_for_sequence_classification():
    return data
def date_gen_for_double_heads():
data = data_gen_for_lm()
data['mc_labels'] = torch.zeros(data['input_ids'].shape[0], dtype=torch.int64)
return data
# define output transform function
output_transform_fn = lambda x: x

...@@ -94,8 +100,8 @@ model_zoo.register(name='transformers_gpt_lm',
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
                   model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
                   data_gen_fn=data_gen_for_lm,
                   data_gen_fn=date_gen_for_double_heads,
                   output_transform_fn=output_transform_fn,
                   output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss),
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',
...
...@@ -12,19 +12,16 @@ from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ColoInitContext
from tests.kit.model_zoo import model_zoo


def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    try:
        if init_method == 'colo':
        if init_method == 'lazy':
            ctx = ColoInitContext()
        elif init_method == 'lazy':
            ctx = LazyInitContext()
        else:
            ctx = nullcontext()
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()

...@@ -50,6 +47,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
            optimizer.step()

    except Exception as e:
        # raise e
        return repr(e)

...@@ -57,8 +55,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize('subset', ['torchvision', 'transformers', 'diffusers'])
@parameterize('init_method', ['none'])
def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool = True):
    """check gemini plugin over model zoo

    Args:
...@@ -71,29 +70,23 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
    passed_models = []
    failed_info = {}    # (model_name, error) pair

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
        # These models lead to CUDA error
        if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
                    'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
                    'torchvision_convnext_base'):
            continue
        # These models are not compatible with gemini
        if name in [
'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet', 'timm_convit',
'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0', 'timm_dm_nfnet',
'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer', 'torchvision_vit_b_16',
'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron', 'transformers_t5',
'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch', 'transformers_t5_for_conditional_generation',
'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small', 'transformers_t5_encoder_model', # does not support apex rmsnorm
'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2', 'transformers_chatglm',
'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert', 'transformers_sam',
'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining', 'transformers_vit'
'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base',
'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model',
'transformers_vit', 'transformers_vit_for_masked_image_modeling',
'transformers_vit_for_image_classification', 'transformers_chatglm',
'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
        ]:
            continue

...@@ -105,7 +98,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
        ]:
            continue

        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()
        if err is None:
            passed_models.append(name)
...
...@@ -18,12 +18,45 @@ from colossalai.testing import (
)
from tests.kit.model_zoo import model_zoo
MODEL_PLACEMENT_CONFIGS = [
{
'placement_policy': 'static',
'shard_param_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 1.0
}, # zero3
{
'placement_policy': 'static',
'shard_param_frac': 0.5
}, # zero3-half
]
OPTIM_PLACEMENT_CONFIGS = [
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.0
}, # zero2
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 1.0
}, # zero2-offload
{
'placement_policy': 'static',
'shard_param_frac': 0.0,
'offload_optim_frac': 0.5
}, # zero2-offload-half
]
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS)
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
@parameterize('use_safetensors', [False, True])
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
    from transformers import BertForSequenceClassification
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    bert_model = model_fn()
...@@ -32,7 +65,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
        pretrained_path = os.path.join(tempdir, 'pretrained')
        bert_model.config.save_pretrained(save_directory=pretrained_path)

        plugin = GeminiPlugin(placement_policy=placement_policy)
        plugin = GeminiPlugin(**placement_config)
        booster = Booster(plugin=plugin)
        bert_model, _, _, _, _ = booster.boost(bert_model)
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
...@@ -46,19 +79,19 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
        dist.barrier()
        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
        check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
        check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
                               new_bert_model.state_dict(), False)


@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS)
@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = lambda x: x.mean()
    plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
    booster = Booster(plugin=plugin)
    model = model_fn()
...@@ -87,12 +120,11 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha
        dist.barrier()

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
                               new_model.unwrap().state_dict(only_rank_0=False), False)
        check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
                               new_optimizer.unwrap().state_dict(only_rank_0=False), False)
        check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False),
                               False)

        # Check the new model/optimizer can successfully run.
        data = data_gen_fn()
...
...@@ -60,12 +60,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
        new_booster.load_model(new_model, model_ckpt_path, strict=True)

        # Add prefix to get aligned with pytorch parameter names.
        check_state_dict_equal(
            model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
            new_model.state_dict(), False)
        check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
                               new_model.state_dict(), False)

        new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
        check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)

        # Check the new model/optimizer can successfully run.
        data = data_gen_fn()
...@@ -124,13 +123,12 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
        new_booster.load_model(new_model, model_ckpt_path, strict=True)

        # Add prefix to get aligned with pytorch parameter names.
        check_state_dict_equal(
            new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
            model.state_dict(), False)
        check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
                               model.state_dict(), False)

        new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        old_state_dict = optimizer.state_dict()
        new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
        new_state_dict = new_optimizer.state_dict(only_rank_0=False)

        # Comparison of param_groups needs special care here,
        # since not all hyperparameters in Adam are used by HybridAdam

...@@ -138,7 +136,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
        for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
            for k in hyperparameters_to_examine:
                assert k in old_group and k in new_group, \
                    f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
                assert old_group[k] == new_group[k]
        check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)
...
...@@ -16,19 +16,21 @@ from colossalai.testing import ( ...@@ -16,19 +16,21 @@ from colossalai.testing import (
)


# stage 1 and 2 process the optimizer/model the same way
# only test 2 is fine
@clear_cache_before_run()
@parameterize('stage', [2])
@parameterize('shard', [True, False])
@parameterize('offload', [False, True])
def check_low_level_zero_checkpointIO(stage: int, shard: bool):
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = HybridAdam((model.parameters()), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    x = torch.randn(1, 3, 224, 224, device='cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
...@@ -50,15 +52,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
        check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
        check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)


def run_dist(rank, world_size, port):
    colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')
    check_low_level_zero_checkpointIO()
    torch.cuda.empty_cache()


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_low_level_zero_checkpointIO():
    spawn(run_dist, 2)
...
import os
from pathlib import Path
import pytest
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
import colossalai
from colossalai.amp import AMP_TYPE
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.trainer import Trainer, hooks
from colossalai.utils import get_dataloader
disable_existing_loggers()
BATCH_SIZE = 4
NUM_EPOCHS = 10
WARMUP_EPOCHS = 5
CONFIG = dict(NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)
def run_trainer(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
disable_existing_loggers()
# get logger
logger = get_dist_logger()
pipelinable = PipelinableContext()
try:
from titans.model.vit import vit_tiny_patch4_32
except ImportError:
logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
logger.warning('please install titan from https://github.com/hpcaitech/Titans')
return
with pipelinable:
model = vit_tiny_patch4_32()
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
# create dataloaders
root = Path(os.environ['DATA'])
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4, pad_if_needed=True),
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
# create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
# initialize
engine, train_dataloader, *_ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)
logger = get_dist_logger()
trainer = Trainer(engine=engine, logger=logger)
hook_list = [
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
]
trainer.fit(train_dataloader=train_dataloader,
max_steps=2,
epochs=NUM_EPOCHS,
hooks=hook_list,
display_progress=True)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_hybrid_parallel():
spawn(run_trainer, 2)
disable_existing_loggers()
if __name__ == '__main__':
test_hybrid_parallel()
import os
import random
from typing import Callable, Type
import numpy as np
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
def set_seed(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def init_ddp(module: torch.nn.Module) -> ColoDDP:
pg = ProcessGroup()
return ColoDDP(module, process_group=pg)
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
chunk_config, *_ = search_chunk_configuration(module, 4, 1024)
chunk_manager = ChunkManager(chunk_config)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ZeroDDP(module, gemini_manager)
class Net(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(3, 3, bias=False)
self.fc2 = torch.nn.Linear(3, 1, bias=False)
def forward(self, x):
return self.fc2(self.fc1(x))
def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
with ColoInitContext(device=get_current_device()):
model = Net().cuda()
w1 = model.fc1.weight
w2 = model.fc2.weight
ddp_cls.set_params_to_ignore([w2])
model = init_ddp_func(model)
x = torch.rand(2, 3, device=get_current_device())
logits = model(x)
loss = torch.sum(logits)
model.backward(loss)
if ddp_cls is ZeroDDP:
w1s_grad = w1
else:
w1s_grad = w1.grad
w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
dist.all_gather(w1_grads, w1s_grad)
assert torch.equal(w1_grads[0], w1_grads[1])
w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
dist.all_gather(w2_grads, w2.grad)
assert not torch.equal(w2_grads[0], w2_grads[1])
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
set_seed(dist.get_rank())
run_fwd_bwd(ColoDDP, init_ddp)
run_fwd_bwd(ZeroDDP, init_ddpv2)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_ddp_ignore_params(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_ddp_ignore_params(2)