Commit 26b7aac0 (signature unverified), authored by ver217, committed by GitHub
Browse files

[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
parent b09adff7
...@@ -4,7 +4,7 @@ from typing import Dict, Optional ...@@ -4,7 +4,7 @@ from typing import Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from .gemini_parallel import GeminiDDP from .gemini import GeminiDDP
def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None): def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
...@@ -99,11 +99,11 @@ def zero_optim_wrapper(model: nn.Module, ...@@ -99,11 +99,11 @@ def zero_optim_wrapper(model: nn.Module,
config_dict['max_scale'] = max_scale config_dict['max_scale'] = max_scale
if zero_stage in [1, 2]: if zero_stage in [1, 2]:
from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer from colossalai.zero.low_level import LowLevelZeroOptimizer
config_dict['partition_grad'] = zero_stage == 2 config_dict['partition_grad'] = zero_stage == 2
config_dict['clip_grad_norm'] = max_norm config_dict['clip_grad_norm'] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict) return LowLevelZeroOptimizer(optimizer, **config_dict)
else: else:
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
config_dict['clipping_norm'] = max_norm config_dict['clipping_norm'] = max_norm
return ZeroOptimizer(optimizer, model, **config_dict) return ZeroOptimizer(optimizer, model, **config_dict)
...@@ -78,7 +78,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel ...@@ -78,7 +78,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
``` ```
......
...@@ -77,7 +77,7 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config ...@@ -77,7 +77,7 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
``` ```
......
...@@ -5,7 +5,7 @@ torchrun --standalone --nproc_per_node=1 debug.py ...@@ -5,7 +5,7 @@ torchrun --standalone --nproc_per_node=1 debug.py
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
import colossalai import colossalai
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx from colossalai.zero import ColoInitContext, post_process_colo_init_ctx
path = "/data/scratch/diffuser/stable-diffusion-v1-4" path = "/data/scratch/diffuser/stable-diffusion-v1-4"
......
...@@ -21,10 +21,9 @@ import colossalai ...@@ -21,10 +21,9 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers() disable_existing_loggers()
logger = get_dist_logger() logger = get_dist_logger()
......
...@@ -23,10 +23,9 @@ import colossalai ...@@ -23,10 +23,9 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers() disable_existing_loggers()
logger = get_dist_logger() logger = get_dist_logger()
......
...@@ -18,7 +18,7 @@ from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, Proc ...@@ -18,7 +18,7 @@ from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, Proc
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext
def set_seed(seed): def set_seed(seed):
......
...@@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext
def init_1d_row_for_linear_weight_spec(model, world_size: int): def init_1d_row_for_linear_weight_spec(model, world_size: int):
......
...@@ -12,10 +12,9 @@ from transformers import AlbertConfig, AlbertForSequenceClassification, BertConf ...@@ -12,10 +12,9 @@ from transformers import AlbertConfig, AlbertForSequenceClassification, BertConf
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
CAI_VERSION = colossalai.__version__ CAI_VERSION = colossalai.__version__
......
...@@ -13,10 +13,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -13,10 +13,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
CAI_VERSION = colossalai.__version__ CAI_VERSION = colossalai.__version__
......
...@@ -34,12 +34,9 @@ from transformers.utils.versions import require_version ...@@ -34,12 +34,9 @@ from transformers.utils.versions import require_version
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP
def get_data(batch_size, seq_len, vocab_size): def get_data(batch_size, seq_len, vocab_size):
...@@ -179,13 +176,15 @@ def main(): ...@@ -179,13 +176,15 @@ def main():
# build model # build model
if args.model_name_or_path is None: if args.model_name_or_path is None:
logger.info("Train a new model from scratch", ranks=[0]) logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half, with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec, default_dist_spec=default_dist_spec,
default_pg=shard_pg): default_pg=shard_pg):
model = OPTForCausalLM(config) model = OPTForCausalLM(config)
else: else:
logger.info("Finetune a pre-trained model", ranks=[0]) logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half, with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec, default_dist_spec=default_dist_spec,
default_pg=shard_pg): default_pg=shard_pg):
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
...@@ -198,8 +197,11 @@ def main(): ...@@ -198,8 +197,11 @@ def main():
numel = sum([p.numel() for p in model.parameters()]) numel = sum([p.numel() for p in model.parameters()])
PLACEMENT_POLICY = 'cpu' PLACEMENT_POLICY = 'cpu'
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, model = GeminiDDP(model,
pin_memory=True, strict_ddp_mode=args.shardinit) device=get_current_device(),
placement_policy=PLACEMENT_POLICY,
pin_memory=True,
strict_ddp_mode=args.shardinit)
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
SEQ_LEN = 1024 SEQ_LEN = 1024
......
...@@ -15,11 +15,9 @@ from torch.utils.data import DataLoader, Dataset ...@@ -15,11 +15,9 @@ from torch.utils.data import DataLoader, Dataset
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
# constants # constants
...@@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: ...@@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
return model return model
## Parameter Sharding Strategies for Tensor Parallelism # Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec) param.set_tensor_spec(*spec)
...@@ -232,7 +230,7 @@ if args.distplan == "colossalai": ...@@ -232,7 +230,7 @@ if args.distplan == "colossalai":
tensor_parallelize(model, pg) tensor_parallelize(model, pg)
model = gemini_zero_dpp(model, pg, args.placement) model = gemini_zero_dpp(model, pg, args.placement)
#optimizer # optimizer
#optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
......
import colossalai
import math import math
import os
import time
from functools import partial
import torch import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
import colossalai.nn as col_nn
from arguments import parse_args from arguments import parse_args
from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt
from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args
from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer
from utils.logger import Logger
from evaluation import evaluate from evaluation import evaluate
from loss import LossForPretraining from loss import LossForPretraining
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from tqdm import tqdm from tqdm import tqdm
import os
import time
from functools import partial
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger
from colossalai.gemini import ChunkManager, GeminiManager import colossalai
from colossalai.utils.model.colo_init_context import ColoInitContext import colossalai.nn as col_nn
from colossalai.utils import get_current_device from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ProcessGroup from colossalai.tensor import ProcessGroup
from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device
from colossalai.zero import ZeroOptimizer
from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager
from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext
from colossalai.zero.legacy.shard_utils import TensorShardStrategy
def main(): def main():
args = parse_args() args = parse_args()
launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)
if args.vscode_debug: if args.vscode_debug:
colossalai.launch(config={}, colossalai.launch(config={},
rank=args.rank, rank=args.rank,
world_size=args.world_size, world_size=args.world_size,
host=args.host, host=args.host,
port=args.port, port=args.port,
backend=args.backend) backend=args.backend)
args.local_rank = -1 args.local_rank = -1
args.log_interval = 1 args.log_interval = 1
else: else:
colossalai.launch_from_torch(args.colossal_config) #args.colossal_config colossalai.launch_from_torch(args.colossal_config) # args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"]) args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + logger.info(
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}'
)
log_args(logger, args) log_args(logger, args)
args.tokenizer = tokenizer args.tokenizer = tokenizer
args.logger = logger args.logger = logger
set_global_variables(launch_time, args.tensorboard_path) set_global_variables(launch_time, args.tensorboard_path)
use_zero = hasattr(gpc.config, 'zero') use_zero = hasattr(gpc.config, 'zero')
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
...@@ -71,8 +69,8 @@ def main(): ...@@ -71,8 +69,8 @@ def main():
if use_zero: if use_zero:
shard_strategy = TensorShardStrategy() shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy,
shard_param=True): shard_param=True):
config, model, numel = get_model(args, logger) config, model, numel = get_model(args, logger)
# model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
else: else:
...@@ -82,9 +80,10 @@ def main(): ...@@ -82,9 +80,10 @@ def main():
os.mkdir(os.path.join(args.ckpt_path, launch_time)) os.mkdir(os.path.join(args.ckpt_path, launch_time))
logger.info(f'Model numel: {numel}') logger.info(f'Model numel: {numel}')
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) # len(dataloader)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
total_steps = steps_per_epoch * args.epoch total_steps = steps_per_epoch * args.epoch
# build optimizer and lr_scheduler # build optimizer and lr_scheduler
...@@ -98,18 +97,23 @@ def main(): ...@@ -98,18 +97,23 @@ def main():
o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
optimizer = get_optimizer(model, lr=args.lr) optimizer = get_optimizer(model, lr=args.lr)
optimizer.load_state_dict(o_l_state_dict['optimizer']) optimizer.load_state_dict(o_l_state_dict['optimizer'])
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] # o_l_state_dict['lr_scheduler']['last_epoch']
lr_scheduler = get_lr_scheduler(optimizer,
total_steps=total_steps,
last_epoch=o_l_state_dict['lr_scheduler']['last_epoch'])
for state in optimizer.state.values(): for state in optimizer.state.values():
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
# if you want delete the above three code, have to move the model to gpu, because in optimizer.step() # if you want delete the above three code, have to move the model to gpu, because in optimizer.step()
lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])
start_epoch = o_l_state_dict['epoch'] start_epoch = o_l_state_dict['epoch']
start_shard = o_l_state_dict['shard'] + 1 start_shard = o_l_state_dict['shard'] + 1
# global_step = o_l_state_dict['global_step'] + 1 # global_step = o_l_state_dict['global_step'] + 1
logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') logger.info(
f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
)
else: else:
optimizer = get_optimizer(model, lr=args.lr) optimizer = get_optimizer(model, lr=args.lr)
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
...@@ -124,12 +128,11 @@ def main(): ...@@ -124,12 +128,11 @@ def main():
# initialize with colossalai # initialize with colossalai
engine, _, _, lr_scheduelr = colossalai.initialize(model=model, engine, _, _, lr_scheduelr = colossalai.initialize(model=model,
optimizer=optimizer, optimizer=optimizer,
criterion=criterion, criterion=criterion,
lr_scheduler=lr_scheduler) lr_scheduler=lr_scheduler)
logger.info(get_mem_info(prefix='After init model, ')) logger.info(get_mem_info(prefix='After init model, '))
best_loss = None best_loss = None
eval_loss = 0 eval_loss = 0
...@@ -146,13 +149,16 @@ def main(): ...@@ -146,13 +149,16 @@ def main():
dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
# pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) iterator_data = tqdm(enumerate(dataset_iterator),
total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
colour='cyan',
smoothing=1)
else: else:
iterator_data = enumerate(dataset_iterator) iterator_data = enumerate(dataset_iterator)
engine.train() engine.train()
for step, batch_data in iterator_data: for step, batch_data in iterator_data:
# batch_data = pretrain_dataset_provider.get_batch(batch_index) # batch_data = pretrain_dataset_provider.get_batch(batch_index)
input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
...@@ -162,7 +168,7 @@ def main(): ...@@ -162,7 +168,7 @@ def main():
# nsp_label = batch_data[5].cuda() # nsp_label = batch_data[5].cuda()
output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
loss = engine.criterion(output.logits, mlm_label) loss = engine.criterion(output.logits, mlm_label)
pretrain_dataset_provider.prefetch_batch() pretrain_dataset_provider.prefetch_batch()
...@@ -172,14 +178,15 @@ def main(): ...@@ -172,14 +178,15 @@ def main():
engine.step() engine.step()
lr_scheduelr.step() lr_scheduelr.step()
engine.zero_grad() engine.zero_grad()
global_step += 1 global_step += 1
if global_step % args.log_interval == 0 and global_step != 0 \ if global_step % args.log_interval == 0 and global_step != 0 \
and torch.distributed.get_rank() == 0: and torch.distributed.get_rank() == 0:
elapsed_time = timers('interval_time').elapsed(reset=False) elapsed_time = timers('interval_time').elapsed(reset=False)
elapsed_time_per_iteration = elapsed_time / global_step elapsed_time_per_iteration = elapsed_time / global_step
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
numel, args, config, elapsed_time, global_step, world_size)
cur_loss = train_loss / args.log_interval cur_loss = train_loss / args.log_interval
current_lr = lr_scheduelr.get_last_lr()[0] current_lr = lr_scheduelr.get_last_lr()[0]
...@@ -189,12 +196,13 @@ def main(): ...@@ -189,12 +196,13 @@ def main():
if args.wandb: if args.wandb:
tensorboard_log = get_tensorboard_writer() tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_train({ tensorboard_log.log_train(
'lr': current_lr, {
'loss': cur_loss, 'lr': current_lr,
'ppl': math.exp(cur_loss), 'loss': cur_loss,
'mins_batch': elapsed_time_per_iteration 'ppl': math.exp(cur_loss),
}, global_step) 'mins_batch': elapsed_time_per_iteration
}, global_step)
train_loss = 0 train_loss = 0
...@@ -202,12 +210,14 @@ def main(): ...@@ -202,12 +210,14 @@ def main():
logger.info('*' * 100) logger.info('*' * 100)
eval_loss += evaluate(engine, args, logger, global_step) eval_loss += evaluate(engine, args, logger, global_step)
save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) save_ckpt(engine.model, optimizer, lr_scheduelr,
os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
shard, global_step)
eval_loss /= len(os.listdir(args.data_path_prefix)) eval_loss /= len(os.listdir(args.data_path_prefix))
logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ logger.info(
f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
+ f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
logger.info('-' * 100) logger.info('-' * 100)
if args.wandb and torch.distributed.get_rank() == 0: if args.wandb and torch.distributed.get_rank() == 0:
tensorboard_log = get_tensorboard_writer() tensorboard_log = get_tensorboard_writer()
......
...@@ -30,24 +30,13 @@ from itertools import chain ...@@ -30,24 +30,13 @@ from itertools import chain
import datasets import datasets
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import transformers
from accelerate.utils import set_seed from accelerate.utils import set_seed
from context import barrier_context from context import barrier_context
from datasets import load_dataset from datasets import load_dataset
from packaging import version from packaging import version
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import colossalai
import transformers
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
...@@ -61,6 +50,15 @@ from transformers import ( ...@@ -61,6 +50,15 @@ from transformers import (
) )
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
......
...@@ -12,10 +12,9 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize ...@@ -12,10 +12,9 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.testing import parameterize from colossalai.testing import parameterize
from colossalai.utils import free_port, get_current_device from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed from tests.test_tensor.common_utils import set_seed
......
...@@ -11,12 +11,11 @@ from colossalai.device.device_mesh import DeviceMesh ...@@ -11,12 +11,11 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.process_group import ProcessGroup from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
class MLP(torch.nn.Module): class MLP(torch.nn.Module):
......
...@@ -10,14 +10,14 @@ import torch.distributed as dist ...@@ -10,14 +10,14 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.nn.parallel import ColoDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.tensor import ProcessGroup from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
def set_seed(seed): def set_seed(seed):
......
import copy import copy
from collections import OrderedDict
from functools import partial
import pytest import pytest
import colossalai
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import colossalai
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.cuda import get_current_device
from functools import partial from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ColoDDP
from collections import OrderedDict
from colossalai.tensor import ProcessGroup, ColoParameter
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
......
import pytest import pytest
import torch import torch
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
@pytest.mark.dist
def test_gemini_manager():
    """Check the memory accounting of the manager behind ``StatefulTensor.GST_MGR``.

    Four stateful tensors are created and then driven through payload
    resets, device moves and state transitions. After each phase the
    manager's tensor count, per-device totals (``total_mem``) and
    per-device, per-state totals (``state_mem``) are compared against
    hand-computed values. The "occupation" figures in the comments equal
    element count times element size of each tensor.

    The test allocates on ``'cuda'``, so a CUDA device is required.
    """
    # reset the manager, in case that there exists memory information left
    manager = StatefulTensor.GST_MGR
    manager.reset()

    # occupation 8
    st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
    # occupation 60
    st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))

    # occupation 28
    t1 = torch.empty(7, device='cuda')
    # occupation 12
    t2 = torch.empty(3, device='cpu')
    st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
    st4 = StatefulTensor(None, TensorState.FREE)

    # Phase 1: right after construction. st4 has no payload, so only
    # st1/st2/st3 contribute: cpu = 60, cuda = 8 + 28.
    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 60
    assert manager.total_mem['cuda'] == 36
    assert manager.state_mem['cpu'][TensorState.HOLD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28

    # Phase 2: point st4 and st3 at the 12-byte CPU tensor. The expected
    # values below imply st4 is then accounted under HOLD (60 + 12 = 72)
    # while st3 keeps HOLD_AFTER_FWD, now on the CPU side.
    st4.payload_reset(t2)
    st3.payload_reset(t2)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 84
    assert manager.total_mem['cuda'] == 8
    assert manager.state_mem['cpu'][TensorState.HOLD] == 72
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0

    # Phase 3: move payloads between devices. st2 is already on the CPU;
    # st1 (8) joins it and st3 (12) goes to the GPU.
    st1.move_to(torch.device('cpu'))
    st2.move_to(torch.device('cpu'))
    st3.move_to(torch.device('cuda', 0))

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD] == 80
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12

    # Phase 4: state transitions only; device totals must not change.
    # st2 passes through COMPUTE before reaching HOLD_AFTER_BWD.
    st1.trans_state(TensorState.COMPUTE)
    st2.trans_state(TensorState.COMPUTE)
    st2.trans_state(TensorState.HOLD_AFTER_BWD)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
    assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
    assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
# Allow running this test directly as a script, without the pytest runner.
if __name__ == '__main__':
    test_gemini_manager()
...@@ -2,7 +2,7 @@ import copy ...@@ -2,7 +2,7 @@ import copy
import torch import torch
from colossalai.gemini.paramhooks import BaseParamHookMgr from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment