Commit bffed0fe authored by dengjb's avatar dengjb
Browse files

update

parents
"""
Common image operations
Reference: https://github.com/hendrycks/robustness
Hacked together for STR by: Rowel Atienza
"""
import cv2
import numpy as np
from scipy.ndimage import zoom as scizoom
def clipped_zoom(img, zoom_factor):
    """
    Zoom into the centre of an image while keeping its original size.

    A centred window of size ceil(dim / zoom_factor) is cropped along each
    spatial axis, magnified by `zoom_factor` with bilinear interpolation
    (order=1) and finally trimmed back to the input height and width.

    Args:
        img: array of shape (H, W, C); H and W may differ.
        zoom_factor: magnification applied to both spatial axes.

    Returns:
        The zoomed array with the same height and width as `img`.
    """
    h, w = img.shape[:2]
    # ceil the crop size so the magnified crop is never smaller than the input
    ch = int(np.ceil(h / float(zoom_factor)))
    cw = int(np.ceil(w / float(zoom_factor)))
    top = (h - ch) // 2
    left = (w - cw) // 2
    img = scizoom(
        img[top:top + ch, left:left + cw],
        (zoom_factor, zoom_factor, 1),
        order=1,
    )
    # trim off any extra pixels, independently per axis
    trim_top = (img.shape[0] - h) // 2
    trim_left = (img.shape[1] - w) // 2
    return img[trim_top:trim_top + h, trim_left:trim_left + w]
def disk(radius, alias_blur=0.1, dtype=np.float32):
    """
    Build a normalised disk kernel, softened with a Gaussian blur.

    For radius <= 8 the kernel grid always spans [-8, 8] and a 3x3 Gaussian is
    used; larger radii use a grid spanning [-radius, radius] and a 5x5 Gaussian.

    Args:
        radius: disk radius in grid units.
        alias_blur: sigmaX passed to cv2.GaussianBlur.
        dtype: dtype of the kernel array.

    Returns:
        The blurred 2d kernel.
    """
    is_small = radius <= 8
    half_extent = 8 if is_small else radius
    blur_ksize = (3, 3) if is_small else (5, 5)
    axis = np.arange(-half_extent, half_extent + 1)

    grid_x, grid_y = np.meshgrid(axis, axis)
    inside = (grid_x ** 2 + grid_y ** 2) <= radius ** 2
    kernel = np.asarray(inside, dtype=dtype)
    # normalise in place so the kernel sums to one before blurring
    kernel /= np.sum(kernel)

    # supersample disk to antialias
    return cv2.GaussianBlur(kernel, ksize=blur_ksize, sigmaX=alias_blur)
# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
def plasma_fractal(mapsize=256, wibbledecay=3, rng=None):
    """
    Generate a heightmap using the diamond-square algorithm.

    Args:
        mapsize: side length of the square map; must be a power of two.
        wibbledecay: the random perturbation amplitude is divided by this
            value after every refinement pass.
        rng: optional numpy random Generator; a fresh `default_rng()` is
            created when None.

    Returns:
        Square 2d float array of side length 'mapsize', shifted and scaled so
        that its minimum is 0 and its maximum is 1.
    """
    assert (mapsize & (mapsize - 1) == 0)
    # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype.
    maparray = np.empty((mapsize, mapsize), dtype=np.float64)
    maparray[0, 0] = 0
    stepsize = mapsize
    wibble = 100
    if rng is None:
        rng = np.random.default_rng()

    def wibbledmean(array):
        # `array` holds the sum of four neighbours; add a random perturbation
        return array / 4 + wibble * rng.uniform(-wibble, wibble, array.shape)

    def fillsquares():
        """For each square of points stepsize apart,
        calculate middle value as mean of points + wibble"""
        cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
        squareaccum += np.roll(squareaccum, shift=-1, axis=1)
        maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)

    def filldiamonds():
        """For each diamond of points stepsize apart,
        calculate middle value as mean of points + wibble"""
        drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
        ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
        lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
        ltsum = ldrsum + lulsum
        maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
        tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
        tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
        ttsum = tdrsum + tulsum
        maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)

    # The closures above read the current `stepsize` and `wibble` on each pass.
    while stepsize >= 2:
        fillsquares()
        filldiamonds()
        stepsize //= 2
        wibble /= wibbledecay

    maparray -= maparray.min()
    return maparray / maparray.max()
\ No newline at end of file
import math
from io import BytesIO
import cv2
import numpy as np
from PIL import Image, ImageOps, ImageDraw
from pkg_resources import resource_filename
from wand.image import Image as WandImage
import albumentations as alb
from .ops import plasma_fractal
class Fog(alb.ImageOnlyTransform):
    """
    Image-only transform that blends a plasma-fractal fog layer into the image.

    `mag` selects one of three severity presets (0, 1, 2); any value outside
    that range makes every call draw a preset at random.
    """

    def __init__(self, mag=-1, always_apply=False, p=1.):
        super().__init__(always_apply=always_apply, p=p)
        self.rng = np.random.default_rng()
        self.mag = mag

    def apply(self, img, **params):
        """Return `img` with fog added, as a uint8 array."""
        pil_img = Image.fromarray(img.astype(np.uint8))
        w, h = pil_img.size

        # (fog strength, wibble decay) per severity level
        presets = [(1.5, 2), (2., 2), (2.5, 1.7)]
        if 0 <= self.mag < len(presets):
            level = self.mag
        else:
            level = self.rng.integers(0, len(presets))
        strength, decay = presets[level]

        isgray = len(pil_img.getbands()) == 1

        pixels = np.asarray(pil_img) / 255.
        max_val = pixels.max()

        # Make sure fog image is at least twice the size of the input image
        max_size = 2 ** math.ceil(math.log2(max(w, h)) + 1)
        fractal = plasma_fractal(mapsize=max_size, wibbledecay=decay, rng=self.rng)
        fog = strength * fractal[:h, :w][..., np.newaxis]
        if isgray:
            fog = np.squeeze(fog)
        else:
            fog = np.repeat(fog, 3, axis=2)

        pixels += fog
        # rescale so the brightest input value maps back to its original level
        pixels = np.clip(pixels * max_val / (max_val + strength), 0, 1) * 255
        return pixels.astype(np.uint8)
class Frost(alb.ImageOnlyTransform):
    """
    Image-only transform that alpha-blends a randomly chosen frost texture
    over the image.

    `mag` selects one of three (image weight, frost weight) presets; any value
    outside 0..2 makes every call draw a preset at random.
    """

    def __init__(self, mag=-1, always_apply=False, p=1.):
        super().__init__(always_apply=always_apply, p=p)
        self.rng = np.random.default_rng()
        self.mag = mag

    def apply(self, img, **params):
        """Return the frosted image as uint8, with the same channel layout as the input."""
        img = Image.fromarray(img.astype(np.uint8))
        w, h = img.size
        c = [(0.78, 0.22), (0.64, 0.36), (0.5, 0.5)]
        if self.mag < 0 or self.mag >= len(c):
            index = self.rng.integers(0, len(c))
        else:
            index = self.mag
        c = c[index]

        filenames = [resource_filename(__name__, 'frost/frost1.png'),
                     resource_filename(__name__, 'frost/frost2.png'),
                     resource_filename(__name__, 'frost/frost3.png'),
                     resource_filename(__name__, 'frost/frost4.jpg'),
                     resource_filename(__name__, 'frost/frost5.jpg'),
                     resource_filename(__name__, 'frost/frost6.jpg')]
        filename = filenames[self.rng.integers(0, len(filenames))]

        # Some images have transparency. Remove alpha channel.
        # The context manager closes the underlying file once converted.
        with Image.open(filename) as frost_file:
            frost = frost_file.convert('RGB')

        # Resize the frost image so it covers the input image's dimensions
        f_w, f_h = frost.size
        if w / h > f_w / f_h:
            f_h = round(f_h * w / f_w)
            f_w = w
        else:
            f_w = round(f_w * h / f_h)
            f_h = h
        frost = np.asarray(frost.resize((f_w, f_h)))

        # randomly crop
        y_start, x_start = self.rng.integers(0, f_h - h + 1), self.rng.integers(0, f_w - w + 1)
        frost = frost[y_start:y_start + h, x_start:x_start + w]

        n_channels = len(img.getbands())
        isgray = n_channels == 1

        img = np.asarray(img)
        if isgray:
            # replicate the single channel so it can be blended with the RGB texture
            img = np.expand_dims(img, axis=2)
            img = np.repeat(img, 3, axis=2)

        img = np.clip(np.round(c[0] * img + c[1] * frost), 0, 255)
        img = img.astype(np.uint8)
        if isgray:
            # np.squeeze is a no-op on an (h, w, 3) array, so convert explicitly
            # to give grayscale input a grayscale (h, w) output.
            img = np.asarray(ImageOps.grayscale(Image.fromarray(img)))
        return img
class Snow(alb.ImageOnlyTransform):
    """
    Image-only transform that overlays motion-blurred snow and whitens the image.

    `mag` selects one of three parameter presets; any value outside 0..2 makes
    every call draw a preset at random.
    """

    def __init__(self, mag=-1, always_apply=False, p=1.):
        super().__init__(always_apply=always_apply, p=p)
        self.rng = np.random.default_rng()
        self.mag = mag

    def apply(self, img, **params):
        """Return the snowy image as uint8, with the same channel layout as the input."""
        img = Image.fromarray(img.astype(np.uint8))
        w, h = img.size
        # Per preset: c[0]/c[1] mean and std of the snow noise, c[3] noise
        # threshold, c[4]/c[5] motion-blur radius and sigma, c[6] weight of the
        # original image. c[2] is unused here.
        c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8),
             (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
             (0.55, 0.3, 4, 0.9, 12, 8, 0.7)]
        if self.mag < 0 or self.mag >= len(c):
            index = self.rng.integers(0, len(c))
        else:
            index = self.mag
        c = c[index]

        n_channels = len(img.getbands())
        isgray = n_channels == 1

        img = np.asarray(img, dtype=np.float32) / 255.
        if isgray:
            img = np.expand_dims(img, axis=2)
            img = np.repeat(img, 3, axis=2)

        snow_layer = self.rng.normal(size=img.shape[:2], loc=c[0], scale=c[1])  # [:2] for monochrome
        snow_layer[snow_layer < c[3]] = 0

        # Round-trip through PNG so Wand can apply a motion blur
        snow_layer = Image.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
        output = BytesIO()
        snow_layer.save(output, format='PNG')
        snow_layer = WandImage(blob=output.getvalue())
        snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=self.rng.uniform(-135, -45))
        snow_layer = cv2.imdecode(np.frombuffer(snow_layer.make_blob(), np.uint8),
                                  cv2.IMREAD_UNCHANGED) / 255.
        snow_layer = snow_layer[..., np.newaxis]

        img = c[6] * img
        gray_img = (1 - c[6]) * np.maximum(img, cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(h, w, 1) * 1.5 + 0.5)
        img += gray_img
        # add the snow layer and a 180-degree rotated copy of it
        img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
        img = img.astype(np.uint8)
        if isgray:
            # np.squeeze is a no-op on an (h, w, 3) array, so convert explicitly
            # to give grayscale input a grayscale (h, w) output.
            img = np.asarray(ImageOps.grayscale(Image.fromarray(img)))
        return img
class Rain(alb.ImageOnlyTransform):
    """
    Image-only transform that draws short slanted light-gray lines (rain drops).

    `mag` selects the base number of drops (50, 70 or 90); any value outside
    0..2 falls back to the first preset.
    """

    def __init__(self, mag=-1, always_apply=False, p=1.):
        super().__init__(always_apply=always_apply, p=p)
        self.rng = np.random.default_rng()
        self.mag = mag

    def apply(self, img, **params):
        """Return the image with rain lines drawn on it, as uint8.

        Images whose shorter side is 5 pixels or less are returned unchanged,
        because no drop of the minimum length 5 fits inside them.
        """
        img = Image.fromarray(img.astype(np.uint8))
        img = img.copy()
        w, h = img.size
        n_channels = len(img.getbands())
        isgray = n_channels == 1
        line_width = self.rng.integers(1, 2)

        c = [50, 70, 90]
        if self.mag < 0 or self.mag >= len(c):
            index = 0
        else:
            index = self.mag
        c = c[index]

        n_rains = self.rng.integers(c, c + 20)
        slant = self.rng.integers(-60, 60)
        fillcolor = 200 if isgray else (200, 200, 200)

        draw = ImageDraw.Draw(img)
        max_length = min(w, h, 10)
        if max_length <= 5:
            # rng.integers(5, max_length) needs max_length > 5; skip drawing
            # instead of raising ValueError on very small images.
            return np.asarray(img).astype(np.uint8)

        for i in range(1, n_rains):
            length = self.rng.integers(5, max_length)
            x1 = self.rng.integers(0, w - length)
            y1 = self.rng.integers(0, h - length)
            x2 = x1 + length * math.sin(slant * math.pi / 180.)
            y2 = y1 + length * math.cos(slant * math.pi / 180.)
            x2 = int(x2)
            y2 = int(y2)
            draw.line([(x1, y1), (x2, y2)], width=line_width, fill=fillcolor)

        img = np.asarray(img).astype(np.uint8)
        return img
class Shadow(alb.ImageOnlyTransform):
    """
    Image-only transform that composites a semi-transparent black quadrilateral
    spanning the image from top to bottom.

    `mag` selects the base opacity (64, 96 or 128); any value outside 0..2
    falls back to the first preset.
    """

    def __init__(self, mag=-1, always_apply=False, p=1.):
        super().__init__(always_apply=always_apply, p=p)
        self.rng = np.random.default_rng()
        self.mag = mag

    def apply(self, img, **params):
        """Return the shadowed image as uint8; grayscale input stays grayscale."""
        source = Image.fromarray(img.astype(np.uint8))
        w, h = source.size
        isgray = len(source.getbands()) == 1

        base_opacities = [64, 96, 128]
        level = self.mag if 0 <= self.mag < len(base_opacities) else 0
        base_opacity = base_opacities[level]

        canvas = source.convert('RGBA')
        overlay = Image.new('RGBA', canvas.size, (255, 255, 255, 0))
        draw = ImageDraw.Draw(overlay)

        transparency = self.rng.integers(base_opacity, base_opacity + 32)
        # Corners: top edge left/right half, then bottom edge right/left half
        # (draws are kept in this order to preserve the random sequence).
        top_left = (self.rng.integers(0, w // 2), 0)
        top_right = (self.rng.integers(w // 2, w), 0)
        bottom_right = (self.rng.integers(w // 2, w), h - 1)
        bottom_left = (self.rng.integers(0, w // 2), h - 1)
        draw.polygon([top_left, top_right, bottom_right, bottom_left],
                     fill=(0, 0, 0, transparency))

        result = Image.alpha_composite(canvas, overlay).convert("RGB")
        if isgray:
            result = ImageOps.grayscale(result)
        return np.asarray(result).astype(np.uint8)
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import cv2
import numpy as np
import torch
## aug functions
def identity_func(img):
    """Return the input image untouched (the no-op augmentation)."""
    return img
def autocontrast_func(img, cutoff=0):
    """
    same output as PIL.ImageOps.autocontrast

    Each channel is stretched independently so that its darkest retained value
    maps to 0 and its brightest to 255.

    Args:
        img: integer image array, split into channels with cv2.split.
        cutoff: percentage of pixels to ignore at each end of the histogram.
    """
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            # Cast to Python ints: for uint8 images ch.min()/ch.max() are
            # unsigned numpy scalars, and negating one below would wrap around
            # (e.g. -np.uint8(10) == 246), corrupting the offset.
            high, low = int(ch.max()), int(ch.min())
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            above = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if above.shape[0] == 0 else int(above[0, 0])
            above = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if above.shape[0] == 0 else n_bins - 1 - int(above[0, 0])
        if high <= low:
            # degenerate range: leave the channel unchanged
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
def equalize_func(img):
    """
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalize

    Every channel is remapped through a lookup table derived from its own
    cumulative histogram.
    """
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        populated = hist[hist != 0].reshape(-1)
        # step is computed without the last populated bin
        step = np.sum(populated[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        shifted = np.empty_like(hist)
        shifted[0] = step // 2
        shifted[1:] = hist[:-1]
        lut = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lut[ch]

    return cv2.merge([tune_channel(ch) for ch in cv2.split(img)])
def rotate_func(img, degree, fill=(0, 0, 0)):
    """
    like PIL, rotate by degree, not radians

    The image is rotated about its centre at scale 1; the output keeps the
    input size and uncovered pixels take the `fill` value.
    """
    height, width = img.shape[0], img.shape[1]
    rotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)
def solarize_func(img, thresh=128):
    """
    same output as PIL.ImageOps.solarize

    Pixel values below `thresh` are kept; values at or above it are inverted
    (255 - value).

    Args:
        img: integer image array with values in 0..255.
        thresh: inversion threshold.
    """
    levels = np.arange(256)
    table = np.where(levels < thresh, levels, 255 - levels)
    table = table.clip(0, 255).astype(np.uint8)
    return table[img]
def color_func(img, factor):
    """
    same output as PIL.ImageEnhance.Color

    The blend between the image and its grayscale version is folded into a
    single 3x3 matrix applied to the channel axis: factor 0 yields the gray
    image, factor 1 the original.
    """
    # Channel weights of the gray conversion, as a column so they broadcast
    # across all three output channels.
    gray_weights = np.float32([[0.114], [0.587], [0.299]])
    # identity minus the gray projection, scaled by `factor` below
    color_part = np.float32(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
    )
    mix = color_part * factor + gray_weights
    blended = np.matmul(img, mix)
    return blended.clip(0, 255).astype(np.uint8)
def contrast_func(img, factor):
    """
    same output as PIL.ImageEnhance.Contrast

    Every pixel value is moved away from (factor > 1) or towards (factor < 1)
    the weighted mean intensity of the image.
    """
    channel_means = np.mean(img, axis=(0, 1))
    mean = np.sum(channel_means * np.array([0.114, 0.587, 0.299]))
    levels = np.arange(256)
    table = ((levels - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return table[img]
def brightness_func(img, factor):
    """
    same output as PIL.ImageEnhance.Brightness

    Every pixel value is multiplied by `factor` and clipped to 0..255.
    """
    levels = np.arange(256, dtype=np.float32)
    table = (levels * factor).clip(0, 255).astype(np.uint8)
    return table[img]
def sharpness_func(img, factor):
    """
    Blend the image with a smoothed copy of itself.

    The differences between this result and PIL are all on the 4 boundaries;
    the center areas are the same.

    factor == 0 returns the smoothed image, factor == 1 the input itself; other
    values interpolate (or extrapolate, for factor > 1) between the two on the
    interior pixels, leaving the one-pixel border untouched.
    """
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        # Slicing only the two spatial axes works for both (H, W) and (H, W, C).
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1]
        out[1:-1, 1:-1] = degenerate + factor * (out[1:-1, 1:-1] - degenerate)
        # Extrapolation can leave 0..255; clip before the uint8 cast so values
        # saturate instead of wrapping around.
        out = out.clip(0, 255).astype(np.uint8)
    return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """
    Shear the image along the x axis with an affine warp.

    The output keeps the input size; uncovered pixels take the `fill` value.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Applies an affine warp whose x translation term is `-offset`; the output
    keeps the input size and uncovered pixels take the `fill` value.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Applies an affine warp whose y translation term is `-offset`; the output
    keeps the input size and uncovered pixels take the `fill` value.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def posterize_func(img, bits):
    """
    same output as PIL.ImageOps.posterize

    Keeps the `bits` most significant bits of every pixel and zeroes the rest.

    Args:
        img: uint8 image array.
        bits: number of high bits to keep (0..8).
    """
    # Mask the shifted value back to 8 bits first: 255 << k exceeds the uint8
    # range for k > 0 and np.uint8() rejects out-of-range Python ints on NumPy 2.
    mask = np.uint8((255 << (8 - bits)) & 0xFF)
    return np.bitwise_and(img, mask)
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """
    Shear the image along the y axis with an affine warp.

    The output keeps the input size; uncovered pixels take the `fill` value.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """
    Overwrite a randomly centred rectangle of a copy of the image with `replace`.

    The rectangle extends `pad_size // 2` pixels on each side of a random
    centre and is clipped at the image border. The input is not modified.
    """
    fill_value = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    frac_row, frac_col = np.random.random(2)
    half = pad_size // 2
    center_row = int(frac_row * height)
    center_col = int(frac_col * width)

    row_start = max(center_row - half, 0)
    row_stop = min(center_row + half, height)
    col_start = max(center_col - half, 0)
    col_stop = min(center_col + half, width)

    result = img.copy()
    result[row_start:row_stop, col_start:col_stop, :] = fill_value
    return result
### level to args
def enhance_level_to_args(MAX_LEVEL):
    """Build a converter mapping a level in 0..MAX_LEVEL to a 1-tuple holding
    an enhancement factor in 0.1..1.9."""

    def level_to_args(level):
        ratio = level / MAX_LEVEL
        return (ratio * 1.8 + 0.1,)

    return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Build a converter mapping a level to (shear factor, fill value); the
    factor has magnitude up to 0.3 and a randomly chosen sign."""

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)

    return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Build a converter mapping a level to (pixel offset, fill value); the
    offset has magnitude up to `translate_const` and a randomly chosen sign."""

    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * float(translate_const)
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)

    return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Build a converter mapping a level to (cutout size, fill value); the size
    is the level scaled to `cutout_const` and truncated to an int."""

    def level_to_args(level):
        size = int((level / MAX_LEVEL) * cutout_const)
        return (size, replace_value)

    return level_to_args
def solarize_level_to_args(MAX_LEVEL):
    """Build a converter mapping a level in 0..MAX_LEVEL to a 1-tuple holding
    a solarize threshold in 0..256."""

    def level_to_args(level):
        threshold = int((level / MAX_LEVEL) * 256)
        return (threshold,)

    return level_to_args
def none_level_to_args(level):
    """Return an empty argument tuple, for operations that take no magnitude."""
    return tuple()
def posterize_level_to_args(MAX_LEVEL):
    """Build a converter mapping a level in 0..MAX_LEVEL to a 1-tuple holding
    a bit count in 0..4."""

    def level_to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits,)

    return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Build a converter mapping a level to (angle in degrees, fill value); the
    angle has magnitude up to 30 and a randomly chosen sign."""

    def level_to_args(level):
        angle = (level / MAX_LEVEL) * 30
        negate = np.random.random() < 0.5
        return (-angle if negate else angle, replace_value)

    return level_to_args
# Maps an augmentation name to the function that applies it to an image.
# cutout_func is defined above but is not registered here.
func_dict = {
    "Identity": identity_func,
    "AutoContrast": autocontrast_func,
    "Equalize": equalize_func,
    "Rotate": rotate_func,
    "Solarize": solarize_func,
    "Color": color_func,
    "Contrast": contrast_func,
    "Brightness": brightness_func,
    "Sharpness": sharpness_func,
    "ShearX": shear_x_func,
    "TranslateX": translate_x_func,
    "TranslateY": translate_y_func,
    "Posterize": posterize_func,
    "ShearY": shear_y_func,
}
# Maximum translation magnitude passed to translate_level_to_args.
translate_const = 10
# Level that corresponds to full augmentation strength.
MAX_LEVEL = 10
# Fill value handed to the rotate, shear and translate operations.
replace_value = (128, 128, 128)
# Maps an augmentation name to the callable that turns a level into the extra
# positional arguments of the matching func_dict entry. The key order also
# defines the default operation list of RandomAugment / VideoRandomAugment.
arg_dict = {
    "Identity": none_level_to_args,
    "AutoContrast": none_level_to_args,
    "Equalize": none_level_to_args,
    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
    "Solarize": solarize_level_to_args(MAX_LEVEL),
    "Color": enhance_level_to_args(MAX_LEVEL),
    "Contrast": enhance_level_to_args(MAX_LEVEL),
    "Brightness": enhance_level_to_args(MAX_LEVEL),
    "Sharpness": enhance_level_to_args(MAX_LEVEL),
    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "Posterize": posterize_level_to_args(MAX_LEVEL),
    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
    """
    Apply N randomly sampled augmentations, each with probability 0.5, at level M.

    Args:
        N: number of operations sampled (with replacement) per call.
        M: level passed to every sampled operation.
        isPIL: when True the input is converted with np.array() first.
        augs: names of the operations to sample from; when empty or None all
            keys of `arg_dict` are used.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        # `augs=None` instead of a shared mutable `[]` default.
        self.N = N
        self.M = M
        self.isPIL = isPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Sample N operations and return them as (name, probability, level) triples."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        """Return the augmented image as a numpy array."""
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            # each sampled operation is skipped with probability 1 - prob
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
class VideoRandomAugment(object):
    """
    Apply one randomly sampled list of augmentations to every frame of a clip.

    Args:
        N: number of distinct operations sampled (without replacement) per call.
        M: level passed to every sampled operation.
        p: each sampled operation is applied when a uniform draw exceeds `p`.
        tensor_in_tensor_out: when True the input is a torch tensor and is
            converted to a uint8 numpy array first.
        augs: names of the operations to sample from; when empty or None all
            keys of `arg_dict` are used.
    """

    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=None):
        # `augs=None` instead of a shared mutable `[]` default.
        self.N = N
        self.M = M
        self.p = p
        self.tensor_in_tensor_out = tensor_in_tensor_out
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Sample N distinct operations and return them as (name, level) pairs."""
        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
        return [(op, self.M) for op in sampled_ops]

    def __call__(self, frames):
        """Augment `frames` of shape (b, h, w, 3) and return a float tensor."""
        assert (
            frames.shape[-1] == 3
        ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
        if self.tensor_in_tensor_out:
            frames = frames.numpy().astype(np.uint8)
        num_frames = frames.shape[0]
        # The same op list and apply mask are shared by all frames.
        ops = num_frames * [self.get_random_ops()]
        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
        frames = torch.stack(
            list(map(self._aug, frames, ops, apply_or_not)), dim=0
        ).float()
        return frames

    def _aug(self, img, ops, apply_or_not):
        """Apply the selected operations to one frame and return it as a tensor."""
        for i, (name, level) in enumerate(ops):
            if not apply_or_not[i]:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return torch.from_numpy(img)
if __name__ == "__main__":
    # Smoke test. The image must be uint8: several operations use it as an
    # index into a 256-entry lookup table, which a float image cannot do.
    a = RandomAugment()
    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    a(img)
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from unimernet.runners.runner_base import RunnerBase
from unimernet.runners.runner_iter import RunnerIter
__all__ = ["RunnerBase", "RunnerIter"]
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import json
import logging
import os
import time
from pathlib import Path
import torch
import torch.distributed as dist
import webdataset as wds
from unimernet.common.dist_utils import (
download_cached_file,
get_rank,
get_world_size,
is_main_process,
main_process,
)
from unimernet.common.registry import registry
from unimernet.common.utils import is_url
from unimernet.datasets.data_utils import reorg_datasets_by_split, concat_datasets
from unimernet.datasets.datasets.dataloader_utils import (
IterLoader,
MultiIterLoader,
ConcatLoader,
PrefetchLoader,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import ChainDataset
@registry.register_runner("runner_base")
class RunnerBase:
"""
A runner class to train and evaluate a model given a task and datasets.
The runner uses pytorch distributed data parallel by default. Future release
will support other distributed frameworks.
"""
def __init__(self, cfg, task, model, datasets, job_id):
    """Store the run configuration and collaborators, reset the lazily built
    members and create the output directories."""
    # collaborators supplied by the caller
    self.config = cfg
    self.task = task
    self.datasets = datasets
    self.job_id = job_id
    self._model = model

    # members built on first access by the corresponding properties
    self._wrapped_model = None
    self._device = None
    self._optimizer = None
    self._scaler = None
    self._dataloaders = None
    self._lr_sched = None

    self.start_epoch = 0

    # self.setup_seeds()
    self.setup_output_dir()
@property
def device(self):
    """Return the torch device named by `run_cfg.device`, created on first access."""
    if self._device is not None:
        return self._device
    self._device = torch.device(self.config.run_cfg.device)
    return self._device
@property
def milestone(self):
    """Return `run_cfg.milestone`, or None when it is not configured."""
    run_cfg = self.config.run_cfg
    return run_cfg.get("milestone", None)
@property
def use_distributed(self):
    """Return the `run_cfg.distributed` setting."""
    run_cfg = self.config.run_cfg
    return run_cfg.distributed
@property
def model(self):
    """
    A property to get the DDP-wrapped model on the device.

    Without distributed training the plain model is returned instead.
    """
    # move model to device
    target_device = self.device
    if self._model.device != target_device:
        self._model = self._model.to(target_device)

    if not self.use_distributed:
        self._wrapped_model = self._model
    elif self._wrapped_model is None:
        # distributed training wrapper, created once
        self._wrapped_model = DDP(
            self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=False
        )
    return self._wrapped_model
@property
def optimizer(self):
    """
    Return the AdamW optimizer, built on first access.

    Trainable parameters are split in two groups: parameters with fewer than
    two dimensions or whose name contains "bias", "ln" or "bn" get no weight
    decay; all others use `run_cfg.weight_decay`.
    """
    # TODO make optimizer class and configurations
    if self._optimizer is not None:
        return self._optimizer

    decayed, undecayed = [], []
    num_parameters = 0
    for name, param in self.model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        skip_decay = (
            param.ndim < 2 or "bias" in name or "ln" in name or "bn" in name
        )
        (undecayed if skip_decay else decayed).append(param)
        num_parameters += param.data.nelement()
    logging.info("number of trainable parameters: %d" % num_parameters)

    run_cfg = self.config.run_cfg
    optim_params = [
        {"params": decayed, "weight_decay": float(run_cfg.weight_decay)},
        {"params": undecayed, "weight_decay": 0},
    ]
    beta2 = run_cfg.get("beta2", 0.999)
    self._optimizer = torch.optim.AdamW(
        optim_params,
        lr=float(run_cfg.init_lr),
        weight_decay=float(run_cfg.weight_decay),
        betas=(0.9, beta2),
    )
    return self._optimizer
@property
def scaler(self):
    """Return a GradScaler (created on first access) when `run_cfg.amp` is
    enabled; otherwise return the stored scaler, which is None by default."""
    amp_enabled = self.config.run_cfg.get("amp", False)
    if amp_enabled and self._scaler is None:
        self._scaler = torch.cuda.amp.GradScaler()
    return self._scaler
@property
def lr_scheduler(self):
    """
    A property to get and create the learning rate scheduler just in need.

    The scheduler class is looked up in the registry under `run_cfg.lr_sched`.
    """
    if self._lr_sched is not None:
        return self._lr_sched

    run_cfg = self.config.run_cfg
    lr_sched_cls = registry.get_lr_scheduler_class(run_cfg.lr_sched)

    # optional parameters
    decay_rate = run_cfg.get("lr_decay_rate", None)
    warmup_start_lr = run_cfg.get("warmup_lr", -1)
    warmup_steps = run_cfg.get("warmup_steps", 0)
    # the default is evaluated eagerly, so the train loader is always built here
    iters_per_epoch = run_cfg.get("iters_per_inner_epoch", len(self.train_loader))

    self._lr_sched = lr_sched_cls(
        optimizer=self.optimizer,
        max_epoch=self.max_epoch,
        min_lr=self.min_lr,
        init_lr=self.init_lr,
        decay_rate=decay_rate,
        warmup_start_lr=warmup_start_lr,
        warmup_steps=warmup_steps,
        iters_per_epoch=iters_per_epoch,
    )
    return self._lr_sched
@property
def dataloaders(self) -> dict:
    """
    A property to get and create dataloaders by split just in need.

    If no train_dataset_ratio is provided, concatenate map-style datasets and
    chain wds.DataPipe datasets separately. Training set becomes a tuple
    (ConcatDataset, ChainDataset), both are optional but at least one of them is
    required. The resultant ConcatDataset and ChainDataset will be sampled evenly.

    If train_dataset_ratio is provided, create a MultiIterLoader to sample
    each dataset by ratios during training.

    Currently do not support multiple datasets for validation and test.

    Returns:
        dict: {split_name: (tuples of) dataloader}
    """
    if self._dataloaders is not None:
        return self._dataloaders

    # reorganize datasets by split and concatenate/chain if necessary
    by_split = reorg_datasets_by_split(self.datasets)
    self.datasets = concat_datasets(by_split)
    # unwrap splits that hold exactly one dataset
    self.datasets = {
        name: (group[0] if len(group) == 1 else group)
        for name, group in self.datasets.items()
    }

    # print dataset statistics after concatenation/chaining
    for split_name in self.datasets:
        split_data = self.datasets[split_name]
        if isinstance(split_data, tuple) or isinstance(split_data, list):
            # mixed wds.DataPipeline and torch.utils.data.Dataset
            num_records = sum(
                [
                    len(d)
                    if not type(d) in [wds.DataPipeline, ChainDataset]
                    else 0
                    for d in split_data
                ]
            )
        elif hasattr(split_data, "__len__"):
            # a single map-style dataset
            num_records = len(split_data)
        else:
            # a single wds.DataPipeline
            num_records = -1
            logging.info(
                "Only a single wds.DataPipeline dataset, no __len__ attribute."
            )

        if num_records >= 0:
            logging.info(
                "Loaded {} records for {} split from the dataset.".format(
                    num_records, split_name
                )
            )

    # create dataloaders
    split_names = sorted(self.datasets.keys())
    datasets = [self.datasets[split] for split in split_names]
    is_trains = [split in self.train_splits for split in split_names]
    batch_sizes = [
        self.config.run_cfg.batch_size_train
        if split == "train"
        else self.config.run_cfg.batch_size_eval
        for split in split_names
    ]

    collate_fns = []
    for dataset in datasets:
        if isinstance(dataset, tuple) or isinstance(dataset, list):
            collate_fns.append([getattr(d, "collater", None) for d in dataset])
        else:
            collate_fns.append(getattr(dataset, "collater", None))

    loaders = self.create_loaders(
        datasets=datasets,
        num_workers=self.config.run_cfg.num_workers,
        batch_sizes=batch_sizes,
        is_trains=is_trains,
        collate_fns=collate_fns,
        # concat=True
    )
    self._dataloaders = dict(zip(split_names, loaders))
    return self._dataloaders
@property
def cuda_enabled(self):
    """Return True when the configured device is of type "cuda"."""
    device_type = self.device.type
    return device_type == "cuda"
@property
def max_epoch(self):
    """Return `run_cfg.max_epoch` as an int."""
    configured = self.config.run_cfg.max_epoch
    return int(configured)
@property
def log_freq(self):
    """Return `run_cfg.log_freq` as an int, defaulting to 50."""
    return int(self.config.run_cfg.get("log_freq", 50))
@property
def init_lr(self):
    """Return `run_cfg.init_lr` as a float."""
    configured = self.config.run_cfg.init_lr
    return float(configured)
@property
def min_lr(self):
    """Return `run_cfg.min_lr` as a float."""
    configured = self.config.run_cfg.min_lr
    return float(configured)
@property
def accum_grad_iters(self):
    """Return `run_cfg.accum_grad_iters` as an int, defaulting to 1."""
    configured = self.config.run_cfg.get("accum_grad_iters", 1)
    return int(configured)
@property
def valid_splits(self):
    """Return the configured validation split names (empty list by default).

    An info message is logged on every access that finds no splits.
    """
    splits = self.config.run_cfg.get("valid_splits", [])
    if len(splits) == 0:
        logging.info("No validation splits found.")
    return splits
@property
def test_splits(self):
    """Return the configured test split names (empty list by default)."""
    return self.config.run_cfg.get("test_splits", [])
@property
def train_splits(self):
    """Return the configured training split names (empty list by default).

    An info message is logged on every access that finds no splits.
    """
    splits = self.config.run_cfg.get("train_splits", [])
    if len(splits) == 0:
        logging.info("Empty train splits.")
    return splits
@property
def evaluate_only(self):
    """
    Return `run_cfg.evaluate`; set it to True to skip training.
    """
    run_cfg = self.config.run_cfg
    return run_cfg.evaluate
@property
def use_dist_eval_sampler(self):
    """Return `run_cfg.use_dist_eval_sampler`, defaulting to True."""
    run_cfg = self.config.run_cfg
    return run_cfg.get("use_dist_eval_sampler", True)
@property
def resume_ckpt_path(self):
    """Return `run_cfg.resume_ckpt_path`, or None when it is not configured."""
    run_cfg = self.config.run_cfg
    return run_cfg.get("resume_ckpt_path", None)
@property
def train_loader(self):
    """Return the dataloader of the "train" split."""
    return self.dataloaders["train"]
def setup_output_dir(self):
    """Create `<library_root>/<run_cfg.output_dir>/<job_id>` and its `result`
    sub-directory, then register both paths and keep them on the runner."""
    library_root = Path(registry.get_path("library_root"))
    self.output_dir = library_root / self.config.run_cfg.output_dir / self.job_id
    self.result_dir = self.output_dir / "result"

    for directory in (self.output_dir, self.result_dir):
        directory.mkdir(parents=True, exist_ok=True)

    registry.register_path("result_dir", str(self.result_dir))
    registry.register_path("output_dir", str(self.output_dir))
def train(self):
    """
    Run the full training loop followed by the final test evaluation.

    Per epoch: train (unless `evaluate_only`), evaluate every validation
    split, save a "best" checkpoint when the aggregated metric of the split
    named "eval" improves, save milestone checkpoints and the latest
    checkpoint, then synchronise the processes.
    """
    start_time = time.time()
    best_agg_metric = 0
    best_epoch = 0

    self.log_config()

    # resume from checkpoint if specified
    if not self.evaluate_only and self.resume_ckpt_path is not None:
        self._load_checkpoint(self.resume_ckpt_path)

    for cur_epoch in range(self.start_epoch, self.max_epoch):
        # training phase
        if not self.evaluate_only:
            logging.info("Start training")
            train_stats = self.train_epoch(cur_epoch)
            self.log_stats(split_name="train", stats=train_stats)

        # evaluation phase
        if len(self.valid_splits) > 0:
            for split_name in self.valid_splits:
                logging.info("Evaluating on {}.".format(split_name))

                val_log = self.eval_epoch(
                    split_name=split_name, cur_epoch=cur_epoch
                )
                if val_log is not None:
                    if is_main_process():
                        assert (
                            "agg_metrics" in val_log
                        ), "No agg_metrics found in validation log."

                        agg_metrics = val_log["agg_metrics"]
                        if agg_metrics > best_agg_metric and split_name == "eval":
                            best_epoch, best_agg_metric = cur_epoch, agg_metrics

                            self._save_checkpoint(cur_epoch, is_best=True)

                        val_log.update({"best_epoch": best_epoch})
                        self.log_stats(val_log, split_name)

        if self.evaluate_only:
            break

        if self.milestone and cur_epoch + 1 in self.milestone:
            self._save_checkpoint(cur_epoch)
        self._save_checkpoint(cur_epoch, latest=True)

        # dist.barrier() raises when no default process group exists (e.g. a
        # non-distributed run), so only synchronise when one is initialised.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    # testing phase
    test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
    self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Training time {}".format(total_time_str))
def evaluate(self, cur_epoch="best", skip_reload=False):
    """
    Evaluate every configured test split and collect the results per split.

    NOTE(review): the source indentation was lost; the return is kept inside
    the `if`, following the upstream layout, so nothing is returned when no
    test split is configured — confirm.
    """
    test_logs = dict()

    splits = self.test_splits
    if len(splits) > 0:
        for name in splits:
            test_logs[name] = self.eval_epoch(
                split_name=name, cur_epoch=cur_epoch, skip_reload=skip_reload
            )

        return test_logs
def train_epoch(self, epoch):
    """Switch the model to training mode and delegate one epoch of training to
    the task, returning whatever the task returns."""
    # train
    self.model.train()

    epoch_kwargs = dict(
        epoch=epoch,
        model=self.model,
        data_loader=self.train_loader,
        optimizer=self.optimizer,
        scaler=self.scaler,
        lr_scheduler=self.lr_scheduler,
        cuda_enabled=self.cuda_enabled,
        log_freq=self.log_freq,
        accum_grad_iters=self.accum_grad_iters,
    )
    return self.task.train_epoch(**epoch_kwargs)
@torch.no_grad()
def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
    """
    Evaluate the model on a given split.

    Args:
        split_name (str): name of the split to evaluate on.
        cur_epoch (int or str): current epoch, or "best" to evaluate the best
            checkpoint.
        skip_reload (bool): whether to skip reloading the best checkpoint.
            The best checkpoint is reloaded only when this is False and
            `cur_epoch` is "best".

    Returns:
        The result of `task.after_evaluation`, or None when the task's
        evaluation produced no results.
    """
    data_loader = self.dataloaders.get(split_name, None)
    assert data_loader, "data_loader for split {} is None.".format(split_name)

    # TODO In validation, you need to compute loss as well as metrics
    # TODO consider moving to model.before_evaluation()
    model = self.unwrap_dist_model(self.model)
    reload_best = cur_epoch == "best" and not skip_reload
    if reload_best:
        model = self._reload_best_model(model)
    model.eval()

    self.task.before_evaluation(
        model=model,
        dataset=self.datasets[split_name],
    )
    results = self.task.evaluation(model, data_loader)
    if results is None:
        return None

    return self.task.after_evaluation(
        val_result=results,
        split_name=split_name,
        epoch=cur_epoch,
    )
def unwrap_dist_model(self, model):
    """Return `model.module` in distributed mode, otherwise the model itself."""
    return model.module if self.use_distributed else model
def create_loaders(
    self,
    datasets,
    num_workers,
    batch_sizes,
    is_trains,
    collate_fns,
    concat=False
):
    """
    Create dataloaders for training and validation.

    The arguments are parallel sequences with one entry per split. An entry of
    `datasets` that is itself a list or tuple yields a MultiIterLoader
    (sampling by each dataset's `sample_ratio`), or a ConcatLoader when
    `concat` is True.

    Returns:
        list of loaders, in the order of `datasets`.
    """

    def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
        # create a single dataloader for each split
        if isinstance(dataset, (ChainDataset, wds.DataPipeline)):
            # wds.WebdDataset instance are chained together
            # webdataset.DataPipeline has its own sampler and collate_fn
            return iter(
                DataLoader(
                    dataset,
                    batch_size=bsz,
                    num_workers=num_workers,
                    pin_memory=True,
                )
            )

        # map-style dataset are concatenated together
        # setup distributed sampler
        sampler = None
        if self.use_distributed:
            sampler = DistributedSampler(
                dataset,
                shuffle=is_train,
                num_replicas=get_world_size(),
                rank=get_rank(),
            )
            if not self.use_dist_eval_sampler and not is_train:
                # e.g. retrieval evaluation
                sampler = None

        loader = DataLoader(
            dataset,
            batch_size=bsz,
            num_workers=num_workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=sampler is None and is_train,
            collate_fn=collate_fn,
            drop_last=True if is_train else False,
        )
        loader = PrefetchLoader(loader)

        if is_train:
            loader = IterLoader(loader, use_distributed=self.use_distributed)
        return loader

    loaders = []
    for dataset, bsz, is_train, collate_fn in zip(
        datasets, batch_sizes, is_trains, collate_fns
    ):
        if not (isinstance(dataset, list) or isinstance(dataset, tuple)):
            loaders.append(
                _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
            )
            continue

        if not concat:
            # ratios are read before any loader is created
            sample_ratios = [d.sample_ratio for d in dataset]
            loader = MultiIterLoader(
                loaders=[
                    _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
                    for i, d in enumerate(dataset)
                ],
                ratios=sample_ratios
            )
        else:
            loader = ConcatLoader(
                loaders=[
                    _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
                    for i, d in enumerate(dataset)
                ]
            )
        loaders.append(loader)

    return loaders
@main_process
def _save_checkpoint(self, cur_epoch, is_best=False, latest=False):
    """
    Save the checkpoint at the current epoch.

    Parameters that do not require gradients are removed from the saved model
    state. The file is written to ``self.output_dir`` as ``checkpoint_best.pth``,
    ``checkpoint_latest.pth`` or ``checkpoint_<cur_epoch + 1>.pth``.
    """
    assert not (is_best and latest), "You can't set 'is_best' and 'latest' the same time."

    bare_model = self.unwrap_dist_model(self.model)
    frozen_names = {
        name
        for name, param in bare_model.named_parameters()
        if not param.requires_grad
    }

    state_dict = bare_model.state_dict()
    for key in list(state_dict):
        if key in frozen_names:
            # delete parameters that do not require gradient
            del state_dict[key]

    save_obj = {
        "model": state_dict,
        "optimizer": self.optimizer.state_dict(),
        "config": self.config.to_dict(),
        "scaler": self.scaler.state_dict() if self.scaler else None,
        "epoch": cur_epoch,
    }

    if is_best:
        tag = "best"
    elif latest:
        tag = "latest"
    else:
        tag = cur_epoch+1
    save_to = os.path.join(self.output_dir, "checkpoint_{}.pth".format(tag))

    logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch+1, save_to))
    torch.save(save_obj, save_to)
def _reload_best_model(self, model):
"""
Load the best checkpoint for evaluation.
"""
checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
logging.info("Loading checkpoint from {}.".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path, map_location="cpu")
try:
model.load_state_dict(checkpoint["model"])
except RuntimeError as e:
logging.warning(
"""
Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
Trying to load the model with strict=False.
"""
)
model.load_state_dict(checkpoint["model"], strict=False)
return model
def _load_checkpoint(self, url_or_filename):
    """
    Resume from a checkpoint.

    Restores the model weights, the optimizer state, the AMP scaler state
    (when a scaler is in use and the checkpoint has one) and ``self.start_epoch``.

    Raises:
        RuntimeError: if ``url_or_filename`` is neither a URL nor an existing file.
    """
    if is_url(url_or_filename):
        local_path = download_cached_file(
            url_or_filename, check_hash=False, progress=True
        )
    elif os.path.isfile(url_or_filename):
        local_path = url_or_filename
    else:
        raise RuntimeError("checkpoint url or path is invalid")

    checkpoint = torch.load(local_path, map_location=self.device)

    bare_model = self.unwrap_dist_model(self.model)
    bare_model.load_state_dict(checkpoint["model"])

    self.optimizer.load_state_dict(checkpoint["optimizer"])
    if self.scaler and "scaler" in checkpoint:
        self.scaler.load_state_dict(checkpoint["scaler"])

    self.start_epoch = checkpoint["epoch"]
    logging.info("Resume checkpoint from {}".format(url_or_filename))
@main_process
def log_stats(self, stats, split_name):
    """
    Append ``stats`` as one JSON line to ``log.txt`` in ``self.output_dir``.

    Only dict stats are written; each key is prefixed with ``<split_name>_``.
    Any other type (including list) is ignored.
    """
    if not isinstance(stats, dict):
        return

    prefixed = {f"{split_name}_{k}": v for k, v in stats.items()}
    log_path = os.path.join(self.output_dir, "log.txt")
    with open(log_path, "a") as f:
        f.write(json.dumps(prefixed) + "\n")
@main_process
def log_config(self):
    """Append the run configuration, as indented JSON, to ``log.txt`` in ``self.output_dir``."""
    config_dump = json.dumps(self.config.to_dict(), indent=4)
    log_path = os.path.join(self.output_dir, "log.txt")
    with open(log_path, "a") as f:
        f.write(config_dump + "\n")
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import logging
import os
import time
import torch
import torch.distributed as dist
import webdataset as wds
from unimernet.common.dist_utils import download_cached_file, is_main_process, main_process
from unimernet.common.registry import registry
from unimernet.common.utils import is_url
from unimernet.datasets.data_utils import reorg_datasets_by_split
from unimernet.runners.runner_base import RunnerBase
from torch.utils.data.dataset import ChainDataset
@registry.register_runner("runner_iter")
class RunnerIter(RunnerBase):
    """
    Run training based on the number of iterations. This is common when
    the training dataset size is large. Under the hood the logic mirrors
    epoch-based training: every #iters_per_inner_epoch steps are treated as
    one inner epoch.

    In the iter-based runner, after every #iters_per_inner_epoch steps, we
        1) do a validation epoch;
        2) schedule the learning rate;
        3) save the checkpoint.
    """

    def __init__(self, cfg, task, model, datasets, job_id):
        """Read ``max_iters`` and ``iters_per_inner_epoch`` from ``run_cfg``; both must be > 0."""
        super().__init__(cfg, task, model, datasets, job_id)

        self.start_iters = 0

        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
        assert self.max_iters > 0, "max_iters must be greater than 0."

        self.iters_per_inner_epoch = int(
            self.config.run_cfg.get("iters_per_inner_epoch", -1)
        )
        assert (
            self.iters_per_inner_epoch > 0
        ), "iters_per_inner_epoch must be greater than 0."

    @property
    def max_epoch(self):
        """Number of inner epochs: ``max_iters / iters_per_inner_epoch``, truncated."""
        return int(self.max_iters / self.iters_per_inner_epoch)

    @property
    def cur_epoch(self):
        """Epoch reported by the training loader, or 0 when it has no ``epoch`` attribute."""
        try:
            return self.train_loader.epoch
        except AttributeError:
            # pipeline data (e.g. LAION) is streaming, have no concept of epoch
            return 0

    def _progress(self, cur_iters):
        """Format a progress tag combining the data epoch and the iteration count."""
        return "{}_iters={}".format(self.cur_epoch, cur_iters)

    def train(self):
        """
        Run the inner-epoch loop: train, validate, checkpoint; then test.

        For each inner epoch the validation splits are evaluated; on the main
        process a new best ``agg_metrics`` on the "eval" split triggers a
        "best" checkpoint. A "latest" checkpoint is written after every inner
        epoch. With ``evaluate_only`` set, only one evaluation pass is run.
        """
        start_time = time.time()
        best_agg_metric = 0
        best_iters = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        cur_epoch = 0
        for start_iters in range(
            self.start_iters, self.max_iters, self.iters_per_inner_epoch
        ):
            end_iters = start_iters + self.iters_per_inner_epoch

            # training phase
            if not self.evaluate_only:
                logging.info(
                    "Start training, max_iters={}, in total {} inner epochs.".format(
                        self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
                    )
                )

                train_stats = self.train_iters(self.cur_epoch, start_iters)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=self._progress(end_iters)
                    )
                    if val_log is not None:
                        if is_main_process():
                            assert (
                                "agg_metrics" in val_log
                            ), "No agg_metrics found in validation log."

                            agg_metrics = val_log["agg_metrics"]
                            if agg_metrics > best_agg_metric and split_name == "eval":
                                best_iters, best_agg_metric = end_iters, agg_metrics

                                self._save_checkpoint(end_iters, is_best=True)

                            val_log.update({"best_iters": best_iters})
                            self.log_stats(val_log, split_name)

                            # print evaluation metric; the individual metrics are
                            # absent when the task only reports "agg_metrics", so
                            # guard against a KeyError in that case.
                            if all(
                                name in val_log
                                for name in ("bleu", "edit_distance", "token_accuracy")
                            ):
                                print(f"bleu:{val_log['bleu']:.6f}, edit_distance:{val_log['edit_distance']:.6f}, token_accuracy:{val_log['token_accuracy']:.6f} ")
                                print("="*80)

            if self.evaluate_only:
                break

            if self.milestone and cur_epoch + 1 in self.milestone:
                # NOTE(review): this passes the inner-epoch counter where
                # _save_checkpoint expects an iteration count, so the file is
                # named checkpoint_<cur_epoch>.pth and stores that value under
                # "iters". Kept as-is to preserve existing file names.
                self._save_checkpoint(cur_epoch)

            self._save_checkpoint(end_iters, latest=True)

            # barrier() raises when no process group exists (single-process runs)
            if dist.is_available() and dist.is_initialized():
                dist.barrier()
            cur_epoch += 1

        # testing phase
        self.evaluate(cur_epoch=self.cur_epoch)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def train_iters(self, epoch, start_iters):
        """Put the model in training mode and delegate one inner epoch to the task."""
        # train by iterations
        self.model.train()

        return self.task.train_iters(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_inner_epoch=self.iters_per_inner_epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )

    @main_process
    def _save_checkpoint(self, cur_iters, is_best=False, latest=False):
        """
        Save a checkpoint tagged with ``cur_iters``.

        The file is written to ``self.output_dir`` as ``checkpoint_best.pth``,
        ``checkpoint_latest.pth`` or ``checkpoint_<cur_iters>.pth``.
        """
        # only save the params requires gradient
        assert not (is_best and latest), "You can't set 'is_best' and 'latest' the same time."

        unwrapped_model = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in unwrapped_model.named_parameters()
        }

        state_dict = unwrapped_model.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic and not param_grad_dic[k]:
                del state_dict[k]

        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "iters": cur_iters,
        }

        if is_best:
            save_to = os.path.join(
                self.output_dir,
                "checkpoint_{}.pth".format("best"),
            )
        elif latest:
            save_to = os.path.join(
                self.output_dir,
                "checkpoint_{}.pth".format("latest"),
            )
        else:
            save_to = os.path.join(
                self.output_dir,
                "checkpoint_{}.pth".format(cur_iters),
            )
        logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
        torch.save(save_obj, save_to)

    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.

        Raises:
            RuntimeError: if ``url_or_filename`` is neither a URL nor an existing file.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        # "iters" already holds the number of completed iterations (train()
        # saves end_iters), so resume exactly there; adding 1 would shift every
        # later progress label and checkpoint name by one per resume.
        self.start_iters = checkpoint["iters"]
        logging.info("Resume checkpoint from {}".format(url_or_filename))

    @property
    def dataloaders(self) -> dict:
        """
        A property to get and create dataloaders by split just in need.

        On first access the datasets are reorganized by split, the number of
        records per split is logged where a length is available, and the
        loaders are built with ``create_loaders`` and cached.

        Returns:
            dict: {split_name: dataloader as returned by ``create_loaders``}
        """
        if self._dataloaders is None:
            # reoganize datasets by split and concatenate/chain if necessary
            self.datasets = reorg_datasets_by_split(self.datasets)
            # to keep the same structure as return value of concat_datasets
            self.datasets = {
                k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
            }

            # print dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                split_data = self.datasets[split_name]
                if isinstance(split_data, (tuple, list)):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset;
                    # isinstance (not type equality) so subclasses of the
                    # streaming types are skipped too, as in create_loaders
                    num_records = sum(
                        0
                        if isinstance(d, (wds.DataPipeline, ChainDataset))
                        else len(d)
                        for d in split_data
                    )
                else:
                    try:
                        # a single map-style dataset
                        num_records = len(split_data)
                    except TypeError:
                        # a single wds.DataPipeline or ChainDataset
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, (tuple, list)):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
            )

            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}

        return self._dataloaders
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from unimernet.common.registry import registry
from unimernet.tasks.base_task import BaseTask
from unimernet.tasks.unimernet_train import UniMERNet_Train
def setup_task(cfg):
    """
    Instantiate the task named by ``cfg.run_cfg.task``.

    The task class is looked up in the registry and built through its own
    ``setup_task`` classmethod.

    Returns:
        The task object produced by the registered class.
    """
    run_cfg = cfg.run_cfg
    assert "task" in run_cfg, "Task name must be provided."

    task_name = run_cfg.task
    task_cls = registry.get_task_class(task_name)
    task = task_cls.setup_task(cfg=cfg)
    assert task is not None, "Task {} not properly registered.".format(task_name)

    return task
# Names exported by a star-import of this package.
__all__ = [
    "BaseTask",
    "UniMERNet_Train",
]
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import torch
import torch.distributed as dist
from unimernet.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
from unimernet.common.logger import MetricLogger, SmoothedValue
from unimernet.common.registry import registry
from unimernet.datasets.data_utils import prepare_sample
class BaseTask:
    """Base task: builds the model and datasets and owns the train/eval loops."""

    def __init__(self, **kwargs):
        super().__init__()

        self.inst_id_key = "instance_id"

    @classmethod
    def setup_task(cls, **kwargs):
        """Build the task; the base implementation ignores ``kwargs``."""
        return cls()

    def build_model(self, cfg):
        """Look up the class registered for ``cfg.model_cfg.arch`` and build it from the model config."""
        model_config = cfg.model_cfg

        model_cls = registry.get_model_class(model_config.arch)
        return model_cls.from_config(model_config)

    def build_datasets(self, cfg):
        """
        Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
        Download dataset and annotations automatically if not exist.

        Args:
            cfg (common.config.Config): configuration whose ``datasets_cfg``
                maps each dataset name to its builder configuration.

        Returns:
            dict: maps each dataset name to the dictionary (keyed by split)
            returned by that dataset's builder.
        """

        datasets = dict()

        datasets_config = cfg.datasets_cfg

        assert len(datasets_config) > 0, "At least one dataset has to be specified."

        for name in datasets_config:
            dataset_config = datasets_config[name]

            builder = registry.get_builder_class(name)(dataset_config)
            dataset = builder.build_datasets()

            # attach the optional sampling weight to the training split
            if "train" in dataset and "sample_ratio" in dataset_config:
                dataset["train"].sample_ratio = float(dataset_config.sample_ratio)

            datasets[name] = dataset

        return datasets

    def train_step(self, model, samples):
        """Run the model on ``samples``; return ``(loss, loss_dict)`` where loss is ``loss_dict["loss"]``."""
        loss_dict = model(samples)
        loss = loss_dict["loss"]
        return loss, loss_dict

    def valid_step(self, model, samples):
        raise NotImplementedError

    def before_evaluation(self, model, dataset, **kwargs):
        model.before_evaluation(dataset=dataset, task_type=type(self))

    def after_evaluation(self, **kwargs):
        pass

    def inference_step(self):
        raise NotImplementedError

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """
        Run ``valid_step`` over every batch of ``data_loader``.

        Returns:
            list: concatenation of the per-batch outputs of ``valid_step``.
        """
        metric_logger = MetricLogger(delimiter="  ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        results = []

        for samples in metric_logger.log_every(data_loader, print_freq, header):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            eval_output = self.valid_step(model=model, samples=samples)
            results.extend(eval_output)

        if is_dist_avail_and_initialized():
            dist.barrier()

        return results

    def train_epoch(
        self,
        epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
    ):
        """Train for ``len(data_loader)`` iterations via ``_train_inner_loop``."""
        return self._train_inner_loop(
            epoch=epoch,
            iters_per_epoch=len(data_loader),
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def train_iters(
        self,
        epoch,
        start_iters,
        iters_per_inner_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
    ):
        """Train for ``iters_per_inner_epoch`` iterations via ``_train_inner_loop``."""
        return self._train_inner_loop(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_epoch=iters_per_inner_epoch,
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def _train_inner_loop(
        self,
        epoch,
        iters_per_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        start_iters=None,
        log_freq=50,
        cuda_enabled=False,
        accum_grad_iters=1,
    ):
        """
        An inner training loop compatible with both epoch-based and iter-based training.

        When using epoch-based, training stops after one epoch; when using iter-based,
        training stops after #iters_per_epoch iterations.

        Returns:
            dict: meter name -> global average formatted with three decimals.
        """
        # mixed precision is enabled exactly when a gradient scaler is supplied
        use_amp = scaler is not None

        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter="  ")
        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))

        # if iter-based runner, schedule lr based on inner epoch.
        logging.info(
            "Start training epoch {}, {} iters per inner epoch.".format(
                epoch, iters_per_epoch
            )
        )
        header = "Train: data epoch: [{}]".format(epoch)
        if start_iters is None:
            # epoch-based runner
            inner_epoch = epoch
        else:
            # In iter-based runner, we schedule the learning rate based on iterations.
            inner_epoch = start_iters // iters_per_epoch
            header = header + "; inner epoch [{}]".format(inner_epoch)

        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            # if using iter-based runner, we stop after iters_per_epoch iterations.
            if i >= iters_per_epoch:
                break

            samples = next(data_loader)

            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            samples.update(
                {
                    "epoch": inner_epoch,
                    "num_iters_per_epoch": iters_per_epoch,
                    "iters": i,
                }
            )

            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)

            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, loss_dict = self.train_step(model=model, samples=samples)
                # In-place division: loss_dict["loss"] is the same object, so
                # the logged loss is scaled as well.
                loss /= accum_grad_iters  # TODO: not affect loss_dict values for logging

            # after_train_step()
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

            metric_logger.update(**loss_dict)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # after train_epoch()
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        return {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }

    @staticmethod
    def save_result(result, result_dir, filename, remove_duplicate=""):
        """
        Write this rank's results to JSON and merge all ranks on the main process.

        Args:
            result (list): JSON-serializable per-sample records of this rank.
            result_dir (str): directory receiving the per-rank and merged files.
            filename (str): base name (without extension) of the output files.
            remove_duplicate (str): if non-empty, the record key used to drop
                repeated entries from the merged list (first occurrence wins).

        Returns:
            str: path of the merged result file (written by the main process only).
        """
        import json

        result_file = os.path.join(
            result_dir, "%s_rank%d.json" % (filename, get_rank())
        )
        final_result_file = os.path.join(result_dir, "%s.json" % filename)

        # close the handle deterministically so the data is flushed before
        # other ranks / later readers open the file
        with open(result_file, "w") as f:
            json.dump(result, f)

        if is_dist_avail_and_initialized():
            dist.barrier()

        if is_main_process():
            logging.warning("rank %d starts merging results." % get_rank())
            # combine results from all processes
            result = []

            for rank in range(get_world_size()):
                rank_file = os.path.join(
                    result_dir, "%s_rank%d.json" % (filename, rank)
                )
                with open(rank_file, "r") as f:
                    result += json.load(f)

            if remove_duplicate:
                deduped = []
                seen_hashable = set()
                seen_unhashable = []
                for res in result:
                    key = res[remove_duplicate]
                    try:
                        is_new = key not in seen_hashable
                        if is_new:
                            seen_hashable.add(key)
                    except TypeError:
                        # unhashable id (e.g. a JSON list): fall back to a linear scan
                        is_new = key not in seen_unhashable
                        if is_new:
                            seen_unhashable.append(key)
                    if is_new:
                        deduped.append(res)
                result = deduped

            with open(final_result_file, "w") as f:
                json.dump(result, f)
            print("result file saved to %s" % final_result_file)

        return final_result_file
import torch
import evaluate
import random
from unimernet.common.registry import registry
from unimernet.tasks.base_task import BaseTask
from unimernet.common.dist_utils import main_process
import os.path as osp
import json
import numpy as np
from rapidfuzz.distance import Levenshtein
@registry.register_task("unimernet_train")
class UniMERNet_Train(BaseTask):
    """Training/evaluation task reporting BLEU, normalized edit distance and token accuracy."""

    def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
        super(UniMERNet_Train, self).__init__()
        # generation settings forwarded to model.generate() in valid_step
        self.temperature = temperature
        self.do_sample = do_sample
        self.top_p = top_p
        self.evaluate = evaluate

        # which metric becomes "agg_metrics": matched by substring
        # ("edit", "bleu" or "token") in _report_metrics
        self.agg_metric = agg_metric
        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        """Build the task from ``cfg.run_cfg`` (generation settings, metric options)."""
        run_cfg = cfg.run_cfg
        generate_cfg = run_cfg.generate_cfg

        temperature = generate_cfg.get('temperature', .2)
        do_sample = generate_cfg.get("do_sample", False)
        top_p = generate_cfg.get("top_p", 0.95)

        evaluate = run_cfg.evaluate
        report_metric = run_cfg.get("report_metric", True)
        agg_metric = run_cfg.get("agg_metric", "edit_distance")

        return cls(
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            evaluate=evaluate,
            report_metric=report_metric,
            agg_metric=agg_metric,
        )

    def valid_step(self, model, samples):
        """
        Generate predictions for one batch and pair them with the ground truth.

        Returns:
            list[dict]: one record per sample with keys "pred_token",
            "pred_str", "truth_str", "truth_token", "token_acc" and "id".
        """
        results = []
        text = samples["text_input"]
        preds = model.generate(
            samples,
            temperature=self.temperature,
            do_sample=self.do_sample,
            top_p=self.top_p
        )
        pred_tokens = preds["pred_tokens"]
        pred_strs = preds["pred_str"]
        pred_ids = preds["pred_ids"]  # [b, n-1]

        truth_inputs = model.tokenizer.tokenize(text)
        # drop the first position of every truth sequence before comparing
        truth_ids = truth_inputs["input_ids"][:, 1:]
        truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
        truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])

        ids = samples["id"]
        pad_id = model.tokenizer.pad_token_id

        for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(
            pred_tokens, pred_strs, pred_ids, truth_tokens, truth_strs, truth_ids, ids
        ):
            pred_id = pred_id.tolist()
            truth_id = truth_id.tolist()
            # right-pad the shorter sequence so both have the same length
            shape_diff = len(pred_id) - len(truth_id)
            if shape_diff < 0:
                pred_id = pred_id + [pad_id] * (-shape_diff)
            else:
                truth_id = truth_id + [pad_id] * shape_diff
            pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)
            # score only positions where at least one side is not padding
            mask = torch.logical_or(pred_id != pad_id, truth_id != pad_id)
            tok_acc = (pred_id == truth_id)[mask].float().mean().item()

            this_item = {
                "pred_token": pred_token,
                "pred_str": pred_str,
                "truth_str": truth_str,
                "truth_token": truth_token,
                "token_acc": tok_acc,
                "id": id_
            }
            results.append(this_item)
        return results

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        """Save the merged results and, if ``report_metric`` is set, compute the metrics."""
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="id",
        )

        if self.report_metric:
            metrics = self._report_metrics(
                eval_result_file=eval_result_file, split_name=split_name
            )
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        """
        Compute BLEU, mean normalized edit distance and mean token accuracy.

        The metrics are appended as a JSON line to ``evaluate.txt`` in the
        registered output directory.

        Returns:
            dict: the three metrics plus "agg_metrics" (0-100, higher is better).

        Raises:
            ValueError: if ``self.agg_metric`` names none of the known metrics.
        """
        with open(eval_result_file) as f:
            results = json.load(f)

        edit_dists = []
        all_pred_tokens = []
        all_truth_tokens = []
        all_pred_strs = []
        all_truth_strs = []
        token_accs = []
        for result in results:
            pred_token, pred_str = result["pred_token"], result["pred_str"]
            truth_token, truth_str = result["truth_token"], result["truth_str"]
            tok_acc = result["token_acc"]

            # edit distance is only scored against a non-empty reference
            if len(truth_str) > 0:
                norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
                edit_dists.append(norm_edit_dist)

            all_pred_tokens.append(pred_token)
            all_truth_tokens.append([truth_token])
            all_pred_strs.append(pred_str)
            all_truth_strs.append(truth_str)
            token_accs.append(tok_acc)

        # bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
        # randint requires integer bounds (a float such as 1e8 raises TypeError
        # on Python 3.12+), hence 10 ** 8.
        bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=random.randint(1, 10 ** 8))
        bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
        bleu_score = bleu_results['bleu']

        edit_distance = np.mean(edit_dists)
        token_accuracy = np.mean(token_accs)
        eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}

        log_stats = {split_name: dict(eval_ret)}

        with open(
            osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = dict(eval_ret)
        # agg_metrics = sum([v for v in eval_ret.values()])
        agg_name = self.agg_metric.lower()
        if "edit" in agg_name:  # edit_distance (lower is better, so invert)
            agg_metrics = (1 - edit_distance) * 100
        elif "bleu" in agg_name:  # bleu_score
            agg_metrics = bleu_score * 100
        elif "token" in agg_name:  # token_accuracy
            agg_metrics = token_accuracy * 100
        else:
            raise ValueError(f"Invalid metrics: '{self.agg_metric}'")

        coco_res["agg_metrics"] = agg_metrics

        return coco_res
import io
import os
import sys
import argparse
import numpy as np
import torch
import hashlib
import pypdfium2
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas
import unimernet.tasks as tasks
from unimernet.common.config import Config
from unimernet.processors import load_processor
# Upper bounds, in pixels, that resize_image passes to Image.thumbnail; also the
# fallback (height, width) returned by get_image_size when no image is given.
MAX_WIDTH = 872
MAX_HEIGHT = 1024
class ImageProcessor:
    """Holds the model and visual processor built from a config file and runs them on images."""

    def __init__(self, cfg_path):
        self.cfg_path = cfg_path
        # prefer the GPU when one is available
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        loaded = self.load_model_and_processor()
        self.model, self.vis_processor = loaded

    def load_model_and_processor(self):
        """Build the model (moved to ``self.device``) and the eval visual processor from the config."""
        cli_args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(cli_args)

        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)

        processor_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
        vis_processor = load_processor("formula_image_eval", processor_cfg)

        return model, vis_processor

    def process_single_image(self, pil_image):
        """Preprocess one PIL image, run generation, and return the first predicted string."""
        batch = self.vis_processor(pil_image).unsqueeze(0).to(self.device)
        generated = self.model.generate({"image": batch})
        return generated["pred_str"][0]
@st.cache_data(show_spinner=False)
def read_markdown(path):
    """Return the full text of the UTF-8 file at ``path``."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()
def open_pdf(pdf_file):
    """Open a ``pypdfium2.PdfDocument`` over an in-memory copy of ``pdf_file.getvalue()``."""
    pdf_bytes = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(pdf_bytes))
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=300):
    """
    Render one page of a PDF to an RGB PIL image.

    Args:
        pdf_file: object accepted by ``open_pdf``.
        page_num (int): 1-based page number.
        dpi (int): target resolution; rendering scale is ``dpi / 72``.
    """
    document = open_pdf(pdf_file)
    zero_based_index = page_num - 1
    rendered_pages = document.render(
        pypdfium2.PdfBitmap.to_pil,
        page_indices=[zero_based_index],
        scale=dpi / 72,
    )
    first_page = list(rendered_pages)[0]
    return first_page.convert("RGB")
@st.cache_data()
def get_uploaded_image(in_file):
    """Open ``in_file`` with PIL and return it converted to RGB."""
    opened = Image.open(in_file)
    return opened.convert("RGB")
def resize_image(pil_image):
    """Shrink ``pil_image`` in place to fit MAX_WIDTH x MAX_HEIGHT; do nothing for None."""
    if pil_image is not None:
        bounds = (MAX_WIDTH, MAX_HEIGHT)
        pil_image.thumbnail(bounds, Image.Resampling.LANCZOS)
def display_image_cropped(pil_image, bbox):
    """Show the region ``bbox`` of ``pil_image`` in the Streamlit page."""
    st.image(pil_image.crop(bbox), use_column_width=True)
@st.cache_data()
def page_count_fn(pdf_file):
    """Return the length of the document that ``open_pdf`` opens for ``pdf_file``."""
    return len(open_pdf(pdf_file))
def get_canvas_hash(pil_image):
    """Return the hex MD5 digest of ``pil_image.tobytes()``."""
    digest = hashlib.md5(pil_image.tobytes())
    return digest.hexdigest()
@st.cache_data()
def get_image_size(pil_image):
    """Return ``(height, width)`` of the image, or ``(MAX_HEIGHT, MAX_WIDTH)`` for None."""
    if pil_image is not None:
        return pil_image.height, pil_image.width
    return MAX_HEIGHT, MAX_WIDTH
@st.cache_data(hash_funcs={ImageProcessor: id})
def infer_image(processor, pil_image, bbox):
    """Crop ``pil_image`` to ``bbox`` and return the processor's prediction for the crop."""
    region = pil_image.crop(bbox)
    return processor.process_single_image(region)
@st.cache_resource()
def load_image_processor(cfg_path):
    """Build an ``ImageProcessor`` for ``cfg_path``."""
    return ImageProcessor(cfg_path)
def run_mode1():
    """Direct Recognition mode: run recognition on the whole uploaded image.

    The image is shown in the left column; the rendered prediction and its
    LaTeX source are shown in the right column. Uses the module-level
    ``processor``.
    """
    left_col, right_col = st.columns([0.5, 0.5])

    in_file = st.sidebar.file_uploader(
        "Input Image:", type=["png", "jpg", "jpeg", "gif", "webp"]
    )
    if in_file is None:
        st.stop()

    pil_image = get_uploaded_image(in_file)
    resize_image(pil_image)

    with left_col:
        st.image(pil_image, use_column_width=True)
        st.markdown(
            "<h4 style='text-align: center; color: black;'>[Input: Image] </h4>",
            unsafe_allow_html=True,
        )

    # a single box covering the full image
    bbox_list = [(0, 0, pil_image.width, pil_image.height)]
    with right_col:
        inferences = [infer_image(processor, pil_image, bbox) for bbox in bbox_list]
        for inference in reversed(inferences):
            st.latex(inference)
            st.markdown(
                "<h4 style='text-align: center; color: black;'>[Prediction: Rendered Image]</h4>",
                unsafe_allow_html=True,
            )
            st.divider()
            st.code(inference)
            st.markdown(
                "<h4 style='text-align: center; color: black;'>[Prediction: LaTeX Code]</h4>",
                unsafe_allow_html=True,
            )
def run_mode2():
    """Manual Selection mode: allows users to select formulas in an image or PDF for recognition.

    The page/image is shown on a drawable canvas in the left column; every
    rectangle drawn on it (or the whole image, when the sidebar button is
    pressed for a non-PDF upload) is recognized and the results are listed in
    the right column, newest first. Uses the module-level ``processor``.
    """
    col1, col2 = st.columns([0.7, 0.3])

    in_file = st.sidebar.file_uploader(
        "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
    )
    if in_file is None:
        st.stop()

    # Determine if the uploaded file is a PDF or an image
    whole_image = False
    if in_file.type == "application/pdf":
        page_count = page_count_fn(in_file)
        page_number = st.sidebar.number_input(
            "Page number:",
            min_value=1,
            value=1,
            max_value=page_count,
        )
        pil_image = get_page_image(in_file, page_number)
    else:
        pil_image = get_uploaded_image(in_file)
        whole_image = st.sidebar.button("Formula Recognition")

    resize_image(pil_image)
    canvas_hash = get_canvas_hash(pil_image) if pil_image else "canvas"

    with col1:
        # Create a canvas component where users can draw rectangles to select formulas
        canvas_result = st_canvas(
            fill_color="rgba(255, 165, 0, 0.1)",  # Fixed fill color with some opacity
            stroke_width=1,
            stroke_color="#FFAA00",
            background_color="#FFF",
            background_image=pil_image,
            update_streamlit=True,
            height=get_image_size(pil_image)[0],
            width=get_image_size(pil_image)[1],
            drawing_mode="rect",
            point_display_radius=0,
            key=canvas_hash,
        )

    # Process the drawn rectangles or the whole image if 'whole_image' is True
    bbox_list = []
    json_data = canvas_result.json_data
    # json_data may still be None when the button is pressed, so only read
    # the drawn objects when the canvas has actually reported some state.
    if json_data is not None:
        objects = pd.json_normalize(json_data["objects"])
        if objects.shape[0] > 0:
            # copy so the added columns do not write into a slice of `objects`
            boxes = objects[objects["type"] == "rect"][
                ["left", "top", "width", "height"]
            ].copy()
            boxes["right"] = boxes["left"] + boxes["width"]
            boxes["bottom"] = boxes["top"] + boxes["height"]
            bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist()

    if whole_image:
        bbox_list = [(0, 0, pil_image.width, pil_image.height)]

    if bbox_list:
        with col2:
            # Perform inference on each selected area and display results
            inferences = [infer_image(processor, pil_image, bbox) for bbox in bbox_list]
            for idx, (bbox, inference) in enumerate(zip(reversed(bbox_list), reversed(inferences))):
                st.markdown(f"### Result {len(inferences) - idx}")
                st.markdown(
                    "<h6 style='text-align: left; color: black;'>[Input: Image] </h6>",
                    unsafe_allow_html=True,
                )
                display_image_cropped(pil_image, bbox)
                st.markdown(
                    "<h6 style='text-align: left; color: black;'>[Prediction: Rendered Image] </h6>",
                    unsafe_allow_html=True,
                )
                st.latex(inference)
                st.markdown(
                    "<h6 style='text-align: left; color: black;'>[Prediction: LaTeX Code] </h6>",
                    unsafe_allow_html=True,
                )
                st.code(inference)
                st.divider()

    with col2:
        tips = """
### Usage tips
- Draw a box around the equation to get the prediction."""
        st.markdown(tips)
def run_mode3():
    """Render the placeholder shown for the third application mode.

    The mode has no implementation yet, so the handler only writes a short
    notice to the page and returns nothing.
    """
    placeholder = "Coming Soon!"
    st.markdown(placeholder)
if __name__ == "__main__":
    # Script entry point: configure the page, render the intro banner, load
    # the recognition model, then hand control to the handler that matches
    # the mode picked in the sidebar.
    st.set_page_config(layout="wide")

    html_code = """
<div style='text-align: center; color: black;'>
<h2>UniMERNet Online Demo</h2>
<h5 style='text-align: left; padding-left: 20px; list-style-position: inside;'
>This App is based on <a href="https://github.com/opendatalab/UniMERNet">UniMERNet</a>. There are three optional modes for mathematical expression recognition:</h5>
<ul style='text-align: left; padding-left: 20px; list-style-position: inside;'>
<li><span style="font-weight: bold;">① Direct Recognition:</span> Input an image containing formulas and output the recognition results.</li>
<li><span style="font-weight: bold;">② Manual Selection:</span> Input a document or webpage screenshot, detect all formulas, then recognize each one.</li>
<li><span style="font-weight: bold;">③ Auto Detection:</span> Input an image or document, and the model automatically detects and recognizes all formulas.</li>
</ul>
</div>
"""
    readme_text = st.markdown(html_code, unsafe_allow_html=True)

    # The demo configuration is looked up relative to the current working
    # directory, so the app is expected to be launched from the project root.
    root_path = os.path.abspath(os.getcwd())
    config_path = os.path.join(root_path, "configs/demo.yaml")
    # NOTE(review): the mode handlers appear to read `processor` as a
    # module-level global, so the name is kept as-is -- confirm before renaming.
    processor = load_image_processor(config_path)

    # Maps each sidebar option to (section heading HTML, handler to run).
    # Insertion order defines the order of the options in the select box.
    _mode_table = {
        # Direct Recognition: Input an image containing formulas and output
        # the recognition results.
        "Direct Recognition": (
            "<h3 style='text-align: center; color: red;'> Direct Recognition </h3>",
            run_mode1,
        ),
        # Manual Selection: Input a document or webpage screenshot, select
        # formula regions, then recognize each one.
        "Manual Selection": (
            "<h3 style='text-align: center; color: red;'> Manual Selection and Recognition </h3>",
            run_mode2,
        ),
        # Auto Detection: Input an image or document, and the model
        # automatically detects and recognizes all formulas.
        "Auto Detection": (
            "<h3 style='text-align: center; color: red;'> Auto Detection and Recognition (Coming Soon) </h3>",
            run_mode3,
        ),
    }

    app_mode = st.sidebar.selectbox("Switch Mode:", list(_mode_table))

    _selected = _mode_table.get(app_mode)
    if _selected is not None:
        _heading_html, _handler = _selected
        st.markdown("---")
        st.markdown(_heading_html, unsafe_allow_html=True)
        _handler()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment