Commit b7535e7c authored by luopl's avatar luopl
Browse files

init

parents
Pipeline #1734 canceled with stages
This diff is collapsed.
"""
Utilities to register and load models, adapted from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_registry.py
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_factory.py
Hacked together by / Copyright 2023 Ross Wightman
"""
import torch
import os
from collections import OrderedDict
from copy import deepcopy
from typing import Any
import sys
import re
import fnmatch
from collections import defaultdict
from copy import deepcopy
# Names exported by a star-import of this module.
# NOTE(review): register_pip_model, load_state_dict, load_checkpoint and create_model
# are defined below but not listed here — confirm that is intentional.
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
# Module-level registries; all of them are filled in by register_pip_model().
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model default_cfgs
def register_pip_model(fn):
    """Decorator that records a model entrypoint function in the module registries.

    The function's ``__name__`` becomes the model name and the last component
    of its ``__module__`` becomes the module name. The model name is also
    appended to (or used to create) the defining module's ``__all__``.

    Args:
        fn: model entrypoint function to register.

    Returns:
        ``fn`` unchanged, so this can be used as a decorator.
    """
    dotted_module = fn.__module__
    mod = sys.modules[dotted_module]
    path_parts = dotted_module.split('.')
    module_name = path_parts[-1] if len(path_parts) else ''

    # Make the entrypoint part of the defining module's public names.
    model_name = fn.__name__
    if not hasattr(mod, '__all__'):
        mod.__all__ = [model_name]
    else:
        mod.__all__.append(model_name)

    # Fill the lookup tables.
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)

    # Only entrypoints whose name matches a default_cfgs key are picked up here;
    # aliases or non-matching names are missed.
    has_pretrained = False
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        cfg = mod.default_cfgs[model_name]
        has_pretrained = 'url' in cfg and 'http' in cfg['url']
        _model_default_cfgs[model_name] = deepcopy(cfg)
    if has_pretrained:
        _model_has_pretrained.add(model_name)
    return fn
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
    """ Return list of available model names, sorted in natural order.

    Args:
        filter (str or list[str]) - Wildcard filter string(s) that work with fnmatch
        module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
        pretrained (bool) - Include only models with pretrained weights if True
        exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
        name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)

    Returns:
        list[str] of matching model names.

    Example:
        list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if module:
        # Use .get(): indexing the defaultdict would insert an empty entry for an
        # unknown module name and make list_modules() report a phantom module.
        all_models = list(_module_to_models.get(module, ()))
    else:
        all_models = _model_entrypoints.keys()

    if filter:
        models = set()
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        for f in include_filters:
            models.update(fnmatch.filter(all_models, f))  # include these models
    else:
        models = all_models

    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for xf in exclude_filters:
            exclude_models = fnmatch.filter(models, xf)  # exclude these models
            if exclude_models:
                models = set(models).difference(exclude_models)

    if pretrained:
        models = _model_has_pretrained.intersection(models)
    if name_matches_cfg:
        models = set(_model_default_cfgs).intersection(models)
    return sorted(models, key=_natural_key)
def is_model(model_name):
    """Return True when ``model_name`` has a registered entrypoint."""
    registered = model_name in _model_entrypoints
    return registered
def model_entrypoint(model_name):
    """Look up the entrypoint function registered under ``model_name``.

    Raises:
        KeyError: if no entrypoint was registered under that name.
    """
    entrypoint_fn = _model_entrypoints[model_name]
    return entrypoint_fn
def list_modules():
    """Return the sorted names of all modules that have registered models."""
    return sorted(_module_to_models)
def is_model_in_modules(model_name, module_names):
    """Check if a model exists within a subset of modules

    Args:
        model_name (str) - name of model to check
        module_names (tuple, list, set) - names of modules to search in

    Returns:
        True if any of the named modules registered ``model_name``.
    """
    assert isinstance(module_names, (tuple, list, set))
    # .get() keeps this a read-only query: indexing the defaultdict would insert
    # an empty set for every unknown module name that is probed.
    return any(model_name in _module_to_models.get(n, ()) for n in module_names)
def has_model_default_key(model_name, cfg_key):
    """Tell whether the registered default_cfg of ``model_name`` contains ``cfg_key``.

    Returns False when the model has no registered default_cfg.
    """
    if model_name not in _model_default_cfgs:
        return False
    return cfg_key in _model_default_cfgs[model_name]
def is_model_default_key(model_name, cfg_key):
    """Return True if ``cfg_key`` holds a truthy value in the model's default_cfg.

    Returns False when the model or the key is not registered.
    """
    if model_name not in _model_default_cfgs:
        return False
    return bool(_model_default_cfgs[model_name].get(cfg_key, False))
def get_model_default_value(model_name, cfg_key):
    """Fetch one value from a model's default_cfg.

    Returns None when either the model or the key is not registered.
    """
    if model_name not in _model_default_cfgs:
        return None
    return _model_default_cfgs[model_name].get(cfg_key, None)
def is_model_pretrained(model_name):
    """Return True if ``model_name`` was registered with a pretrained weight url."""
    has_weights = model_name in _model_has_pretrained
    return has_weights
def load_state_dict(checkpoint_path, use_ema=False):
    """Load a state dict from a checkpoint file onto the CPU.

    If the checkpoint is a dict holding a ``'state_dict'`` entry (or
    ``'state_dict_ema'`` when ``use_ema`` is set and that entry exists), that
    entry is returned with a leading ``module.`` prefix removed from its keys.
    Otherwise the loaded object is returned as-is.

    Args:
        checkpoint_path: path of the checkpoint file.
        use_ema: prefer the ``'state_dict_ema'`` entry when it is present.

    Returns:
        The state dict (an ``OrderedDict`` when extracted from a wrapper dict).

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty or not a file.
    """
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        print(message)
        raise FileNotFoundError(message)

    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # trusted sources (or consider weights_only=True where checkpoints allow it).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    state_dict_key = 'state_dict'
    is_mapping = isinstance(checkpoint, dict)
    if is_mapping and use_ema and 'state_dict_ema' in checkpoint:
        state_dict_key = 'state_dict_ema'

    if is_mapping and state_dict_key in checkpoint:
        # Strip only a real `module.` prefix; a bare startswith('module') would
        # also mangle keys such as 'module_list.0.weight'.
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint[state_dict_key].items()
        )
    else:
        state_dict = checkpoint
    print("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
    return state_dict
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Load weights from ``checkpoint_path`` into ``model``.

    Files ending in ``.npz`` / ``.npy`` (case-insensitive) are delegated to
    ``model.load_pretrained``; any other file is read with ``load_state_dict``
    and applied through ``model.load_state_dict``.

    Raises:
        NotImplementedError: for a numpy checkpoint when the model has no
            ``load_pretrained`` method.
    """
    extension = os.path.splitext(checkpoint_path)[-1].lower()
    if extension not in ('.npz', '.npy'):
        weights = load_state_dict(checkpoint_path, use_ema)
        model.load_state_dict(weights, strict=strict)
        return
    # numpy checkpoint: only loadable through a model-specific loader
    if not hasattr(model, 'load_pretrained'):
        raise NotImplementedError('Model cannot load numpy checkpoint')
    model.load_pretrained(checkpoint_path)
def create_model(
        model_name,
        pretrained=False,
        checkpoint_path='',
        **kwargs):
    """Instantiate a registered model and optionally load a checkpoint into it.

    Args:
        model_name: name of a registered model entrypoint.
        pretrained: forwarded to the entrypoint as ``pretrained``.
        checkpoint_path: when non-empty, loaded into the model after creation.
        **kwargs: forwarded to the entrypoint.

    Returns:
        The object returned by the entrypoint.
    """
    builder = model_entrypoint(model_name)
    net = builder(pretrained=pretrained, **kwargs)
    if not checkpoint_path:
        return net
    load_checkpoint(net, checkpoint_path)
    return net
\ No newline at end of file
#!/bin/bash
# Launches train.py through torchrun with 2 processes per node (master port 29501).
# The variables below are only used to fill in the command-line flags of train.py.
# Dataset directory, passed as --data_dir.
DATA_PATH="/ImageNet/train"
# Model name, passed as --model.
MODEL=mamba_vision_T
# Passed as --batch-size. NOTE(review): presumably per-process — confirm in train.py.
BS=2
# Passed as --tag.
EXP=Test
# Passed as --lr.
LR=8e-4
# Passed as --weight-decay.
WD=0.05
# Passed as --warmup-lr.
WR_LR=1e-6
# Passed as --drop-path.
DR=0.38
# Passed as --mesa.
MESA=0.25
torchrun --nproc_per_node=2 --master_port=29501 train.py --mesa ${MESA} --input-size 3 224 224 --crop-pct=0.875 \
--data_dir=$DATA_PATH --model $MODEL --amp --weight-decay ${WD} --drop-path ${DR} --batch-size $BS --tag $EXP --lr $LR --warmup-lr $WR_LR
#!/bin/bash
# Runs validate.py for the mamba_vision_T model against a saved checkpoint.
# Dataset directory, passed as --data_dir.
# NOTE(review): the validate.py argument parser visible alongside this script declares
# --data-dir (with a dash) — confirm which spelling the target script accepts.
DATA_PATH="/ImageNet/val"
# Passed as --batch-size.
BS=128
# Checkpoint file, passed as --checkpoint.
checkpoint='/model_weights/mambavision_tiny_1k.pth.tar'
python validate.py --model mamba_vision_T --checkpoint=$checkpoint --data_dir=$DATA_PATH --batch-size $BS --input-size 3 224 224 \
--num-gpu 2
\ No newline at end of file
from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from .scheduler_factory import create_scheduler
""" Cosine Scheduler
Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
import numpy as np
import torch
from .scheduler import Scheduler
_logger = logging.getLogger(__name__)
class CosineLRScheduler(Scheduler):
    """Cosine LR decay with optional linear warmup, restarts (cycles) and k-decay.

    Cosine decay with restarts follows https://arxiv.org/abs/1608.03983, with
    inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py

    The k-decay option is based on `k-decay: A New Method For Learning Rate Schedule`
    - https://arxiv.org/abs/2004.05909
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min: float = 0.,
                 cycle_mul: float = 1.,
                 cycle_decay: float = 1.,
                 cycle_limit: int = 1,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 k_decay=1.0,
                 initialize=True) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        degenerate = t_initial == 1 and cycle_mul == 1 and cycle_decay == 1
        if degenerate:
            _logger.warning("Cosine annealing scheduler will have no effect on the learning "
                            "rate since t_initial = t_mul = eta_mul = 1.")

        # length of the first cycle and lower LR bound
        self.t_initial = t_initial
        self.lr_min = lr_min
        # cycle / restart configuration
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        # warmup configuration
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        self.k_decay = k_decay

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # per-step LR increment so warmup ends at each group's base value
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)

    def _locate_cycle(self, t):
        """Return (cycle index, cycle length, offset within the cycle) for step ``t``."""
        if self.cycle_mul != 1:
            i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
            t_i = self.cycle_mul ** i * self.t_initial
            t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
            return i, t_i, t_curr
        i = t // self.t_initial
        return i, self.t_initial, t - (self.t_initial * i)

    def _get_lr(self, t):
        """Compute one LR per param group for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * step for step in self.warmup_steps]

        if self.warmup_prefix:
            t = t - self.warmup_t

        i, t_i, t_curr = self._locate_cycle(t)
        gamma = self.cycle_decay ** i
        lr_max_values = [v * gamma for v in self.base_values]
        k = self.k_decay

        if i < self.cycle_limit:
            return [
                self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
                for lr_max in lr_max_values
            ]
        # past the last allowed cycle: hold at the minimum LR
        return [self.lr_min for _ in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else None."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else None."""
        return None if self.t_in_epochs else self._get_lr(num_updates)

    def get_cycle_length(self, cycles=0):
        """Return the total number of steps covered by ``cycles`` cycles.

        ``cycles`` falls back to ``cycle_limit`` when it is 0 and is at least 1.
        """
        n_cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul != 1.0:
            return int(math.floor(-self.t_initial * (self.cycle_mul ** n_cycles - 1) / (1 - self.cycle_mul)))
        return self.t_initial * n_cycles
""" MultiStep LR Scheduler
Basic multi step LR schedule with warmup, noise.
"""
import torch
import bisect
from timm.scheduler.scheduler import Scheduler
from typing import List
class MultiStepLRScheduler(Scheduler):
    """Multi-step LR schedule with optional linear warmup and noise.

    After warmup, each base LR is multiplied by ``decay_rate`` once for every
    milestone in ``decay_t`` that ``get_curr_decay_steps`` counts as passed.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: List[int],
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # per-step LR increment so warmup ends at each group's base value
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)

    def get_curr_decay_steps(self, t):
        """Count the milestones in ``decay_t`` that are <= ``t + 1``.

        Assumes ``self.decay_t`` is sorted.
        """
        return bisect.bisect_right(self.decay_t, t+1)

    def _get_lr(self, t):
        """Compute one LR per param group for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * step for step in self.warmup_steps]
        return [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else None."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else None."""
        return None if self.t_in_epochs else self._get_lr(num_updates)
""" Plateau Scheduler
Adapts PyTorch plateau scheduler and allows application of noise, warmup.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from .scheduler import Scheduler
class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the monitored metric plateaus.

    Wraps ``torch.optim.lr_scheduler.ReduceLROnPlateau`` and adds a linear
    warmup phase plus optional LR noise on top of it. ``mode`` ('max' by
    default) is forwarded to the wrapped scheduler unchanged.
    """
    def __init__(self,
                 optimizer,
                 decay_rate=0.1,
                 patience_t=10,
                 verbose=True,
                 threshold=1e-4,
                 cooldown_t=0,
                 warmup_t=0,
                 warmup_lr_init=0,
                 lr_min=0,
                 mode='max',
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize=True,
                 ):
        super().__init__(
            optimizer,
            'lr',
            noise_range_t=noise_range_t,
            noise_type=noise_type,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )
        # The wrapped torch scheduler owns the actual plateau detection and decay.
        # NOTE(review): `verbose` is forwarded as-is — confirm the installed torch
        # version still accepts it.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            patience=patience_t,
            factor=decay_rate,
            verbose=verbose,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min
        )
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if self.warmup_t:
            # per-epoch LR increment so warmup ends at each group's base value
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]
        # LRs saved by _apply_noise so the un-noised values can be put back before
        # the wrapped scheduler steps again
        self.restore_lr = None

    def state_dict(self):
        """Return only the wrapped scheduler's ``best`` and ``last_epoch`` values."""
        return {
            'best': self.lr_scheduler.best,
            'last_epoch': self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore ``best`` (required) and ``last_epoch`` (if present) on the wrapped scheduler."""
        self.lr_scheduler.best = state_dict['best']
        if 'last_epoch' in state_dict:
            self.lr_scheduler.last_epoch = state_dict['last_epoch']

    # override the base class step fn completely
    def step(self, epoch, metric=None):
        """Advance the schedule for ``epoch`` using ``metric`` for plateau detection.

        While ``epoch <= warmup_t`` the LR follows the linear warmup; afterwards
        the wrapped scheduler is stepped and noise is applied when ``epoch``
        falls inside the configured noise range.
        """
        if epoch <= self.warmup_t:
            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
            super().update_groups(lrs)
        else:
            if self.restore_lr is not None:
                # restore actual LR from before our last noise perturbation before stepping base
                for i, param_group in enumerate(self.optimizer.param_groups):
                    param_group['lr'] = self.restore_lr[i]
                self.restore_lr = None
            self.lr_scheduler.step(metric, epoch) # step the base scheduler
            if self._is_apply_noise(epoch):
                self._apply_noise(epoch)

    def _apply_noise(self, epoch):
        """Perturb every param group's LR by the noise for ``epoch``, remembering the old values."""
        noise = self._calculate_noise(epoch)
        # apply the noise on top of previous LR, cache the old value so we can restore for normal
        # stepping of base scheduler
        restore_lr = []
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            restore_lr.append(old_lr)
            new_lr = old_lr + old_lr * noise
            param_group['lr'] = new_lr
        self.restore_lr = restore_lr
""" Polynomial Scheduler
Polynomial LR schedule with warmup, noise.
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
import logging
import torch
from .scheduler import Scheduler
_logger = logging.getLogger(__name__)
class PolyLRScheduler(Scheduler):
    """ Polynomial LR Scheduler w/ warmup, noise, and k-decay

    Within a cycle the LR goes from the cycle's max value towards ``lr_min``
    following ``(1 - (t_curr / t_i) ** k) ** power``.

    k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 power: float = 0.5,
                 lr_min: float = 0.,
                 cycle_mul: float = 1.,
                 cycle_decay: float = 1.,
                 cycle_limit: int = 1,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 k_decay=1.0,
                 initialize=True) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
            # Message previously named the cosine scheduler and non-existent
            # parameters (copy-paste from the cosine implementation).
            _logger.warning("Polynomial scheduler will have no effect on the learning "
                            "rate since t_initial = cycle_mul = cycle_decay = 1.")
        self.t_initial = t_initial
        self.power = power
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        self.k_decay = k_decay
        if self.warmup_t:
            # per-step LR increment so warmup ends at each group's base value
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Compute one LR per param group for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * s for s in self.warmup_steps]

        if self.warmup_prefix:
            t = t - self.warmup_t

        # locate the current cycle: index i, its length t_i and the offset t_curr
        if self.cycle_mul != 1:
            i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
            t_i = self.cycle_mul ** i * self.t_initial
            t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
        else:
            i = t // self.t_initial
            t_i = self.t_initial
            t_curr = t - (self.t_initial * i)

        gamma = self.cycle_decay ** i
        lr_max_values = [v * gamma for v in self.base_values]
        k = self.k_decay

        if i < self.cycle_limit:
            return [
                self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power
                for lr_max in lr_max_values
            ]
        # past the last allowed cycle: hold at the minimum LR
        return [self.lr_min for _ in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else None."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        return None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else None."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        return None

    def get_cycle_length(self, cycles=0):
        """Return the total number of steps covered by ``cycles`` cycles.

        ``cycles`` falls back to ``cycle_limit`` when it is 0 and is at least 1.
        """
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
            return self.t_initial * cycles
        return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
from typing import Dict, Any
import torch
class Scheduler:
    """ Parameter Scheduler Base Class

    A scheduler base class that can be used to schedule any optimizer parameter groups.

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
    * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
    * At the END of each optimizer update, after incrementing the update count, to calculate next update's value

    The schedulers built on this should try to remain as stateless as possible (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
    and -1 values for special behaviour. All epoch and update counts must be tracked in the training
    code and explicitly passed in to the schedulers on the corresponding step or step_update call.

    Based on ideas from:
     * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
     * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        initial_field = self._initial_param_group_field

        for idx, group in enumerate(self.optimizer.param_groups):
            if initialize:
                # record the starting value once; an existing initial value is kept
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{idx}]")
                group.setdefault(initial_field, group[param_group_field])
            elif initial_field not in group:
                raise KeyError(f"{initial_field} missing from param_groups[{idx}]")

        self.base_values = [group[initial_field] for group in self.optimizer.param_groups]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return every instance attribute except the optimizer."""
        state = dict(self.__dict__)
        state.pop('optimizer', None)
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Per-epoch values; the base class schedules nothing and returns None."""
        return None

    def get_update_values(self, num_updates: int):
        """Per-update values; the base class schedules nothing and returns None."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Apply the per-epoch values (with noise) to the optimizer, if any."""
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is None:
            return
        self.update_groups(self._add_noise(values, epoch))

    def step_update(self, num_updates: int, metric: float = None):
        """Apply the per-update values (with noise) to the optimizer, if any."""
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is None:
            return
        self.update_groups(self._add_noise(values, num_updates))

    def update_groups(self, values):
        """Write ``values`` into the scheduled field of each param group.

        A single value is broadcast to all groups. Groups holding an
        ``'lr_scale'`` entry get their value multiplied by it.
        """
        groups = self.optimizer.param_groups
        if not isinstance(values, (list, tuple)):
            values = [values] * len(groups)
        for group, value in zip(groups, values):
            group[self.param_group_field] = value * group['lr_scale'] if 'lr_scale' in group else value

    def _add_noise(self, lrs, t):
        """Return ``lrs`` scaled by ``1 + noise`` when ``t`` is in the noise range."""
        if not self._is_apply_noise(t):
            return lrs
        noise = self._calculate_noise(t)
        return [v + v * noise for v in lrs]

    def _is_apply_noise(self, t) -> bool:
        """Return True if scheduler in noise range."""
        noise_range = self.noise_range_t
        if noise_range is None:
            return False
        if isinstance(noise_range, (list, tuple)):
            return noise_range[0] <= t < noise_range[1]
        return t >= noise_range

    def _calculate_noise(self, t) -> float:
        """Draw the noise factor for step ``t`` from a generator seeded with ``noise_seed + t``."""
        g = torch.Generator()
        g.manual_seed(self.noise_seed + t)
        if self.noise_type != 'normal':
            return 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
        while True:
            # resample if noise out of percent limit, brute force but shouldn't spin much
            noise = torch.randn(1, generator=g).item()
            if abs(noise) < self.noise_pct:
                return noise
""" Scheduler Factory
Hacked together by / Copyright 2021 Ross Wightman
"""
from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
def create_scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.sched``.

    Reads ``epochs``, ``data_len``, ``batch_size``, ``world_size``,
    ``warmup_epochs`` and ``sched`` from ``args``, plus the scheduler-specific
    attributes used in the matching branch below. Optional attributes
    (``lr_noise``, ``lr_noise_pct``, ``lr_noise_std``, ``seed``,
    ``lr_cycle_mul``, ``lr_cycle_decay``, ``lr_cycle_limit``, ``lr_k_decay``,
    ``eval_metric``) fall back to defaults when missing.

    Returns:
        ``(lr_scheduler, num_epochs)``; ``lr_scheduler`` is None when
        ``args.sched`` matches none of the known names.
    """
    num_epochs = args.epochs
    n_iter = args.data_len // (args.batch_size * args.world_size)
    tot_iter = num_epochs * n_iter
    warmup_iters = args.warmup_epochs * n_iter

    # lr_noise is given as fraction(s) of the epoch count
    lr_noise = getattr(args, 'lr_noise', None)
    noise_range = None
    if lr_noise is not None:
        if isinstance(lr_noise, (list, tuple)):
            noise_range = [n * num_epochs for n in lr_noise]
            if len(noise_range) == 1:
                noise_range = noise_range[0]
        else:
            noise_range = lr_noise * num_epochs

    noise_args = {
        'noise_range_t': noise_range,
        'noise_pct': getattr(args, 'lr_noise_pct', 0.67),
        'noise_std': getattr(args, 'lr_noise_std', 1.),
        'noise_seed': getattr(args, 'seed', 42),
    }
    cycle_args = {
        'cycle_mul': getattr(args, 'lr_cycle_mul', 1.),
        'cycle_decay': getattr(args, 'lr_cycle_decay', 0.1),
        'cycle_limit': getattr(args, 'lr_cycle_limit', 1),
    }

    lr_scheduler = None
    sched = args.sched
    if sched == 'cosine':
        # the cosine schedule is configured in iterations, not epochs
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=tot_iter,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=warmup_iters,
            k_decay=getattr(args, 'lr_k_decay', 1.0),
            t_in_epochs=args.lr_ep,
            **cycle_args,
            **noise_args,
        )
        cycle_length = lr_scheduler.get_cycle_length() // n_iter
        num_epochs = cycle_length + args.cooldown_epochs
    elif sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            t_in_epochs=True,
            **cycle_args,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif sched == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=args.decay_milestones,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif sched == 'plateau':
        # a metric whose name contains 'loss' is minimised, anything else maximised
        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_args,
        )
    elif sched == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=args.decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            k_decay=getattr(args, 'lr_k_decay', 1.0),
            **cycle_args,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs

    return lr_scheduler, num_epochs
""" Step Scheduler
Basic step LR schedule with warmup, noise.
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch
from .scheduler import Scheduler
class StepLRScheduler(Scheduler):
    """Step LR schedule with optional linear warmup and noise.

    After warmup, each base LR is multiplied by ``decay_rate`` raised to
    ``t // decay_t``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: float,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # per-step LR increment so warmup ends at each group's base value
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Compute one LR per param group for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * step for step in self.warmup_steps]
        return [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else None."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else None."""
        return None if self.t_in_epochs else self._get_lr(num_updates)
""" TanH Scheduler
TanH schedule with warmup, cycle/restarts, noise.
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
import numpy as np
import torch
from .scheduler import Scheduler
_logger = logging.getLogger(__name__)
class TanhLRScheduler(Scheduler):
    """
    Hyperbolic-Tangent decay with restarts.

    This is described in the paper https://arxiv.org/abs/1806.01593
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lb: float = -7.,
                 ub: float = 3.,
                 lr_min: float = 0.,
                 cycle_mul: float = 1.,
                 cycle_decay: float = 1.,
                 cycle_limit: int = 1,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        assert lb < ub
        assert cycle_limit >= 0
        assert warmup_t >= 0
        assert warmup_lr_init >= 0

        # tanh argument bounds at the start (lb) and end (ub) of a cycle
        self.lb = lb
        self.ub = ub
        self.t_initial = t_initial
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # Without a warmup prefix, warmup targets the schedule's own value at
            # t == warmup_t rather than the base LR.
            t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
            super().update_groups(self.warmup_lr_init)

    def _locate_cycle(self, t):
        """Return (cycle index, cycle length, offset within the cycle) for step ``t``."""
        if self.cycle_mul != 1:
            i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
            t_i = self.cycle_mul ** i * self.t_initial
            t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
            return i, t_i, t_curr
        i = t // self.t_initial
        return i, self.t_initial, t - (self.t_initial * i)

    def _get_lr(self, t):
        """Compute one LR per param group for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * step for step in self.warmup_steps]

        if self.warmup_prefix:
            t = t - self.warmup_t

        i, t_i, t_curr = self._locate_cycle(t)
        if i < self.cycle_limit:
            gamma = self.cycle_decay ** i
            lr_max_values = [v * gamma for v in self.base_values]
            tr = t_curr / t_i
            return [
                self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
                for lr_max in lr_max_values
            ]
        # past the last allowed cycle: hold at the minimum LR
        return [self.lr_min for _ in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return LRs for ``epoch`` when scheduling per epoch, else None."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return LRs for ``num_updates`` when scheduling per update, else None."""
        return None if self.t_in_epochs else self._get_lr(num_updates)

    def get_cycle_length(self, cycles=0):
        """Return the total number of steps covered by ``cycles`` cycles.

        ``cycles`` falls back to ``cycle_limit`` when it is 0 and is at least 1.
        """
        n_cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul != 1.0:
            return int(math.floor(-self.t_initial * (self.cycle_mul ** n_cycles - 1) / (1 - self.cycle_mul)))
        return self.t_initial * n_cycles
import torch
from tensorboardX import SummaryWriter
class TensorboardLogger(object):
    """Thin wrapper around a tensorboardX ``SummaryWriter`` that tracks a global step."""

    def __init__(self, log_dir):
        self.writer = SummaryWriter(logdir=log_dir)
        self.step = 0

    def set_step(self, step=None):
        """Set the current step to ``step``, or advance it by one when omitted."""
        if step is None:
            self.step += 1
            return
        self.step = step

    def update(self, head='scalar', step=None, **kwargs):
        """Log each keyword as the scalar ``<head>/<name>``.

        None values are skipped and tensors are converted with ``.item()``.
        The explicit ``step`` is used when given, otherwise the tracked step.
        """
        for name, value in kwargs.items():
            if value is None:
                continue
            if isinstance(value, torch.Tensor):
                value = value.item()
            assert isinstance(value, (float, int))
            at_step = self.step if step is None else step
            self.writer.add_scalar(head + "/" + name, value, at_step)

    def flush(self):
        """Flush the underlying writer."""
        self.writer.flush()
\ No newline at end of file
This diff is collapsed.
#!/bin/bash
# Launches train.py as a single python process.
# The variables below are only used to fill in the command-line flags of train.py.
# Dataset directory, passed as --data_dir.
DATA_PATH="/ImageNet/train"
# Model name, passed as --model.
MODEL=mamba_vision_T
# Passed as --batch-size.
BS=2
# Passed as --tag.
EXP=Test
# Passed as --lr.
LR=8e-4
# Passed as --weight-decay.
WD=0.05
# Passed as --warmup-lr.
WR_LR=1e-6
# Passed as --drop-path.
DR=0.38
# Passed as --mesa.
MESA=0.25
python train.py --mesa ${MESA} --input-size 3 224 224 --crop-pct=0.875 \
--data_dir=$DATA_PATH --model $MODEL --amp --weight-decay ${WD} --drop-path ${DR} --batch-size $BS --tag $EXP --lr $LR --warmup-lr $WR_LR
This diff is collapsed.
#!/usr/bin/env python3
""" ImageNet Validation Script
This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import csv
import glob
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from functools import partial
import torch
import torch.nn as nn
import torch.nn.parallel
from models.mamba_vision import *
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
decay_batch_step, check_batch_size_retry, ParseKwargs
# Optional-feature detection: each flag records whether the feature is usable here.

# NVIDIA apex mixed precision
try:
    from apex import amp
except ImportError:
    has_apex = False
else:
    has_apex = True

# native torch.cuda.amp autocast
try:
    has_native_amp = getattr(torch.cuda.amp, 'autocast') is not None
except AttributeError:
    has_native_amp = False

# functorch memory-efficient fusion
try:
    from functorch.compile import memory_efficient_fusion
except ImportError:
    has_functorch = False
else:
    has_functorch = True

# torch.compile
has_compile = hasattr(torch, 'compile')
_logger = logging.getLogger('validate')

# Command-line interface for the validation script. The parsed namespace is
# handed to validate() / main(), which read the attributes defined here.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')

# --- dataset selection ---
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')

# --- model and loader ---
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')

# --- input / preprocessing overrides ---
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
                    help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
                    metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
# Fixed duplicated word in the help text ("of of").
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')

# --- logging and weights ---
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

# --- device / execution ---
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
                    help='enable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
parser.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
# Fixed unbalanced parenthesis in the help text.
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (require CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

# --- graph capture modes; at most one of these may be given ---
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd support.")

# --- result reporting ---
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--results-format', default='csv', type=str,
                    help='Format for results file one of (csv, json) (default: csv).')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                    help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                    help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--retry', default=False, action='store_true',
                    help='Enable batch size decay & retry for single model validation')
def validate(args):
    """Build the model and data loader described by ``args`` and evaluate it.

    Args:
        args: Parsed command-line namespace. It is mutated in place:
            ``pretrained`` and ``prefetcher`` are derived from other flags,
            and ``num_classes`` is filled from the model when not given.

    Returns:
        OrderedDict with keys ``model``, ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (parameter count divided by 1e6),
        ``img_size`` (last entry of the resolved input size), ``crop_pct``
        and ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(args.device)

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_autocast = suppress  # no-op context manager unless native AMP is selected
    if not args.amp:
        _logger.info('Validating in float32. AMP not enabled.')
    elif args.amp_impl == 'apex':
        assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
        assert args.amp_dtype == 'float16'
        use_amp = 'apex'
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
        assert args.amp_dtype in ('float16', 'bfloat16')
        use_amp = 'native'
        amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        _logger.info('Validating in mixed precision with native PyTorch AMP.')

    if args.fuser:
        set_jit_fuser(args.fuser)

    if args.fast_norm:
        set_fast_norm()

    # create model; an explicit --in-chans wins over the channel dim of --input-size
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]
    else:
        in_chans = 3

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=in_chans,
        global_pool=args.gp,
        scriptable=args.torchscript,
        **args.model_kwargs,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(
        vars(args),
        model=model,
        use_test_size=not args.use_train_size,
        verbose=True,
    )
    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    model = model.to(device)
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # optional graph capture (the CLI makes these mutually exclusive)
    if args.torchscript:
        assert use_amp != 'apex', 'Cannot use APEX AMP with torchscripted model'
        model = torch.jit.script(model)
    elif args.torchcompile:
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        torch._dynamo.reset()
        model = torch.compile(model, backend=args.torchcompile)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    if use_amp == 'apex':
        model = amp.initialize(model, opt_level='O1')

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().to(device)

    # the deprecated positional `data` takes precedence over --data-dir
    root_dir = args.data or args.data_dir
    dataset = create_dataset(
        root=root_dir,
        name=args.dataset,
        split=args.split,
        download=args.dataset_download,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map,
    )

    # optional list of class indices (one int per line) used to slice the logits
    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as label_file:
            valid_labels = [int(line.rstrip()) for line in label_file]

    real_labels = (
        RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
        if args.real_labels
        else None
    )

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        crop_mode=data_config['crop_mode'],
        pin_memory=args.pin_mem,
        device=device,
        tf_preprocessing=args.tf_preprocessing,
    )

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        warmup_shape = (args.batch_size,) + tuple(data_config['input_size'])
        warmup_batch = torch.randn(warmup_shape).to(device)
        if args.channels_last:
            warmup_batch = warmup_batch.contiguous(memory_format=torch.channels_last)
        with amp_autocast():
            model(warmup_batch)

        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.no_prefetcher:
                # without the prefetcher, batches are moved to the device here
                target = target.to(device)
                images = images.to(device)
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(images)

                if valid_labels is not None:
                    output = output[:, valid_labels]
                loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            num_samples = images.size(0)
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), num_samples)
            top1.update(acc1.item(), num_samples)
            top5.update(acc5.item(), num_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=num_samples / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5
                    )
                )

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
        top5a = real_labels.get_accuracy(k=5)
    else:
        top1a = top1.avg
        top5a = top5.avg

    results = OrderedDict(
        model=args.model,
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'],
    )

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def _try_run(args, initial_batch_size):
    """Run ``validate`` and retry with a smaller batch size on ``RuntimeError``.

    Args:
        args: Parsed command-line namespace; ``args.batch_size`` is
            overwritten on every attempt.
        initial_batch_size: Batch size for the first attempt, before it is
            multiplied by ``args.num_gpu``.

    Returns:
        The results from ``validate`` on success. Otherwise an OrderedDict
        whose ``'error'`` key holds the last error message (``'Unknown'``
        if no attempt was made).
    """
    results = OrderedDict()
    error_str = 'Unknown'
    batch_size = initial_batch_size
    while batch_size:
        # multiply by num-gpu for DataParallel case
        args.batch_size = batch_size * args.num_gpu
        try:
            if torch.cuda.is_available() and 'cuda' in args.device:
                torch.cuda.empty_cache()
            results = validate(args)
        except RuntimeError as err:
            error_str = str(err)
            _logger.error(f'"{error_str}" while running validation.')
            # give up unless the error is one a smaller batch might fix
            if not check_batch_size_retry(error_str):
                break
        else:
            return results
        batch_size = decay_batch_step(batch_size)
        _logger.warning(f'Reducing batch size to {batch_size} for retry.')

    results['error'] = error_str
    _logger.error(f'{args.model} failed to validate ({error_str}).')
    return results
# Name patterns excluded when `--model all` sweeps the pretrained models.
_NON_IN1K_FILTERS = ['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae', '*seer']


def main():
    """Parse the command line and run single-model or bulk validation.

    Bulk mode is selected when ``--checkpoint`` is a directory, when
    ``--model`` is ``all`` or a wildcard with matches, or when ``--model``
    is a file of model names. Results are optionally written to
    ``--results-file`` and always printed to stdout as JSON after a
    ``--result`` delimiter line.
    """
    setup_default_logging()
    args = parser.parse_args()

    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') + glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, ckpt) for ckpt in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(
                pretrained=True,
                exclude_filters=_NON_IN1K_FILTERS,
            )
            model_cfgs = [(name, '') for name in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(
                args.model,
                pretrained=True,
            )
            model_cfgs = [(name, '') for name in model_names]

        # last resort: treat --model as a text file with one model name per line
        if not model_cfgs and os.path.isfile(args.model):
            with open(args.model) as names_file:
                model_names = [line.rstrip() for line in names_file]
            model_cfgs = [(name, None) for name in model_names if name]

    if model_cfgs:
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        # Ctrl-C ends the sweep early; results collected so far are still reported
        with suppress(KeyboardInterrupt):
            initial_batch_size = args.batch_size
            for model_name, checkpoint in model_cfgs:
                args.model = model_name
                args.checkpoint = checkpoint
                run_result = _try_run(args, initial_batch_size)
                if 'error' in run_result:
                    continue
                if args.checkpoint:
                    run_result['checkpoint'] = args.checkpoint
                results.append(run_result)
        results.sort(key=lambda entry: entry['top1'], reverse=True)
    else:
        results = _try_run(args, args.batch_size) if args.retry else validate(args)

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')
def write_results(results_file, results, format='csv'):
    """Write validation results to ``results_file`` as JSON or CSV.

    Args:
        results_file: Destination path; the file is truncated and rewritten.
        results: A single result mapping, or a list/tuple of them.
        format: ``'json'`` dumps ``results`` as-is with an indent of 4; any
            other value writes CSV.

    In CSV mode the header is the union of all row keys in first-seen order,
    so rows with extra keys (e.g. ``checkpoint``) no longer raise
    ``ValueError``; missing values are written as empty strings. An empty
    result list leaves an empty file.
    """
    # The csv module requires newline='' so it controls line endings itself;
    # otherwise Windows text mode turns each '\r\n' into '\r\r\n'.
    # JSON output keeps the platform default.
    newline = None if format == 'json' else ''
    with open(results_file, mode='w', newline=newline) as cf:
        if format == 'json':
            json.dump(results, cf, indent=4)
            return

        rows = results if isinstance(results, (list, tuple)) else [results]
        if not rows:
            return

        # dict.fromkeys de-duplicates while preserving first-seen key order
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))
        dw = csv.DictWriter(cf, fieldnames=fieldnames, restval='')
        dw.writeheader()
        dw.writerows(rows)
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
#!/bin/bash
# Evaluate a saved checkpoint with validate.py.

DATA_PATH="/ImageNet/val"
BS=128
checkpoint='/model_weights/mambavision_tiny_1k.pth.tar'

# validate.py declares the dataset root option as --data-dir. argparse does not
# treat --data_dir as an alias, so the underscore spelling is rejected as an
# unrecognized argument.
python validate.py \
    --model mamba_vision_T \
    --checkpoint="$checkpoint" \
    --data-dir="$DATA_PATH" \
    --batch-size "$BS" \
    --input-size 3 224 224
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment