Commit 459ecd48 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add transformers model

parents
Pipeline #317 failed with stages
in 0 seconds
# from: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
x = self.to_patch_embedding(img)
pe = posemb_sincos_2d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# https://github.com/berniwal/swin-transformer-pytorch
import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat
class CyclicShift(nn.Module):
def __init__(self, displacement):
super().__init__()
self.displacement = displacement
def forward(self, x):
return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
def create_mask(window_size, displacement, upper_lower, left_right):
mask = torch.zeros(window_size ** 2, window_size ** 2)
if upper_lower:
mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')
if left_right:
mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
mask[:, -displacement:, :, :-displacement] = float('-inf')
mask[:, :-displacement, :, -displacement:] = float('-inf')
mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
return mask
def get_relative_distances(window_size):
indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
distances = indices[None, :, :] - indices[:, None, :]
return distances
class WindowAttention(nn.Module):
def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
super().__init__()
inner_dim = head_dim * heads
self.heads = heads
self.scale = head_dim ** -0.5
self.window_size = window_size
self.relative_pos_embedding = relative_pos_embedding
self.shifted = shifted
if self.shifted:
displacement = window_size // 2
self.cyclic_shift = CyclicShift(-displacement)
self.cyclic_back_shift = CyclicShift(displacement)
self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
upper_lower=True, left_right=False), requires_grad=False)
self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
upper_lower=False, left_right=True), requires_grad=False)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
if self.relative_pos_embedding:
self.relative_indices = get_relative_distances(window_size) + window_size - 1
self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
else:
self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x):
if self.shifted:
x = self.cyclic_shift(x)
b, n_h, n_w, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
nw_h = n_h // self.window_size
nw_w = n_w // self.window_size
q, k, v = map(
lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
h=h, w_h=self.window_size, w_w=self.window_size), qkv)
dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
if self.relative_pos_embedding:
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
else:
dots += self.pos_embedding
if self.shifted:
dots[:, :, -nw_w:] += self.upper_lower_mask
dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
attn = dots.softmax(dim=-1)
out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
out = self.to_out(out)
if self.shifted:
out = self.cyclic_back_shift(out)
return out
class SwinBlock(nn.Module):
def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
super().__init__()
self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
heads=heads,
head_dim=head_dim,
shifted=shifted,
window_size=window_size,
relative_pos_embedding=relative_pos_embedding)))
self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))
def forward(self, x):
x = self.attention_block(x)
x = self.mlp_block(x)
return x
class PatchMerging(nn.Module):
def __init__(self, in_channels, out_channels, downscaling_factor):
super().__init__()
self.downscaling_factor = downscaling_factor
self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)
def forward(self, x):
b, c, h, w = x.shape
new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
x = self.linear(x)
return x
class StageModule(nn.Module):
def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
relative_pos_embedding):
super().__init__()
assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'
self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
downscaling_factor=downscaling_factor)
self.layers = nn.ModuleList([])
for _ in range(layers // 2):
self.layers.append(nn.ModuleList([
SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
]))
def forward(self, x):
x = self.patch_partition(x)
for regular_block, shifted_block in self.layers:
x = regular_block(x)
x = shifted_block(x)
return x.permute(0, 3, 1, 2)
class SwinTransformer(nn.Module):
def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
super().__init__()
self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.mlp_head = nn.Sequential(
nn.LayerNorm(hidden_dim * 8),
nn.Linear(hidden_dim * 8, num_classes)
)
def forward(self, img):
x = self.stage1(img)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = x.mean(dim=[2, 3])
return self.mlp_head(x)
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
# -*- coding: utf-8 -*-
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
cfg = {
'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
class VGG(nn.Module):
def __init__(self, vgg_name):
super(VGG, self).__init__()
self.features = self._make_layers(cfg[vgg_name])
self.classifier = nn.Linear(512, 10)
def forward(self, x):
out = self.features(x)
out = out.view(out.size(0), -1)
out = self.classifier(out)
return out
def _make_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)]
in_channels = x
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
return nn.Sequential(*layers)
def test():
net = VGG('VGG11')
x = torch.randn(2,3,32,32)
y = net(x)
print(y.size())
# test()
\ No newline at end of file
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_for_small_dataset.py
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class LSA(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()
mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(mask, mask_value)
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SPT(nn.Module):
def __init__(self, *, dim, patch_size, channels = 3):
super().__init__()
patch_dim = patch_size * patch_size * 5 * channels
self.to_patch_tokens = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim)
)
def forward(self, x):
shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
return self.to_patch_tokens(x_with_shifts)
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
# code in this file is adpated from rpmcruz/autoaugment
# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
import random
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
import numpy as np
import torch
from PIL import Image
def ShearX(img, v): # [-0.3, 0.3]
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
def ShearY(img, v): # [-0.3, 0.3]
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[0]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[1]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
def Rotate(img, v): # [-30, 30]
assert -30 <= v <= 30
if random.random() > 0.5:
v = -v
return img.rotate(v)
def AutoContrast(img, _):
return PIL.ImageOps.autocontrast(img)
def Invert(img, _):
return PIL.ImageOps.invert(img)
def Equalize(img, _):
return PIL.ImageOps.equalize(img)
def Flip(img, _): # not from the paper
return PIL.ImageOps.mirror(img)
def Solarize(img, v): # [0, 256]
assert 0 <= v <= 256
return PIL.ImageOps.solarize(img, v)
def SolarizeAdd(img, addition=0, threshold=128):
img_np = np.array(img).astype(np.int)
img_np = img_np + addition
img_np = np.clip(img_np, 0, 255)
img_np = img_np.astype(np.uint8)
img = Image.fromarray(img_np)
return PIL.ImageOps.solarize(img, threshold)
def Posterize(img, v): # [4, 8]
v = int(v)
v = max(1, v)
return PIL.ImageOps.posterize(img, v)
def Contrast(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Contrast(img).enhance(v)
def Color(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Color(img).enhance(v)
def Brightness(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Brightness(img).enhance(v)
def Sharpness(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Sharpness(img).enhance(v)
def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
assert 0.0 <= v <= 0.2
if v <= 0.:
return img
v = v * img.size[0]
return CutoutAbs(img, v)
def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
# assert 0 <= v <= 20
if v < 0:
return img
w, h = img.size
x0 = np.random.uniform(w)
y0 = np.random.uniform(h)
x0 = int(max(0, x0 - v / 2.))
y0 = int(max(0, y0 - v / 2.))
x1 = min(w, x0 + v)
y1 = min(h, y0 + v)
xy = (x0, y0, x1, y1)
color = (125, 123, 114)
# color = (0, 0, 0)
img = img.copy()
PIL.ImageDraw.Draw(img).rectangle(xy, color)
return img
def SamplePairing(imgs): # [0, 0.4]
def f(img1, v):
i = np.random.choice(len(imgs))
img2 = PIL.Image.fromarray(imgs[i])
return PIL.Image.blend(img1, img2, v)
return f
def Identity(img, v):
return img
def augment_list(): # 16 oeprations and their ranges
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
# l = [
# (Identity, 0., 1.0),
# (ShearX, 0., 0.3), # 0
# (ShearY, 0., 0.3), # 1
# (TranslateX, 0., 0.33), # 2
# (TranslateY, 0., 0.33), # 3
# (Rotate, 0, 30), # 4
# (AutoContrast, 0, 1), # 5
# (Invert, 0, 1), # 6
# (Equalize, 0, 1), # 7
# (Solarize, 0, 110), # 8
# (Posterize, 4, 8), # 9
# # (Contrast, 0.1, 1.9), # 10
# (Color, 0.1, 1.9), # 11
# (Brightness, 0.1, 1.9), # 12
# (Sharpness, 0.1, 1.9), # 13
# # (Cutout, 0, 0.2), # 14
# # (SamplePairing(imgs), 0, 0.4), # 15
# ]
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
l = [
(AutoContrast, 0, 1),
(Equalize, 0, 1),
(Invert, 0, 1),
(Rotate, 0, 30),
(Posterize, 0, 4),
(Solarize, 0, 256),
(SolarizeAdd, 0, 110),
(Color, 0.1, 1.9),
(Contrast, 0.1, 1.9),
(Brightness, 0.1, 1.9),
(Sharpness, 0.1, 1.9),
(ShearX, 0., 0.3),
(ShearY, 0., 0.3),
(CutoutAbs, 0, 40),
(TranslateXabs, 0., 100),
(TranslateYabs, 0., 100),
]
return l
class Lighting(object):
"""Lighting noise(AlexNet - style PCA - based noise)"""
def __init__(self, alphastd, eigval, eigvec):
self.alphastd = alphastd
self.eigval = torch.Tensor(eigval)
self.eigvec = torch.Tensor(eigvec)
def __call__(self, img):
if self.alphastd == 0:
return img
alpha = img.new().resize_(3).normal_(0, self.alphastd)
rgb = self.eigvec.type_as(img).clone() \
.mul(alpha.view(1, 3).expand(3, 3)) \
.mul(self.eigval.view(1, 3).expand(3, 3)) \
.sum(1).squeeze()
return img.add(rgb.view(3, 1, 1).expand_as(img))
class CutoutDefault(object):
"""
Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
"""
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
class RandAugment:
def __init__(self, n, m):
self.n = n
self.m = m # [0, 30]
self.augment_list = augment_list()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
val = (float(self.m) / 30) * float(maxval - minval) + minval
img = op(img, val)
return img
\ No newline at end of file
#!/usr/bin/env bash
data_dir="data"
if [ ! -x "$data_dir" ]; then
ln -s /data/swin/train/ $data_dir
fi
export HIP_VISIBLE_DEVICES=0,1,2,3
torchrun --nproc_per_node=4 train_cifar10.py --net swin --n_epochs 500 --noaug $1 $2
#!/usr/bin/env bash
export HIP_VISIBLE_DEVICES=0,1,2,3
torchrun --nproc_per_node=4 train_cifar10.py --net swin --n_epochs 500 --noaug --log_dir pid.txt 2>&1 | tee swin_dcu_`date +%Y%m%d%H%M%S`.log
# -*- coding: utf-8 -*-
'''
Train CIFAR10 with PyTorch and Vision Transformers!
written by @kentaroy47, @arutema47
'''
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import pandas as pd
#import csv
import time
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
dist.init_process_group(backend="nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
from models import *
from utils import progress_bar
from randomaug import RandAugment
from models.vit import ViT
from models.convmixer import ConvMixer
def write_pid_file(pid_file_path):
'''Write pid file for watching the process later.
In each round of case, we will write the current pid in the same path.
'''
if os.path.exists(pid_file_path):
os.remove(pid_file_path)
file_d=open(pid_file_path,"w")
file_d.write("%s\n" % os.getpid())
file_d.close()
# parsers
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') # resnets.. 1e-3, Vit..1e-4
parser.add_argument('--opt', default="adam")
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--noaug', action='store_true', help='disable use randomaug')
parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
#parser.add_argument('--nowandb', action='store_true', help='disable wandb')
parser.add_argument('--mixup', action='store_true', help='add mixup augumentations')
parser.add_argument('--net', default='vit')
parser.add_argument('--bs', default='512')
parser.add_argument('--size', default="32")
parser.add_argument('--n_epochs', type=int, default='200')
parser.add_argument('--patch', default='4', type=int, help="patch for ViT")
parser.add_argument('--dimhead', default="512", type=int)
parser.add_argument('--convkernel', default='8', type=int, help="parameter for convmixer")
parser.add_argument("--log_dir",
type=str,
default="/data/flagperf/training/result/",
help="Log directory in container.")
args = parser.parse_args()
if dist.get_rank() == 0:
write_pid_file(args.log_dir)
# take in args
#usewandb = ~args.nowandb
#if usewandb:
# import wandb
# watermark = "{}_lr{}".format(args.net, args.lr)
# wandb.init(project="cifar10-challange",
# name=watermark)
# wandb.config.update(args)
bs = int(args.bs)
bs = int(bs / world_size)
imsize = int(args.size)
use_amp = bool(~args.noamp)
aug = args.noaug
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device('cuda', local_rank)
torch.cuda.set_device(device)
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
global_steps = 0
target_acc = 84.49
final_acc = 0
num_trained_samples = 0
log_file_name = f'rank{local_rank}.out.log'
log_file = open(log_file_name, 'w')
# Data
#print('==> Preparing data..')
if args.net=="vit_timm":
size = 384
else:
size = imsize
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.Resize(size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.Resize(size),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Add RandAugment with N, M(hyperparameter)
if aug:
N = 2; M = 14;
transform_train.transforms.insert(0, RandAugment(N, M))
# Prepare dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, num_replicas=world_size, rank=local_rank, shuffle=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, sampler=train_sampler, num_workers=8)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
test_sampler = torch.utils.data.distributed.DistributedSampler(testset)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, sampler=test_sampler,shuffle=False, num_workers=8)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Model factory..
#print('==> Building model..')
# net = VGG('VGG19')
if args.net=='res18':
net = ResNet18()
elif args.net=='vgg':
net = VGG('VGG19')
elif args.net=='res34':
net = ResNet34()
elif args.net=='res50':
net = ResNet50()
elif args.net=='res101':
net = ResNet101()
elif args.net=="convmixer":
# from paper, accuracy >96%. you can tune the depth and dim to scale accuracy and speed.
net = ConvMixer(256, 16, kernel_size=args.convkernel, patch_size=1, n_classes=10)
elif args.net=="mlpmixer":
from models.mlpmixer import MLPMixer
net = MLPMixer(
image_size = 32,
channels = 3,
patch_size = args.patch,
dim = 512,
depth = 6,
num_classes = 10
)
elif args.net=="vit_small":
from models.vit_small import ViT
net = ViT(
image_size = size,
patch_size = args.patch,
num_classes = 10,
dim = int(args.dimhead),
depth = 6,
heads = 8,
mlp_dim = 512,
dropout = 0.1,
emb_dropout = 0.1
)
elif args.net=="vit_tiny":
from models.vit_small import ViT
net = ViT(
image_size = size,
patch_size = args.patch,
num_classes = 10,
dim = int(args.dimhead),
depth = 4,
heads = 6,
mlp_dim = 256,
dropout = 0.1,
emb_dropout = 0.1
)
elif args.net=="simplevit":
from models.simplevit import SimpleViT
net = SimpleViT(
image_size = size,
patch_size = args.patch,
num_classes = 10,
dim = int(args.dimhead),
depth = 6,
heads = 8,
mlp_dim = 512
)
elif args.net=="vit":
# ViT for cifar10
net = ViT(
image_size = size,
patch_size = args.patch,
num_classes = 10,
dim = int(args.dimhead),
depth = 6,
heads = 8,
mlp_dim = 512,
dropout = 0.1,
emb_dropout = 0.1
)
elif args.net=="vit_timm":
import timm
net = timm.create_model("vit_base_patch16_384", pretrained=True)
net.head = nn.Linear(net.head.in_features, 10)
elif args.net=="cait":
from models.cait import CaiT
net = CaiT(
image_size = size,
patch_size = args.patch,
num_classes = 10,
dim = int(args.dimhead),
depth = 6, # depth of transformer for patch to patch attention only
cls_depth=2, # depth of cross attention of CLS tokens to patch
heads = 8,
mlp_dim = 512,
dropout = 0.1,
emb_dropout = 0.1,
layer_dropout = 0.05
)
elif args.net=="cait_small":
from models.cait import CaiT
net = CaiT(
image_size = size,
patch_size = args.patch,
num_classes = 10,
dim = int(args.dimhead),
depth = 6, # depth of transformer for patch to patch attention only
cls_depth=2, # depth of cross attention of CLS tokens to patch
heads = 6,
mlp_dim = 256,
dropout = 0.1,
emb_dropout = 0.1,
layer_dropout = 0.05
)
elif args.net=="swin":
from models.swin import swin_t
net = swin_t(window_size=args.patch,
num_classes=10,
downscaling_factors=(2,2,2,1))
# For Multi-GPU
#if 'cuda' in device:
#print(device)
#print("using data parallel")
#net = torch.nn.DataParallel(net) # make parallel
net = net.to(device)
net = DistributedDataParallel(net, device_ids=[device])
cudnn.benchmark = True
if args.resume:
# Load checkpoint.
#print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/{}-ckpt.t7'.format(args.net))
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
# Loss is CE
criterion = nn.CrossEntropyLoss()
if args.opt == "adam":
optimizer = optim.Adam(net.parameters(), lr=args.lr)
elif args.opt == "sgd":
optimizer = optim.SGD(net.parameters(), lr=args.lr)
# use cosine scheduling
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs)
##### Training
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
def train(epoch):
global num_trained_samples, global_steps
#print('\nEpoch: %d' % epoch)
train_sampler.set_epoch(epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
# Train with amp
with torch.cuda.amp.autocast(enabled=use_amp):
outputs = net(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
num_trained_samples += targets.size(0)
global_steps += 1
learning_rate = f'{optimizer.param_groups[0]["lr"]:.9f}'
loss_str = "%.4f" % (train_loss/(batch_idx+1))
acc = 100.*correct/total
step_output = f'[PerfLog] {{"event": "STEP_END", "value": {{"epoch": {epoch+1}, "global_steps": {global_steps},"loss": {loss_str},"accuracy":{acc:.4f},"num_trained_samples": {num_trained_samples}, "learning_rate": {learning_rate}}}}}'
log_file.write(step_output + '\n')
print(f'rank {local_rank}: ' + step_output)
#progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
# % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
return train_loss/(batch_idx+1)
##### Validation
def test(epoch):
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
#progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
# % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
torch.cuda.synchronize()
t = torch.tensor([total, correct], device='cuda')
dist.all_reduce(t)
total = t[0]
correct = t[1]
# Save checkpoint.
acc = 100.*correct/total
if acc > best_acc:
if dist.get_rank() == 0:
#print('Saving..')
state = {"model": net.state_dict(),
"optimizer": optimizer.state_dict(),
"scaler": scaler.state_dict()}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/'+args.net+'-{}-ckpt.t7'.format(args.patch))
best_acc = acc
#os.makedirs("log", exist_ok=True)
#content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
#print(content)
#with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender:
# appender.write(content + "\n")
return test_loss, acc, total
#list_loss = []
#list_acc = []
#if usewandb:
# wandb.watch(net)
training_start = time.time()
training_only = 0
net.cuda()
for epoch in range(start_epoch, args.n_epochs):
start = time.time()
trainloss = train(epoch)
epoch_time = time.time() - start
training_only += epoch_time
start = time.time()
val_loss, acc, total= test(epoch)
eval_time = time.time() - start
eval_output = f'[PerfLog] {{"event": "EVALUATE_END", "value": {{"global_steps": {global_steps},"eval_loss": {val_loss:.4f},"eval_mlm_accuracy":{acc:.4f},"eval_time": {eval_time:.4f},"epoch_time":{epoch_time:.4f},"num_eval_samples":{total}}}}}'
log_file.write(eval_output + '\n')
print(f'rank {local_rank}: ' + eval_output)
if acc >= target_acc:
final_acc = acc
break
scheduler.step(epoch-1) # step cosine scheduling
#list_loss.append(val_loss)
#list_acc.append(acc)
# Log training..
#if usewandb:
# wandb.log({'epoch': epoch, 'train_loss': trainloss, 'val_loss': val_loss, "val_acc": acc, "lr": optimizer.param_groups[0]["lr"],
# "epoch_time": time.time()-start})
# Write out csv..
#with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f:
# writer = csv.writer(f, lineterminator='\n')
# writer.writerow(list_loss)
# writer.writerow(list_acc)
#print(list_loss)
train_time = time.time() - training_start
samples_sec = num_trained_samples / training_only
train_output = f'[PerfLog] {{"event": "TRAIN_END", "value": {{"accuracy":{final_acc:.4f},"train_time":{train_time:.4f},"samples/sec: {samples_sec:.4f}","num_trained_samples":{num_trained_samples}}}}}'
log_file.write(train_output + '\n')
print(f'rank {local_rank}: ' + train_output)
log_file.close()
# writeout wandb
#if usewandb:
# wandb.save("wandb_{}.h5".format(args.net))
# -*- coding: utf-8 -*-
'''Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset.
- msr_init: net parameter initialization.
- progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math
import torch.nn as nn
import torch.nn.init as init
def get_mean_and_std(dataset):
'''Compute the mean and std value of dataset.'''
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
mean = torch.zeros(3)
std = torch.zeros(3)
print('==> Computing mean and std..')
for inputs, targets in dataloader:
for i in range(3):
mean[i] += inputs[:,i,:,:].mean()
std[i] += inputs[:,i,:,:].std()
mean.div_(len(dataset))
std.div_(len(dataset))
return mean, std
def init_params(net):
'''Init layer parameters.'''
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal(m.weight, mode='fan_out')
if m.bias:
init.constant(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant(m.weight, 1)
init.constant(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal(m.weight, std=1e-3)
if m.bias:
init.constant(m.bias, 0)
try:
_, term_width = os.popen('stty size', 'r').read().split()
except:
term_width = 80
term_width = int(term_width)
TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
global last_time, begin_time
if current == 0:
begin_time = time.time() # Reset for new bar.
cur_len = int(TOTAL_BAR_LENGTH*current/total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
sys.stdout.write(' [')
for i in range(cur_len):
sys.stdout.write('=')
sys.stdout.write('>')
for i in range(rest_len):
sys.stdout.write('.')
sys.stdout.write(']')
cur_time = time.time()
step_time = cur_time - last_time
last_time = cur_time
tot_time = cur_time - begin_time
L = []
L.append(' Step: %s' % format_time(step_time))
L.append(' | Tot: %s' % format_time(tot_time))
if msg:
L.append(' | ' + msg)
msg = ''.join(L)
sys.stdout.write(msg)
for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
sys.stdout.write(' ')
# Go back to the center of the bar.
for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
sys.stdout.write('\b')
sys.stdout.write(' %d/%d ' % (current+1, total))
if current < total-1:
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
sys.stdout.flush()
def format_time(seconds):
days = int(seconds / 3600/24)
seconds = seconds - days*3600*24
hours = int(seconds / 3600)
seconds = seconds - hours*3600
minutes = int(seconds / 60)
seconds = seconds - minutes*60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds*1000)
f = ''
i = 1
if days > 0:
f += str(days) + 'D'
i += 1
if hours > 0 and i <= 2:
f += str(hours) + 'h'
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + 'm'
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + 's'
i += 1
if millis > 0 and i <= 2:
f += str(millis) + 'ms'
i += 1
if f == '':
f = '0ms'
return f
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment