Commit 94586767 authored by Zeqiang Lai, committed by zhe chen

[Classification] support deepspeed, fix optimizer bugs (#83)

parent 2d975df6
# ---- config file 1 ----
DATA:
IMG_ON_MEMORY: True
MODEL:
TYPE: intern_image
DROP_PATH_RATE: 0.4
INTERN_IMAGE:
CORE_OP: 'DCNv3'
DEPTHS: [4, 4, 21, 4]
GROUPS: [5, 10, 20, 40]
CHANNELS: 80
LAYER_SCALE: 1e-5
OFFSET_SCALE: 1.0
MLP_RATIO: 4.0
POST_NORM: True
TRAIN:
EMA:
ENABLE: True
DECAY: 0.9999
BASE_LR: 5e-4
# ---- config file 2 ----
DATA:
IMG_ON_MEMORY: True
MODEL:
TYPE: intern_image
DROP_PATH_RATE: 0.1
INTERN_IMAGE:
CORE_OP: 'DCNv3'
DEPTHS: [4, 4, 18, 4]
GROUPS: [4, 8, 16, 32]
CHANNELS: 64
OFFSET_SCALE: 1.0
MLP_RATIO: 4.0
TRAIN:
EMA:
ENABLE: True
DECAY: 0.9999
BASE_LR: 5e-4
# ---- config file 3 ----
DATA:
IMG_SIZE: 384
IMG_ON_MEMORY: True
AUG:
MIXUP: 0.0
CUTMIX: 0.0
REPROB: 0.0
MODEL:
TYPE: intern_image
DROP_PATH_RATE: 0.2
LABEL_SMOOTHING: 0.3
INTERN_IMAGE:
CORE_OP: 'DCNv3'
DEPTHS: [5, 5, 24, 5]
GROUPS: [12, 24, 48, 96]
CHANNELS: 192
LAYER_SCALE: 1e-5
OFFSET_SCALE: 2.0
MLP_RATIO: 4.0
POST_NORM: True
TRAIN:
EMA:
ENABLE: true
DECAY: 0.9999
EPOCHS: 20
WARMUP_EPOCHS: 2
WEIGHT_DECAY: 0.05
BASE_LR: 2e-05 # 512
WARMUP_LR: .0
MIN_LR: .0
LR_LAYER_DECAY: true
LR_LAYER_DECAY_RATIO: 0.9
USE_CHECKPOINT: true
OPTIMIZER:
DCN_LR_MUL: 0.1
AMP_OPT_LEVEL: O0
EVAL_FREQ: 1
\ No newline at end of file
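The `# 512` comment on BASE_LR above marks the reference total batch size for the linear scaling rule that `scale_learning_rate` applies in the training scripts below. A quick sketch of the arithmetic, assuming 8 GPUs and a per-GPU batch size of 32 (both values hypothetical):

# a minimal sketch of the linear LR scaling rule (assumed: 8 GPUs, micro-batch 32)
base_lr = 2e-5                                # BASE_LR from the 384 fine-tune config
total_batch = 32 * 8                          # per-GPU batch size x world size = 256
effective_lr = base_lr * total_batch / 512.0  # -> 1e-5, scaled against the 512 reference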
@@ -4,4 +4,4 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
-from .build import build_loader
\ No newline at end of file
+from .build import build_loader, build_loader2
\ No newline at end of file
@@ -147,6 +147,58 @@ def build_loader(config):
data_loader_val, data_loader_test, mixup_fn
def build_loader2(config):
config.defrost()
dataset_train, config.MODEL.NUM_CLASSES = build_dataset('train',
config=config)
config.freeze()
dataset_val, _ = build_dataset('val', config=config)
dataset_test, _ = build_dataset('test', config=config)
data_loader_train = torch.utils.data.DataLoader(
dataset_train,
shuffle=True,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
persistent_workers=True) if dataset_train is not None else None
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=config.DATA.BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False,
persistent_workers=True) if dataset_val is not None else None
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=config.DATA.BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False,
persistent_workers=True) if dataset_test is not None else None
# setup mixup / cutmix
mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
mixup_fn = Mixup(mixup_alpha=config.AUG.MIXUP,
cutmix_alpha=config.AUG.CUTMIX,
cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB,
switch_prob=config.AUG.MIXUP_SWITCH_PROB,
mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING,
num_classes=config.MODEL.NUM_CLASSES)
return dataset_train, dataset_val, dataset_test, data_loader_train, \
data_loader_val, data_loader_test, mixup_fn
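A minimal usage sketch of `build_loader2`, showing how its seven return values unpack (the loop below is hypothetical, not part of this commit):

# a minimal sketch, assuming `config` is an already-built YACS config
(dataset_train, dataset_val, dataset_test,
 train_loader, val_loader, test_loader, mixup_fn) = build_loader2(config)
for samples, targets in train_loader:
    if mixup_fn is not None:  # mixup/cutmix turns labels into soft targets
        samples, targets = mixup_fn(samples, targets)
    break  # one batch is enough to illustrate the interface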
def build_dataset(split, config):
transform = build_transform(split == 'train', config)
dataset = None
......
import datetime
import argparse
import os
import time
import logging
import random
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from accelerate import Accelerator
from accelerate import GradScalerKwargs
from accelerate.logging import get_logger
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import AverageMeter, accuracy, ModelEma
from tqdm import tqdm
import warnings
from config import get_config
from models import build_model
from dataset import build_loader2
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from utils import load_pretrained, load_ema_checkpoint
from ddp_hooks import fp16_compress_hook
logger = get_logger(__name__)
warnings.filterwarnings('ignore')
def parse_option():
parser = argparse.ArgumentParser(
'InternImage training and evaluation script', add_help=False)
parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
parser.add_argument("--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+')
# easy config modification
parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
parser.add_argument('--dataset', type=str, help='dataset name', default=None)
parser.add_argument('--data-path', type=str, help='path to dataset')
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
help='no: no cache, '
'full: cache all data, '
                             'part: shard the dataset into non-overlapping pieces and cache only one piece'
)
parser.add_argument('--pretrained', help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--output', default='output', type=str, metavar='PATH',
help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)'
)
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')
parser.add_argument('--save-ckpt-num', default=1, type=int)
parser.add_argument('--accumulation-steps', type=int, default=1, help="gradient accumulation steps")
    parser.add_argument('--disable-grad-scalar', action='store_true', help='disable the gradient scaler (GradScaler)')
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
help=(
"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
" for experiment tracking and logging of model metrics and model checkpoints"
),
)
args, unparsed = parser.parse_known_args()
config = get_config(args)
config.defrost()
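    # hard-coded overrides for this script (rationale assumed): ZeRO sharding is left
    # to the DeepSpeed backend, outputs go to a separate '_deepspeed' folder, and
    # in-memory image caching is disabled to save host RAM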
config.TRAIN.OPTIMIZER.USE_ZERO = False
config.OUTPUT += '_deepspeed'
config.DATA.IMG_ON_MEMORY = False
config.freeze()
return args, config
def seed_everything(seed, rank):
seed = seed + rank
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
def save_config(config):
path = os.path.join(config.OUTPUT, "config.json")
with open(path, "w") as f:
f.write(config.dump())
logger.info(f"Full config saved to {path}")
def build_criterion(config):
if config.AUG.MIXUP > 0.:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif config.MODEL.LABEL_SMOOTHING > 0.:
criterion = LabelSmoothingCrossEntropy(
smoothing=config.MODEL.LABEL_SMOOTHING)
else:
criterion = torch.nn.CrossEntropyLoss()
return criterion
def scale_learning_rate(config, num_processes):
    # linearly scale the learning rate with the total batch size; may not be optimal
linear_scaled_lr = config.TRAIN.BASE_LR * \
config.DATA.BATCH_SIZE * num_processes / 512.0
linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * \
config.DATA.BATCH_SIZE * num_processes / 512.0
linear_scaled_min_lr = config.TRAIN.MIN_LR * \
config.DATA.BATCH_SIZE * num_processes / 512.0
    # gradient accumulation also needs the learning rate scaled by the accumulation steps
if config.TRAIN.ACCUMULATION_STEPS > 1:
linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
config.defrost()
config.TRAIN.BASE_LR = linear_scaled_lr
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
config.TRAIN.MIN_LR = linear_scaled_min_lr
config.freeze()
logger.info('BASE_LR={}'.format(config.TRAIN.BASE_LR))
logger.info('WARMUP_LR={}'.format(config.TRAIN.WARMUP_LR))
logger.info('MIN_LR={}'.format(config.TRAIN.MIN_LR))
def setup_autoresume(config):
if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME:
last_checkpoint = os.path.join(config.OUTPUT, 'last')
resume_file = last_checkpoint if os.path.exists(last_checkpoint) else None
if resume_file:
if config.MODEL.RESUME:
logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
config.defrost()
config.MODEL.RESUME = resume_file
config.freeze()
logger.info(f'auto resuming from {resume_file}')
else:
logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
def load_model_checkpoint(config, model, accelerator):
if config.MODEL.RESUME:
try:
checkpoint = torch.load(config.MODEL.RESUME)['model']
checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint)
        except Exception:
accelerator.load_state(config.MODEL.RESUME)
elif config.MODEL.PRETRAINED:
try:
load_pretrained(config, model, logger)
        except Exception:
accelerator.load_state(config.MODEL.PRETRAINED)
return model
def save_checkpoint(save_dir, accelerator, epoch, max_acc, config, lr_scheduler=None):
# let accelerator handle the model and optimizer state for ddp and deepspeed.
accelerator.save_state(save_dir)
if accelerator.is_main_process:
save_state = {
'lr_scheduler': lr_scheduler.state_dict(),
'max_acc': max_acc,
'epoch': epoch,
'config': config
}
torch.save(save_state, os.path.join(save_dir, 'additional_state.pth'))
def load_checkpoint_if_needed(accelerator, config, lr_scheduler=None):
setup_autoresume(config)
save_dir = config.MODEL.RESUME
if not save_dir:
return 0.0
accelerator.load_state(save_dir)
checkpoint = torch.load(os.path.join(save_dir, 'additional_state.pth'), map_location='cpu')
if lr_scheduler is not None:
logger.info('resuming lr_scheduler')
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
config.defrost()
config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
config.freeze()
max_acc = checkpoint.get('max_acc', 0.0)
logger.info(f"=> loaded successfully {config.MODEL.RESUME} (epoch {checkpoint['epoch']})")
return max_acc
def log_model_statistic(model_wo_ddp):
n_parameters = sum(p.numel() for p in model_wo_ddp.parameters()
if p.requires_grad)
logger.info(f"number of params: {n_parameters}")
if hasattr(model_wo_ddp, 'flops'):
flops = model_wo_ddp.flops()
logger.info(f"number of GFLOPs: {flops / 1e9}")
def train_epoch(*, model, optimizer, data_loader, scheduler, criterion, mixup_fn,
accelerator: Accelerator, epoch, config):
model.train()
num_steps = len(data_loader)
batch_time = AverageMeter()
model_time = AverageMeter()
loss_meter = AverageMeter()
end = time.time()
gradient_accumulation_steps = config.TRAIN.ACCUMULATION_STEPS
for step, (samples, targets) in enumerate(data_loader):
iter_begin_time = time.time()
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
with accelerator.accumulate(model):
outputs = model(samples)
loss = criterion(outputs, targets)
accelerator.backward(loss)
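            # sync_gradients is True only on the micro-step that completes an
            # accumulation window, so clipping and the optimizer step run once per effective batch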
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
optimizer.step()
optimizer.zero_grad()
accelerator.wait_for_everyone()
if (step + 1) % gradient_accumulation_steps == 0:
if scheduler is not None:
scheduler.step_update((epoch * num_steps + step) // gradient_accumulation_steps)
batch_time.update(time.time() - end)
model_time.update(time.time() - iter_begin_time)
loss_meter.update(loss.item())
end = time.time()
if accelerator.is_main_process and step % config.PRINT_FREQ == 0:
lr = optimizer.param_groups[0]['lr']
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
etas = batch_time.avg * (num_steps - step)
logger.info(
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{step}/{num_steps}]\t'
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.10f}\t'
f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
f'model_time {model_time.val:.4f} ({model_time.avg:.4f})\t'
f'loss {loss_meter.val:.8f} ({loss_meter.avg:.4f})\t'
f'mem {memory_used:.0f}MB')
@torch.no_grad()
def eval_epoch(*, config, data_loader, model, accelerator: Accelerator):
model.eval()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
    for idx, (images, target) in enumerate(tqdm(data_loader, disable=not accelerator.is_main_process)):  # progress bar on the main process only
output = model(images)
# convert 22k to 1k to evaluate
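        # map22kto1k.txt is assumed to hold one ImageNet-22K class index per line,
        # one for each of the 1000 ImageNet-1K classes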
if output.size(-1) == 21841:
convert_file = './meta_data/map22kto1k.txt'
with open(convert_file, 'r') as f:
convert_list = [int(line) for line in f.readlines()]
output = output[:, convert_list]
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1 = accelerator.gather(acc1).mean(0)
acc5 = accelerator.gather(acc5).mean(0)
acc1_meter.update(acc1.item(), target.size(0))
acc5_meter.update(acc5.item(), target.size(0))
if (idx + 1) % config.PRINT_FREQ == 0 or idx + 1 == len(data_loader):
logger.info(f'Test: [{idx+1}/{len(data_loader)}]\t'
f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
)
return acc1_meter.avg
def eval(config, accelerator: Accelerator):
_, _, _, _, validate_dataloader, _, _ = build_loader2(config)
model = build_model(config)
model, validate_dataloader = accelerator.prepare(model, validate_dataloader)
model = load_model_checkpoint(config, model, accelerator)
log_model_statistic(accelerator.unwrap_model(model))
eval_epoch(config=config, data_loader=validate_dataloader, model=model, accelerator=accelerator)
def train(config, accelerator: Accelerator):
_, _, _, training_dataloader, validate_dataloader, _, mixup_fn = build_loader2(config)
model = build_model(config)
optimizer = build_optimizer(config, model)
criterion = build_criterion(config)
model, optimizer, training_dataloader, validate_dataloader = accelerator.prepare(
model, optimizer, training_dataloader, validate_dataloader)
effective_update_steps_per_epoch = len(training_dataloader) // config.TRAIN.ACCUMULATION_STEPS
lr_scheduler = build_scheduler(config, optimizer, effective_update_steps_per_epoch)
try:
model.register_comm_hook(state=None, hook=fp16_compress_hook)
logger.info('using fp16_compress_hook!')
    except Exception:
logger.info("cannot register fp16_compress_hook!")
max_acc = load_checkpoint_if_needed(accelerator, config, lr_scheduler)
logger.info(f"Created model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
logger.info(str(model))
logger.info("Effective Optimizer Steps: {}".format(effective_update_steps_per_epoch))
logger.info("Start training")
logger.info("Max accuracy: {}".format(max_acc))
log_model_statistic(accelerator.unwrap_model(model))
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
train_epoch(model=model, optimizer=optimizer, data_loader=training_dataloader,
scheduler=lr_scheduler, criterion=criterion, mixup_fn=mixup_fn,
accelerator=accelerator, epoch=epoch, config=config)
acc = eval_epoch(config=config, data_loader=validate_dataloader, model=model,
accelerator=accelerator)
accelerator.wait_for_everyone()
if acc > max_acc:
max_acc = acc
save_checkpoint(os.path.join(config.OUTPUT, 'best'), accelerator, epoch, max_acc, config, lr_scheduler)
logger.info(f'Max Acc@1 {max_acc:.3f}')
save_checkpoint(os.path.join(config.OUTPUT, 'last'), accelerator, epoch, max_acc, config, lr_scheduler)
def main():
args, config = parse_option()
os.makedirs(config.OUTPUT, exist_ok=True)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
filename=os.path.join(config.OUTPUT, 'run.log'),
level=logging.INFO,
)
loggers = ['tensorboard']
accelerator = Accelerator(
log_with=loggers,
project_dir=config.OUTPUT,
gradient_accumulation_steps=config.TRAIN.ACCUMULATION_STEPS,
        # When using DeepSpeed, this handler must stay in place,
        # even if the loss scale is set to 1.0 in the DeepSpeed config.
kwargs_handlers=[GradScalerKwargs(enabled=not args.disable_grad_scalar)],
)
logger.info(accelerator.state, main_process_only=False)
scale_learning_rate(config, accelerator.num_processes)
seed_everything(config.SEED, accelerator.process_index)
save_config(config)
logger.info(config.dump())
if config.EVAL_MODE:
eval(config, accelerator)
else:
train(config, accelerator)
if __name__ == '__main__':
main()
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
import time
import random
import argparse
import datetime
import numpy as np
import subprocess
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import deepspeed
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter
from config import get_config
from models import build_model
from dataset import build_loader
from lr_scheduler import build_scheduler
from optimizer import set_weight_decay_and_lr
from logger import create_logger
from utils import load_pretrained, reduce_tensor, MyAverageMeter
from ddp_hooks import fp16_compress_hook
def parse_option():
parser = argparse.ArgumentParser(
'InternImage training and evaluation script', add_help=False)
parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
parser.add_argument("--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+')
# easy config modification
parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
parser.add_argument('--dataset', type=str, help='dataset name', default=None)
parser.add_argument('--data-path', type=str, help='path to dataset')
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
help='no: no cache, '
'full: cache all data, '
                             'part: shard the dataset into non-overlapping pieces and cache only one piece'
)
parser.add_argument('--pretrained', help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--output', default='output', type=str, metavar='PATH',
help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)'
)
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')
parser.add_argument('--save-ckpt-num', default=1, type=int)
parser.add_argument('--accumulation-steps', type=int, default=1, help="gradient accumulation steps")
# distributed training
parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
    parser.add_argument('--disable-grad-scalar', action='store_true', help='disable the gradient scaler (GradScaler)')
args, unparsed = parser.parse_known_args()
config = get_config(args)
return args, config
def seed_everything(seed, rank):
seed = seed + rank
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
def save_config(config):
path = os.path.join(config.OUTPUT, "config.json")
with open(path, "w") as f:
f.write(config.dump())
logger.info(f"Full config saved to {path}")
def build_criterion(config):
if config.AUG.MIXUP > 0.:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif config.MODEL.LABEL_SMOOTHING > 0.:
criterion = LabelSmoothingCrossEntropy(
smoothing=config.MODEL.LABEL_SMOOTHING)
else:
criterion = torch.nn.CrossEntropyLoss()
return criterion
def scale_learning_rate(config, num_processes):
    # linearly scale the learning rate with the total batch size; may not be optimal
linear_scaled_lr = config.TRAIN.BASE_LR * \
config.DATA.BATCH_SIZE * num_processes / 512.0
linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * \
config.DATA.BATCH_SIZE * num_processes / 512.0
linear_scaled_min_lr = config.TRAIN.MIN_LR * \
config.DATA.BATCH_SIZE * num_processes / 512.0
    # gradient accumulation also needs the learning rate scaled by the accumulation steps
if config.TRAIN.ACCUMULATION_STEPS > 1:
linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
config.defrost()
config.TRAIN.BASE_LR = linear_scaled_lr
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
config.TRAIN.MIN_LR = linear_scaled_min_lr
config.freeze()
logger.info('BASE_LR={}'.format(config.TRAIN.BASE_LR))
logger.info('WARMUP_LR={}'.format(config.TRAIN.WARMUP_LR))
logger.info('MIN_LR={}'.format(config.TRAIN.MIN_LR))
def log_model_statistic(model_wo_ddp):
n_parameters = sum(p.numel() for p in model_wo_ddp.parameters()
if p.requires_grad)
logger.info(f"number of params: {n_parameters/1e6} M")
if hasattr(model_wo_ddp, 'flops'):
flops = model_wo_ddp.flops()
logger.info(f"number of GFLOPs: {flops / 1e9}")
def get_parameter_groups(model, config):
skip = {}
skip_keywords = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
if hasattr(model, 'no_weight_decay_keywords'):
skip_keywords = model.no_weight_decay_keywords()
parameters = set_weight_decay_and_lr(
model,
config.TRAIN.WEIGHT_DECAY,
config.TRAIN.BASE_LR,
skip,
skip_keywords,
lr_layer_decay=config.TRAIN.LR_LAYER_DECAY,
lr_layer_decay_ratio=config.TRAIN.LR_LAYER_DECAY_RATIO,
freeze_backbone=config.TRAIN.OPTIMIZER.FREEZE_BACKBONE,
dcn_lr_mul=config.TRAIN.OPTIMIZER.DCN_LR_MUL,
)
return parameters
def get_optimizer_state_str(optimizer):
states = []
for param_group in optimizer.param_groups:
states.append(f'name={param_group["name"]} lr={param_group["lr"]} weight_decay={param_group["weight_decay"]}')
return '\n'.join(states)
def build_ds_config(config, args):
opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
if opt_lower == 'adamw':
optimizer = {
"type": "AdamW",
"params": {
"lr": config.TRAIN.BASE_LR,
"eps": config.TRAIN.OPTIMIZER.EPS,
"betas": config.TRAIN.OPTIMIZER.BETAS,
"weight_decay": config.TRAIN.WEIGHT_DECAY
}
}
else:
        raise NotImplementedError(f'optimizer {opt_lower} is not supported by build_ds_config')
ds_config = {
"train_micro_batch_size_per_gpu": config.DATA.BATCH_SIZE,
"optimizer": optimizer,
"fp16": {
"enabled": True,
"auto_cast": True,
"loss_scale": 1 if args.disable_grad_scalar else 0
},
"zero_optimization": {
"stage": 1,
},
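        # an effectively infinite interval silences DeepSpeed's periodic step logging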
"steps_per_print": 1e10,
"gradient_accumulation_steps": config.TRAIN.ACCUMULATION_STEPS,
"gradient_clipping": config.TRAIN.CLIP_GRAD,
}
return ds_config
@torch.no_grad()
def throughput(data_loader, model, logger):
model.eval()
for idx, (images, _) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
batch_size = images.shape[0]
for i in range(50):
model(images)
torch.cuda.synchronize()
logger.info(f"throughput averaged with 30 times")
tic1 = time.time()
for i in range(30):
model(images)
torch.cuda.synchronize()
tic2 = time.time()
logger.info(
f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}"
)
return
def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
model.train()
num_steps = len(data_loader)
batch_time = AverageMeter()
model_time = AverageMeter()
loss_meter = AverageMeter()
norm_meter = MyAverageMeter(300)
start = time.time()
end = time.time()
for idx, (samples, targets) in enumerate(data_loader):
iter_begin_time = time.time()
samples = samples.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
outputs = model(samples)
loss = criterion(outputs, targets)
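        # the DeepSpeed engine handles loss scaling, accumulation boundaries, and the
        # optimizer step/zero_grad internally, so no manual optimizer calls appear here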
model.backward(loss)
model.step()
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
lr_scheduler.step_update(epoch * num_steps + idx)
torch.cuda.synchronize()
loss_meter.update(loss.item(), targets.size(0))
norm_meter.update(optimizer._global_grad_norm)
batch_time.update(time.time() - end)
model_time.update(time.time() - iter_begin_time)
end = time.time()
if idx % config.PRINT_FREQ == 0:
lr = optimizer.param_groups[0]['lr']
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
etas = batch_time.avg * (num_steps - idx)
logger.info(
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
f'model_time {model_time.val:.4f} ({model_time.avg:.4f})\t'
f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f}/{norm_meter.var:.4f})\t'
f'mem {memory_used:.0f}MB')
epoch_time = time.time() - start
logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
@torch.no_grad()
def eval_epoch(config, data_loader, model, epoch=None):
criterion = torch.nn.CrossEntropyLoss()
model.eval()
batch_time = AverageMeter()
loss_meter = AverageMeter()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
end = time.time()
for idx, (images, target) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
output = model(images)
# convert 22k to 1k to evaluate
if output.size(-1) == 21841:
convert_file = './meta_data/map22kto1k.txt'
with open(convert_file, 'r') as f:
convert_list = [int(line) for line in f.readlines()]
output = output[:, convert_list]
# measure accuracy and record loss
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1 = reduce_tensor(acc1)
acc5 = reduce_tensor(acc5)
loss = reduce_tensor(loss)
loss_meter.update(loss.item(), target.size(0))
acc1_meter.update(acc1.item(), target.size(0))
acc5_meter.update(acc5.item(), target.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if idx % config.PRINT_FREQ == 0:
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
f'Mem {memory_used:.0f}MB')
if epoch is not None:
logger.info(f'[Epoch:{epoch}] * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
else:
logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
def train(config, ds_config):
# -------------- build ---------------- #
_, dataset_val, _, data_loader_train, data_loader_val, _, mixup_fn = build_loader(config)
model = build_model(config)
model.cuda()
if config.MODEL.PRETRAINED:
load_pretrained(config, model, logger)
logger.info(ds_config)
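    # deepspeed.initialize builds the ds_config AdamW over these parameter groups, so
    # per-group lr/weight_decay (layer decay, DCN lr multiplier) override the defaults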
model, optimizer, _, _ = deepspeed.initialize(
config=ds_config,
model=model,
model_parameters=get_parameter_groups(model, config),
dist_init_required=False,
)
try:
model.register_comm_hook(state=None, hook=fp16_compress_hook)
logger.info('using fp16_compress_hook!')
    except Exception:
logger.info("cannot register fp16_compress_hook!")
model_without_ddp = model.module
lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
criterion = build_criterion(config)
# -------------- resume ---------------- #
max_accuracy = 0.0
client_state = {}
if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME:
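        # DeepSpeed's save_checkpoint writes a 'latest' marker file naming the most
        # recent tag, so its presence indicates a resumable run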
if os.path.exists(os.path.join(config.OUTPUT, 'latest')):
config.defrost()
config.MODEL.RESUME = config.OUTPUT
config.freeze()
tag = None
    elif config.MODEL.RESUME:
        # read the tag before rewriting RESUME; taking basename() afterwards would
        # return the parent directory name instead of the checkpoint tag
        tag = os.path.basename(config.MODEL.RESUME)
        config.defrost()
        config.MODEL.RESUME = os.path.dirname(config.MODEL.RESUME)
        config.freeze()
if config.MODEL.RESUME:
logger.info('loading checkpoint from {}'.format(config.MODEL.RESUME))
_, client_state = model.load_checkpoint(load_dir=config.MODEL.RESUME, tag=tag)
logger.info(f'client_state={client_state.keys()}')
lr_scheduler.load_state_dict(client_state['custom_lr_scheduler'])
max_accuracy = client_state['max_accuracy']
# -------------- training ---------------- #
logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
logger.info(str(model))
logger.info(get_optimizer_state_str(optimizer))
logger.info("Start training")
logger.info('max_accuracy: {}'.format(max_accuracy))
log_model_statistic(model_without_ddp)
start_time = time.time()
for epoch in range(client_state.get('epoch', config.TRAIN.START_EPOCH), config.TRAIN.EPOCHS):
data_loader_train.sampler.set_epoch(epoch)
train_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
model.save_checkpoint(
save_dir=config.OUTPUT,
tag=f'epoch{epoch}',
client_state={
'custom_lr_scheduler': lr_scheduler.state_dict(),
'max_accuracy': max_accuracy,
'epoch': epoch,
'config': config
}
)
if epoch % config.EVAL_FREQ == 0:
acc1, _, _ = eval_epoch(config, data_loader_val, model, epoch)
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
            if acc1 > max_accuracy:
                # update the best accuracy first so the saved checkpoint records it
                max_accuracy = acc1
                model.save_checkpoint(
                    save_dir=config.OUTPUT,
                    tag='best',
                    client_state={
                        'custom_lr_scheduler': lr_scheduler.state_dict(),
                        'max_accuracy': max_accuracy,
                        'epoch': epoch,
                        'config': config
                    }
                )
            logger.info(f'Max accuracy: {max_accuracy:.2f}%')
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str))
def eval(config):
_, _, _, _, data_loader_val, _, _ = build_loader(config)
model = build_model(config)
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
model_wo_ddp = model.module
if config.MODEL.RESUME:
try:
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
msg = model_wo_ddp.load_state_dict(checkpoint['model'], strict=False)
logger.info(msg)
        except Exception:
try:
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
ckpt_dir = os.path.dirname(config.MODEL.RESUME)
tag = os.path.basename(config.MODEL.RESUME)
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir=ckpt_dir, tag=tag)
model_wo_ddp.load_state_dict(state_dict)
            except Exception:
checkpoint = torch.load(os.path.join(config.MODEL.RESUME, 'mp_rank_00_model_states.pt'), map_location='cpu')
model_wo_ddp.load_state_dict(checkpoint['module'])
elif config.MODEL.PRETRAINED:
load_pretrained(config, model_wo_ddp, logger)
if config.THROUGHPUT_MODE:
throughput(data_loader_val, model, logger)
eval_epoch(config, data_loader_val, model)
if __name__ == '__main__':
args, config = parse_option()
# init distributed env
if 'SLURM_PROCID' in os.environ:
print("\nDist init: SLURM")
rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count()
config.defrost()
config.LOCAL_RANK = gpu
config.freeze()
world_size = int(os.environ["SLURM_NTASKS"])
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = "29501"
node_list = os.environ["SLURM_NODELIST"]
addr = subprocess.getoutput(
f"scontrol show hostname {node_list} | head -n1")
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = addr
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(gpu)
os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
os.environ['WORLD_SIZE'] = str(world_size)
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ['WORLD_SIZE'])
print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
else:
rank = -1
world_size = -1
torch.cuda.set_device(config.LOCAL_RANK)
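    # the env:// init method reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE set above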
torch.distributed.init_process_group(backend='nccl',
init_method='env://',
world_size=world_size,
rank=rank)
torch.distributed.barrier()
os.makedirs(config.OUTPUT, exist_ok=True)
logger = create_logger(output_dir=config.OUTPUT,
dist_rank=dist.get_rank(),
name=f"{config.MODEL.NAME}")
logger.info(config.dump())
if dist.get_rank() == 0: save_config(config)
scale_learning_rate(config, dist.get_world_size())
seed_everything(config.SEED, dist.get_rank())
if config.EVAL_MODE:
eval(config)
else:
train(config, build_ds_config(config, args))
@@ -37,33 +37,22 @@ def build_optimizer(config, model):
if use_zero:
print(f"\nUse Zero!")
if opt_lower == 'sgd':
# an ugly implementation
# https://github.com/pytorch/pytorch/issues/71347
optimizer = ZeroRedundancyOptimizer(
-            parameters[0]['params'],
+            parameters,
optimizer_class=optim.SGD,
momentum=config.TRAIN.OPTIMIZER.MOMENTUM,
nesterov=True,
lr=config.TRAIN.BASE_LR,
weight_decay=config.TRAIN.WEIGHT_DECAY)
-        if len(parameters[1]['params']) > 0:
-            optimizer.add_param_group({
-                "params": parameters[1]['params'],
-                'weight_decay': 0.
-            })
elif opt_lower == 'adamw':
optimizer = ZeroRedundancyOptimizer(
-            parameters[0]['params'],
+            parameters,
optimizer_class=optim.AdamW,
eps=config.TRAIN.OPTIMIZER.EPS,
betas=config.TRAIN.OPTIMIZER.BETAS,
lr=config.TRAIN.BASE_LR,
weight_decay=config.TRAIN.WEIGHT_DECAY)
-        if len(parameters[1]['params']) > 0:
-            optimizer.add_param_group({
-                "params": parameters[1]['params'],
-                'weight_decay': 0.
-            })
else:
if opt_lower == 'sgd':
optimizer = optim.SGD(parameters,
@@ -148,7 +137,7 @@ def set_weight_decay_and_lr(
lr_ratio_log[name] = (base_lr, ratio, wd, param.requires_grad)
else:
lr = base_lr
-        parameters.append({'params': [param], 'weight_decay': wd, 'lr': lr})
+        parameters.append({'params': [param], 'weight_decay': wd, 'lr': lr, 'name': name})
    print(f'no decay params: {no_decay_name}')
if layerwise_lr:
......
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-12}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
--quotatype=spot \
${SRUN_ARGS} \
python -u main_deepspeed.py \
--cfg ${CONFIG} \
--local_rank 0 \
--data-path /mnt/lustre/share/images \
--output work_dirs_deepspeed ${@:4}
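# example invocation (script name, partition, job name, and config path are hypothetical):
#   GPUS=16 GPUS_PER_NODE=8 bash train_deepspeed.sh <partition> <job_name> configs/some_config.yaml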