Commit b7535e7c authored by luopl's avatar luopl
Browse files

init

parents
Pipeline #1734 canceled with stages
#!/usr/bin/env python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn as nn
from timm.models.registry import register_model
import math
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
from timm.models._builder import resolve_pretrained_cfg
try:
from timm.models._builder import _update_default_kwargs as update_args
except:
from timm.models._builder import _update_default_model_kwargs as update_args
from timm.models.vision_transformer import Mlp, PatchEmbed
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from einops import rearrange, repeat
from .registry import register_pip_model
from pathlib import Path
def _cfg(url='', **kwargs):
return {'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': None,
'crop_pct': 0.875,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
**kwargs
}
default_cfgs = {
'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
crop_pct=0.98,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
crop_pct=0.93,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center')
}
def window_partition(x, window_size):
"""
Args:
x: (B, C, H, W)
window_size: window size
h_w: Height of window
w_w: Width of window
Returns:
local window features (num_windows*B, window_size*window_size, C)
"""
B, C, H, W = x.shape
x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: local window features (num_windows*B, window_size, window_size, C)
window_size: Window size
H: Height of image
W: Width of image
Returns:
x: (B, C, H, W)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
return x
def _load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(module)
load = None
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
if unexpected_keys:
err_msg.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
if len(err_msg) > 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
def _load_checkpoint(model,
filename,
map_location='cpu',
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = torch.load(filename, map_location=map_location)
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
_load_state_dict(model, state_dict, strict, logger)
return checkpoint
class Downsample(nn.Module):
"""
Down-sampling block"
"""
def __init__(self,
dim,
keep_dim=False,
):
"""
Args:
dim: feature size dimension.
norm_layer: normalization layer.
keep_dim: bool argument for maintaining the resolution.
"""
super().__init__()
if keep_dim:
dim_out = dim
else:
dim_out = 2 * dim
self.reduction = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
)
def forward(self, x):
x = self.reduction(x)
return x
class PatchEmbed(nn.Module):
"""
Patch embedding block"
"""
def __init__(self, in_chans=3, in_dim=64, dim=96):
"""
Args:
in_chans: number of input channels.
dim: feature size dimension.
"""
# in_dim = 1
super().__init__()
self.proj = nn.Identity()
self.conv_down = nn.Sequential(
nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
nn.BatchNorm2d(in_dim, eps=1e-4),
nn.ReLU(),
nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
nn.BatchNorm2d(dim, eps=1e-4),
nn.ReLU()
)
def forward(self, x):
x = self.proj(x)
x = self.conv_down(x)
return x
class ConvBlock(nn.Module):
def __init__(self, dim,
drop_path=0.,
layer_scale=None,
kernel_size=3):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
self.act1 = nn.GELU(approximate= 'tanh')
self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
self.layer_scale = layer_scale
if layer_scale is not None and type(layer_scale) in [int, float]:
self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
self.layer_scale = True
else:
self.layer_scale = False
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
input = x
x = self.conv1(x)
x = self.norm1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.norm2(x)
if self.layer_scale:
x = x * self.gamma.view(1, -1, 1, 1)
x = input + self.drop_path(x)
return x
class MambaVisionMixer(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True,
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
self.x_proj = nn.Linear(
self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
dt = torch.exp(
torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
self.dt_proj.bias._no_reinit = True
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner//2,
).contiguous()
A_log = torch.log(A)
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.conv1d_x = nn.Conv1d(
in_channels=self.d_inner//2,
out_channels=self.d_inner//2,
bias=conv_bias//2,
kernel_size=d_conv,
groups=self.d_inner//2,
**factory_kwargs,
)
self.conv1d_z = nn.Conv1d(
in_channels=self.d_inner//2,
out_channels=self.d_inner//2,
bias=conv_bias//2,
kernel_size=d_conv,
groups=self.d_inner//2,
**factory_kwargs,
)
def forward(self, hidden_states):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
_, seqlen, _ = hidden_states.shape
xz = self.in_proj(hidden_states)
xz = rearrange(xz, "b l d -> b d l")
x, z = xz.chunk(2, dim=1)
A = -torch.exp(self.A_log.float())
x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
y = selective_scan_fn(x,
dt,
A,
B,
C,
self.D.float(),
z=None,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=None)
y = torch.cat([y, z], dim=1)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = True
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
counter,
transformer_blocks,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
Mlp_block=Mlp,
layer_scale=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
if counter in transformer_blocks:
self.mixer = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
else:
self.mixer = MambaVisionMixer(d_model=dim,
d_state=8,
d_conv=3,
expand=1
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
def forward(self, x):
x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class MambaVisionLayer(nn.Module):
"""
MambaVision layer"
"""
def __init__(self,
dim,
depth,
num_heads,
window_size,
conv=False,
downsample=True,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
layer_scale=None,
layer_scale_conv=None,
transformer_blocks = [],
):
"""
Args:
dim: feature size dimension.
depth: number of layers in each stage.
window_size: window size in each stage.
conv: bool argument for conv stage flag.
downsample: bool argument for down-sampling.
mlp_ratio: MLP ratio.
num_heads: number of heads in each stage.
qkv_bias: bool argument for query, key, value learnable bias.
qk_scale: bool argument to scaling query, key.
drop: dropout rate.
attn_drop: attention dropout rate.
drop_path: drop path rate.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
layer_scale_conv: conv layer scaling coefficient.
transformer_blocks: list of transformer blocks.
"""
super().__init__()
self.conv = conv
self.transformer_block = False
if conv:
self.blocks = nn.ModuleList([ConvBlock(dim=dim,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
layer_scale=layer_scale_conv)
for i in range(depth)])
self.transformer_block = False
else:
self.blocks = nn.ModuleList([Block(dim=dim,
counter=i,
transformer_blocks=transformer_blocks,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
layer_scale=layer_scale)
for i in range(depth)])
self.transformer_block = True
self.downsample = None if not downsample else Downsample(dim=dim)
self.do_gt = False
self.window_size = window_size
def forward(self, x):
_, _, H, W = x.shape
if self.transformer_block:
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
if pad_r > 0 or pad_b > 0:
x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
_, _, Hp, Wp = x.shape
else:
Hp, Wp = H, W
x = window_partition(x, self.window_size)
for _, blk in enumerate(self.blocks):
x = blk(x)
if self.transformer_block:
x = window_reverse(x, self.window_size, Hp, Wp)
if pad_r > 0 or pad_b > 0:
x = x[:, :, :H, :W].contiguous()
if self.downsample is None:
return x
return self.downsample(x)
class MambaVision(nn.Module):
"""
MambaVision,
"""
def __init__(self,
dim,
in_dim,
depths,
window_size,
mlp_ratio,
num_heads,
drop_path_rate=0.2,
in_chans=3,
num_classes=1000,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
layer_scale=None,
layer_scale_conv=None,
**kwargs):
"""
Args:
dim: feature size dimension.
depths: number of layers in each stage.
window_size: window size in each stage.
mlp_ratio: MLP ratio.
num_heads: number of heads in each stage.
drop_path_rate: drop path rate.
in_chans: number of input channels.
num_classes: number of classes.
qkv_bias: bool argument for query, key, value learnable bias.
qk_scale: bool argument to scaling query, key.
drop_rate: dropout rate.
attn_drop_rate: attention dropout rate.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
layer_scale_conv: conv layer scaling coefficient.
"""
super().__init__()
num_features = int(dim * 2 ** (len(depths) - 1))
self.num_classes = num_classes
self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.levels = nn.ModuleList()
for i in range(len(depths)):
conv = True if (i == 0 or i == 1) else False
level = MambaVisionLayer(dim=int(dim * 2 ** i),
depth=depths[i],
num_heads=num_heads[i],
window_size=window_size[i],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
conv=conv,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
downsample=(i < 3),
layer_scale=layer_scale,
layer_scale_conv=layer_scale_conv,
transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
)
self.levels.append(level)
self.norm = nn.BatchNorm2d(num_features)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, LayerNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'rpb'}
def forward_features(self, x):
x = self.patch_embed(x)
for level in self.levels:
x = level(x)
x = self.norm(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _load_state_dict(self,
pretrained,
strict: bool = False):
_load_checkpoint(self,
pretrained,
strict=strict)
@register_pip_model
@register_model
def mamba_vision_T(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar")
depths = kwargs.pop("depths", [1, 3, 8, 4])
num_heads = kwargs.pop("num_heads", [2, 4, 8, 16])
window_size = kwargs.pop("window_size", [8, 8, 14, 7])
dim = kwargs.pop("dim", 80)
in_dim = kwargs.pop("in_dim", 32)
mlp_ratio = kwargs.pop("mlp_ratio", 4)
resolution = kwargs.pop("resolution", 224)
drop_path_rate = kwargs.pop("drop_path_rate", 0.2)
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[1, 3, 8, 4],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=80,
in_dim=32,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.2,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_pip_model
@register_model
def mamba_vision_T2(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T2.pth.tar")
depths = kwargs.pop("depths", [1, 3, 11, 4])
num_heads = kwargs.pop("num_heads", [2, 4, 8, 16])
window_size = kwargs.pop("window_size", [8, 8, 14, 7])
dim = kwargs.pop("dim", 80)
in_dim = kwargs.pop("in_dim", 32)
mlp_ratio = kwargs.pop("mlp_ratio", 4)
resolution = kwargs.pop("resolution", 224)
drop_path_rate = kwargs.pop("drop_path_rate", 0.2)
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T2').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[1, 3, 11, 4],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=80,
in_dim=32,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.2,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_pip_model
@register_model
def mamba_vision_S(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_S.pth.tar")
depths = kwargs.pop("depths", [3, 3, 7, 5])
num_heads = kwargs.pop("num_heads", [2, 4, 8, 16])
window_size = kwargs.pop("window_size", [8, 8, 14, 7])
dim = kwargs.pop("dim", 96)
in_dim = kwargs.pop("in_dim", 64)
mlp_ratio = kwargs.pop("mlp_ratio", 4)
resolution = kwargs.pop("resolution", 224)
drop_path_rate = kwargs.pop("drop_path_rate", 0.2)
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_S').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 7, 5],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=96,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.2,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_pip_model
@register_model
def mamba_vision_B(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B.pth.tar")
depths = kwargs.pop("depths", [3, 3, 10, 5])
num_heads = kwargs.pop("num_heads", [2, 4, 8, 16])
window_size = kwargs.pop("window_size", [8, 8, 14, 7])
dim = kwargs.pop("dim", 128)
in_dim = kwargs.pop("in_dim", 64)
mlp_ratio = kwargs.pop("mlp_ratio", 4)
resolution = kwargs.pop("resolution", 224)
drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
layer_scale = kwargs.pop("layer_scale", 1e-5)
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 10, 5],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=128,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.3,
layer_scale=1e-5,
layer_scale_conv=None,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_pip_model
@register_model
def mamba_vision_L(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L.pth.tar")
depths = kwargs.pop("depths", [3, 3, 10, 5])
num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
window_size = kwargs.pop("window_size", [8, 8, 14, 7])
dim = kwargs.pop("dim", 196)
in_dim = kwargs.pop("in_dim", 64)
mlp_ratio = kwargs.pop("mlp_ratio", 4)
resolution = kwargs.pop("resolution", 224)
drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
layer_scale = kwargs.pop("layer_scale", 1e-5)
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 10, 5],
num_heads=[4, 8, 16, 32],
window_size=[8, 8, 14, 7],
dim=196,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.3,
layer_scale=1e-5,
layer_scale_conv=None,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_pip_model
@register_model
def mamba_vision_L2(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2.pth.tar")
depths = kwargs.pop("depths", [3, 3, 12, 5])
num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
window_size = kwargs.pop("window_size", [8, 8, 14, 7])
dim = kwargs.pop("dim", 196)
in_dim = kwargs.pop("in_dim", 64)
mlp_ratio = kwargs.pop("mlp_ratio", 4)
resolution = kwargs.pop("resolution", 224)
drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
layer_scale = kwargs.pop("layer_scale", 1e-5)
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 12, 5],
num_heads=[4, 8, 16, 32],
window_size=[8, 8, 14, 7],
dim=196,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.3,
layer_scale=1e-5,
layer_scale_conv=None,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
"""
Scripts to register and load model, adopted from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_registry.py
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_factory.py
Hacked together by / Copyright 2023 Ross Wightman
"""
import torch
import os
from collections import OrderedDict
from copy import deepcopy
from typing import Any
import sys
import re
import fnmatch
from collections import defaultdict
from copy import deepcopy
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model default_cfgs
def register_pip_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
# add model to __all__ in module
model_name = fn.__name__
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name]
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
has_pretrained = False # check if model has a pretrained url to allow filtering on this
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
_model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
""" Return list of available model names, sorted alphabetically
Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
"""
if module:
all_models = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
if filter:
models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models):
models = set(models).union(include_models)
else:
models = all_models
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
exclude_filters = [exclude_filters]
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = set(models).difference(exclude_models)
if pretrained:
models = _model_has_pretrained.intersection(models)
if name_matches_cfg:
models = set(_model_default_cfgs).intersection(models)
return list(sorted(models, key=_natural_key))
def is_model(model_name):
""" Check if a model name exists
"""
return model_name in _model_entrypoints
def model_entrypoint(model_name):
"""Fetch a model entrypoint for specified model name
"""
return _model_entrypoints[model_name]
def list_modules():
""" Return list of module names that contain models / model entrypoints
"""
modules = _module_to_models.keys()
return list(sorted(modules))
def is_model_in_modules(model_name, module_names):
"""Check if a model exists within a subset of modules
Args:
model_name (str) - name of model to check
module_names (tuple, list, set) - names of modules to search in
"""
assert isinstance(module_names, (tuple, list, set))
return any(model_name in _module_to_models[n] for n in module_names)
def has_model_default_key(model_name, cfg_key):
""" Query model default_cfgs for existence of a specific key.
"""
if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
return True
return False
def is_model_default_key(model_name, cfg_key):
""" Return truthy value for specified model default_cfg key, False if does not exist.
"""
if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
return True
return False
def get_model_default_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if it doesn't exist.
"""
if model_name in _model_default_cfgs:
return _model_default_cfgs[model_name].get(cfg_key, None)
else:
return None
def is_model_pretrained(model_name):
return model_name in _model_has_pretrained
def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = 'state_dict'
if isinstance(checkpoint, dict):
if use_ema and 'state_dict_ema' in checkpoint:
state_dict_key = 'state_dict_ema'
if state_dict_key and state_dict_key in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint[state_dict_key].items():
# strip `module.` prefix
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
state_dict = new_state_dict
else:
state_dict = checkpoint
print("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
return state_dict
else:
print("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
model.load_pretrained(checkpoint_path)
else:
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict)
def create_model(
model_name,
pretrained=False,
checkpoint_path='',
**kwargs):
create_fn = model_entrypoint(model_name)
model = create_fn(pretrained=pretrained, **kwargs)
if checkpoint_path:
load_checkpoint(model, checkpoint_path)
return model
\ No newline at end of file
#!/bin/bash
DATA_PATH="/ImageNet/train"
MODEL=mamba_vision_T
BS=2
EXP=Test
LR=8e-4
WD=0.05
WR_LR=1e-6
DR=0.38
MESA=0.25
torchrun --nproc_per_node=2 --master_port=29501 train.py --mesa ${MESA} --input-size 3 224 224 --crop-pct=0.875 \
--data_dir=$DATA_PATH --model $MODEL --amp --weight-decay ${WD} --drop-path ${DR} --batch-size $BS --tag $EXP --lr $LR --warmup-lr $WR_LR
#!/bin/bash
DATA_PATH="/ImageNet/val"
BS=128
checkpoint='/model_weights/mambavision_tiny_1k.pth.tar'
python validate.py --model mamba_vision_T --checkpoint=$checkpoint --data_dir=$DATA_PATH --batch-size $BS --input-size 3 224 224 \
--num-gpu 2
\ No newline at end of file
from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from .scheduler_factory import create_scheduler
""" Cosine Scheduler
Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
import numpy as np
import torch
from .scheduler import Scheduler
_logger = logging.getLogger(__name__)
class CosineLRScheduler(Scheduler):
"""
Cosine decay with restarts.
This is described in the paper https://arxiv.org/abs/1608.03983.
Inspiration from
https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lr_min: float = 0.,
cycle_mul: float = 1.,
cycle_decay: float = 1.,
cycle_limit: int = 1,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=False,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
k_decay=1.0,
initialize=True) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
assert t_initial > 0
assert lr_min >= 0
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
_logger.warning("Cosine annealing scheduler will have no effect on the learning "
"rate since t_initial = t_mul = eta_mul = 1.")
self.t_initial = t_initial
self.lr_min = lr_min
self.cycle_mul = cycle_mul
self.cycle_decay = cycle_decay
self.cycle_limit = cycle_limit
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
self.k_decay = k_decay
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
if self.warmup_prefix:
t = t - self.warmup_t
if self.cycle_mul != 1:
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
t_i = self.cycle_mul ** i * self.t_initial
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
else:
i = t // self.t_initial
t_i = self.t_initial
t_curr = t - (self.t_initial * i)
gamma = self.cycle_decay ** i
lr_max_values = [v * gamma for v in self.base_values]
k = self.k_decay
if i < self.cycle_limit:
lrs = [
self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
for lr_max in lr_max_values
]
else:
lrs = [self.lr_min for _ in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
else:
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
""" MultiStep LR Scheduler
Basic multi step LR schedule with warmup, noise.
"""
import torch
import bisect
from timm.scheduler.scheduler import Scheduler
from typing import List
class MultiStepLRScheduler(Scheduler):
"""
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
decay_t: List[int],
decay_rate: float = 1.,
warmup_t=0,
warmup_lr_init=0,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
self.decay_t = decay_t
self.decay_rate = decay_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def get_curr_decay_steps(self, t):
# find where in the array t goes,
# assumes self.decay_t is sorted
return bisect.bisect_right(self.decay_t, t+1)
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
""" Plateau Scheduler
Adapts PyTorch plateau scheduler and allows application of noise, warmup.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from .scheduler import Scheduler
class PlateauLRScheduler(Scheduler):
"""Decay the LR by a factor every time the validation loss plateaus."""
def __init__(self,
optimizer,
decay_rate=0.1,
patience_t=10,
verbose=True,
threshold=1e-4,
cooldown_t=0,
warmup_t=0,
warmup_lr_init=0,
lr_min=0,
mode='max',
noise_range_t=None,
noise_type='normal',
noise_pct=0.67,
noise_std=1.0,
noise_seed=None,
initialize=True,
):
super().__init__(
optimizer,
'lr',
noise_range_t=noise_range_t,
noise_type=noise_type,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize,
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
patience=patience_t,
factor=decay_rate,
verbose=verbose,
threshold=threshold,
cooldown=cooldown_t,
mode=mode,
min_lr=lr_min
)
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
self.restore_lr = None
def state_dict(self):
return {
'best': self.lr_scheduler.best,
'last_epoch': self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
self.lr_scheduler.best = state_dict['best']
if 'last_epoch' in state_dict:
self.lr_scheduler.last_epoch = state_dict['last_epoch']
# override the base class step fn completely
def step(self, epoch, metric=None):
if epoch <= self.warmup_t:
lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
super().update_groups(lrs)
else:
if self.restore_lr is not None:
# restore actual LR from before our last noise perturbation before stepping base
for i, param_group in enumerate(self.optimizer.param_groups):
param_group['lr'] = self.restore_lr[i]
self.restore_lr = None
self.lr_scheduler.step(metric, epoch) # step the base scheduler
if self._is_apply_noise(epoch):
self._apply_noise(epoch)
def _apply_noise(self, epoch):
noise = self._calculate_noise(epoch)
# apply the noise on top of previous LR, cache the old value so we can restore for normal
# stepping of base scheduler
restore_lr = []
for i, param_group in enumerate(self.optimizer.param_groups):
old_lr = float(param_group['lr'])
restore_lr.append(old_lr)
new_lr = old_lr + old_lr * noise
param_group['lr'] = new_lr
self.restore_lr = restore_lr
""" Polynomial Scheduler
Polynomial LR schedule with warmup, noise.
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
import logging
import torch
from .scheduler import Scheduler
_logger = logging.getLogger(__name__)
class PolyLRScheduler(Scheduler):
""" Polynomial LR Scheduler w/ warmup, noise, and k-decay
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
t_initial: int,
power: float = 0.5,
lr_min: float = 0.,
cycle_mul: float = 1.,
cycle_decay: float = 1.,
cycle_limit: int = 1,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=False,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
k_decay=1.0,
initialize=True) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
assert t_initial > 0
assert lr_min >= 0
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
_logger.warning("Cosine annealing scheduler will have no effect on the learning "
"rate since t_initial = t_mul = eta_mul = 1.")
self.t_initial = t_initial
self.power = power
self.lr_min = lr_min
self.cycle_mul = cycle_mul
self.cycle_decay = cycle_decay
self.cycle_limit = cycle_limit
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
self.k_decay = k_decay
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
if self.warmup_prefix:
t = t - self.warmup_t
if self.cycle_mul != 1:
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
t_i = self.cycle_mul ** i * self.t_initial
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
else:
i = t // self.t_initial
t_i = self.t_initial
t_curr = t - (self.t_initial * i)
gamma = self.cycle_decay ** i
lr_max_values = [v * gamma for v in self.base_values]
k = self.k_decay
if i < self.cycle_limit:
lrs = [
self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power
for lr_max in lr_max_values
]
else:
lrs = [self.lr_min for _ in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
else:
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
from typing import Dict, Any
import torch
class Scheduler:
""" Parameter Scheduler Base Class
A scheduler base class that can be used to schedule any optimizer parameter groups.
Unlike the builtin PyTorch schedulers, this is intended to be consistently called
* At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
* At the END of each optimizer update, after incrementing the update count, to calculate next update's value
The schedulers built on this should try to remain as stateless as possible (for simplicity).
This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
and -1 values for special behaviour. All epoch and update counts must be tracked in the training
code and explicitly passed in to the schedulers on the corresponding step or step_update call.
Based on ideas from:
* https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
* https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
param_group_field: str,
noise_range_t=None,
noise_type='normal',
noise_pct=0.67,
noise_std=1.0,
noise_seed=None,
initialize: bool = True) -> None:
self.optimizer = optimizer
self.param_group_field = param_group_field
self._initial_param_group_field = f"initial_{param_group_field}"
if initialize:
for i, group in enumerate(self.optimizer.param_groups):
if param_group_field not in group:
raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
group.setdefault(self._initial_param_group_field, group[param_group_field])
else:
for i, group in enumerate(self.optimizer.param_groups):
if self._initial_param_group_field not in group:
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
self.metric = None # any point to having this for all?
self.noise_range_t = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
self.noise_std = noise_std
self.noise_seed = noise_seed if noise_seed is not None else 42
self.update_groups(self.base_values)
def state_dict(self) -> Dict[str, Any]:
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.__dict__.update(state_dict)
def get_epoch_values(self, epoch: int):
return None
def get_update_values(self, num_updates: int):
return None
def step(self, epoch: int, metric: float = None) -> None:
self.metric = metric
values = self.get_epoch_values(epoch)
if values is not None:
values = self._add_noise(values, epoch)
self.update_groups(values)
def step_update(self, num_updates: int, metric: float = None):
self.metric = metric
values = self.get_update_values(num_updates)
if values is not None:
values = self._add_noise(values, num_updates)
self.update_groups(values)
def update_groups(self, values):
if not isinstance(values, (list, tuple)):
values = [values] * len(self.optimizer.param_groups)
for param_group, value in zip(self.optimizer.param_groups, values):
if 'lr_scale' in param_group:
param_group[self.param_group_field] = value * param_group['lr_scale']
else:
param_group[self.param_group_field] = value
def _add_noise(self, lrs, t):
if self._is_apply_noise(t):
noise = self._calculate_noise(t)
lrs = [v + v * noise for v in lrs]
return lrs
def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range."""
apply_noise = False
if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else:
apply_noise = t >= self.noise_range_t
return apply_noise
def _calculate_noise(self, t) -> float:
g = torch.Generator()
g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
return noise
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
return noise
""" Scheduler Factory
Hacked together by / Copyright 2021 Ross Wightman
"""
from .cosine_lr import CosineLRScheduler
from .multistep_lr import MultiStepLRScheduler
from .plateau_lr import PlateauLRScheduler
from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
def create_scheduler(args, optimizer):
num_epochs = args.epochs
n_iter = args.data_len // (args.batch_size * args.world_size)
tot_iter = num_epochs * n_iter
warmup_iters = args.warmup_epochs * n_iter
if getattr(args, 'lr_noise', None) is not None:
lr_noise = getattr(args, 'lr_noise')
if isinstance(lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in lr_noise]
if len(noise_range) == 1:
noise_range = noise_range[0]
else:
noise_range = lr_noise * num_epochs
else:
noise_range = None
noise_args = dict(
noise_range_t=noise_range,
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
noise_std=getattr(args, 'lr_noise_std', 1.),
noise_seed=getattr(args, 'seed', 42),
)
cycle_args = dict(
cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
)
lr_scheduler = None
if args.sched == 'cosine':
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=tot_iter,
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=warmup_iters,
k_decay=getattr(args, 'lr_k_decay', 1.0),
t_in_epochs=args.lr_ep,
**cycle_args,
**noise_args,
)
cycle_length = lr_scheduler.get_cycle_length() // n_iter
num_epochs = cycle_length + args.cooldown_epochs
elif args.sched == 'tanh':
lr_scheduler = TanhLRScheduler(
optimizer,
t_initial=num_epochs,
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
t_in_epochs=True,
**cycle_args,
**noise_args,
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'step':
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
**noise_args,
)
elif args.sched == 'multistep':
lr_scheduler = MultiStepLRScheduler(
optimizer,
decay_t=args.decay_milestones,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
**noise_args,
)
elif args.sched == 'plateau':
mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
lr_scheduler = PlateauLRScheduler(
optimizer,
decay_rate=args.decay_rate,
patience_t=args.patience_epochs,
lr_min=args.min_lr,
mode=mode,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cooldown_t=0,
**noise_args,
)
elif args.sched == 'poly':
lr_scheduler = PolyLRScheduler(
optimizer,
power=args.decay_rate, # overloading 'decay_rate' as polynomial power
t_initial=num_epochs,
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
k_decay=getattr(args, 'lr_k_decay', 1.0),
**cycle_args,
**noise_args,
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
return lr_scheduler, num_epochs
""" Step Scheduler
Basic step LR schedule with warmup, noise.
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch
from .scheduler import Scheduler
class StepLRScheduler(Scheduler):
"""
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
decay_t: float,
decay_rate: float = 1.,
warmup_t=0,
warmup_lr_init=0,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True,
) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
self.decay_t = decay_t
self.decay_rate = decay_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.t_in_epochs = t_in_epochs
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
""" TanH Scheduler
TanH schedule with warmup, cycle/restarts, noise.
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
import numpy as np
import torch
from .scheduler import Scheduler
_logger = logging.getLogger(__name__)
class TanhLRScheduler(Scheduler):
"""
Hyberbolic-Tangent decay with restarts.
This is described in the paper https://arxiv.org/abs/1806.01593
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lb: float = -7.,
ub: float = 3.,
lr_min: float = 0.,
cycle_mul: float = 1.,
cycle_decay: float = 1.,
cycle_limit: int = 1,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=False,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
assert t_initial > 0
assert lr_min >= 0
assert lb < ub
assert cycle_limit >= 0
assert warmup_t >= 0
assert warmup_lr_init >= 0
self.lb = lb
self.ub = ub
self.t_initial = t_initial
self.lr_min = lr_min
self.cycle_mul = cycle_mul
self.cycle_decay = cycle_decay
self.cycle_limit = cycle_limit
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
if self.warmup_t:
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
if self.warmup_prefix:
t = t - self.warmup_t
if self.cycle_mul != 1:
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
t_i = self.cycle_mul ** i * self.t_initial
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
else:
i = t // self.t_initial
t_i = self.t_initial
t_curr = t - (self.t_initial * i)
if i < self.cycle_limit:
gamma = self.cycle_decay ** i
lr_max_values = [v * gamma for v in self.base_values]
tr = t_curr / t_i
lrs = [
self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
for lr_max in lr_max_values
]
else:
lrs = [self.lr_min for _ in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
else:
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
import torch
from tensorboardX import SummaryWriter
class TensorboardLogger(object):
def __init__(self, log_dir):
self.writer = SummaryWriter(logdir=log_dir)
self.step = 0
def set_step(self, step=None):
if step is not None:
self.step = step
else:
self.step += 1
def update(self, head='scalar', step=None, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
def flush(self):
self.writer.flush()
\ No newline at end of file
""" ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examplesf
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by / Copyright 2023 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import ImageDataset, create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm import utils
from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
LabelSmoothingCrossEntropy
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import *
from timm.utils import ApexScaler, NativeScaler
from scheduler.scheduler_factory import create_scheduler
import shutil
from utils.datasets import imagenet_lmdb_dataset
from tensorboard import TensorboardLogger
from models.mamba_vision import *
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside of the dataset group because it is positional.
parser.add_argument('--data_dir', metavar='DIR',
help='path to dataset')
group.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
group.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--tag', default='exp', type=str, metavar='TAG')
# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='gc_vit_tiny', type=str, metavar='MODEL',
help='Name of model to train (default: "gc_vit_tiny"')
group.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--loadcheckpoint', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=0.875, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"')
group.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8, use opt default)')
group.add_argument('--opt-betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
group.add_argument('--clip-grad', type=float, default=5.0, metavar='NORM',
help='Clip gradient norm (default: 5.0, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"')
parser.add_argument('--lr-ep', action='store_true', default=False,
help='using the epoch-based scheduler')
group.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=1.0, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
group.add_argument('--min-lr', type=float, default=5e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (5e-6)')
group.add_argument('--epochs', type=int, default=310, metavar='N',
help='number of epochs to train (default: 310)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=100, metavar='N',
help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=20, metavar='N',
help='epochs to warmup LR, if scheduler supports')
group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
group.add_argument('--aa', type=str, default="rand-m9-mstd0.5-inc1", metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
group.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
group.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop-rate', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
group.add_argument('--attn-drop-rate', type=float, default=0.0, metavar='PCT',
help='Drop of the attention, gaussian std')
# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=1, metavar='N',
help='number of checkpoints to keep (default: 3)')
group.add_argument('-j', '--workers', type=int, default=8, metavar='N',
help='how many training processes to use (default: 8)')
group.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
group.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
group.add_argument('--log_dir', default='./log_dir/', type=str,
help='where to store tensorboard')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument("--data_len", default=1281167, type=int,help='size of the dataset')
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
group.add_argument('--validate_only', action='store_true', default=False,
help='run model validation only')
group.add_argument('--no_saver', action='store_true', default=False,
help='Save checkpoints')
group.add_argument('--ampere_sparsity', action='store_true', default=False,
help='Save checkpoints')
group.add_argument('--lmdb_dataset', action='store_true', default=False,
help='use lmdb dataset')
group.add_argument('--bfloat', action='store_true', default=False,
help='use bfloat datatype')
group.add_argument('--mesa', type=float, default=0.0,
help='use memory efficient sharpness optimization, enabled if >0.0')
group.add_argument('--mesa-start-ratio', type=float, default=0.25,
help='when to start MESA, ratio to total training time, def 0.25')
kl_loss = torch.nn.KLDivLoss(reduction='batchmean').cuda()
def kdloss(y, teacher_scores):
T = 3
p = torch.nn.functional.log_softmax(y/T, dim=1)
q = torch.nn.functional.softmax(teacher_scores/T, dim=1)
l_kl = 50.0*kl_loss(p, q)
return l_kl
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
parser.set_defaults(**cfg)
# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
args = parser.parse_args(remaining)
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
def main():
utils.setup_default_logging()
args, args_text = _parse_args()
if args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.local_rank = int(os.environ['LOCAL_RANK'])
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
# torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d. Local rank %d'
% (args.rank, args.world_size, args.local_rank))
else:
_logger.info('Training with a single process on 1 GPUs.')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
utils.random_seed(args.seed, args.rank)
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
attn_drop_rate=args.attn_drop_rate,
drop_rate=args.drop_rate,
drop_path_rate=args.drop_path)
if args.bfloat:
args.dtype = torch.bfloat16
else:
args.dtype = torch.float16
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
assert not args.split_bn
if has_apex and use_amp == 'apex':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
print("filter_bias_and_bn")
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_epoch = None
if not os.path.isfile(args.resume):
args.resume = ""
if args.resume:
resume_epoch = resume_checkpoint(
model, args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = utils.ModelEmaV2(
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
if args.loadcheckpoint:
_logger.info(f"Loading checkpoint {args.loadcheckpoint}, checking for existing parameters if their shape match")
new_model_weights = torch.load(args.loadcheckpoint)["state_dict"]
current_model = model.state_dict()
new_state_dict = OrderedDict()
for k in current_model.keys():
if k in new_model_weights.keys():
if new_model_weights[k].size() == current_model[k].size():
print(f"loading weights {k} {new_model_weights[k].size()}")
new_state_dict[k] = new_model_weights[k]
model.load_state_dict(new_state_dict, strict=False)
model_ema.module.load_state_dict(new_state_dict, strict=False)
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
_logger.info('Scheduled epochs: {}'.format(num_epochs))
if args.lmdb_dataset:
train_dir = os.path.join(args.data_dir, 'train')
if 'lmdb' in args.data_dir:
dataset_train = imagenet_lmdb_dataset(
train_dir, transform=None)
else:
if not os.path.exists(train_dir):
_logger.error('Training folder does not exist at: {}'.format(train_dir))
exit(1)
dataset_train = ImageDataset(train_dir)
eval_dir = os.path.join(args.data_dir, 'val')
if 'lmdb' in args.data_dir:
dataset_eval = imagenet_lmdb_dataset(
eval_dir, transform=None)
else:
if not os.path.isdir(eval_dir):
eval_dir = os.path.join(args.data_dir, 'validation')
if not os.path.isdir(eval_dir):
_logger.error('Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = ImageDataset(eval_dir)
else:
dataset_train = create_dataset(
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeiine
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
use_prefetcher=args.prefetcher,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
auto_augment=args.aa,
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
batch_size=args.validation_batch_size or args.batch_size,
is_training=False,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
)
# setup loss function
if args.jsd_loss:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform which outputs sparse, soft targets
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
else:
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
else:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
output_dir = None
if args.rank == 0:
log_dir = args.log_dir + '_' + args.tag
os.makedirs(log_dir, exist_ok=True)
log_writer = TensorboardLogger(log_dir=log_dir)
else:
log_writer = None
if args.rank == 0:
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
args.experiment = exp_name
output_dir = utils.get_outdir(args.output if args.output else f'../output/train/{args.tag}/', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = utils.CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=1)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
if 1: #args.copy_code
# copy .py files
files = [os.path.join(dp, f) for dp, dn, filenames in os.walk('.') for f in filenames if
os.path.splitext(f)[1] == '.py']
for f in files:
if "/code_copy/" in f: continue
new_path = output_dir + "/code_copy/" + f
os.makedirs(os.path.dirname(new_path), exist_ok=True)
shutil.copyfile(f, new_path)
if args.validate_only:
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
print(eval_metrics)
exit()
test_acc_track=[]
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
saver = saver if not args.no_saver else None
train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
# print(train_metrics.keys())
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
_logger.info("Distributing BatchNorm running means and vars")
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
if log_writer is not None:
log_writer.update(test_acc1=eval_metrics['top1'], head="perf", step=epoch)
log_writer.update(test_loss=eval_metrics['loss'], head="perf", step=epoch)
log_writer.update(train_loss=train_metrics['loss'], head="perf", step=epoch)
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
log_writer.update(lr=lr, head="perf", step=epoch)
test_acc_track.append(eval_metrics['top1'])
stopif = True if len(test_acc_track)>1 and test_acc_track[-1]<1.0 else False
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics
if log_writer is not None:
log_writer.update(test_acc1_ema=eval_metrics['top1'], head="perf", step=epoch)
log_writer.update(test_loss_ema=eval_metrics['loss'], head="perf", step=epoch)
if lr_scheduler is not None and args.lr_ep:
# step LR for next epoch
lr_scheduler.step(epoch + 1, None if eval_metrics is None else eval_metrics[eval_metric])
if output_dir is not None:
utils.update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
if saver is not None:
# save proper checkpoint with eval metric
save_metric = None if eval_metrics is None else eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
if not np.isfinite(eval_metrics['loss']) or stopif:
# if got None then exit
if args.local_rank == 0:
_logger.info("Nan in loss, exit")
_logger.error("Nan in loss, exit")
input, target = next(iter(loader_eval))
input = input.cuda()
with torch.autograd.detect_anomaly():
with amp_autocast(dtype=args.dtype):
output = model(input)
print(output)
exit(1)
return 0
except KeyboardInterrupt:
pass
except SystemExit:
pass
if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
batch_time_m = utils.AverageMeter()
data_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
model.train()
end = time.time()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
num_iters = len(loader)
display_first = True
if args.ampere_sparsity:
model.enforce_mask()
for batch_idx, (input, target) in enumerate(loader):
if lr_scheduler is not None and not args.lr_ep:
lr_scheduler.step_update(num_updates=(epoch * num_iters) + batch_idx + 1)
if (batch_idx == 0) or (batch_idx % 50 == 0):
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast(dtype=args.dtype):
output = model(input)
loss = loss_fn(output, target)
if args.mesa>0.0:
if epoch/args.epochs > args.mesa_start_ratio:
with torch.no_grad():
ema_output = model_ema.module(input).data.detach()
kd = kdloss(output, ema_output)
loss += args.mesa * kd
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
else:
if args.ampere_sparsity:
model.enforce_mask(grad=True)
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'LR: {lr:.3e} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=batch_idx)
if lr_scheduler is not None and args.lr_ep:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
end = time.time()
# end for
if hasattr(optimizer, 'sync_lookahead'):
optimizer.sync_lookahead()
return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
top5_m = utils.AverageMeter()
model.eval()
if args.ampere_sparsity:
model.enforce_mask()
end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast(dtype=args.dtype):
output = model(input)
if isinstance(output, (tuple, list)):
output = output[0]
# augmentation reduction
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
acc1 = utils.reduce_tensor(acc1, args.world_size)
acc5 = utils.reduce_tensor(acc5, args.world_size)
else:
reduced_loss = loss.data
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))
batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
'{0}: [{1:>4d}/{2}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m,
loss=losses_m, top1=top1_m, top5=top5_m))
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
return metrics
if __name__ == '__main__':
main()
#!/bin/bash
DATA_PATH="/ImageNet/train"
MODEL=mamba_vision_T
BS=2
EXP=Test
LR=8e-4
WD=0.05
WR_LR=1e-6
DR=0.38
MESA=0.25
python train.py --mesa ${MESA} --input-size 3 224 224 --crop-pct=0.875 \
--data_dir=$DATA_PATH --model $MODEL --amp --weight-decay ${WD} --drop-path ${DR} --batch-size $BS --tag $EXP --lr $LR --warmup-lr $WR_LR
"""Code for getting the data loaders."""
import numpy as np
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch._utils import _accumulate
from timm.data import IterableImageDataset, ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
def get_loaders(args, mode='eval', dataset=None):
"""Get data loaders for required dataset."""
if dataset is None:
dataset = args.dataset
if dataset == 'imagenet':
return get_imagenet_loader(args, mode)
else:
if mode == 'search':
return get_loaders_search(args)
elif mode == 'eval':
return get_loaders_eval(dataset, args)
class Subset_imagenet(torch.utils.data.Dataset):
r"""
Subset of a dataset at specified indices.
Args:
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
def __init__(self, dataset , indices) -> None:
self.dataset = dataset
self.indices = indices
self.transform = None
def __getitem__(self, idx):
img, target = self.dataset[self.indices[idx]]
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.indices)
def get_loaders_eval(dataset, args):
"""Get train and valid loaders for cifar10/tiny imagenet."""
if dataset == 'cifar10':
num_classes = 10
train_transform, valid_transform = _data_transforms_cifar10(args)
train_data = dset.CIFAR10(
root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR10(
root=args.data, train=False, download=True, transform=valid_transform)
elif dataset == 'cifar100':
num_classes = 100
train_transform, valid_transform = _data_transforms_cifar10(args)
train_data = dset.CIFAR100(
root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR100(
root=args.data, train=False, download=True, transform=valid_transform)
train_sampler, valid_sampler = None, None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_data)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_data)
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
shuffle=(train_sampler is None),
sampler=train_sampler, pin_memory=True, num_workers=16)
valid_queue = torch.utils.data.DataLoader(
valid_data, batch_size=args.batch_size, shuffle=False,
sampler=valid_sampler, pin_memory=True, num_workers=16)
return train_queue, valid_queue, num_classes
def get_loaders_search(args):
"""Get train and valid loaders for cifar10/tiny imagenet."""
if args.dataset == 'cifar10':
num_classes = 10
train_transform, _ = _data_transforms_cifar10(args)
train_data = dset.CIFAR10(
root=args.data, train=True, download=True, transform=train_transform)
elif args.dataset == 'cifar100':
num_classes = 100
train_transform, _ = _data_transforms_cifar10(args)
train_data = dset.CIFAR100(
root=args.data, train=True, download=True, transform=train_transform)
num_train = len(train_data)
print('Found %d samples' % (num_train))
sub_num_train = int(np.floor(args.train_portion * num_train))
sub_num_valid = num_train - sub_num_train
sub_train_data, sub_valid_data = my_random_split(
train_data, [sub_num_train, sub_num_valid], seed=0)
print('Train: Split into %d samples' % (len(sub_train_data)))
print('Valid: Split into %d samples' % (len(sub_valid_data)))
train_sampler, valid_sampler = None, None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(
sub_train_data)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
sub_valid_data)
train_queue = torch.utils.data.DataLoader(
sub_train_data, batch_size=args.batch_size,
shuffle=(train_sampler is None),
sampler=train_sampler, pin_memory=True, num_workers=16, drop_last=True)
valid_queue = torch.utils.data.DataLoader(
sub_valid_data, batch_size=args.batch_size,
shuffle=(valid_sampler is None),
sampler=valid_sampler, pin_memory=True, num_workers=16, drop_last=True)
return train_queue, valid_queue, num_classes
################################################################################
# ImageNet
################################################################################
def get_imagenet_loader(args, mode='eval', testdir = ""):
"""Get train/val for imagenet."""
traindir = os.path.join(args.data, 'train')
validdir = os.path.join(args.data, 'val')
print("verify testing path")
if len(testdir) < 2:
testdir = os.path.join("../ImageNetV2/", 'test')
# print("\n\n\n loading imagenet v2 \n\n\n")
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
downscale = 1
val_transform = transforms.Compose([
transforms.Resize(args.resize//downscale),
transforms.CenterCrop(args.resolution//downscale),
transforms.ToTensor(),
normalize,
])
train_transform = transforms.Compose([
transforms.RandomResizedCrop(args.resolution//downscale),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
])
if mode == 'eval':
if 'lmdb' in args.data:
train_data = imagenet_lmdb_dataset(
traindir, transform=train_transform)
valid_data = imagenet_lmdb_dataset(
validdir, transform=val_transform)
else:
train_data = dset.ImageFolder(traindir, transform=train_transform)
valid_data = dset.ImageFolder(validdir, transform=val_transform)
train_sampler, valid_sampler = None, None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_data)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_data)
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
shuffle=(train_sampler is None),
pin_memory=True, num_workers=16, sampler=train_sampler, drop_last=True)
valid_queue = torch.utils.data.DataLoader(
valid_data, batch_size=args.batch_size, shuffle=(valid_sampler is None),
pin_memory=True, num_workers=16, sampler=valid_sampler)
elif mode == 'search':
if 'lmdb' in args.data:
train_data = imagenet_lmdb_dataset(
traindir, transform=val_transform)
else:
train_data = dset.ImageFolder(traindir, val_transform)
num_train = len(train_data)
print('Found %d samples' % (num_train))
sub_num_train = int(np.floor(args.train_portion * num_train))
sub_num_valid = num_train - sub_num_train
sub_train_data, sub_valid_data = my_random_split(
train_data, [sub_num_train, sub_num_valid], seed=0)
print('Train: Split into %d samples' % (len(sub_train_data)))
print('Valid: Split into %d samples' % (len(sub_valid_data)))
train_sampler, valid_sampler = None, None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(
sub_train_data)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
sub_valid_data)
train_queue = torch.utils.data.DataLoader(
sub_train_data, batch_size=args.batch_size,
sampler=train_sampler, shuffle=(train_sampler is None),
pin_memory=True, num_workers=16, drop_last=True)
valid_queue = torch.utils.data.DataLoader(
sub_valid_data, batch_size=args.batch_size,
sampler=valid_sampler, shuffle=(valid_sampler is None),
pin_memory=True, num_workers=16, drop_last=False)
elif mode == 'timm':
if 'lmdb' in args.data:
train_data = imagenet_lmdb_dataset(
traindir, transform=None)
valid_data = imagenet_lmdb_dataset(
traindir, transform=val_transform)
else:
train_data = ImageDataset(traindir)
valid_data = dset.ImageFolder(traindir, transform=val_transform)
train_interpolation = 'bicubic'
train_queue = create_loader(
train_data,
input_size=args.resize // downscale,
batch_size=args.batch_size,
is_training=True,
use_prefetcher=True,
no_aug=False,
re_prob=0.2,
re_mode="pixel",
re_count=1,
re_split=False,
scale=[0.08, 1.0],
ratio=[0.75, 1.3333333333333333],
hflip=0.5,
vflip=0.0,
color_jitter=0.4,
auto_augment="rand-m9-mstd0.5",
num_aug_splits=0,
interpolation=train_interpolation,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
num_workers=16,
distributed=args.distributed,
collate_fn=None,
pin_memory=False,
use_multi_epochs_loader=False
)
num_train = len(valid_data)
print('Found %d samples' % (num_train))
sub_num_train = int(np.floor(args.train_portion * num_train))
sub_num_valid = num_train - sub_num_train
_, sub_valid_data = my_random_split(
valid_data, [sub_num_train, sub_num_valid], seed=0)
print('Valid: Split into %d samples' % (len(sub_valid_data)))
train_sampler, valid_sampler = None, None
if args.distributed:
valid_sampler = torch.utils.data.distributed.DistributedSampler(
sub_valid_data)
valid_queue = torch.utils.data.DataLoader(
sub_valid_data, batch_size=args.batch_size,
shuffle=(valid_sampler is None),
sampler=valid_sampler, pin_memory=True, num_workers=16, drop_last=False)
elif mode == 'timm2':
if 'lmdb' in args.data:
train_data = imagenet_lmdb_dataset(
traindir, transform=None)
valid_data = imagenet_lmdb_dataset(
traindir, transform=val_transform)
else:
train_data = ImageDataset(traindir)
valid_data = ImageDataset(testdir)
train_interpolation = "bicubic"
train_queue = create_loader(
train_data,
input_size=args.resize // downscale,
batch_size=args.batch_size,
is_training=True,
use_prefetcher=True,
no_aug=False,
re_prob=0.2,
re_mode="pixel",
re_count=1,
re_split=False,
scale=[0.08, 1.0],
ratio=[0.75, 1.3333333333333333],
hflip=0.5,
vflip=0.0,
color_jitter=0.4,
auto_augment="rand-m9-mstd0.5",
num_aug_splits=0,
# interpolation=train_interpolation,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
num_workers=16,
distributed=args.distributed,
collate_fn=None,
pin_memory=False,
use_multi_epochs_loader=False
)
valid_queue = create_loader(
valid_data,
input_size=args.resize // downscale,
batch_size=args.batch_size,
is_training=False,
use_prefetcher=True,
interpolation=train_interpolation,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
num_workers=16,
distributed=args.distributed,
crop_pct=0.875,
color_jitter=0.4,
pin_memory=False,
)
elif mode == 'timm3':
# with test set from ImageNetV2 test split
if 'lmdb' in args.data:
train_data = imagenet_lmdb_dataset(
traindir, transform=None)
valid_data = imagenet_lmdb_dataset(
traindir, transform=val_transform)
else:
train_data = ImageDataset(traindir)
valid_data = ImageDataset(testdir)
# valid_data = ImageDataset(traindir)
train_interpolation = 'bicubic'
train_queue = create_loader(
train_data,
input_size=args.resize // downscale,
batch_size=args.batch_size,
is_training=True,
use_prefetcher=True,
no_aug=False,
re_prob=0.2,
re_mode="pixel",
re_count=1,
re_split=False,
scale=[0.08, 1.0],
ratio=[0.75, 1.3333333333333333],
hflip=0.5,
vflip=0.0,
color_jitter=0.4,
auto_augment="rand-m9-mstd0.5",
num_aug_splits=0,
interpolation=train_interpolation,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
# num_workers=16,
num_workers=8,
distributed=args.distributed,
collate_fn=None,
pin_memory=False,
use_multi_epochs_loader=False
)
valid_queue = create_loader(
valid_data,
input_size=args.resize // downscale,
batch_size=args.batch_size * 4,
is_training=True,
use_prefetcher=True,
interpolation=train_interpolation,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
# num_workers=16,
num_workers=8,
distributed=args.distributed,
crop_pct=0.875,
color_jitter=0.0,
pin_memory=False,
)
return train_queue, valid_queue, 1000
################################################################################
def my_random_split(dataset, lengths, seed=0):
"""
Randomly split a dataset into non-overlapping new datasets of given lengths.
Arguments:
dataset (Dataset): Dataset to be split
lengths (sequence): lengths of splits to be produced
"""
if sum(lengths) != len(dataset):
raise ValueError(
"Sum of input lengths does not equal the length of the input dataset!")
g = torch.Generator()
g.manual_seed(seed)
indices = torch.randperm(sum(lengths), generator=g)
return [Subset_imagenet(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
################################################################################
def my_random_split_perc(dataset, percent_train, seed=0):
"""
Randomly split a dataset into non-overlapping new datasets of given lengths.
Arguments:
dataset (Dataset): Dataset to be split
percent_train (float): portion of the dataset to be used for training
"""
num_train = len(dataset)
print('Found %d samples' % (num_train))
sub_num_train = int(np.floor(percent_train * num_train))
sub_num_valid = num_train - sub_num_train
dataset_train, dataset_validation = my_random_split(dataset, [sub_num_train, sub_num_valid], seed=seed)
print('Train: Split into %d samples' % (len(dataset)))
return [dataset_train, dataset_validation]
################################################################################
# ImageNet - LMDB
################################################################################
import io
import os
try:
import lmdb
except:
pass
import torch
from torchvision import datasets
from PIL import Image
def lmdb_loader(path, lmdb_data):
# In-memory binary streams
with lmdb_data.begin(write=False, buffers=True) as txn:
bytedata = txn.get(path.encode('ascii'))
img = Image.open(io.BytesIO(bytedata))
return img.convert('RGB')
def imagenet_lmdb_dataset(
root, transform=None, target_transform=None,
loader=lmdb_loader):
if root.endswith('/'):
root = root[:-1]
pt_path = os.path.join(
root + '_faster_imagefolder.lmdb.pt')
lmdb_path = os.path.join(
root + '_faster_imagefolder.lmdb')
if os.path.isfile(pt_path) and os.path.isdir(lmdb_path):
print('Loading pt {} and lmdb {}'.format(pt_path, lmdb_path))
data_set = torch.load(pt_path)
else:
data_set = datasets.ImageFolder(
root, None, None, None)
torch.save(data_set, pt_path, pickle_protocol=4)
print('Saving pt to {}'.format(pt_path))
print('Building lmdb to {}'.format(lmdb_path))
env = lmdb.open(lmdb_path, map_size=1e12)
with env.begin(write=True) as txn:
for path, class_index in data_set.imgs:
with open(path, 'rb') as f:
data = f.read()
txn.put(path.encode('ascii'), data)
data_set.lmdb_data = lmdb.open(
lmdb_path, readonly=True, max_readers=1, lock=False, readahead=False,
meminit=False)
# reset transform and target_transform
data_set.samples = data_set.imgs
data_set.transform = transform
data_set.target_transform = target_transform
data_set.loader = lambda path: loader(path, data_set.lmdb_data)
return data_set
if __name__ == '__main__':
import torch.distributed as dist
import argparse
import matplotlib
matplotlib.use('tkagg')
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser('Cell search')
args = parser.parse_args()
args.data = '/data/datasets/imagenet_lmdb/'
args.train_portion = 0.9
args.batch_size = 48
args.seed = 1
args.local_rank = 0
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '6020'
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://', rank=0, world_size=1)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
q1, q2, _ = get_imagenet_loader(args, mode='search')
iterator = iter(q1)
input_search, target_search = next(iterator)
print(len(q1), len(q2))
ind = 0
for batch, target in q1:
"""
img = batch[0].numpy().transpose(1, 2, 0)[:, :, 0]
plt.imshow(img)
plt.show()
plt.pause(1.)
"""
if ind % 100 == 0:
print(ind)
ind += 1
t1, t2, _ = get_imagenet_loader(args, mode='eval')
print(len(t1), len(t2))
for batch, target in t1:
img = batch[0].numpy().transpose(1, 2, 0)[:, :, 0]
plt.imshow(img)
plt.show()
plt.pause(1.)
break
#!/usr/bin/env python3
""" ImageNet Validation Script
This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import csv
import glob
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from functools import partial
import torch
import torch.nn as nn
import torch.nn.parallel
from models.mamba_vision import *
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
decay_batch_step, check_batch_size_retry, ParseKwargs
try:
from apex import amp
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
has_compile = hasattr(torch, 'compile')
_logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR',
help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
help='enable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.")
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--results-format', default='csv', type=str,
help='Format for results file one of (csv, json) (default: csv).')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--retry', default=False, action='store_true',
help='Enable batch size decay & retry for single model validation')
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = torch.device(args.device)
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_autocast = suppress
if args.amp:
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
assert args.amp_dtype == 'float16'
use_amp = 'apex'
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
assert args.amp_dtype in ('float16', 'bfloat16')
use_amp = 'native'
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
_logger.info('Validating in mixed precision with native PyTorch AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
if args.fuser:
set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
# create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=in_chans,
global_pool=args.gp,
scriptable=args.torchscript,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema)
param_count = sum([m.numel() for m in model.parameters()])
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(
vars(args),
model=model,
use_test_size=not args.use_train_size,
verbose=True,
)
test_time_pool = False
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config)
model = model.to(device)
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
model = torch.jit.script(model)
elif args.torchcompile:
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
torch._dynamo.reset()
model = torch.compile(model, backend=args.torchcompile)
elif args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
if use_amp == 'apex':
model = amp.initialize(model, opt_level='O1')
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().to(device)
root_dir = args.data or args.data_dir
dataset = create_dataset(
root=root_dir,
name=args.dataset,
split=args.split,
download=args.dataset_download,
load_bytes=args.tf_preprocessing,
class_map=args.class_map,
)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
valid_labels = [int(line.rstrip()) for line in f]
else:
valid_labels = None
if args.real_labels:
real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
else:
real_labels = None
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
dataset,
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=crop_pct,
crop_mode=data_config['crop_mode'],
pin_memory=args.pin_mem,
device=device,
tf_preprocessing=args.tf_preprocessing,
)
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
model(input)
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.to(device)
input = input.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# compute output
with amp_autocast():
output = model(input)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if real_labels is not None:
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.log_freq == 0:
_logger.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx,
len(loader),
batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses,
top1=top1,
top5=top5
)
)
if real_labels is not None:
# real labels mode replaces topk values at the end
top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
else:
top1a, top5a = top1.avg, top5.avg
results = OrderedDict(
model=args.model,
top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'],
)
_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results
def _try_run(args, initial_batch_size):
batch_size = initial_batch_size
results = OrderedDict()
error_str = 'Unknown'
while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try:
if torch.cuda.is_available() and 'cuda' in args.device:
torch.cuda.empty_cache()
results = validate(args)
return results
except RuntimeError as e:
error_str = str(e)
_logger.error(f'"{error_str}" while running validation.')
if not check_batch_size_retry(error_str):
break
batch_size = decay_batch_step(batch_size)
_logger.warning(f'Reducing batch size to {batch_size} for retry.')
results['error'] = error_str
_logger.error(f'{args.model} failed to validate ({error_str}).')
return results
_NON_IN1K_FILTERS = ['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae', '*seer']
def main():
setup_default_logging()
args = parser.parse_args()
model_cfgs = []
model_names = []
if os.path.isdir(args.checkpoint):
# validate all checkpoints in a path with same model
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
checkpoints += glob.glob(args.checkpoint + '/*.pth')
model_names = list_models(args.model)
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
else:
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(
pretrained=True,
exclude_filters=_NON_IN1K_FILTERS,
)
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter
model_names = list_models(
args.model,
pretrained=True,
)
model_cfgs = [(n, '') for n in model_names]
if not model_cfgs and os.path.isfile(args.model):
with open(args.model) as f:
model_names = [line.rstrip() for line in f]
model_cfgs = [(n, None) for n in model_names if n]
if len(model_cfgs):
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = []
try:
initial_batch_size = args.batch_size
for m, c in model_cfgs:
args.model = m
args.checkpoint = c
r = _try_run(args, initial_batch_size)
if 'error' in r:
continue
if args.checkpoint:
r['checkpoint'] = args.checkpoint
results.append(r)
except KeyboardInterrupt as e:
pass
results = sorted(results, key=lambda x: x['top1'], reverse=True)
else:
if args.retry:
results = _try_run(args, args.batch_size)
else:
results = validate(args)
if args.results_file:
write_results(args.results_file, results, format=args.results_format)
# output results in JSON to stdout w/ delimiter for runner script
print(f'--result\n{json.dumps(results, indent=4)}')
def write_results(results_file, results, format='csv'):
with open(results_file, mode='w') as cf:
if format == 'json':
json.dump(results, cf, indent=4)
else:
if not isinstance(results, (list, tuple)):
results = [results]
if not results:
return
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__':
main()
#!/bin/bash
DATA_PATH="/ImageNet/val"
BS=128
checkpoint='/model_weights/mambavision_tiny_1k.pth.tar'
python validate.py --model mamba_vision_T --checkpoint=$checkpoint --data_dir=$DATA_PATH --batch-size $BS --input-size 3 224 224
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment