Commit 26e59280 authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #2674 failed with stages
in 0 seconds
# --------------------------------------------------------
# InternVL
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.scheduler import Scheduler
from timm.scheduler.step_lr import StepLRScheduler
def build_scheduler(config, optimizer, n_iter_per_epoch):
num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS *
n_iter_per_epoch)
lr_scheduler = None
if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=num_steps,
# t_mul=1.,
lr_min=config.TRAIN.MIN_LR,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
cycle_limit=1,
t_in_epochs=False,
)
elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
lr_scheduler = LinearLRScheduler(
optimizer,
t_initial=num_steps,
lr_min_rate=0.01,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
t_in_epochs=False,
)
elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=decay_steps,
decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
warmup_lr_init=config.TRAIN.WARMUP_LR,
warmup_t=warmup_steps,
t_in_epochs=False,
)
return lr_scheduler
class LinearLRScheduler(Scheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lr_min_rate: float,
warmup_t=0,
warmup_lr_init=0.,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True,
) -> None:
super().__init__(optimizer,
param_group_field='lr',
noise_range_t=noise_range_t,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize)
self.t_initial = t_initial
self.lr_min_rate = lr_min_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t
for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
t = t - self.warmup_t
total_t = self.t_initial - self.warmup_t
lrs = [
v - ((v - v * self.lr_min_rate) * (t / total_t))
for v in self.base_values
]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import argparse
import datetime
import os
import random
import subprocess
import time
from contextlib import suppress
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from config import get_config
from dataset import build_loader
from logger import create_logger
from lr_scheduler import build_scheduler
from models import build_model
from optimizer import build_optimizer
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import ApexScaler, AverageMeter, ModelEma, accuracy
from utils import MyAverageMeter
from utils import NativeScalerWithGradNormCount as NativeScaler
from utils import (auto_resume_helper, get_grad_norm, load_checkpoint,
load_ema_checkpoint, load_pretrained, reduce_tensor,
save_checkpoint)
try:
from apex import amp
has_apex = True
except ImportError:
has_apex = False
# assert not has_apex, "The code is modified based on native amp"
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
def obsolete_torch_version(torch_version, version_threshold):
return torch_version == 'parrots' or torch_version <= version_threshold
def parse_option():
parser = argparse.ArgumentParser(
'InternVL training and evaluation script', add_help=False)
parser.add_argument('--cfg',
type=str,
required=True,
metavar='FILE',
help='path to config file')
parser.add_argument(
'--opts',
help="Modify config options by adding 'KEY VALUE' pairs. ",
default=None,
nargs='+')
# easy config modification
parser.add_argument('--batch-size',
type=int,
help='batch size for single GPU')
parser.add_argument('--dataset',
type=str,
help='dataset name',
default=None)
parser.add_argument('--data-path', type=str, help='path to dataset')
parser.add_argument('--zip',
action='store_true',
help='use zipped dataset instead of folder dataset')
parser.add_argument(
'--cache-mode',
type=str,
default='part',
choices=['no', 'full', 'part'],
help='no: no cache, '
'full: cache all data, '
'part: sharding the dataset into nonoverlapping pieces and only cache one piece'
)
parser.add_argument(
'--pretrained',
help=
'pretrained weight from checkpoint, could be imagenet22k pretrained weight'
)
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps',
type=int,
default=1,
help='gradient accumulation steps')
parser.add_argument(
'--use-checkpoint',
action='store_true',
help='whether to use gradient checkpointing to save memory')
parser.add_argument(
'--amp-opt-level',
type=str,
default='O1',
choices=['O0', 'O1', 'O2'],
help='mixed precision opt level, if O0, no amp is used')
parser.add_argument(
'--output',
default='work_dirs',
type=str,
metavar='PATH',
help=
'root of output folder, the full path is <output>/<model_name>/<tag> (default: output)'
)
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval',
action='store_true',
help='Perform evaluation only')
parser.add_argument('--throughput',
action='store_true',
help='Test throughput only')
parser.add_argument('--save-ckpt-num', default=1, type=int)
parser.add_argument(
'--use-zero',
action='store_true',
help='whether to use ZeroRedundancyOptimizer (ZeRO) to save memory')
# distributed training
parser.add_argument('--local-rank',
type=int,
required=True,
help='local rank for DistributedDataParallel')
parser.add_argument('--launcher',
choices=['pytorch', 'slurm'],
default='pytorch')
args, unparsed = parser.parse_known_args()
config = get_config(args)
return args, config
@torch.no_grad()
def throughput(data_loader, model, logger):
model.eval()
for idx, (images, _) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
batch_size = images.shape[0]
for i in range(50):
model(images)
torch.cuda.synchronize()
logger.info(f'throughput averaged with 30 times')
tic1 = time.time()
for i in range(30):
model(images)
torch.cuda.synchronize()
tic2 = time.time()
logger.info(
f'batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}'
)
return
def main(config):
# prepare data loaders
dataset_train, dataset_val, dataset_test, data_loader_train, \
data_loader_val, data_loader_test, mixup_fn = build_loader(config)
# build runner
logger.info(f'Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}')
model = build_model(config)
model.cuda()
logger.info(str(model))
# build optimizer
optimizer = build_optimizer(config, model)
if config.AMP_OPT_LEVEL != 'O0':
config.defrost()
if has_native_amp:
config.native_amp = True
use_amp = 'native'
elif has_apex:
config.apex_amp = True
use_amp = 'apex'
else:
use_amp = None
logger.warning(
'Neither APEX or native Torch AMP is available, using float32. '
'Install NVIDA apex or upgrade to PyTorch 1.6')
config.freeze()
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if config.AMP_OPT_LEVEL != 'O0':
if use_amp == 'apex':
model, optimizer = amp.initialize(model,
optimizer,
opt_level=config.AMP_OPT_LEVEL)
loss_scaler = ApexScaler()
if config.LOCAL_RANK == 0:
logger.info(
'Using NVIDIA APEX AMP. Training in mixed precision.')
if use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if config.LOCAL_RANK == 0:
logger.info(
'Using native Torch AMP. Training in mixed precision.')
else:
if config.LOCAL_RANK == 0:
logger.info('AMP not enabled. Training in float32.')
# put model on gpus
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
# try:
# model.register_comm_hook(state=None, hook=fp16_compress_hook)
# logger.info('using fp16_compress_hook!')
# except:
# logger.info("cannot register fp16_compress_hook!")
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters()
if p.requires_grad)
logger.info(f'number of params: {n_parameters}')
if hasattr(model_without_ddp, 'flops'):
flops = model_without_ddp.flops()
logger.info(f'number of GFLOPs: {flops / 1e9}')
# build learning rate scheduler
lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) \
if not config.EVAL_MODE else None
# build criterion
if config.AUG.MIXUP > 0.:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif config.MODEL.LABEL_SMOOTHING > 0.:
criterion = LabelSmoothingCrossEntropy(
smoothing=config.MODEL.LABEL_SMOOTHING)
else:
criterion = torch.nn.CrossEntropyLoss()
max_accuracy = 0.0
max_ema_accuracy = 0.0
# set auto resume
if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME:
resume_file = auto_resume_helper(config.OUTPUT)
if resume_file:
if config.MODEL.RESUME:
logger.warning(
f'auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}'
)
config.defrost()
config.MODEL.RESUME = resume_file
config.freeze()
logger.info(f'auto resuming from {resume_file}')
else:
logger.info(
f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
)
# set resume and pretrain
if config.MODEL.RESUME:
max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
lr_scheduler, loss_scaler, logger)
if data_loader_val is not None:
if config.DATA.DATASET == 'imagenet-real':
filenames = dataset_val.filenames()
filenames = [os.path.basename(item) for item in filenames]
from dataset.imagenet_real import RealLabelsImagenet
real_labels = RealLabelsImagenet(filenames, real_json='meta_data/real.json')
acc1, acc5, loss = validate_real(config, data_loader_val, model, real_labels, amp_autocast=amp_autocast)
logger.info(
f'ReaL Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
)
else:
acc1, acc5, loss = validate(config, data_loader_val, model, amp_autocast=amp_autocast)
logger.info(
f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
)
elif config.MODEL.PRETRAINED:
load_pretrained(config, model_without_ddp, logger)
if data_loader_val is not None:
acc1, acc5, loss = validate(config, data_loader_val, model, amp_autocast=amp_autocast)
logger.info(
f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
)
# evaluate EMA
model_ema = None
if config.TRAIN.EMA.ENABLE:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma(model, decay=config.TRAIN.EMA.DECAY)
print('Using EMA with decay = %.8f' % config.TRAIN.EMA.DECAY)
if config.MODEL.RESUME:
load_ema_checkpoint(config, model_ema, logger)
if config.DATA.DATASET == 'imagenet-real':
# assert only one gpu
assert dist.get_world_size() == 1, 'imagenet-real should test with one gpu'
filenames = dataset_val.filenames()
filenames = [os.path.basename(item) for item in filenames]
from dataset.imagenet_real import RealLabelsImagenet
real_labels = RealLabelsImagenet(filenames, real_json='meta_data/real.json')
acc1, acc5, loss = validate_real(config, data_loader_val, model_ema.ema, real_labels,
amp_autocast=amp_autocast)
logger.info(
f'ReaL Accuracy of the ema network on the {len(dataset_val)} test images: {acc1:.1f}%'
)
else:
acc1, acc5, loss = validate(config, data_loader_val, model_ema.ema, amp_autocast=amp_autocast)
logger.info(
f'Accuracy of the ema network on the {len(dataset_val)} test images: {acc1:.1f}%'
)
if config.THROUGHPUT_MODE:
throughput(data_loader_val, model, logger)
if config.EVAL_MODE:
return
# train
logger.info('Start training')
start_time = time.time()
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
data_loader_train.sampler.set_epoch(epoch)
train_one_epoch(config,
model,
criterion,
data_loader_train,
optimizer,
epoch,
mixup_fn,
lr_scheduler,
amp_autocast,
loss_scaler,
model_ema=model_ema)
if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)) and config.TRAIN.OPTIMIZER.USE_ZERO:
optimizer.consolidate_state_dict(to=0)
if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
save_checkpoint(config,
epoch,
model_without_ddp,
max_accuracy,
optimizer,
lr_scheduler,
loss_scaler,
logger,
model_ema=model_ema)
if data_loader_val is not None and epoch % config.EVAL_FREQ == 0:
acc1, acc5, loss = validate(config, data_loader_val, model, epoch, amp_autocast=amp_autocast)
logger.info(
f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
)
if dist.get_rank() == 0 and acc1 > max_accuracy:
save_checkpoint(config,
epoch,
model_without_ddp,
max_accuracy,
optimizer,
lr_scheduler,
loss_scaler,
logger,
model_ema=model_ema,
best='best')
max_accuracy = max(max_accuracy, acc1)
logger.info(f'Max accuracy: {max_accuracy:.2f}%')
if config.TRAIN.EMA.ENABLE:
acc1, acc5, loss = validate(config, data_loader_val,
model_ema.ema, epoch, amp_autocast=amp_autocast)
logger.info(
f'Accuracy of the ema network on the {len(dataset_val)} test images: {acc1:.1f}%'
)
if dist.get_rank() == 0 and acc1 > max_ema_accuracy:
save_checkpoint(config,
epoch,
model_without_ddp,
max_accuracy,
optimizer,
lr_scheduler,
loss_scaler,
logger,
model_ema=model_ema,
best='ema_best')
max_ema_accuracy = max(max_ema_accuracy, acc1)
logger.info(f'Max ema accuracy: {max_ema_accuracy:.2f}%')
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str))
def train_one_epoch(config,
model,
criterion,
data_loader,
optimizer,
epoch,
mixup_fn,
lr_scheduler,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None):
model.train()
optimizer.zero_grad()
num_steps = len(data_loader)
batch_time = AverageMeter()
model_time = AverageMeter()
loss_meter = AverageMeter()
norm_meter = MyAverageMeter(300)
start = time.time()
end = time.time()
amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
for idx, (samples, targets) in enumerate(data_loader):
iter_begin_time = time.time()
samples = samples.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
if not obsolete_torch_version(TORCH_VERSION,
(1, 9)) and config.AMP_OPT_LEVEL != 'O0':
with amp_autocast(dtype=amp_type):
outputs = model(samples)
else:
with amp_autocast():
outputs = model(samples)
if config.TRAIN.ACCUMULATION_STEPS > 1:
if not obsolete_torch_version(
TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
with amp_autocast(dtype=amp_type):
loss = criterion(outputs, targets)
loss = loss / config.TRAIN.ACCUMULATION_STEPS
else:
with amp_autocast():
loss = criterion(outputs, targets)
loss = loss / config.TRAIN.ACCUMULATION_STEPS
if config.AMP_OPT_LEVEL != 'O0':
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
grad_norm = loss_scaler(loss,
optimizer,
clip_grad=config.TRAIN.CLIP_GRAD,
parameters=model.parameters(),
create_graph=is_second_order,
update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model)
else:
loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
optimizer.step()
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model)
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
lr_scheduler.step_update(epoch * num_steps + idx)
else:
if not obsolete_torch_version(
TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
with amp_autocast(dtype=amp_type):
loss = criterion(outputs, targets)
else:
with amp_autocast():
loss = criterion(outputs, targets)
optimizer.zero_grad()
if config.AMP_OPT_LEVEL != 'O0':
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
grad_norm = loss_scaler(loss,
optimizer,
clip_grad=config.TRAIN.CLIP_GRAD,
parameters=model.parameters(),
create_graph=is_second_order,
update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
if model_ema is not None:
model_ema.update(model)
else:
loss.backward()
if config.TRAIN.CLIP_GRAD:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
optimizer.step()
if model_ema is not None:
model_ema.update(model)
lr_scheduler.step_update(epoch * num_steps + idx)
torch.cuda.synchronize()
loss_meter.update(loss.item(), targets.size(0))
if grad_norm is not None:
norm_meter.update(grad_norm.item())
batch_time.update(time.time() - end)
model_time.update(time.time() - iter_begin_time)
end = time.time()
if idx % config.PRINT_FREQ == 0:
lr = optimizer.param_groups[0]['lr']
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
etas = batch_time.avg * (num_steps - idx)
logger.info(
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
f'model_time {model_time.val:.4f} ({model_time.avg:.4f})\t'
f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f}/{norm_meter.var:.4f})\t'
f'mem {memory_used:.0f}MB')
epoch_time = time.time() - start
logger.info(
f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}'
)
@torch.no_grad()
def validate_real(config, data_loader, model, real_labels, amp_autocast=suppress):
# https://github.com/baaivision/EVA/blob/master/EVA-01/eva/engine_for_finetuning.py#L195
criterion = torch.nn.CrossEntropyLoss()
model.eval()
batch_time = AverageMeter()
loss_meter = AverageMeter()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
end = time.time()
amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
for idx, (images, target) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
if not obsolete_torch_version(TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
with amp_autocast(dtype=amp_type):
output = model(images)
else:
with amp_autocast():
output = model(images)
# convert 22k to 1k to evaluate
if output.size(-1) == 21841:
convert_file = './meta_data/map22kto1k.txt'
with open(convert_file, 'r') as f:
convert_list = [int(line) for line in f.readlines()]
output = output[:, convert_list]
real_labels.add_result(output)
# measure accuracy and record loss
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1 = reduce_tensor(acc1)
acc5 = reduce_tensor(acc5)
loss = reduce_tensor(loss)
loss_meter.update(loss.item(), target.size(0))
acc1_meter.update(acc1.item(), target.size(0))
acc5_meter.update(acc5.item(), target.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if idx % config.PRINT_FREQ == 0:
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
f'Mem {memory_used:.0f}MB')
# real labels mode replaces topk values at the end
top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
print('* ReaL Acc@1 {:.3f} Acc@5 {:.3f} loss {losses:.3f}'
.format(top1a, top5a, losses=loss_meter.avg))
return top1a, top5a, loss_meter.avg
@torch.no_grad()
def validate(config, data_loader, model, epoch=None, amp_autocast=suppress):
criterion = torch.nn.CrossEntropyLoss()
model.eval()
batch_time = AverageMeter()
loss_meter = AverageMeter()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
end = time.time()
amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
for idx, (images, target) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
if not obsolete_torch_version(TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
with amp_autocast(dtype=amp_type):
output = model(images)
else:
with amp_autocast():
output = model(images)
# convert 22k to 1k to evaluate
if output.size(-1) == 21841:
convert_file = './meta_data/map22kto1k.txt'
with open(convert_file, 'r') as f:
convert_list = [int(line) for line in f.readlines()]
output = output[:, convert_list]
if config.DATA.DATASET == 'imagenet_a':
from dataset.imagenet_a_r_indices import imagenet_a_mask
output = output[:, imagenet_a_mask]
elif config.DATA.DATASET == 'imagenet_r':
from dataset.imagenet_a_r_indices import imagenet_r_mask
output = output[:, imagenet_r_mask]
# measure accuracy and record loss
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1 = reduce_tensor(acc1)
acc5 = reduce_tensor(acc5)
loss = reduce_tensor(loss)
loss_meter.update(loss.item(), target.size(0))
acc1_meter.update(acc1.item(), target.size(0))
acc5_meter.update(acc5.item(), target.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if idx % config.PRINT_FREQ == 0:
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
f'Mem {memory_used:.0f}MB')
if epoch is not None:
logger.info(
f'[Epoch:{epoch}] * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}'
)
else:
logger.info(
f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
if __name__ == '__main__':
_, config = parse_option()
if config.AMP_OPT_LEVEL != 'O0':
assert has_native_amp, 'Please update pytorch(1.6+) to support amp!'
# init distributed env
if _.launcher == 'slurm':
print('\nDist init: SLURM')
rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count()
config.defrost()
config.LOCAL_RANK = gpu
config.freeze()
world_size = int(os.environ['SLURM_NTASKS'])
if 'MASTER_PORT' not in os.environ:
os.environ['MASTER_PORT'] = '29501'
node_list = os.environ['SLURM_NODELIST']
addr = subprocess.getoutput(
f'scontrol show hostname {node_list} | head -n1')
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(gpu)
os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
os.environ['WORLD_SIZE'] = str(world_size)
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
else:
rank = -1
world_size = -1
torch.cuda.set_device(config.LOCAL_RANK)
torch.distributed.init_process_group(backend='nccl',
init_method='env://',
world_size=world_size,
rank=rank)
torch.distributed.barrier()
seed = config.SEED + dist.get_rank()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
# linear scale the learning rate according to total batch size, may not be optimal
linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
# gradient accumulation also need to scale the learning rate
if config.TRAIN.ACCUMULATION_STEPS > 1:
linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
config.defrost()
config.TRAIN.BASE_LR = linear_scaled_lr
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
config.TRAIN.MIN_LR = linear_scaled_min_lr
print(config.AMP_OPT_LEVEL, _.amp_opt_level)
config.freeze()
os.makedirs(config.OUTPUT, exist_ok=True)
logger = create_logger(output_dir=config.OUTPUT,
dist_rank=dist.get_rank(),
name=f'{config.MODEL.NAME}')
if dist.get_rank() == 0:
path = os.path.join(config.OUTPUT, 'config.json')
with open(path, 'w') as f:
f.write(config.dump())
logger.info(f'Full config saved to {path}')
# print config
logger.info(config.dump())
main(config)
This source diff could not be displayed because it is too large. You can view the blob instead.
{
"n01440764": 0,
"n01443537": 1,
"n01484850": 2,
"n01491361": 3,
"n01494475": 4,
"n01496331": 5,
"n01498041": 6,
"n01514668": 7,
"n01514859": 8,
"n01518878": 9,
"n01530575": 10,
"n01531178": 11,
"n01532829": 12,
"n01534433": 13,
"n01537544": 14,
"n01558993": 15,
"n01560419": 16,
"n01580077": 17,
"n01582220": 18,
"n01592084": 19,
"n01601694": 20,
"n01608432": 21,
"n01614925": 22,
"n01616318": 23,
"n01622779": 24,
"n01629819": 25,
"n01630670": 26,
"n01631663": 27,
"n01632458": 28,
"n01632777": 29,
"n01641577": 30,
"n01644373": 31,
"n01644900": 32,
"n01664065": 33,
"n01665541": 34,
"n01667114": 35,
"n01667778": 36,
"n01669191": 37,
"n01675722": 38,
"n01677366": 39,
"n01682714": 40,
"n01685808": 41,
"n01687978": 42,
"n01688243": 43,
"n01689811": 44,
"n01692333": 45,
"n01693334": 46,
"n01694178": 47,
"n01695060": 48,
"n01697457": 49,
"n01698640": 50,
"n01704323": 51,
"n01728572": 52,
"n01728920": 53,
"n01729322": 54,
"n01729977": 55,
"n01734418": 56,
"n01735189": 57,
"n01737021": 58,
"n01739381": 59,
"n01740131": 60,
"n01742172": 61,
"n01744401": 62,
"n01748264": 63,
"n01749939": 64,
"n01751748": 65,
"n01753488": 66,
"n01755581": 67,
"n01756291": 68,
"n01768244": 69,
"n01770081": 70,
"n01770393": 71,
"n01773157": 72,
"n01773549": 73,
"n01773797": 74,
"n01774384": 75,
"n01774750": 76,
"n01775062": 77,
"n01776313": 78,
"n01784675": 79,
"n01795545": 80,
"n01796340": 81,
"n01797886": 82,
"n01798484": 83,
"n01806143": 84,
"n01806567": 85,
"n01807496": 86,
"n01817953": 87,
"n01818515": 88,
"n01819313": 89,
"n01820546": 90,
"n01824575": 91,
"n01828970": 92,
"n01829413": 93,
"n01833805": 94,
"n01843065": 95,
"n01843383": 96,
"n01847000": 97,
"n01855032": 98,
"n01855672": 99,
"n01860187": 100,
"n01871265": 101,
"n01872401": 102,
"n01873310": 103,
"n01877812": 104,
"n01882714": 105,
"n01883070": 106,
"n01910747": 107,
"n01914609": 108,
"n01917289": 109,
"n01924916": 110,
"n01930112": 111,
"n01943899": 112,
"n01944390": 113,
"n01945685": 114,
"n01950731": 115,
"n01955084": 116,
"n01968897": 117,
"n01978287": 118,
"n01978455": 119,
"n01980166": 120,
"n01981276": 121,
"n01983481": 122,
"n01984695": 123,
"n01985128": 124,
"n01986214": 125,
"n01990800": 126,
"n02002556": 127,
"n02002724": 128,
"n02006656": 129,
"n02007558": 130,
"n02009229": 131,
"n02009912": 132,
"n02011460": 133,
"n02012849": 134,
"n02013706": 135,
"n02017213": 136,
"n02018207": 137,
"n02018795": 138,
"n02025239": 139,
"n02027492": 140,
"n02028035": 141,
"n02033041": 142,
"n02037110": 143,
"n02051845": 144,
"n02056570": 145,
"n02058221": 146,
"n02066245": 147,
"n02071294": 148,
"n02074367": 149,
"n02077923": 150,
"n02085620": 151,
"n02085782": 152,
"n02085936": 153,
"n02086079": 154,
"n02086240": 155,
"n02086646": 156,
"n02086910": 157,
"n02087046": 158,
"n02087394": 159,
"n02088094": 160,
"n02088238": 161,
"n02088364": 162,
"n02088466": 163,
"n02088632": 164,
"n02089078": 165,
"n02089867": 166,
"n02089973": 167,
"n02090379": 168,
"n02090622": 169,
"n02090721": 170,
"n02091032": 171,
"n02091134": 172,
"n02091244": 173,
"n02091467": 174,
"n02091635": 175,
"n02091831": 176,
"n02092002": 177,
"n02092339": 178,
"n02093256": 179,
"n02093428": 180,
"n02093647": 181,
"n02093754": 182,
"n02093859": 183,
"n02093991": 184,
"n02094114": 185,
"n02094258": 186,
"n02094433": 187,
"n02095314": 188,
"n02095570": 189,
"n02095889": 190,
"n02096051": 191,
"n02096177": 192,
"n02096294": 193,
"n02096437": 194,
"n02096585": 195,
"n02097047": 196,
"n02097130": 197,
"n02097209": 198,
"n02097298": 199,
"n02097474": 200,
"n02097658": 201,
"n02098105": 202,
"n02098286": 203,
"n02098413": 204,
"n02099267": 205,
"n02099429": 206,
"n02099601": 207,
"n02099712": 208,
"n02099849": 209,
"n02100236": 210,
"n02100583": 211,
"n02100735": 212,
"n02100877": 213,
"n02101006": 214,
"n02101388": 215,
"n02101556": 216,
"n02102040": 217,
"n02102177": 218,
"n02102318": 219,
"n02102480": 220,
"n02102973": 221,
"n02104029": 222,
"n02104365": 223,
"n02105056": 224,
"n02105162": 225,
"n02105251": 226,
"n02105412": 227,
"n02105505": 228,
"n02105641": 229,
"n02105855": 230,
"n02106030": 231,
"n02106166": 232,
"n02106382": 233,
"n02106550": 234,
"n02106662": 235,
"n02107142": 236,
"n02107312": 237,
"n02107574": 238,
"n02107683": 239,
"n02107908": 240,
"n02108000": 241,
"n02108089": 242,
"n02108422": 243,
"n02108551": 244,
"n02108915": 245,
"n02109047": 246,
"n02109525": 247,
"n02109961": 248,
"n02110063": 249,
"n02110185": 250,
"n02110341": 251,
"n02110627": 252,
"n02110806": 253,
"n02110958": 254,
"n02111129": 255,
"n02111277": 256,
"n02111500": 257,
"n02111889": 258,
"n02112018": 259,
"n02112137": 260,
"n02112350": 261,
"n02112706": 262,
"n02113023": 263,
"n02113186": 264,
"n02113624": 265,
"n02113712": 266,
"n02113799": 267,
"n02113978": 268,
"n02114367": 269,
"n02114548": 270,
"n02114712": 271,
"n02114855": 272,
"n02115641": 273,
"n02115913": 274,
"n02116738": 275,
"n02117135": 276,
"n02119022": 277,
"n02119789": 278,
"n02120079": 279,
"n02120505": 280,
"n02123045": 281,
"n02123159": 282,
"n02123394": 283,
"n02123597": 284,
"n02124075": 285,
"n02125311": 286,
"n02127052": 287,
"n02128385": 288,
"n02128757": 289,
"n02128925": 290,
"n02129165": 291,
"n02129604": 292,
"n02130308": 293,
"n02132136": 294,
"n02133161": 295,
"n02134084": 296,
"n02134418": 297,
"n02137549": 298,
"n02138441": 299,
"n02165105": 300,
"n02165456": 301,
"n02167151": 302,
"n02168699": 303,
"n02169497": 304,
"n02172182": 305,
"n02174001": 306,
"n02177972": 307,
"n02190166": 308,
"n02206856": 309,
"n02219486": 310,
"n02226429": 311,
"n02229544": 312,
"n02231487": 313,
"n02233338": 314,
"n02236044": 315,
"n02256656": 316,
"n02259212": 317,
"n02264363": 318,
"n02268443": 319,
"n02268853": 320,
"n02276258": 321,
"n02277742": 322,
"n02279972": 323,
"n02280649": 324,
"n02281406": 325,
"n02281787": 326,
"n02317335": 327,
"n02319095": 328,
"n02321529": 329,
"n02325366": 330,
"n02326432": 331,
"n02328150": 332,
"n02342885": 333,
"n02346627": 334,
"n02356798": 335,
"n02361337": 336,
"n02363005": 337,
"n02364673": 338,
"n02389026": 339,
"n02391049": 340,
"n02395406": 341,
"n02396427": 342,
"n02397096": 343,
"n02398521": 344,
"n02403003": 345,
"n02408429": 346,
"n02410509": 347,
"n02412080": 348,
"n02415577": 349,
"n02417914": 350,
"n02422106": 351,
"n02422699": 352,
"n02423022": 353,
"n02437312": 354,
"n02437616": 355,
"n02441942": 356,
"n02442845": 357,
"n02443114": 358,
"n02443484": 359,
"n02444819": 360,
"n02445715": 361,
"n02447366": 362,
"n02454379": 363,
"n02457408": 364,
"n02480495": 365,
"n02480855": 366,
"n02481823": 367,
"n02483362": 368,
"n02483708": 369,
"n02484975": 370,
"n02486261": 371,
"n02486410": 372,
"n02487347": 373,
"n02488291": 374,
"n02488702": 375,
"n02489166": 376,
"n02490219": 377,
"n02492035": 378,
"n02492660": 379,
"n02493509": 380,
"n02493793": 381,
"n02494079": 382,
"n02497673": 383,
"n02500267": 384,
"n02504013": 385,
"n02504458": 386,
"n02509815": 387,
"n02510455": 388,
"n02514041": 389,
"n02526121": 390,
"n02536864": 391,
"n02606052": 392,
"n02607072": 393,
"n02640242": 394,
"n02641379": 395,
"n02643566": 396,
"n02655020": 397,
"n02666196": 398,
"n02667093": 399,
"n02669723": 400,
"n02672831": 401,
"n02676566": 402,
"n02687172": 403,
"n02690373": 404,
"n02692877": 405,
"n02699494": 406,
"n02701002": 407,
"n02704792": 408,
"n02708093": 409,
"n02727426": 410,
"n02730930": 411,
"n02747177": 412,
"n02749479": 413,
"n02769748": 414,
"n02776631": 415,
"n02777292": 416,
"n02782093": 417,
"n02783161": 418,
"n02786058": 419,
"n02787622": 420,
"n02788148": 421,
"n02790996": 422,
"n02791124": 423,
"n02791270": 424,
"n02793495": 425,
"n02794156": 426,
"n02795169": 427,
"n02797295": 428,
"n02799071": 429,
"n02802426": 430,
"n02804414": 431,
"n02804610": 432,
"n02807133": 433,
"n02808304": 434,
"n02808440": 435,
"n02814533": 436,
"n02814860": 437,
"n02815834": 438,
"n02817516": 439,
"n02823428": 440,
"n02823750": 441,
"n02825657": 442,
"n02834397": 443,
"n02835271": 444,
"n02837789": 445,
"n02840245": 446,
"n02841315": 447,
"n02843684": 448,
"n02859443": 449,
"n02860847": 450,
"n02865351": 451,
"n02869837": 452,
"n02870880": 453,
"n02871525": 454,
"n02877765": 455,
"n02879718": 456,
"n02883205": 457,
"n02892201": 458,
"n02892767": 459,
"n02894605": 460,
"n02895154": 461,
"n02906734": 462,
"n02909870": 463,
"n02910353": 464,
"n02916936": 465,
"n02917067": 466,
"n02927161": 467,
"n02930766": 468,
"n02939185": 469,
"n02948072": 470,
"n02950826": 471,
"n02951358": 472,
"n02951585": 473,
"n02963159": 474,
"n02965783": 475,
"n02966193": 476,
"n02966687": 477,
"n02971356": 478,
"n02974003": 479,
"n02977058": 480,
"n02978881": 481,
"n02979186": 482,
"n02980441": 483,
"n02981792": 484,
"n02988304": 485,
"n02992211": 486,
"n02992529": 487,
"n02999410": 488,
"n03000134": 489,
"n03000247": 490,
"n03000684": 491,
"n03014705": 492,
"n03016953": 493,
"n03017168": 494,
"n03018349": 495,
"n03026506": 496,
"n03028079": 497,
"n03032252": 498,
"n03041632": 499,
"n03042490": 500,
"n03045698": 501,
"n03047690": 502,
"n03062245": 503,
"n03063599": 504,
"n03063689": 505,
"n03065424": 506,
"n03075370": 507,
"n03085013": 508,
"n03089624": 509,
"n03095699": 510,
"n03100240": 511,
"n03109150": 512,
"n03110669": 513,
"n03124043": 514,
"n03124170": 515,
"n03125729": 516,
"n03126707": 517,
"n03127747": 518,
"n03127925": 519,
"n03131574": 520,
"n03133878": 521,
"n03134739": 522,
"n03141823": 523,
"n03146219": 524,
"n03160309": 525,
"n03179701": 526,
"n03180011": 527,
"n03187595": 528,
"n03188531": 529,
"n03196217": 530,
"n03197337": 531,
"n03201208": 532,
"n03207743": 533,
"n03207941": 534,
"n03208938": 535,
"n03216828": 536,
"n03218198": 537,
"n03220513": 538,
"n03223299": 539,
"n03240683": 540,
"n03249569": 541,
"n03250847": 542,
"n03255030": 543,
"n03259280": 544,
"n03271574": 545,
"n03272010": 546,
"n03272562": 547,
"n03290653": 548,
"n03291819": 549,
"n03297495": 550,
"n03314780": 551,
"n03325584": 552,
"n03337140": 553,
"n03344393": 554,
"n03345487": 555,
"n03347037": 556,
"n03355925": 557,
"n03372029": 558,
"n03376595": 559,
"n03379051": 560,
"n03384352": 561,
"n03388043": 562,
"n03388183": 563,
"n03388549": 564,
"n03393912": 565,
"n03394916": 566,
"n03400231": 567,
"n03404251": 568,
"n03417042": 569,
"n03424325": 570,
"n03425413": 571,
"n03443371": 572,
"n03444034": 573,
"n03445777": 574,
"n03445924": 575,
"n03447447": 576,
"n03447721": 577,
"n03450230": 578,
"n03452741": 579,
"n03457902": 580,
"n03459775": 581,
"n03461385": 582,
"n03467068": 583,
"n03476684": 584,
"n03476991": 585,
"n03478589": 586,
"n03481172": 587,
"n03482405": 588,
"n03483316": 589,
"n03485407": 590,
"n03485794": 591,
"n03492542": 592,
"n03494278": 593,
"n03495258": 594,
"n03496892": 595,
"n03498962": 596,
"n03527444": 597,
"n03529860": 598,
"n03530642": 599,
"n03532672": 600,
"n03534580": 601,
"n03535780": 602,
"n03538406": 603,
"n03544143": 604,
"n03584254": 605,
"n03584829": 606,
"n03590841": 607,
"n03594734": 608,
"n03594945": 609,
"n03595614": 610,
"n03598930": 611,
"n03599486": 612,
"n03602883": 613,
"n03617480": 614,
"n03623198": 615,
"n03627232": 616,
"n03630383": 617,
"n03633091": 618,
"n03637318": 619,
"n03642806": 620,
"n03649909": 621,
"n03657121": 622,
"n03658185": 623,
"n03661043": 624,
"n03662601": 625,
"n03666591": 626,
"n03670208": 627,
"n03673027": 628,
"n03676483": 629,
"n03680355": 630,
"n03690938": 631,
"n03691459": 632,
"n03692522": 633,
"n03697007": 634,
"n03706229": 635,
"n03709823": 636,
"n03710193": 637,
"n03710637": 638,
"n03710721": 639,
"n03717622": 640,
"n03720891": 641,
"n03721384": 642,
"n03724870": 643,
"n03729826": 644,
"n03733131": 645,
"n03733281": 646,
"n03733805": 647,
"n03742115": 648,
"n03743016": 649,
"n03759954": 650,
"n03761084": 651,
"n03763968": 652,
"n03764736": 653,
"n03769881": 654,
"n03770439": 655,
"n03770679": 656,
"n03773504": 657,
"n03775071": 658,
"n03775546": 659,
"n03776460": 660,
"n03777568": 661,
"n03777754": 662,
"n03781244": 663,
"n03782006": 664,
"n03785016": 665,
"n03786901": 666,
"n03787032": 667,
"n03788195": 668,
"n03788365": 669,
"n03791053": 670,
"n03792782": 671,
"n03792972": 672,
"n03793489": 673,
"n03794056": 674,
"n03796401": 675,
"n03803284": 676,
"n03804744": 677,
"n03814639": 678,
"n03814906": 679,
"n03825788": 680,
"n03832673": 681,
"n03837869": 682,
"n03838899": 683,
"n03840681": 684,
"n03841143": 685,
"n03843555": 686,
"n03854065": 687,
"n03857828": 688,
"n03866082": 689,
"n03868242": 690,
"n03868863": 691,
"n03871628": 692,
"n03873416": 693,
"n03874293": 694,
"n03874599": 695,
"n03876231": 696,
"n03877472": 697,
"n03877845": 698,
"n03884397": 699,
"n03887697": 700,
"n03888257": 701,
"n03888605": 702,
"n03891251": 703,
"n03891332": 704,
"n03895866": 705,
"n03899768": 706,
"n03902125": 707,
"n03903868": 708,
"n03908618": 709,
"n03908714": 710,
"n03916031": 711,
"n03920288": 712,
"n03924679": 713,
"n03929660": 714,
"n03929855": 715,
"n03930313": 716,
"n03930630": 717,
"n03933933": 718,
"n03935335": 719,
"n03937543": 720,
"n03938244": 721,
"n03942813": 722,
"n03944341": 723,
"n03947888": 724,
"n03950228": 725,
"n03954731": 726,
"n03956157": 727,
"n03958227": 728,
"n03961711": 729,
"n03967562": 730,
"n03970156": 731,
"n03976467": 732,
"n03976657": 733,
"n03977966": 734,
"n03980874": 735,
"n03982430": 736,
"n03983396": 737,
"n03991062": 738,
"n03992509": 739,
"n03995372": 740,
"n03998194": 741,
"n04004767": 742,
"n04005630": 743,
"n04008634": 744,
"n04009552": 745,
"n04019541": 746,
"n04023962": 747,
"n04026417": 748,
"n04033901": 749,
"n04033995": 750,
"n04037443": 751,
"n04039381": 752,
"n04040759": 753,
"n04041544": 754,
"n04044716": 755,
"n04049303": 756,
"n04065272": 757,
"n04067472": 758,
"n04069434": 759,
"n04070727": 760,
"n04074963": 761,
"n04081281": 762,
"n04086273": 763,
"n04090263": 764,
"n04099969": 765,
"n04111531": 766,
"n04116512": 767,
"n04118538": 768,
"n04118776": 769,
"n04120489": 770,
"n04125021": 771,
"n04127249": 772,
"n04131690": 773,
"n04133789": 774,
"n04136333": 775,
"n04141076": 776,
"n04141327": 777,
"n04141975": 778,
"n04146614": 779,
"n04147183": 780,
"n04149813": 781,
"n04152593": 782,
"n04153751": 783,
"n04154565": 784,
"n04162706": 785,
"n04179913": 786,
"n04192698": 787,
"n04200800": 788,
"n04201297": 789,
"n04204238": 790,
"n04204347": 791,
"n04208210": 792,
"n04209133": 793,
"n04209239": 794,
"n04228054": 795,
"n04229816": 796,
"n04235860": 797,
"n04238763": 798,
"n04239074": 799,
"n04243546": 800,
"n04251144": 801,
"n04252077": 802,
"n04252225": 803,
"n04254120": 804,
"n04254680": 805,
"n04254777": 806,
"n04258138": 807,
"n04259630": 808,
"n04263257": 809,
"n04264628": 810,
"n04265275": 811,
"n04266014": 812,
"n04270147": 813,
"n04273569": 814,
"n04275548": 815,
"n04277352": 816,
"n04285008": 817,
"n04286575": 818,
"n04296562": 819,
"n04310018": 820,
"n04311004": 821,
"n04311174": 822,
"n04317175": 823,
"n04325704": 824,
"n04326547": 825,
"n04328186": 826,
"n04330267": 827,
"n04332243": 828,
"n04335435": 829,
"n04336792": 830,
"n04344873": 831,
"n04346328": 832,
"n04347754": 833,
"n04350905": 834,
"n04355338": 835,
"n04355933": 836,
"n04356056": 837,
"n04357314": 838,
"n04366367": 839,
"n04367480": 840,
"n04370456": 841,
"n04371430": 842,
"n04371774": 843,
"n04372370": 844,
"n04376876": 845,
"n04380533": 846,
"n04389033": 847,
"n04392985": 848,
"n04398044": 849,
"n04399382": 850,
"n04404412": 851,
"n04409515": 852,
"n04417672": 853,
"n04418357": 854,
"n04423845": 855,
"n04428191": 856,
"n04429376": 857,
"n04435653": 858,
"n04442312": 859,
"n04443257": 860,
"n04447861": 861,
"n04456115": 862,
"n04458633": 863,
"n04461696": 864,
"n04462240": 865,
"n04465501": 866,
"n04467665": 867,
"n04476259": 868,
"n04479046": 869,
"n04482393": 870,
"n04483307": 871,
"n04485082": 872,
"n04486054": 873,
"n04487081": 874,
"n04487394": 875,
"n04493381": 876,
"n04501370": 877,
"n04505470": 878,
"n04507155": 879,
"n04509417": 880,
"n04515003": 881,
"n04517823": 882,
"n04522168": 883,
"n04523525": 884,
"n04525038": 885,
"n04525305": 886,
"n04532106": 887,
"n04532670": 888,
"n04536866": 889,
"n04540053": 890,
"n04542943": 891,
"n04548280": 892,
"n04548362": 893,
"n04550184": 894,
"n04552348": 895,
"n04553703": 896,
"n04554684": 897,
"n04557648": 898,
"n04560804": 899,
"n04562935": 900,
"n04579145": 901,
"n04579432": 902,
"n04584207": 903,
"n04589890": 904,
"n04590129": 905,
"n04591157": 906,
"n04591713": 907,
"n04592741": 908,
"n04596742": 909,
"n04597913": 910,
"n04599235": 911,
"n04604644": 912,
"n04606251": 913,
"n04612504": 914,
"n04613696": 915,
"n06359193": 916,
"n06596364": 917,
"n06785654": 918,
"n06794110": 919,
"n06874185": 920,
"n07248320": 921,
"n07565083": 922,
"n07579787": 923,
"n07583066": 924,
"n07584110": 925,
"n07590611": 926,
"n07613480": 927,
"n07614500": 928,
"n07615774": 929,
"n07684084": 930,
"n07693725": 931,
"n07695742": 932,
"n07697313": 933,
"n07697537": 934,
"n07711569": 935,
"n07714571": 936,
"n07714990": 937,
"n07715103": 938,
"n07716358": 939,
"n07716906": 940,
"n07717410": 941,
"n07717556": 942,
"n07718472": 943,
"n07718747": 944,
"n07720875": 945,
"n07730033": 946,
"n07734744": 947,
"n07742313": 948,
"n07745940": 949,
"n07747607": 950,
"n07749582": 951,
"n07753113": 952,
"n07753275": 953,
"n07753592": 954,
"n07754684": 955,
"n07760859": 956,
"n07768694": 957,
"n07802026": 958,
"n07831146": 959,
"n07836838": 960,
"n07860988": 961,
"n07871810": 962,
"n07873807": 963,
"n07875152": 964,
"n07880968": 965,
"n07892512": 966,
"n07920052": 967,
"n07930864": 968,
"n07932039": 969,
"n09193705": 970,
"n09229709": 971,
"n09246464": 972,
"n09256479": 973,
"n09288635": 974,
"n09332890": 975,
"n09399592": 976,
"n09421951": 977,
"n09428293": 978,
"n09468604": 979,
"n09472597": 980,
"n09835506": 981,
"n10148035": 982,
"n10565667": 983,
"n11879895": 984,
"n11939491": 985,
"n12057211": 986,
"n12144580": 987,
"n12267677": 988,
"n12620546": 989,
"n12768682": 990,
"n12985857": 991,
"n12998815": 992,
"n13037406": 993,
"n13040303": 994,
"n13044778": 995,
"n13052670": 996,
"n13054560": 997,
"n13133613": 998,
"n15075141": 999
}
359
368
460
475
486
492
496
514
516
525
547
548
556
563
575
641
648
723
733
765
801
826
852
858
878
896
900
905
908
910
935
946
947
994
999
1003
1005
1010
1027
1029
1048
1055
1064
1065
1069
1075
1079
1081
1085
1088
1093
1106
1143
1144
1145
1147
1168
1171
1178
1187
1190
1197
1205
1216
1223
1230
1236
1241
1245
1257
1259
1260
1267
1268
1269
1271
1272
1273
1277
1303
1344
1349
1355
1357
1384
1388
1391
1427
1429
1432
1437
1450
1461
1462
1474
1502
1503
1512
1552
1555
1577
1584
1587
1589
1599
1615
1616
1681
1692
1701
1716
1729
1757
1759
1764
1777
1786
1822
1841
1842
1848
1850
1856
1860
1861
1864
1876
1897
1898
1910
1913
1918
1922
1928
1932
1935
1947
1951
1953
1970
1977
1979
2001
2017
2067
2081
2087
2112
2128
2135
2147
2174
2175
2176
2177
2178
2181
2183
2184
2187
2189
2190
2191
2192
2193
2197
2202
2203
2206
2208
2209
2211
2212
2213
2214
2215
2216
2217
2219
2222
2223
2224
2225
2226
2227
2228
2229
2230
2236
2238
2240
2241
2242
2243
2244
2245
2247
2248
2249
2250
2251
2252
2255
2256
2257
2262
2263
2264
2265
2266
2268
2270
2271
2272
2273
2275
2276
2279
2280
2281
2282
2285
2289
2292
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2309
2310
2312
2313
2314
2315
2316
2318
2319
2321
2322
2326
2329
2330
2331
2332
2334
2335
2336
2337
2338
2339
2341
2342
2343
2344
2346
2348
2349
2351
2352
2353
2355
2357
2358
2359
2360
2364
2365
2368
2369
2377
2382
2383
2385
2397
2398
2400
2402
2405
2412
2421
2428
2431
2432
2433
2436
2441
2445
2450
2453
2454
2465
2469
2532
2533
2538
2544
2547
2557
2565
2578
2612
2658
2702
2722
2731
2738
2741
2747
2810
2818
2833
2844
2845
2867
2874
2882
2884
2888
2889
3008
3012
3019
3029
3033
3042
3091
3106
3138
3159
3164
3169
3280
3296
3311
3318
3320
3324
3330
3366
3375
3381
3406
3419
3432
3434
3435
3493
3495
3503
3509
3511
3513
3517
3521
3526
3546
3554
3600
3601
3606
3612
3613
3616
3622
3623
3627
3632
3634
3636
3638
3644
3646
3649
3650
3651
3656
3663
3673
3674
3689
3690
3702
3733
3769
3971
3974
4065
4068
4073
4102
4136
4140
4151
4159
4165
4207
4219
4226
4249
4256
4263
4270
4313
4321
4378
4386
4478
4508
4512
4536
4542
4550
4560
4562
4570
4571
4572
4583
4588
4594
4604
4608
4623
4634
4636
4646
4651
4652
4686
4688
4691
4699
4724
4727
4737
4770
4774
4789
4802
4807
4819
4880
4886
4908
4927
4931
4936
4964
4976
4993
5028
5033
5043
5046
5096
5111
5114
5131
5132
5183
5199
5235
5275
5291
5293
5294
5343
5360
5362
5364
5390
5402
5418
5428
5430
5437
5443
5473
5484
5486
5505
5507
5508
5510
5567
5578
5580
5584
5606
5613
5629
5672
5676
5692
5701
5760
5769
5770
5779
5814
5850
5871
5893
5911
5949
5954
6005
6006
6012
6017
6023
6024
6040
6050
6054
6087
6105
6157
6235
6237
6256
6259
6286
6291
6306
6339
6341
6343
6379
6383
6393
6405
6479
6511
6517
6541
6561
6608
6611
6615
6678
6682
6707
6752
6798
6850
6880
6885
6890
6920
6981
7000
7009
7038
7049
7050
7052
7073
7078
7098
7111
7165
7198
7204
7280
7283
7286
7287
7293
7294
7305
7318
7341
7346
7354
7382
7427
7428
7435
7445
7450
7455
7467
7469
7497
7502
7506
7514
7523
7651
7661
7664
7672
7679
7685
7696
7730
7871
7873
7895
7914
7915
7920
7934
7935
7949
8009
8036
8051
8065
8074
8090
8112
8140
8164
8168
8178
8182
8198
8212
8216
8230
8242
8288
8289
8295
8318
8352
8368
8371
8375
8376
8401
8416
8419
8436
8460
8477
8478
8482
8498
8500
8539
8543
8552
8555
8580
8584
8586
8594
8598
8601
8606
8610
8611
8622
8627
8639
8649
8650
8653
8654
8667
8672
8673
8674
8676
8684
8720
8723
8750
8753
8801
8815
8831
8835
8842
8845
8858
8897
8916
8951
8954
8959
8970
8976
8981
8983
8989
8991
8993
9019
9039
9042
9043
9056
9057
9070
9087
9098
9106
9130
9131
9155
9171
9183
9198
9199
9201
9204
9211
9220
9224
9228
9249
9259
9270
9278
9294
9299
9309
9321
9344
9351
9375
9376
9381
9391
9400
9404
9440
9448
9463
9474
9501
9504
9513
9514
9544
9566
9575
9607
9608
9623
9632
9638
9642
9655
9673
9739
9751
9759
9766
9777
9801
9819
9838
9878
9923
9955
9960
9962
9969
9996
10009
10030
10039
10051
10072
10074
10077
10093
10096
10108
10117
10120
10123
10157
10225
10275
10303
10306
10313
10314
10331
10336
10337
10412
10422
10450
10462
10464
10486
10518
10521
10522
10531
10533
10534
10550
10558
10573
10582
10585
10588
10611
10625
10634
10637
10676
10682
10725
10775
10781
10782
10806
10836
10839
10847
10858
10870
10880
10883
10907
10913
10920
10935
10946
10950
10951
10956
10998
11002
11017
11022
11024
11026
11044
11054
11094
11109
11136
11136
11167
11185
11220
11222
11241
11254
11258
11278
11305
11310
11330
11366
11376
11388
11391
11400
11406
11436
11448
11465
11468
11472
11477
11482
11483
11506
11535
11557
11565
11574
11583
11593
11610
11611
11618
11620
11639
11642
11663
11673
11688
11708
11709
11715
11720
11725
11728
11742
11759
11770
11836
11838
11855
11875
11877
11883
11888
11895
11916
11922
11929
11943
11951
11979
11983
12213
12228
12238
12240
12241
12246
12282
12348
12368
12372
12421
12559
12565
12574
12687
12754
12767
12777
12779
12811
12831
12834
12835
12842
12846
12848
12849
12855
12857
12872
12937
12970
13016
13037
13045
13058
13084
13085
13087
13093
13133
13181
13229
13405
13443
13613
13689
13697
13708
13748
13803
13981
14050
14058
14218
14245
14255
14263
14293
14323
14366
14388
14393
14437
14441
14964
15730
16742
18035
18203
18533
18790
19100
20017
20460
21024
21043
21161
21169
21179
21194
21198
21367
21815
This source diff could not be displayed because it is too large. You can view the blob instead.
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from .build import build_model
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from .clip_vit import CLIPViT
from .intern_vit_6b import InternViT6B
def build_model(config):
model_type = config.MODEL.TYPE
if model_type == 'intern_vit_6b':
model = InternViT6B(
num_classes=config.MODEL.NUM_CLASSES,
patch_size=config.MODEL.INTERN_VIT_6B.PATCH_SIZE,
img_size=config.DATA.IMG_SIZE,
pretrain_size=config.MODEL.INTERN_VIT_6B.PRETRAIN_SIZE,
qkv_bias=config.MODEL.INTERN_VIT_6B.QKV_BIAS,
drop_path_rate=config.MODEL.DROP_PATH_RATE,
embed_dim=config.MODEL.INTERN_VIT_6B.EMBED_DIM,
num_heads=config.MODEL.INTERN_VIT_6B.NUM_HEADS,
mlp_ratio=config.MODEL.INTERN_VIT_6B.MLP_RATIO,
init_values=config.MODEL.INTERN_VIT_6B.INIT_VALUES,
qk_normalization=config.MODEL.INTERN_VIT_6B.QK_NORMALIZATION,
depth=config.MODEL.INTERN_VIT_6B.DEPTH,
use_flash_attn=config.MODEL.INTERN_VIT_6B.USE_FLASH_ATTN,
with_cp=config.TRAIN.USE_CHECKPOINT,
freeze_vit=config.MODEL.INTERN_VIT_6B.FREEZE_VIT,
pretrained=config.MODEL.INTERN_VIT_6B.PRETRAINED,
cls_target=config.MODEL.INTERN_VIT_6B.CLS_TARGET,
norm_type=config.MODEL.INTERN_VIT_6B.NORM_TYPE,
)
elif model_type == 'clip_vit':
model = CLIPViT(
patch_size=config.MODEL.CLIP_VIT.PATCH_SIZE,
img_size=config.DATA.IMG_SIZE,
pretrain_size=config.MODEL.CLIP_VIT.PRETRAIN_SIZE,
embed_dim=config.MODEL.CLIP_VIT.EMBED_DIM,
num_heads=config.MODEL.CLIP_VIT.NUM_HEADS,
mlp_ratio=config.MODEL.CLIP_VIT.MLP_RATIO,
depth=config.MODEL.CLIP_VIT.DEPTH,
with_cp=config.TRAIN.USE_CHECKPOINT,
freeze_vit=config.MODEL.CLIP_VIT.FREEZE_VIT,
pretrained=config.MODEL.CLIP_VIT.PRETRAINED,
cls_target=config.MODEL.CLIP_VIT.CLS_TARGET,
)
else:
raise NotImplementedError(f'Unkown model: {model_type}')
return model
# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from transformers import CLIPModel
def _freeze_params(module):
for param in module.parameters():
param.requires_grad = False
class CrossAttention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None, out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
super().__init__()
self.norm1_q = norm_layer(dim)
self.norm1_k = norm_layer(dim)
self.norm1_v = norm_layer(dim)
self.cross_attn = CrossAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_attn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv, pos_q, pos_k = x, 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
x = x.squeeze(1)
return x
class CLIPViT(nn.Module):
def __init__(self, patch_size=14, img_size=336, pretrain_size=336, embed_dim=1024, num_heads=16,
mlp_ratio=4, depth=48, with_cp=True, freeze_vit=True, cls_target='cls_patch_concat',
num_classes=1000, pretrained=None):
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.pretrain_size = pretrain_size
self.img_size = img_size
self.patch_size = patch_size
self.cls_target = cls_target
self.depth = depth
self.mlp_ratio = mlp_ratio
self.with_cp = with_cp
model = CLIPModel.from_pretrained(pretrained)
model.post_layernorm = nn.Identity()
self.model = model.vision_model
if freeze_vit:
_freeze_params(self)
if cls_target == 'cls_patch_concat':
self.norm = nn.SyncBatchNorm(embed_dim * 2, eps=1e-6)
self.head = nn.Linear(embed_dim * 2, num_classes) if num_classes > 0 else nn.Identity()
elif cls_target == 'attention_pooling':
self.attn_pooling = AttentionPoolingBlock(
dim=embed_dim, num_heads=num_heads, qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0.0, norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=embed_dim)
self.norm = nn.SyncBatchNorm(embed_dim, eps=1e-6)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
else:
raise NotImplementedError
if type(self.head) != nn.Identity:
self.head.weight.data.normal_(mean=0.0, std=0.01)
self.head.bias.data.zero_()
@property
def dtype(self):
return self.model.embeddings.patch_embedding.weight.dtype
def forward_features(self, x):
x = x.type(self.dtype)
x = self.model(x)
x = x.last_hidden_state
return x
def forward(self, x):
x = self.forward_features(x)
if self.cls_target == 'cls_patch_concat':
x = torch.cat((x[:, 0, :], x[:, 1:, :].mean(dim=1)), dim=-1)
elif self.cls_target == 'attention_pooling':
x = self.attn_pooling(x)
else:
raise NotImplementedError
x = self.norm(x)
x = self.head(x)
return x
@torch.jit.ignore
def lr_decay_keywords(self, decay_ratio=0.95):
lr_ratios = {}
# layers
for idx in range(self.depth):
tag = 'layers.{}.'.format(idx)
decay = 1.0 * (decay_ratio ** (self.depth - idx))
lr_ratios[tag] = decay
# patch_embedding
lr_ratios['patch_embedding'] = 1.0 * (decay_ratio ** (self.depth + 1))
lr_ratios['position_embedding'] = 1.0 * (decay_ratio ** (self.depth + 1))
lr_ratios['pre_layrnorm'] = 1.0 * (decay_ratio ** (self.depth + 1))
return lr_ratios
import torch
import torch.nn as nn
from einops import rearrange
try: # v1
from flash_attn.flash_attn_interface import \
flash_attn_unpadded_qkvpacked_func
except: # v2
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
if unpadded: (nnz, 3, h, d)
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
return output, None
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple
try:
from .flash_attention import FlashAttention
has_flash_attn = True
except:
print('FlashAttention is not installed.')
has_flash_attn = False
def _freeze_params(module):
for param in module.parameters():
param.requires_grad = False
class CrossAttention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None, out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
super().__init__()
self.norm1_q = norm_layer(dim)
self.norm1_k = norm_layer(dim)
self.norm1_v = norm_layer(dim)
self.cross_attn = CrossAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_attn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv, pos_q, pos_k = x, 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
x = x.squeeze(1)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
try:
from apex.normalization import FusedRMSNorm
RMSNorm = FusedRMSNorm # noqa
print('Discovered apex.normalization.FusedRMSNorm - will use it instead of RMSNorm')
except ImportError:
# using the normal RMSNorm
pass
except Exception:
print('discovered apex but it failed to load, falling back to RMSNorm')
pass
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
self.force_fp32 = force_fp32
@torch.cuda.amp.autocast(enabled=False)
def forward(self, x):
if self.force_fp32:
output_type = x.dtype
out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
return out.to(dtype=output_type)
else:
out = x.mul_(self.gamma) if self.inplace else x * self.gamma
return out
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
causal=False, norm_layer=nn.LayerNorm, qk_normalization=False):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.use_flash_attn = use_flash_attn
if use_flash_attn:
self.causal = causal
self.inner_attn = FlashAttention(attention_dropout=attn_drop)
self.qk_normalization = qk_normalization
self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
def _naive_attn(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
if self.qk_normalization:
B_, H_, N_, D_ = q.shape
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
attn = ((q * self.scale) @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
qkv = self.qkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
if self.qk_normalization:
q, k, v = qkv.unbind(2)
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
qkv = torch.stack([q, k, v], dim=2)
context, _ = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
)
outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
outs = self.proj_drop(outs)
return outs
def forward(self, x):
x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
return x
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
bias=True, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, with_cp=False,
qk_normalization=False, layerscale_force_fp32=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
qk_normalization=qk_normalization)
self.ls1 = LayerScale(dim, init_values=init_values,
force_fp32=layerscale_force_fp32) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values,
force_fp32=layerscale_force_fp32) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.with_cp = with_cp
def forward(self, x):
def _inner_forward(x):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
if self.with_cp:
return checkpoint.checkpoint(_inner_forward, x)
else:
return _inner_forward(x)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x, **kwargs):
x = self.proj(x)
_, _, H, W = x.shape
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x, H, W
class InternViT6B(nn.Module):
def __init__(self, in_chans=3, patch_size=14, img_size=224, pretrain_size=224, qkv_bias=False, drop_path_rate=0.0,
embed_dim=3200, num_heads=25, mlp_ratio=4, init_values=0.1, qk_normalization=True, depth=48,
use_flash_attn=True, with_cp=True, layerscale_force_fp32=False, freeze_vit=True,
cls_target='cls_patch_concat', num_classes=1000, attn_pool_num_heads=16, clip_embed_dim=768,
norm_type='rms', pretrained=None):
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.pretrain_size = pretrain_size
self.drop_path_rate = drop_path_rate
self.img_size = img_size
self.patch_size = patch_size
self.cls_target = cls_target
self.depth = depth
use_flash_attn = use_flash_attn and has_flash_attn
if use_flash_attn and not has_flash_attn:
print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
use_flash_attn = [use_flash_attn] * depth if not isinstance(use_flash_attn, list) else use_flash_attn
if norm_type == 'rms':
norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
elif norm_type == 'ln':
norm_layer_for_blocks = partial(nn.LayerNorm, eps=1e-6)
else:
raise NotImplementedError
self.norm_layer_for_blocks = norm_layer_for_blocks
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.num_patches = num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Identity()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
norm_layer=norm_layer_for_blocks,
drop_path=dpr[i], init_values=init_values, attn_drop=0.,
use_flash_attn=use_flash_attn[i],
with_cp=with_cp,
qk_normalization=qk_normalization,
layerscale_force_fp32=layerscale_force_fp32)
for i in range(depth)])
if cls_target == 'clip_projector':
self.clip_projector = AttentionPoolingBlock(
dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
self.init_weights(pretrained)
if freeze_vit:
_freeze_params(self)
if cls_target == 'cls_patch_concat':
self.norm = nn.SyncBatchNorm(embed_dim * 2, eps=1e-6)
self.head = nn.Linear(embed_dim * 2, num_classes) if num_classes > 0 else nn.Identity()
elif cls_target == 'attention_pooling':
self.attn_pooling = AttentionPoolingBlock(
dim=embed_dim, num_heads=num_heads, qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0.0, norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=embed_dim)
self.norm = nn.SyncBatchNorm(embed_dim, eps=1e-6)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
elif cls_target == 'clip_projector':
self.norm = nn.SyncBatchNorm(clip_embed_dim, eps=1e-6)
self.head = nn.Linear(clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
else:
raise NotImplementedError
if type(self.head) != nn.Identity:
self.head.weight.data.normal_(mean=0.0, std=0.01)
self.head.bias.data.zero_()
def init_weights(self, pretrained=None):
print(f'pretrained: {pretrained}')
def resize_pos_embed(pos_embed, H, W):
cls = pos_embed[:, :1, :]
pos_embed = pos_embed[:, 1:, :].reshape(
1, self.pretrain_size // 14, self.pretrain_size // 14, -1).permute(0, 3, 1, 2)
pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
reshape(1, -1, H * W).permute(0, 2, 1)
pos_embed = torch.cat([cls, pos_embed], dim=1)
return pos_embed
if isinstance(pretrained, str):
checkpoint = torch.load(pretrained, map_location='cpu')
if 'module' in checkpoint:
checkpoint = checkpoint['module']
# resize pos_embed
pos_embed = checkpoint['pos_embed']
checkpoint['pos_embed'] = resize_pos_embed(
pos_embed, self.img_size // self.patch_size, self.img_size // self.patch_size)
# resize patch_embed
patch_embed = checkpoint['patch_embed.proj.weight']
checkpoint['patch_embed.proj.weight'] = F.interpolate(
patch_embed, size=(self.patch_size, self.patch_size),
mode='bicubic', align_corners=False)
message = self.load_state_dict(checkpoint, strict=False)
print(message)
@property
def dtype(self):
return self.patch_embed.proj.weight.dtype
def forward_features(self, x):
x, _, _ = self.patch_embed(x.type(self.dtype))
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
for idx, blk in enumerate(self.blocks):
x = blk(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.cls_target == 'cls_patch_concat':
x = torch.cat((x[:, 0, :], x[:, 1:, :].mean(dim=1)), dim=-1)
elif self.cls_target == 'attention_pooling':
x = self.attn_pooling(x)
elif self.cls_target == 'clip_projector':
x = self.clip_projector(x)
else:
raise NotImplementedError
x = self.norm(x)
x = self.head(x)
return x
@torch.jit.ignore
def lr_decay_keywords(self, decay_ratio=0.95):
lr_ratios = {}
# blocks
for idx in range(self.depth):
tag = 'blocks.{}.'.format(idx)
decay = 1.0 * (decay_ratio ** (self.depth - idx))
lr_ratios[tag] = decay
# patch_embed
lr_ratios['patch_embed'] = 1.0 * (decay_ratio ** (self.depth + 1))
lr_ratios['pos_embed'] = 1.0 * (decay_ratio ** (self.depth + 1))
lr_ratios['cls_token'] = 1.0 * (decay_ratio ** (self.depth + 1))
return lr_ratios
# --------------------------------------------------------
# InternVL
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from torch import optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
def build_optimizer(config, model):
"""
Build optimizer, set weight decay of normalization to 0 by default.
"""
skip = {}
skip_keywords = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
if hasattr(model, 'no_weight_decay_keywords'):
skip_keywords = model.no_weight_decay_keywords()
parameters = set_weight_decay_and_lr(
model,
config.TRAIN.WEIGHT_DECAY,
config.TRAIN.BASE_LR,
skip,
skip_keywords,
lr_layer_decay=config.TRAIN.LR_LAYER_DECAY,
lr_layer_decay_ratio=config.TRAIN.LR_LAYER_DECAY_RATIO,
freeze_backbone=config.TRAIN.OPTIMIZER.FREEZE_BACKBONE,
dcn_lr_mul=config.TRAIN.OPTIMIZER.DCN_LR_MUL,
)
opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
optimizer = None
use_zero = config.TRAIN.OPTIMIZER.USE_ZERO
if use_zero:
print(f'\nUse Zero!')
if opt_lower == 'sgd':
# an ugly implementation
# this problem is fixed after torch 1.12
# https://github.com/pytorch/pytorch/issues/71347
# before 1.12, we could only pass list to zero optimizer, so we first pass parameters[0] with its lr and weight decay,
# then we add other parameter via parameter group.
optimizer = ZeroRedundancyOptimizer(
parameters[0]['params'],
optimizer_class=optim.SGD,
momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
lr=parameters[0]['lr'], weight_decay=parameters[0]['weight_decay']
)
if len(parameters) > 1:
for param_group in parameters[1:]:
optimizer.add_param_group(param_group)
elif opt_lower == 'adamw':
optimizer = ZeroRedundancyOptimizer(
parameters[0]['params'],
optimizer_class=optim.AdamW,
eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
lr=parameters[0]['lr'], weight_decay=parameters[0]['weight_decay']
)
if len(parameters) > 1:
for param_group in parameters[1:]:
optimizer.add_param_group(param_group)
else:
if opt_lower == 'sgd':
optimizer = optim.SGD(parameters,
momentum=config.TRAIN.OPTIMIZER.MOMENTUM,
nesterov=True,
lr=config.TRAIN.BASE_LR,
weight_decay=config.TRAIN.WEIGHT_DECAY)
elif opt_lower == 'sgd_linear_probing':
optimizer = optim.SGD(parameters,
momentum=0.9,
nesterov=False,
lr=config.TRAIN.BASE_LR,
weight_decay=0)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters,
eps=config.TRAIN.OPTIMIZER.EPS,
betas=config.TRAIN.OPTIMIZER.BETAS,
lr=config.TRAIN.BASE_LR,
weight_decay=config.TRAIN.WEIGHT_DECAY)
else:
raise NotImplementedError
return optimizer
def check_keywords_in_name(name, keywords=()):
isin = False
for keyword in keywords:
if keyword in name:
isin = True
return isin
def check_keywords_in_dict(name, keywords_dict):
for k, v in keywords_dict.items():
if k in name:
return v
return None
def set_weight_decay_and_lr(
model,
weight_decay,
base_lr,
skip_list=(),
skip_keywords=(),
lr_layer_decay=None,
lr_layer_decay_ratio=None,
freeze_backbone=None,
dcn_lr_mul=None,
layerwise_lr=True,
):
parameters = []
no_decay_name = []
lr_ratio_log = {}
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if freeze_backbone:
for i in freeze_backbone:
if f'levels.{i}' in name:
param.requires_grad = False
# 1. check wd
if len(param.shape) == 1 or name.endswith('.bias') or (
name in skip_list) or check_keywords_in_name(name, skip_keywords):
wd = 0.
no_decay_name.append(name)
else:
wd = weight_decay
if lr_layer_decay:
print('layer-wise lr decay is used !')
assert hasattr(model, 'lr_decay_keywords')
lr_ratio_keywards = model.lr_decay_keywords(lr_layer_decay_ratio)
# 2. check lr
ratio = check_keywords_in_dict(name, lr_ratio_keywards)
if ratio is not None:
lr = ratio * base_lr
else:
lr = base_lr
# dcn lr
if dcn_lr_mul is not None:
if 'offset' in name or 'attention_weights' in name or 'center_feature_scale_proj' in name or 'alpha_beta' in name:
lr = dcn_lr_mul * lr
lr_ratio_log[name] = (base_lr, ratio, wd, param.requires_grad)
else:
lr = base_lr
parameters.append({'params': [param], 'weight_decay': wd, 'lr': lr, 'name': name})
print('no decay params: {no_decay_name}')
if layerwise_lr:
print('lr_ratio_params:')
for k, v in lr_ratio_log.items():
print(k, v)
return parameters
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-10}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
--quotatype=reserved \
${SRUN_ARGS} \
python -u main.py \
--cfg ${CONFIG} \
--accumulation-steps 1 \
--local-rank 0 \
--output work_dirs ${@:4}
# --------------------------------------------------------
# InternVL
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
import os
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
from timm.utils import get_state_dict
try:
# noinspection PyUnresolvedReferences
from apex import amp
except ImportError:
amp = None
def load_ema_checkpoint(config, model_ema, logger):
logger.info(
f'==============> Resuming form {config.MODEL.RESUME}....................'
)
if config.MODEL.RESUME.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(config.MODEL.RESUME,
map_location='cpu',
check_hash=True)
else:
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
assert isinstance(checkpoint, dict)
if 'model_ema' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['model_ema'].items():
if model_ema.ema_has_module:
name = 'module.' + k if not k.startswith('module') else k
else:
name = k
new_state_dict[name] = v
msg = model_ema.ema.load_state_dict(new_state_dict, strict=False)
logger.info(msg)
logger.info('Loaded state_dict_ema')
else:
logger.warning(
'Failed to find state_dict_ema, starting from loaded model weights'
)
max_accuracy_ema = 0
if 'max_accuracy_ema' in checkpoint:
max_accuracy_ema = checkpoint['max_accuracy_ema']
if 'ema_decay' in checkpoint:
model_ema.decay = checkpoint['ema_decay']
return max_accuracy_ema
def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
logger.info(
f'==============> Resuming form {config.MODEL.RESUME}....................'
)
if config.MODEL.RESUME.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(config.MODEL.RESUME,
map_location='cpu',
check_hash=True)
else:
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
print('resuming model')
model_checkpoint = checkpoint['model']
msg = model.load_state_dict(model_checkpoint, strict=False)
logger.info(msg)
max_accuracy = 0.0
if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
if optimizer is not None:
print('resuming optimizer')
try:
optimizer.load_state_dict(checkpoint['optimizer'])
except:
print('resume optimizer failed')
if lr_scheduler is not None:
print('resuming lr_scheduler')
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
config.defrost()
config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
config.freeze()
if 'amp' in checkpoint and config.AMP_OPT_LEVEL != 'O0' and checkpoint['config'].AMP_OPT_LEVEL != 'O0':
scaler.load_state_dict(checkpoint['amp'])
logger.info(
f"=> loaded successfully {config.MODEL.RESUME} (epoch {checkpoint['epoch']})"
)
if 'max_accuracy' in checkpoint:
max_accuracy = checkpoint['max_accuracy']
del checkpoint
torch.cuda.empty_cache()
return max_accuracy
def load_pretrained(config, model, logger):
logger.info(
f'==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......'
)
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
state_dict = checkpoint
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'module' in checkpoint:
state_dict = checkpoint['module']
first_key = list(state_dict.keys())[0]
# delete teacher weights
if 'student' in first_key or 'teacher' in first_key:
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if 'student_proj' in k:
continue
if 'student' in k:
new_k = k.replace('student.', '')
new_state_dict[new_k] = v
state_dict = new_state_dict
# weights from sim
if 'mask_token' in first_key:
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if 'mm_dcnv3' in k:
continue
if 'dcnv3' not in k and 'clip_projector' not in k:
continue
new_k = k.replace('dcnv3.', '')
new_state_dict[new_k] = v
new_state_dict['fc_norm.weight'] = state_dict[
'clip.classifier_ln.weight']
new_state_dict['fc_norm.bias'] = state_dict['clip.classifier_ln.bias']
new_state_dict['head.weight'] = state_dict['clip.classifier.weight']
new_state_dict['head.bias'] = state_dict['clip.classifier.bias']
state_dict = new_state_dict
# delete relative_position_index since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_position_index' in k
]
for k in relative_position_index_keys:
del state_dict[k]
# delete relative_coords_table since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_coords_table' in k
]
for k in relative_position_index_keys:
del state_dict[k]
# delete attn_mask since we always re-init it
attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k]
for k in attn_mask_keys:
del state_dict[k]
# bicubic interpolate relative_position_bias_table if not match
relative_position_bias_table_keys = [
k for k in state_dict.keys() if 'relative_position_bias_table' in k
]
for k in relative_position_bias_table_keys:
relative_position_bias_table_pretrained = state_dict[k]
relative_position_bias_table_current = model.state_dict()[k]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {k}, passing......')
else:
if L1 != L2:
# bicubic interpolate relative_position_bias_table if not match
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
# bicubic interpolate absolute_pos_embed if not match
absolute_pos_embed_keys = [
k for k in state_dict.keys() if 'absolute_pos_embed' in k
]
for k in absolute_pos_embed_keys:
# dpe
absolute_pos_embed_pretrained = state_dict[k]
absolute_pos_embed_current = model.state_dict()[k]
_, L1, C1 = absolute_pos_embed_pretrained.size()
_, L2, C2 = absolute_pos_embed_current.size()
if C1 != C1:
logger.warning(f'Error in loading {k}, passing......')
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
absolute_pos_embed_pretrained,
size=(S2, S2),
mode='bicubic')
absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
state_dict[k] = absolute_pos_embed_pretrained_resized
# check classifier, if not match, then re-init classifier to zero
if 'head.bias' in state_dict:
head_bias_pretrained = state_dict['head.bias']
Nc1 = head_bias_pretrained.shape[0]
Nc2 = model.head.bias.shape[0]
if (Nc1 != Nc2):
if config.TRAIN.RAND_INIT_FT_HEAD:
model.head.weight.data = model.head.weight.data * 0.001
model.head.bias.data = model.head.bias.data * 0.001
del state_dict['head.weight']
del state_dict['head.bias']
logger.warning(f'Error in loading classifier head, re-init classifier head to 0')
elif Nc1 == 21841 and Nc2 == 1000:
logger.info('loading ImageNet-22K weight to ImageNet-1K ......')
map22kto1k_path = 'meta_data/map22kto1k.txt'
logger.info(map22kto1k_path)
with open(map22kto1k_path) as f:
map22kto1k = f.readlines()
map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
msg = model.load_state_dict(state_dict, strict=False)
logger.warning(msg)
logger.info(f'=> loaded successfully {config.MODEL.PRETRAINED}')
del checkpoint
torch.cuda.empty_cache()
def convert_22k_head_to_1k(model, logger):
head_weight = model.module.head.weight
head_bias = model.module.head.bias
Nc1 = head_bias.shape[0]
if Nc1 == 21841:
logger.info('converting ImageNet-22K head to ImageNet-1K ......')
map22kto1k_path = 'meta_data/map22kto1k.txt'
logger.info(map22kto1k_path)
with open(map22kto1k_path) as f:
map22kto1k = f.readlines()
map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
model.module.head.weight = torch.nn.Parameter(head_weight[map22kto1k, :])
model.module.head.bias = torch.nn.Parameter(head_bias[map22kto1k])
else:
logger.warning(f'Error in converting classifier head')
return model
def save_checkpoint(config,
epoch,
model,
max_accuracy,
optimizer,
lr_scheduler,
scaler,
logger,
model_ema=None,
max_accuracy_ema=None,
ema_decay=None,
model_ems=None,
max_accuracy_ems=None,
ems_model_num=None,
best=None):
save_state = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'max_accuracy': max_accuracy,
'epoch': epoch,
'config': config
}
if model_ema is not None:
save_state['model_ema'] = get_state_dict(model_ema)
if max_accuracy_ema is not None:
save_state['max_accuracy_ema'] = max_accuracy_ema
if ema_decay is not None:
save_state['ema_decay'] = ema_decay
if model_ems is not None:
save_state['model_ems'] = get_state_dict(model_ems)
if max_accuracy_ems is not None:
save_state['max_accuracy_ems'] = max_accuracy_ems
if ems_model_num is not None:
save_state['ems_model_num'] = ems_model_num
if config.AMP_OPT_LEVEL != 'O0':
# save_state['amp'] = amp.state_dict()
save_state['amp'] = scaler.state_dict()
if best is None:
save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
else:
save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{best}.pth')
logger.info(f'{save_path} saving......')
torch.save(save_state, save_path)
logger.info(f'{save_path} saved !!!')
if dist.get_rank() == 0 and isinstance(epoch, int):
to_del = epoch - config.SAVE_CKPT_NUM * config.SAVE_FREQ
old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{to_del}.pth')
if os.path.exists(old_ckpt):
os.remove(old_ckpt)
def get_grad_norm(parameters, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
return total_norm
def auto_resume_helper(output_dir):
checkpoints = os.listdir(output_dir)
checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
print(f'All checkpoints founded in {output_dir}: {checkpoints}')
if len(checkpoints) > 0:
latest_checkpoint = max(
[os.path.join(output_dir, d) for d in checkpoints],
key=os.path.getmtime)
print(f'The latest checkpoint founded: {latest_checkpoint}')
resume_file = latest_checkpoint
else:
resume_file = None
return resume_file
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= dist.get_world_size()
return rt
# https://github.com/facebookresearch/ConvNeXt/blob/main/utils.py
class NativeScalerWithGradNormCount:
state_dict_key = 'amp_scaler'
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self,
loss,
optimizer,
clip_grad=None,
parameters=None,
create_graph=False,
update_grad=True):
self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
else:
self._scaler.unscale_(optimizer)
norm = get_grad_norm(parameters)
self._scaler.step(optimizer)
self._scaler.update()
else:
norm = None
return norm
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
class MyAverageMeter(object):
"""Computes and stores the average and current value."""
def __init__(self, max_len=-1):
self.val_list = []
self.count = []
self.max_len = max_len
self.val = 0
self.avg = 0
self.var = 0
def update(self, val):
self.val = val
self.avg = 0
self.var = 0
if not math.isnan(val) and not math.isinf(val):
self.val_list.append(val)
if self.max_len > 0 and len(self.val_list) > self.max_len:
self.val_list = self.val_list[-self.max_len:]
if len(self.val_list) > 0:
self.avg = np.mean(np.array(self.val_list))
self.var = np.std(np.array(self.val_list))
This source diff could not be displayed because it is too large. You can view the blob instead.
=======
Credits
=======
* `Mehdi Cherti <https://github.com/mehdidc>`_
* `Romain Beaumont <https://github.com/rom1504>`_
.. highlight:: shell
============
Contributing
============
Contributions are welcome, and they are greatly appreciated! Every little bit
helps, and credit will always be given.
You can contribute in many ways:
Types of Contributions
----------------------
Report Bugs
~~~~~~~~~~~
Report bugs at https://github.com/LAION-AI/CLIP_benchmark/issues.
If you are reporting a bug, please include:
* Your operating system name and version.
* Any details about your local setup that might be helpful in troubleshooting.
* Detailed steps to reproduce the bug.
Fix Bugs
~~~~~~~~
Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
wanted" is open to whoever wants to implement it.
Implement Features
~~~~~~~~~~~~~~~~~~
Look through the GitHub issues for features. Anything tagged with "enhancement"
and "help wanted" is open to whoever wants to implement it.
Write Documentation
~~~~~~~~~~~~~~~~~~~
CLIP Benchmark could always use more documentation, whether as part of the
official CLIP Benchmark docs, in docstrings, or even on the web in blog posts,
articles, and such.
Submit Feedback
~~~~~~~~~~~~~~~
The best way to send feedback is to file an issue at https://github.com/LAION-AI/CLIP_benchmark/issues.
If you are proposing a feature:
* Explain in detail how it would work.
* Keep the scope as narrow as possible, to make it easier to implement.
* Remember that this is a volunteer-driven project, and that contributions
are welcome :)
Get Started!
------------
Ready to contribute? Here's how to set up `clip_benchmark` for local development.
1. Fork the `clip_benchmark` repo on GitHub.
2. Clone your fork locally::
$ git clone git@github.com:your_name_here/clip_benchmark.git
3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::
$ mkvirtualenv clip_benchmark
$ cd clip_benchmark/
$ python setup.py develop
4. Create a branch for local development::
$ git checkout -b name-of-your-bugfix-or-feature
Now you can make your changes locally.
5. When you're done making changes, check that your changes pass flake8 and the
tests, including testing other Python versions with tox::
$ flake8 clip_benchmark tests
$ python setup.py test or pytest
$ tox
To get flake8 and tox, just pip install them into your virtualenv.
6. Commit your changes and push your branch to GitHub::
$ git add .
$ git commit -m "Your detailed description of your changes."
$ git push origin name-of-your-bugfix-or-feature
7. Submit a pull request through the GitHub website.
Pull Request Guidelines
-----------------------
Before you submit a pull request, check that it meets these guidelines:
1. The pull request should include tests.
2. If the pull request adds functionality, the docs should be updated. Put
your new functionality into a function with a docstring, and add the
feature to the list in README.rst.
3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check
https://travis-ci.com/mehdidc/clip_benchmark/pull_requests
and make sure that the tests pass for all supported Python versions.
Tips
----
To run a subset of tests::
$ python -m unittest tests.test_clip_benchmark
Deploying
---------
A reminder for the maintainers on how to deploy.
Make sure all your changes are committed (including an entry in HISTORY.rst).
Then run::
$ bump2version patch # possible: major / minor / patch
$ git push
$ git push --tags
Travis will then deploy to PyPI if tests pass.
## History
### 1.4.0
* Fix silent webdataset error-handling
* Added support for wds/voc2007_multilabel
* default to float32
* add mscoco generative benchmark
### 1.3.0
* update flickr8k results, solve issue #48, thanks to @orchidmajumder
* Evaluate multiple models/datasets/languages using the CLI directly
* Support Japanese CLIP by rinna
* Add arabic imagenet
* updating CuPL prompts with more generated sentences + ensembled with openAI prompts
* put model in eval mode before evaluation
* Webdataset updates
* Make verbose the default
### 1.2.0
* Added support for loading webdatasets
### 1.1.0
* Added better support for multilingual eval
* Added better support for linear probing
* Added support for CuPL prompts
### 1.0.1
* pypi description as markdown
### 1.0.0
* Actual first release on PyPI.
### 0.1.0
* First release on PyPI.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment