Commit 73557d95 authored by yuguo960516

glm
import importlib
import time
# helper functions
def exists(val):
return val is not None
# time helpers
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time
# print helpers
def print_ribbon(s, symbol="=", repeat=40):
flank = symbol * repeat
return f"{flank} {s} {flank}"
# import helpers
def import_or_print_error(pkg_name, err_str=None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError:
if exists(err_str):
print(err_str)
exit()
# from https://github.com/lucidrains/vector_quantize_pytorch/vector_quantize_pytorch.py
import oneflow as flow
import oneflow.nn.functional as F
from einops import rearrange, repeat
from oneflow import einsum, nn
from libai.utils import distributed
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def noop(*args, **kwargs):
pass
def l2norm(t):
return F.normalize(t, p=2, dim=-1)
def log(t, eps=1e-20):
return flow.log(t.clamp(min=eps))
def uniform_init(*shape):
t = flow.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def gumbel_noise(t):
noise = flow.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature=1.0, dim=-1):
if temperature == 0:
return t.argmax(dim=dim)
return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim)
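# Note: this is the Gumbel-max trick. Adding Gumbel(0, 1) noise to the scaled logits and
# taking the argmax draws an index distributed as softmax(t / temperature); temperature=0
# falls back to a plain (greedy) argmax. A minimal sketch with hypothetical values:
#
#   logits = flow.tensor([[1.0, 2.0, 3.0]])
#   gumbel_sample(logits, temperature=1.0)  # stochastic index in {0, 1, 2}
#   gumbel_sample(logits, temperature=0)    # deterministic, always index 2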
def ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories, eps=1e-5):
return (x + eps) / (x.sum() + n_categories * eps)
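# Additive (Laplace) smoothing: every category keeps a small nonzero probability and the
# result still sums to 1 when `x` holds all `n_categories` counts. Hypothetical example:
#
#   laplace_smoothing(flow.tensor([0.0, 3.0, 1.0]), 3)  # roughly [2.5e-06, 0.75, 0.25]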
def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = flow.randperm(num_samples, device=device)[:num]
else:
indices = flow.randint(0, num_samples, (num,), device=device)
return samples[indices]
def batched_sample_vectors(samples, num):
return flow.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
def pad_shape(shape, size, dim=0):
return [size if i == dim else s for i, s in enumerate(shape)]
def sample_multinomial(total_count, probs):
device = probs.device
probs = probs.cpu()
total_count = probs.new_full((), total_count)
remainder = probs.new_ones(())
sample = flow.empty_like(probs, dtype=flow.long)
for i, p in enumerate(probs):
s = flow.binomial(total_count, p / remainder)
sample[i] = s
total_count -= s
remainder -= p
return sample.to(device)
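# This draws one multinomial sample (per-category counts summing to `total_count`) by
# peeling off one binomial draw per category and renormalizing the remaining probability
# mass; it is used below to decide how many k-means init vectors each rank contributes.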
def all_gather_sizes(x, dim):
size = flow.tensor(x.shape[dim], dtype=flow.long, device=x.device)
all_sizes = [flow.empty_like(size) for _ in range(distributed.get_world_size())]
distributed.all_gather(all_sizes, size)
return flow.stack(all_sizes)
def all_gather_variably_sized(x, sizes, dim=0):
rank = distributed.get_rank()
all_x = []
for i, size in enumerate(sizes):
t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
distributed.broadcast(t, src=i, async_op=True)
all_x.append(t)
distributed.barrier()
return all_x
def sample_vectors_distributed(local_samples, num):
rank = distributed.get_rank()
all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0:
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
else:
samples_per_rank = flow.empty_like(all_num_samples)
distributed.broadcast(samples_per_rank, src=0)
samples_per_rank = samples_per_rank.tolist()
local_samples = batched_sample_vectors(local_samples, samples_per_rank[rank])
all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0)
return flow.cat(all_samples, dim=0)
def batched_bincount(x, *, minlength):
batch, dtype, device = x.shape[0], x.dtype, x.device
target = flow.zeros(batch, minlength, dtype=dtype, device=device)
values = flow.ones_like(x)
target.scatter_add_(-1, x, values)
return target
def kmeans(
samples,
num_clusters,
num_iters=10,
use_cosine_sim=False,
sample_fn=batched_sample_vectors,
all_reduce_fn=noop,
):
num_codebooks, dim, dtype, _ = (
samples.shape[0],
samples.shape[-1],
samples.dtype,
samples.device,
)
means = sample_fn(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ rearrange(means, "h n d -> h d n")
else:
dists = -flow.cdist(samples, means, p=2)
buckets = flow.argmax(dists, dim=-1)
bins = batched_bincount(buckets, minlength=num_clusters)
all_reduce_fn(bins)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)
new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
all_reduce_fn(new_means)
if use_cosine_sim:
new_means = l2norm(new_means)
means = flow.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)
return means, bins
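# Shape sketch (following the call sites below): `samples` is (num_codebooks, num_points, dim);
# the returned `means` is (num_codebooks, num_clusters, dim) and `bins` is
# (num_codebooks, num_clusters), holding the cluster sizes from the last iteration.
# With use_cosine_sim=True the assignments use dot products of (assumed) l2-normalized inputs.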
def batched_embedding(indices, embeds):
batch, dim = indices.shape[1], embeds.shape[-1]
indices = repeat(indices, "h b n -> h b n d", d=dim)
embeds = repeat(embeds, "h c d -> h b c d", b=batch)
return embeds.gather(2, indices)
# regularization losses
def orthgonal_loss_fn(t):
# eq (2) from https://arxiv.org/abs/2112.00384
h, n = t.shape[:2]
normed_codes = l2norm(t)
identity = repeat(flow.eye(n, device=t.device), "i j -> h i j", h=h)
cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
return ((cosine_sim - identity) ** 2).sum() / (h * n ** 2)
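# In formula form, with l2-normalized codes c_i per codebook (h codebooks, n codes each):
#   L_ortho = (1 / (h * n^2)) * sum_{i,j} (<c_i, c_j> - delta_ij)^2
# i.e. it pushes the cosine-similarity matrix of each codebook towards the identity.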
# distance types
class EuclideanCodebook(nn.Module):
def __init__(
self,
dim,
codebook_size,
num_codebooks=1,
kmeans_init=False,
kmeans_iters=10,
decay=0.8,
eps=1e-5,
threshold_ema_dead_code=2,
use_ddp=False,
learnable_codebook=False,
sample_codebook_temp=0,
):
super().__init__()
self.decay = decay
init_fn = uniform_init if not kmeans_init else flow.zeros
embed = init_fn(num_codebooks, codebook_size, dim)
self.codebook_size = codebook_size
self.num_codebooks = num_codebooks
self.kmeans_iters = kmeans_iters
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.sample_codebook_temp = sample_codebook_temp
self.sample_fn = sample_vectors_distributed if use_ddp else batched_sample_vectors
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", flow.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", flow.zeros(num_codebooks, codebook_size))
self.register_buffer("embed_avg", embed.clone())
self.learnable_codebook = learnable_codebook
if learnable_codebook:
self.embed = nn.Parameter(embed)
else:
self.register_buffer("embed", embed)
def init_embed_(self, data):
if self.initted:
return
embed, cluster_size = kmeans(
data,
self.codebook_size,
self.kmeans_iters,
sample_fn=self.sample_fn,
all_reduce_fn=self.all_reduce_fn,
)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(flow.Tensor([True]))
def replace(self, batch_samples, batch_mask):
batch_samples = l2norm(batch_samples)
for ind, (samples, mask) in enumerate(
zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
):
if not flow.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
self.embed.data[ind][mask] = rearrange(sampled, "1 ... -> ...")
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not flow.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
self.replace(batch_samples, batch_mask=expired_codes)
def forward(self, x):
needs_codebook_dim = x.ndim < 4
x = x.float()
if needs_codebook_dim:
x = rearrange(x, "... -> 1 ...")
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, "h ... d -> h (...) d")
self.init_embed_(flatten)
embed = self.embed if not self.learnable_codebook else self.embed.detach()
dist = -flow.cdist(flatten, embed, p=2)
embed_ind = gumbel_sample(dist, dim=-1, temperature=self.sample_codebook_temp)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = embed_ind.view(*shape[:-1])
quantize = batched_embedding(embed_ind, self.embed)
if self.training:
cluster_size = embed_onehot.sum(dim=1)
self.all_reduce_fn(cluster_size)
ema_inplace(self.cluster_size, cluster_size, self.decay)
embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
self.all_reduce_fn(embed_sum)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
self.embed.data.copy_(embed_normalized)
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = map(lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind))
return quantize, embed_ind
class CosineSimCodebook(nn.Module):
def __init__(
self,
dim,
codebook_size,
num_codebooks=1,
kmeans_init=False,
kmeans_iters=10,
decay=0.8,
eps=1e-5,
threshold_ema_dead_code=2,
use_ddp=False,
learnable_codebook=False,
sample_codebook_temp=0.0,
):
super().__init__()
self.decay = decay
if not kmeans_init:
embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
else:
embed = flow.zeros(num_codebooks, codebook_size, dim)
self.codebook_size = codebook_size
self.num_codebooks = num_codebooks
self.kmeans_iters = kmeans_iters
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.sample_codebook_temp = sample_codebook_temp
self.sample_fn = sample_vectors_distributed if use_ddp else batched_sample_vectors
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", flow.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", flow.zeros(num_codebooks, codebook_size))
self.learnable_codebook = learnable_codebook
if learnable_codebook:
self.embed = nn.Parameter(embed)
else:
self.register_buffer("embed", embed)
def init_embed_(self, data):
if self.initted:
return
embed, cluster_size = kmeans(
data,
self.codebook_size,
self.kmeans_iters,
use_cosine_sim=True,
sample_fn=self.sample_fn,
all_reduce_fn=self.all_reduce_fn,
)
self.embed.data.copy_(embed)
self.cluster_size.data.copy_(cluster_size)
self.initted.data.copy_(flow.Tensor([True]))
def replace(self, batch_samples, batch_mask):
batch_samples = l2norm(batch_samples)
for ind, (samples, mask) in enumerate(
zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))
):
if not flow.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
self.embed.data[ind][mask] = rearrange(sampled, "1 ... -> ...")
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not flow.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
self.replace(batch_samples, batch_mask=expired_codes)
def forward(self, x):
needs_codebook_dim = x.ndim < 4
x = x.float()
if needs_codebook_dim:
x = rearrange(x, "... -> 1 ...")
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, "h ... d -> h (...) d")
flatten = l2norm(flatten)
self.init_embed_(flatten)
embed = self.embed if not self.learnable_codebook else self.embed.detach()
embed = l2norm(embed)
dist = einsum("h n d, h c d -> h n c", flatten, embed)
embed_ind = gumbel_sample(dist, dim=-1, temperature=self.sample_codebook_temp)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = embed_ind.view(*shape[:-1])
quantize = batched_embedding(embed_ind, self.embed)
if self.training:
bins = embed_onehot.sum(dim=1)
self.all_reduce_fn(bins)
ema_inplace(self.cluster_size, bins, self.decay)
zero_mask = bins == 0
bins = bins.masked_fill(zero_mask, 1.0)
embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
self.all_reduce_fn(embed_sum)
embed_normalized = embed_sum / rearrange(bins, "... -> ... 1")
embed_normalized = l2norm(embed_normalized)
embed_normalized = flow.where(
rearrange(zero_mask, "... -> ... 1"), embed, embed_normalized
)
ema_inplace(self.embed, embed_normalized, self.decay)
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = map(lambda t: rearrange(t, "1 ... -> ..."), (quantize, embed_ind))
return quantize, embed_ind
# main class
class VectorQuantize(nn.Module):
def __init__(
self,
dim,
codebook_size,
codebook_dim=None,
heads=1,
separate_codebook_per_head=False,
decay=0.8,
eps=1e-5,
kmeans_init=False,
kmeans_iters=10,
use_cosine_sim=False,
threshold_ema_dead_code=0,
channel_last=True,
accept_image_fmap=False,
commitment_weight=1.0,
orthogonal_reg_weight=0.0,
orthogonal_reg_active_codes_only=False,
orthogonal_reg_max_codes=None,
sample_codebook_temp=0.0,
sync_codebook=False,
):
super().__init__()
self.heads = heads
self.separate_codebook_per_head = separate_codebook_per_head
codebook_dim = default(codebook_dim, dim)
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.eps = eps
self.commitment_weight = commitment_weight
has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
self.orthogonal_reg_weight = orthogonal_reg_weight
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
self._codebook = codebook_class(
dim=codebook_dim,
num_codebooks=heads if separate_codebook_per_head else 1,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
eps=eps,
threshold_ema_dead_code=threshold_ema_dead_code,
use_ddp=sync_codebook,
learnable_codebook=has_codebook_orthogonal_loss,
sample_codebook_temp=sample_codebook_temp,
)
self.codebook_size = codebook_size
self.accept_image_fmap = accept_image_fmap
self.channel_last = channel_last
@property
def codebook(self):
return self._codebook.embed
def forward(self, x):
_, device, heads, is_multiheaded, _ = (
x.shape,
x.device,
self.heads,
self.heads > 1,
self.codebook_size,
)
need_transpose = not self.channel_last and not self.accept_image_fmap
if self.accept_image_fmap:
height, width = x.shape[-2:]
x = rearrange(x, "b c h w -> b (h w) c")
if need_transpose:
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
if is_multiheaded:
ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d"
x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = flow.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
if self.orthogonal_reg_weight > 0:
codebook = self.codebook
if self.orthogonal_reg_active_codes_only:
# only calculate orthogonal loss for the activated codes for this batch
unique_code_ids = flow.unique(embed_ind)
codebook = codebook[unique_code_ids]
num_codes = codebook.shape[0]
if (
exists(self.orthogonal_reg_max_codes)
and num_codes > self.orthogonal_reg_max_codes
):
rand_ids = flow.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
codebook = codebook[rand_ids]
orthogonal_reg_loss = orthgonal_loss_fn(codebook)
loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
if is_multiheaded:
if self.separate_codebook_per_head:
quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads)
embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads)
else:
quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads)
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
quantize = self.project_out(quantize)
if need_transpose:
quantize = rearrange(quantize, "b n d -> b d n")
if self.accept_image_fmap:
quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width)
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
return quantize, embed_ind, loss
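# Hedged usage sketch (hypothetical shapes; channel-last sequence input, single head):
#
#   vq = VectorQuantize(dim=256, codebook_size=512, decay=0.8, commitment_weight=1.0)
#   x = flow.randn(2, 1024, 256)        # (batch, sequence, dim)
#   quantized, indices, loss = vq(x)    # (2, 1024, 256), (2, 1024), 1-element loss tensor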
import copy
from functools import partial, wraps
from math import sqrt
import flowvision
import oneflow as flow
import oneflow.nn.functional as F
from einops import rearrange, repeat
from oneflow import einsum, nn
from oneflow.autograd import grad as flow_grad
from libai.layers import Linear
from .einops_exts import Rearrange, rearrange_many
from .vector_quantize_flow import VectorQuantize as VQ
# constants
MList = nn.ModuleList
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# decorators
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def remove_vgg(fn):
@wraps(fn)
def inner(self, *args, **kwargs):
has_vgg = hasattr(self, "vgg")
if has_vgg:
vgg = self.vgg
delattr(self, "vgg")
out = fn(self, *args, **kwargs)
if has_vgg:
self.vgg = vgg
return out
return inner
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, string_input):
return string_input.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
)
return kwargs_without_prefix, kwargs
# tensor helper functions
def log(t, eps=1e-10):
return flow.log(t + eps)
def gradient_penalty(images, output, weight=10):
gradients = flow_grad(
outputs=output,
inputs=images,
grad_outputs=flow.ones(output.size(), device=images.device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = rearrange(gradients, "b ... -> b (...)")
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
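# This is the standard gradient penalty: weight * mean((||grad_x D(x)||_2 - 1)^2), computed
# here with respect to the images the discriminator logits were produced from; the caller is
# expected to have set requires_grad on `images` (see VQGanVAE.forward below).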
def l2norm(t):
return F.normalize(t, dim=-1)
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def stable_softmax(t, dim=-1, alpha=32 ** 2):
t = t / alpha
t = t - flow.amax(t, dim=dim, keepdim=True).detach()
return (t * alpha).softmax(dim=dim)
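# Since softmax is invariant to a constant shift per row, dividing by alpha, subtracting the
# row-wise max, and multiplying back by alpha yields exactly softmax(t, dim) while keeping the
# exponentiated values bounded; the large default alpha = 32 ** 2 is presumably there to tame
# large attention logits numerically.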
def safe_div(numer, denom, eps=1e-8):
return numer / (denom + eps)
# gan losses
def hinge_discr_loss(fake, real):
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
def hinge_gen_loss(fake):
return -fake.mean()
def bce_discr_loss(fake, real):
return (-log(1 - flow.sigmoid(fake)) - log(flow.sigmoid(real))).mean()
def bce_gen_loss(fake):
return -log(flow.sigmoid(fake)).mean()
def grad_layer_wrt_loss(loss, layer):
return flow_grad(
outputs=loss, inputs=layer, grad_outputs=flow.ones_like(loss), retain_graph=True
)[0].detach()
# vqgan vae
class LayerNormChan(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(flow.ones(1, dim, 1, 1))
def forward(self, x):
var = flow.var(x, dim=1, unbiased=False, keepdim=True)
mean = flow.mean(x, dim=1, keepdim=True)
return (x - mean) / (var + self.eps).sqrt() * self.gamma
# discriminator
class Discriminator(nn.Module):
def __init__(self, dims, channels=3, groups=16, init_kernel_size=5):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])
self.layers = MList(
[
nn.Sequential(
nn.Conv2d(channels, dims[0], init_kernel_size, padding=init_kernel_size // 2),
leaky_relu(),
)
]
)
for dim_in, dim_out in dim_pairs:
self.layers.append(
nn.Sequential(
nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1),
nn.GroupNorm(groups, dim_out),
leaky_relu(),
)
)
dim = dims[-1]
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
nn.Conv2d(dim, dim, 1), leaky_relu(), nn.Conv2d(dim, 1, 4)
)
def forward(self, x):
for net in self.layers:
x = net(x)
return self.to_logits(x)
# positional encoding
class ContinuousPositionBias(nn.Module):
"""from https://arxiv.org/abs/2111.09883"""
def __init__(self, *, dim, heads, layers=2):
super().__init__()
self.net = MList([])
self.net.append(nn.Sequential(Linear(2, dim), leaky_relu()))
for _ in range(layers - 1):
self.net.append(nn.Sequential(Linear(dim, dim), leaky_relu()))
self.net.append(Linear(dim, heads))
self.register_buffer("rel_pos", None, persistent=False)
def forward(self, x):
n, device = x.shape[-1], x.device
fmap_size = int(sqrt(n))
if not exists(self.rel_pos):
pos = flow.arange(fmap_size, device=device)
grid = flow.stack(flow.meshgrid(pos, pos, indexing="ij"))
grid = rearrange(grid, "c i j -> (i j) c")
rel_pos = rearrange(grid, "i c -> i 1 c") - rearrange(grid, "j c -> 1 j c")
rel_pos = flow.sign(rel_pos) * flow.log(rel_pos.abs() + 1)
self.register_buffer("rel_pos", rel_pos, persistent=False)
rel_pos = self.rel_pos.float()
for layer in self.net:
rel_pos = layer(rel_pos)
bias = rearrange(rel_pos, "i j h -> h i j")
return x + bias
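# Following Swin Transformer V2 (the paper linked in the docstring), the relative (dy, dx)
# offsets between feature-map positions are log-spaced (sign(d) * log(|d| + 1)) and fed
# through a small MLP that outputs one bias value per attention head; the resulting
# (heads, i, j) bias is added to the attention similarity matrix `x`.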
# resnet encoder / decoder
class ResnetEncDec(nn.Module):
def __init__(
self,
dim,
*,
channels=3,
layers=4,
layer_mults=None,
num_resnet_blocks=1,
resnet_groups=16,
first_conv_kernel_size=5,
use_attn=True,
attn_dim_head=64,
attn_heads=8,
attn_dropout=0.0,
):
super().__init__()
assert (
dim % resnet_groups == 0
), f"dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)"
self.layers = layers
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert (
len(layer_mults) == layers
), "layer multipliers must be equal to designated number of layers"
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
def append(arr, t):
arr.append(t)
def prepend(arr, t):
arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert (
len(num_resnet_blocks) == layers
), "number of resnet blocks config must be equal to number of layers"
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(
range(layers), dim_pairs, num_resnet_blocks, use_attn
):
append(
self.encoders,
nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1), leaky_relu()),
)
prepend(
self.decoders,
nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()),
)
if layer_use_attn:
prepend(
self.decoders,
VQGanAttention(
dim=dim_out, heads=attn_heads, dim_head=attn_dim_head, dropout=attn_dropout
),
)
for _ in range(layer_num_resnet_blocks):
append(self.encoders, ResBlock(dim_out, groups=resnet_groups))
prepend(self.decoders, GLUResBlock(dim_out, groups=resnet_groups))
if layer_use_attn:
append(
self.encoders,
VQGanAttention(
dim=dim_out, heads=attn_heads, dim_head=attn_dim_head, dropout=attn_dropout
),
)
prepend(
self.encoders,
nn.Conv2d(channels, dim, first_conv_kernel_size, padding=first_conv_kernel_size // 2),
)
append(self.decoders, nn.Conv2d(dim, channels, 1))
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
return x
def decode(self, x):
for dec in self.decoders:
x = dec(x)
return x
class GLUResBlock(nn.Module):
def __init__(self, chan, groups=16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan * 2, 3, padding=1),
nn.GLU(dim=1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan * 2, 3, padding=1),
nn.GLU(dim=1),
nn.GroupNorm(groups, chan),
nn.Conv2d(chan, chan, 1),
)
def forward(self, x):
return self.net(x) + x
class ResBlock(nn.Module):
def __init__(self, chan, groups=16):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(chan, chan, 3, padding=1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 3, padding=1),
nn.GroupNorm(groups, chan),
leaky_relu(),
nn.Conv2d(chan, chan, 1),
)
def forward(self, x):
return self.net(x) + x
# vqgan attention layer
class VQGanAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8, dropout=0.0):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.dropout = nn.Dropout(dropout)
self.pre_norm = LayerNormChan(dim)
self.cpb = ContinuousPositionBias(dim=dim // 4, heads=heads)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias=False)
def forward(self, x):
h = self.heads
height, width, residual = *x.shape[-2:], x.clone()
x = self.pre_norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=h), (q, k, v))
sim = einsum("b h c i, b h c j -> b h i j", q, k) * self.scale
sim = self.cpb(sim)
attn = stable_softmax(sim, dim=-1)
attn = self.dropout(attn)
out = einsum("b h i j, b h c j -> b h c i", attn, v)
out = rearrange(out, "b h c (x y) -> b (h c) x y", x=height, y=width)
out = self.to_out(out)
return out + residual
# ViT encoder / decoder
class RearrangeImage(nn.Module):
def forward(self, x):
n = x.shape[1]
w = h = int(sqrt(n))
return rearrange(x, "b (h w) ... -> b h w ...", h=h, w=w)
class Attention(nn.Module):
def __init__(self, dim, *, heads=8, dim_head=32):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
q = q * self.scale
sim = einsum("b h i d, b h j d -> b h i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
def FeedForward(dim, mult=4):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult, bias=False),
nn.GELU(),
nn.Linear(dim * mult, dim, bias=False),
)
class Transformer(nn.Module):
def __init__(self, dim, *, layers, dim_head=32, heads=8, ff_mult=4):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(layers):
self.layers.append(
nn.ModuleList(
[
Attention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
self.norm = nn.LayerNorm(dim)
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class ViTEncDec(nn.Module):
def __init__(self, dim, channels=3, layers=4, patch_size=8, dim_head=32, heads=8, ff_mult=4):
super().__init__()
self.encoded_dim = dim
self.patch_size = patch_size
input_dim = channels * (patch_size ** 2)
self.encoder = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size),
Linear(input_dim, dim),
Transformer(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, layers=layers),
RearrangeImage(),
Rearrange("b h w c -> b c h w"),
)
self.decoder = nn.Sequential(
Rearrange("b c h w -> b (h w) c"),
Transformer(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, layers=layers),
nn.Sequential(
Linear(dim, dim * 4, bias=False),
nn.Tanh(),
Linear(dim * 4, input_dim, bias=False),
),
RearrangeImage(),
Rearrange("b h w (p1 p2 c) -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size),
)
def get_encoded_fmap_size(self, image_size):
return image_size // self.patch_size
@property
def last_dec_layer(self):
return self.decoder[-3][-1].weight
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
# main vqgan-vae classes
class NullVQGanVAE(nn.Module):
def __init__(self, *, channels):
super().__init__()
self.encoded_dim = channels
self.layers = 0
def get_encoded_fmap_size(self, size):
return size
def copy_for_eval(self):
return self
def encode(self, x):
return x
def decode(self, x):
return x
class VQGanVAE(nn.Module):
def __init__(
self,
*,
dim,
image_size,
channels=3,
layers=4,
l2_recon_loss=False,
use_hinge_loss=True,
vgg=None,
vq_codebook_dim=256,
vq_codebook_size=512,
vq_decay=0.8,
vq_commitment_weight=1.0,
vq_kmeans_init=True,
vq_use_cosine_sim=True,
use_vgg_and_gan=True,
vae_type="resnet",
discr_layers=4,
**kwargs,
):
super().__init__()
vq_kwargs, kwargs = groupby_prefix_and_trim("vq_", kwargs)
encdec_kwargs, kwargs = groupby_prefix_and_trim("encdec_", kwargs)
self.image_size = image_size
self.channels = channels
self.codebook_size = vq_codebook_size
if vae_type == "resnet":
enc_dec_klass = ResnetEncDec
elif vae_type == "vit":
enc_dec_klass = ViTEncDec
else:
raise ValueError(f"{vae_type} not valid")
self.enc_dec = enc_dec_klass(dim=dim, channels=channels, layers=layers, **encdec_kwargs)
self.vq = VQ(
dim=self.enc_dec.encoded_dim,
codebook_dim=vq_codebook_dim,
codebook_size=vq_codebook_size,
decay=vq_decay,
commitment_weight=vq_commitment_weight,
accept_image_fmap=True,
kmeans_init=vq_kmeans_init,
use_cosine_sim=vq_use_cosine_sim,
**vq_kwargs,
)
# reconstruction loss
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
# turn off GAN and perceptual loss if grayscale
self.vgg = None
self.discr = None
self.use_vgg_and_gan = use_vgg_and_gan
if not use_vgg_and_gan:
return
# perceptual loss
if exists(vgg):
self.vgg = vgg
else:
self.vgg = flowvision.models.vgg16(pretrained=True)
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
# gan related losses
layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.discr = Discriminator(dims=dims, channels=channels)
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
@property
def encoded_dim(self):
return self.enc_dec.encoded_dim
def get_encoded_fmap_size(self, image_size):
return self.enc_dec.get_encoded_fmap_size(image_size)
def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
if vae_copy.use_vgg_and_gan:
del vae_copy.discr
del vae_copy.vgg
vae_copy.eval()
return vae_copy.to(device)
@remove_vgg
def state_dict(self, *args, **kwargs):
return super().state_dict(*args, **kwargs)
@remove_vgg
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)
@property
def codebook(self):
return self.vq.codebook
def encode(self, fmap):
fmap = self.enc_dec.encode(fmap)
return fmap
def decode(self, fmap, return_indices_and_loss=False):
fmap, indices, commit_loss = self.vq(fmap)
fmap = self.enc_dec.decode(fmap)
if not return_indices_and_loss:
return fmap
return fmap, indices, commit_loss
def forward(
self,
img,
return_loss=False,
return_discr_loss=False,
return_recons=False,
add_gradient_penalty=True,
):
_, channels, height, width, _ = *img.shape, img.device
assert (
height == self.image_size and width == self.image_size
), "height and width of input image must be equal to {self.image_size}"
assert (
channels == self.channels
), "number of channels on image or sketch is not equal to the channels set on this VQGanVAE"
fmap = self.encode(img)
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss=True)
if not return_loss and not return_discr_loss:
return fmap
assert (
return_loss ^ return_discr_loss
), "you should either return autoencoder loss or discriminator loss, but not both"
# whether to return discriminator loss
if return_discr_loss:
assert exists(self.discr), "discriminator must exist to train it"
fmap.detach_()
img.requires_grad_()
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
if add_gradient_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap
return loss
# reconstruction loss
recon_loss = self.recon_loss_fn(fmap, img)
# early return if training on grayscale
if not self.use_vgg_and_gan:
if return_recons:
return recon_loss, fmap
return recon_loss
# perceptual loss
img_vgg_input = img
fmap_vgg_input = fmap
if img.shape[1] == 1:
# handle grayscale for vgg
img_vgg_input, fmap_vgg_input = map(
lambda t: repeat(t, "b 1 ... -> b c ...", c=3), (img_vgg_input, fmap_vgg_input)
)
img_vgg_feats = self.vgg(img_vgg_input)
recon_vgg_feats = self.vgg(fmap_vgg_input)
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
# generator loss
gen_loss = self.gen_loss(self.discr(fmap))
# calculate adaptive weight
last_dec_layer = self.enc_dec.last_dec_layer
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(
p=2
)
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
adaptive_weight.clamp_(max=1e4)
# combine losses
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
if return_recons:
return loss, fmap
return loss
import os
from typing import Dict
import oneflow as flow
from dalle2.dalle2_loader import Dalle2ModelLoader
from dalle2.model_weights.download_utils import download_dalle2_weights
from dalle2.tokenizer import SimpleTokenizer
from oneflow.framework import balanced_splitter
import libai.utils.distributed as dist
from libai.inference.basic import BasePipeline
class Dalle2Pipeline(BasePipeline):
def __init__(
self,
config_file,
data_parallel=None,
tensor_parallel=None,
pipeline_parallel=None,
pipeline_stage_id=None,
pipeline_num_layers=None,
model_path=None,
mode="libai",
**kwargs,
):
super().__init__(
config_file,
data_parallel,
tensor_parallel,
pipeline_parallel,
pipeline_stage_id,
model_path,
pipeline_num_layers,
mode,
**kwargs,
)
def update_cfg(
self,
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_stage_id=None,
pipeline_num_layers=None,
):
super().update_cfg(
data_parallel,
tensor_parallel,
pipeline_parallel,
pipeline_stage_id,
pipeline_num_layers,
)
self.cfg.model.prior.clip.name = "./dalle2/model_weights/ViT-L-14.pt"
self.cfg.model.prior_weight_path = "./dalle2/model_weights/prior_aes_finetune.pth"
self.cfg.model.decoder_weight_path = "./dalle2/model_weights/latest.pth"
self.cfg.swinir.swinir_path = (
"./swinir/weights/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
)
def load_pretrain_weight(self, libai_cfg_model, model_path, mode=None):
if dist.is_main_process():
download_dalle2_weights(self.cfg)
dist.synchronize()
model_loader = Dalle2ModelLoader(libai_cfg_model, self.cfg, model_path)
return model_loader.load()
def build_tokenizer(self, cfg):
return SimpleTokenizer() # return instantiate(cfg.tokenizer)
def _parse_parameters(self, model_path=None, save_images=False, upsample_scale=None, **kwargs):
preprocess_params = {}
forward_params = {
"model_path": model_path,
"num_samples_per_batch": kwargs.get("num_samples_per_batch", 2),
"prior_cond_scale": kwargs.get("prior_cond_scale", 1.0),
"decoder_cond_scale": kwargs.get("decoder_cond_scale", 3.5),
}
postprocess_params = {
"save_images": save_images,
"upsample_scale": upsample_scale,
"swinir_path": "./swinir/weights/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth",
}
return preprocess_params, forward_params, postprocess_params
def split_data(self, text):
rank = dist.get_rank()
indices = balanced_splitter.BalancedRanges(len(text), dist.get_world_size())
return text[indices[rank][0] : indices[rank][1]]
def preprocess(self, input_, **preprocess_parameters: Dict) -> dict:
tokens = self.tokenizer.tokenize(input_).to_global(
placement=flow.placement(type="cuda", ranks=list(range(dist.get_world_size()))),
sbp=flow.sbp.broadcast,
)
return {"text": input_, "tokens": tokens}
def forward(self, model_input_dict, **forward_params) -> dict:
tokens = model_input_dict["tokens"]
text_embed, text_encodings, text_mask = self.model.prior.clip.embed_text(tokens)
image_embed = self.model.prior.sample(
tokens,
num_samples_per_batch=forward_params["num_samples_per_batch"],
cond_scale=forward_params["prior_cond_scale"],
)
image_embed = self.model.decoder.sample(
image_embed=image_embed,
text_encodings=text_encodings,
text_mask=text_mask,
cond_scale=forward_params["decoder_cond_scale"],
)
return {"image_embed": image_embed}
def postprocess(self, model_output_dict, **postprocess_params: Dict) -> dict:
if not postprocess_params.get("save_images", False):
return model_output_dict
output_path = postprocess_params.get("output_dir", "./outputs")
os.makedirs(output_path, exist_ok=True)
import flowvision.transforms as T
to_pil = T.ToPILImage()
images = model_output_dict["image_embed"].to("cpu")
images_64x64 = list(map(to_pil, [images[i] for i in range(images.shape[0])]))
for i, image in enumerate(images_64x64):
image.save(f"{output_path}/{i}.png")
if postprocess_params.get("upsample_scale", False):
from swinir import load_model, upsample4x, upsample16x
swinir = load_model(postprocess_params.get("swinir_path", ""))
upsample_fun = upsample4x if postprocess_params["upsample_scale"] == 4 else upsample16x
images = upsample_fun(images, swinir).to("cpu")
images = list(map(to_pil, [images[i] for i in range(images.shape[0])]))
for i, image in enumerate(images):
image.save(f"{output_path}/{i}_{args.upsample_scale}x.png")
print(f"Images have been saved under {output_path}.")
return model_output_dict
def parse_args():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_file", type=str, default="configs/dalle2_config.py")
parser.add_argument("--data_parallel", type=int, default=1)
parser.add_argument("--tensor_parallel", type=int, default=4)
parser.add_argument("--pipeline_parallel", type=int, default=1)
parser.add_argument(
"--upsample_scale",
type=int,
choices=[4, 16],
default=None,
help="upsample scale, if 4x, output resolution will be 256 x 256.",
)
parser.add_argument(
"--swinir_path",
type=str,
default="./swinir/weights/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth",
)
parser.add_argument("--output_dir", type=str, default="./outputs")
parser.add_argument("--save_images", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
model = Dalle2Pipeline(
config_file=args.config_file,
data_parallel=args.data_parallel,
tensor_parallel=args.tensor_parallel,
pipeline_parallel=args.pipeline_parallel,
)
texts = [
"a shiba inu wearing a beret and black turtleneck",
"a teddy bear on a skateboard in times square",
"trump fight with biden in white house",
"Donald trump fight with biden in white house",
]
imgs = model(texts, **vars(args))
# DALLE2
This project is adapted from [dalle2_pytorch](https://github.com/lucidrains/DALLE2-pytorch); dalle2_pytorch version 0.15.4 is used, following this [colab](https://colab.research.google.com/github/LAION-AI/dalle2-laion/blob/main/notebooks/dalle2_laion_alpha.ipynb).
This project aims to show new users how to port PyTorch models to OneFlow and run distributed inference with [LiBai](https://github.com/Oneflow-Inc/libai); details can be found [here](../../docs/source/notes/How_to_use_model_parallel_in_LiBai.md).
## How to run this project
```sh
cd libai/projects/DALLE2
pip install -r requirements.txt
python3 -m oneflow.distributed.launch \
--nproc_per_node 4 \
dalle2_inference.py \
--save_images \
--output_dir ./outputs \
--upsample_scale 4
```
`--nproc_per_node 4` means the model will be executed on 4 GPUs in model-parallel mode.
The output images will be saved to `--output_dir` when `--save_images` is set. The generated images are 64x64 by default and can be upscaled to 256x256 with `--upsample_scale 4` (or to 1024x1024 with `--upsample_scale 16`) using [SwinIR](https://github.com/JingyunLiang/SwinIR).
At the bottom of dalle2_inference.py, try feeding in different texts and see what the model generates.
ftfy
resize_right
einops
kornia
from .models import SwinIR
from .upsample import load_model, upsample4x, upsample16x
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
# code from https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py
import math
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
from oneflow.utils import checkpoint
from .utils import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
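# Worked example (hypothetical sizes): B=2, H=W=8, C=32, window_size=4 gives
# (8 // 4) * (8 // 4) = 4 windows per image, so the output shape is (4 * 2, 4, 4, 32).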
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
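# window_reverse is the inverse of window_partition: when H and W are divisible by
# window_size, window_reverse(window_partition(x, ws), ws, H, W) recovers x exactly.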
class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional):
If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
flow.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = flow.arange(self.window_size[0])
coords_w = flow.arange(self.window_size[1])
coords = flow.stack(flow.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = flow.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = flow.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
return attn_mask
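# The mask labels each image region produced by the cyclic shift with a distinct integer;
# token pairs whose labels differ get -100 (effectively -inf before the softmax), so shifted
# windows cannot attend across the wrap-around boundary.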
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = flow.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (recompute the attention mask when the input size differs
# from the training resolution, e.g. at test time)
if self.input_resolution == x_size:
attn_windows = self.attn(
x_windows, mask=self.attn_mask
) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = flow.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, "
f"num_heads={self.num_heads}, window_size={self.window_size},"
f"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
)
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = flow.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
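# Net effect: each non-overlapping 2x2 group of tokens is concatenated (C -> 4C) and then
# linearly reduced to 2C, so the token count drops by 4x while the channel width doubles.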
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional):
Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional):
If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional):
Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional):
Downsample layer at the end of the layer. Default: None
use_checkpoint (bool):
Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
img_size=224,
patch_size=4,
resi_connection="1conv",
):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(
dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint,
)
if resi_connection == "1conv":
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == "3conv":
# to save parameters and memory
self.conv = nn.Sequential(
nn.Conv2d(dim, dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1),
)
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None
)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None
)
def forward(self, x, x_size):
return (
self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)))
+ x
)
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r"""Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f"scale {scale} is not supported. " "Supported scales: 2^n and 3.")
super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module
(the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(nn.Module):
r"""SwinIR
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`,
based on Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float):
Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool):
If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool):
Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor.
2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction module.
'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(
self,
img_size=64,
patch_size=1,
in_chans=3,
embed_dim=96,
depths=[6, 6, 6, 6],
num_heads=[6, 6, 6, 6],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
upscale=2,
img_range=1.0,
upsampler="",
resi_connection="1conv",
**kwargs,
):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = flow.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = flow.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
# 1, shallow feature extraction
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
# 2, deep feature extraction
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(flow.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[
sum(depths[:i_layer]) : sum(depths[: i_layer + 1])
], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection,
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == "1conv":
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == "3conv":
# to save parameters and memory
self.conv_after_body = nn.Sequential(
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
)
# 3, high quality image reconstruction
if self.upsampler == "pixelshuffle":
# for classical SR
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == "pixelshuffledirect":
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(
upscale, embed_dim, num_out_ch, (patches_resolution[0], patches_resolution[1])
)
elif self.upsampler == "nearest+conv":
# for real-world SR (less artifacts)
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
)
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def no_weight_decay(self):
return {"absolute_pos_embed"}
def no_weight_decay_keywords(self):
return {"relative_position_bias_table"}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x).to(x.device)
x = (x - self.mean) * self.img_range
if self.upsampler == "pixelshuffle":
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == "pixelshuffledirect":
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == "nearest+conv":
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(
self.conv_up1(flow.nn.functional.interpolate(x, scale_factor=2, mode="nearest"))
)
if self.upscale == 4:
x = self.lrelu(
self.conv_up2(flow.nn.functional.interpolate(x, scale_factor=2, mode="nearest"))
)
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, : H * self.upscale, : W * self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == "__main__":
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
model = SwinIR(
upscale=2,
img_size=(height, width),
window_size=window_size,
img_range=1.0,
depths=[6, 6, 6, 6],
embed_dim=60,
num_heads=[6, 6, 6, 6],
mlp_ratio=2,
upsampler="pixelshuffledirect",
)
print(model)
print(height, width, model.flops() / 1e9)
x = flow.randn((1, 3, height, width))
x = model(x)
print(x.shape)
import os
import oneflow as flow
import requests
from .models import SwinIR as net
def load_torch_weight(model, model_path):
# load torch weight
import torch
param_key_g = "params_ema"
pretrained_model = torch.load(model_path, map_location="cpu")
pretrained_model = (
pretrained_model[param_key_g]
if param_key_g in pretrained_model.keys()
else pretrained_model
)
new_state_dict = {}
for k, v in pretrained_model.items():
flow_tensor = flow.tensor(v.numpy())
new_state_dict[k] = flow_tensor
model.load_state_dict(new_state_dict, strict=True)
return model
def load_model(model_path=None):
# set up model
if os.path.exists(model_path):
print(f"loading model from {model_path}")
else:
os.makedirs(os.path.dirname(model_path), exist_ok=True)
url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}".format(
os.path.basename(model_path)
)
r = requests.get(url, allow_redirects=True)
print(f"downloading model {model_path}")
open(model_path, "wb").write(r.content)
model = net(
upscale=4,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="3conv",
)
model = load_torch_weight(model, model_path)
return model
def upsample4x(img_lq, model):
"""upsample img from h*w to (4h) * (4w)"""
device = flow.device("cuda" if flow.cuda.is_available() else "cpu")
model.eval()
model = model.to(device)
img_lq = img_lq.to(device)
window_size = 8
scale = 4
# inference
with flow.no_grad():
# pad input image to be a multiple of window_size
_, _, h_old, w_old = img_lq.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img_lq = flow.cat([img_lq, flow.flip(img_lq, [2])], 2)[:, :, : h_old + h_pad, :]
img_lq = flow.cat([img_lq, flow.flip(img_lq, [3])], 3)[:, :, :, : w_old + w_pad]
output = model(img_lq)
output = output[..., : h_old * scale, : w_old * scale]
output = output.clamp_(0, 1)
return output
def upsample16x(imgs, model):
return upsample4x(upsample4x(imgs, model), model)
# -----------------------------------------------------------------------------------
# from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
# -----------------------------------------------------------------------------------
import collections.abc
import math
import warnings
from itertools import repeat
import oneflow as flow
import oneflow.nn as nn
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from Pytorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with flow.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (flow.Tensor, float, float, float, float) -> flow.Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
NOTE: this impl is similar to the Pytorch trunc_normal_, the bounds [a, b] are
applied while sampling the normal with mean/std applied, therefore a, b args
should be adjusted to match the range of mean, std args.
Args:
tensor: an n-dimensional `flow.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = flow.empty(3, 5)
>>> trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# -----------------------------------------------------------------------------------
# from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
# -----------------------------------------------------------------------------------
# From Pytorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
# -----------------------------------------------------------------------------------
# from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
# -----------------------------------------------------------------------------------
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a
separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect
as a layer name and use 'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
from omegaconf import DictConfig
from libai.config import LazyCall
from projects.GLM.modeling_glm import GLMModel
cfg = dict(
num_layers=48,
vocab_size=30592,
hidden_size=4096,
num_attention_heads=64,
max_sequence_length=1024,
embedding_dropout_prob=0.1,
attention_dropout_prob=0.1,
output_dropout_prob=0.1,
layernorm_epsilon=1e-5,
initializer_range=0.02,
use_scaled_init_for_output_weights=True,
bias_gelu_fusion=True,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
amp_enabled=False,
block_position_encoding=True,
attention_scale=1.0,
padding_idx=None,
# Inference
is_encoder_decoder=False,
max_length=512,
min_length=0,
do_sample=False,
early_stopping=False,
num_beams=1,
num_beam_groups=1,
diversity_penalty=0.0,
temperature=1.0,
top_k=50,
top_p=1.0,
typical_p=1.0,
repetition_penalty=1.0,
length_penalty=1.0,
no_repeat_ngram_size=0,
encoder_no_repeat_ngram_size=0,
num_return_sequences=1,
chunk_size_feed_forward=0,
output_scores=False,
forced_bos_token_id=None,
forced_eos_token_id=None,
remove_invalid_values=False,
exponential_decay_length_penalty=None,
use_cache=False,
# Tokenizer
pad_token_id=50000,
eos_token_id=50007,
bos_token_id=None,
sep_token_id=None,
decoder_start_token_id=None,
)
cfg = DictConfig(cfg)
glm_model = LazyCall(GLMModel)(cfg=cfg)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import oneflow as flow
from oneflow import nn
from libai.layers.linear import Linear
class MultiheadAttention(nn.Module):
def __init__(
self,
hidden_size,
num_attention_heads,
attention_dropout_prob=0.0,
output_dropout_prob=0.0,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
attention_scale=1.0,
*,
layer_idx=0
):
super().__init__()
self.hidden_size = hidden_size
self.attention_scale = attention_scale
if output_layer_init_method is None:
output_layer_init_method = init_method
assert (
hidden_size % num_attention_heads == 0
), "hidden_size must be divisible by num_attention_heads."
self.num_heads = num_attention_heads
self.head_size = hidden_size // num_attention_heads
self.attention_dropout_prob = attention_dropout_prob
self.dropout = nn.Dropout(p=attention_dropout_prob)
self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
self.coeff = None
if apply_query_key_layer_scaling:
self.coeff = layer_idx + 1
self.norm_factor /= self.coeff
self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
self.bias_dropout_fusion = bias_dropout_fusion
if self.bias_dropout_fusion:
self.output_dropout_prob = output_dropout_prob
else:
self.output_dropout = nn.Dropout(p=output_dropout_prob)
self.query_key_value = Linear(
self.hidden_size,
self.hidden_size * 3,
parallel="col",
init_method=init_method,
layer_idx=layer_idx,
)
self.dense = Linear(
self.hidden_size,
self.hidden_size,
parallel="row",
init_method=output_layer_init_method,
skip_bias_add=self.bias_dropout_fusion,
layer_idx=layer_idx,
)
def forward(
self,
hidden_states: flow.Tensor,
attention_mask: flow.Tensor = None,
mem=None,
):
attention_mask = (
attention_mask.to_global(placement=hidden_states.placement)
if attention_mask is not None
else None
)
bsz, tgt_len = hidden_states.size()[:2]
if mem is not None:
hidden_states = flow.cat((mem, hidden_states), dim=1)
query_key_value = self.query_key_value(hidden_states)
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
query_key_value = query_key_value.permute(0, 2, 1, 3)
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
if mem is not None:
query = query[:, :, -tgt_len:]
if self.attention_scale > 1.0:
attention_scores = flow.matmul(
query / math.sqrt(self.attention_scale),
key / math.sqrt(self.head_size * self.attention_scale),
transpose_b=True,
)
else:
attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor)
if self.scale_mask_softmax_fusion:
attention_weights = flow._C.fused_scale_mask_softmax_dropout(
attention_scores,
attention_mask,
fill_value=-10000.0,
scale=self.coeff,
p=self.attention_dropout_prob,
)[0]
else:
if self.coeff is not None:
attention_scores *= self.coeff
attention_scores = flow.mul(attention_scores, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
attention_weights = flow.softmax(attention_scores, dim=-1)
attention_weights = self.dropout(attention_weights)
context = flow.matmul(attention_weights, value)
context = context.transpose(1, 2)
output = self.dense(context.flatten(2))
if self.bias_dropout_fusion:
output, bias = output
output = flow._C.fused_bias_add_dropout(
output, bias, p=self.output_dropout_prob, axis=output.ndim - 1
)
else:
output = self.output_dropout(output)
return output
def extra_repr(self) -> str:
return "hidden_size={}, num_heads={}".format(
self.hidden_size,
self.num_heads,
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
import libai.utils.distributed as dist
from libai.layers import Embedding, VocabEmbedding
from libai.models.utils import init_method_normal
class GLMEmbedding(nn.Module):
def __init__(
self,
vocab_size,
hidden_size,
max_seq_length,
padding_idx=None,
init_method=init_method_normal(0.02, 0),
embedding_dropout_prob=0.0,
amp_enabled=False,
block_position_encoding=False,
):
super().__init__()
self.block_position_encoding = block_position_encoding
self.word_embeddings = VocabEmbedding(
vocab_size,
hidden_size,
padding_idx=padding_idx,
init_method=init_method,
amp_enabled=amp_enabled,
)
if block_position_encoding:
self.position_embeddings = Embedding(
max_seq_length + 1, hidden_size, init_method=init_method, amp_enabled=amp_enabled
)
self.block_position_embeddings = Embedding(
max_seq_length + 1, hidden_size, init_method=init_method, amp_enabled=amp_enabled
)
self.embedding_dropout = nn.Dropout(embedding_dropout_prob)
self.position_ids = flow.arange(
max_seq_length,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
).unsqueeze(0)
self.block_position_ids = flow.zeros(
(1, max_seq_length),
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
def forward(self, input_ids, position_ids=None):
bsz, seq_len = input_ids.size()
if self.block_position_encoding and position_ids is not None:
position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_len]
position_ids = position_ids.expand_as(input_ids).to_global(sbp=input_ids.sbp)
block_position_ids = self.block_position_ids[:, :seq_len]
block_position_ids = block_position_ids.expand_as(input_ids).to_global(
sbp=input_ids.sbp
)
word_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
input_embeddings = word_embeddings + position_embeddings
if self.block_position_encoding:
block_position_embeddings = self.block_position_embeddings(block_position_ids)
input_embeddings = input_embeddings + block_position_embeddings
input_embeddings = self.embedding_dropout(input_embeddings)
return input_embeddings
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import oneflow as flow
from oneflow import nn
import libai.utils.distributed as dist
class SinePositionalEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
position_embedding = flow.zeros(
num_embeddings,
embedding_dim,
dtype=flow.float32,
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
position = flow._C.global_arange(
start=0,
end=num_embeddings,
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
dtype=flow.float32,
).unsqueeze(1)
position_range = flow._C.global_arange(
start=0,
end=embedding_dim,
step=2,
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
dtype=flow.float32,
)
div_term = flow.exp(position_range * (-math.log(10000.0) / embedding_dim))
position_embedding[:, : embedding_dim // 2] = flow.sin(position * div_term)
position_embedding[:, embedding_dim // 2 :] = flow.cos(position * div_term)
self.register_buffer("position_embedding", position_embedding)
def forward(self, position_ids):
position_embeds = flow._C.gather(self.position_embedding, position_ids, axis=0)
return position_embeds
def extra_repr(self) -> str:
s = "num_embeddings={num_embeddings}, embedding_dim={embedding_dim}"
return s.format(**self.__dict__)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow.nn as nn
from libai.layers.layer_norm import LayerNorm
from libai.layers.mlp import MLP
from libai.utils import distributed as dist
from projects.GLM.layers.attention_layer import MultiheadAttention
class TransformerLayer(nn.Module):
def __init__(
self,
hidden_size,
num_attention_heads,
attention_dropout_prob=0.0,
output_dropout_prob=0.0,
layernorm_epsilon=1e-5,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
attention_scale=1.0,
*,
layer_idx=0
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.attention_dropout_prob = attention_dropout_prob
self.output_dropout_prob = output_dropout_prob
self.layernorm_epsilon = layernorm_epsilon
self.attention_scale = attention_scale
self.layer_idx = layer_idx
self.bias_gelu_fusion = bias_gelu_fusion
self.bias_dropout_fusion = bias_dropout_fusion
self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.init_method = init_method
if output_layer_init_method is None:
output_layer_init_method = init_method
self.output_layer_init_method = output_layer_init_method
self.input_layernorm = LayerNorm(
self.hidden_size, eps=self.layernorm_epsilon, layer_idx=self.layer_idx
)
self.attention = self.build_attention()
self.post_attention_layernorm = LayerNorm(
self.hidden_size, eps=self.layernorm_epsilon, layer_idx=self.layer_idx
)
self.mlp = MLP(
self.hidden_size,
4 * self.hidden_size,
self.output_dropout_prob,
self.init_method,
output_layer_init_method=self.output_layer_init_method,
bias_gelu_fusion=self.bias_gelu_fusion,
bias_dropout_fusion=self.bias_dropout_fusion,
layer_idx=self.layer_idx,
)
def forward(
self,
hidden_states,
attention_mask,
mem=None,
):
hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx))
attention_mask = (
attention_mask.to_global(placement=dist.get_layer_placement(self.layer_idx))
if attention_mask is not None
else None
)
mem = (
mem.to_global(placement=dist.get_layer_placement(self.layer_idx))
if mem is not None
else None
)
layernorm_output = self.input_layernorm(hidden_states)
mem = self.input_layernorm(mem) if mem is not None else None
attention_output = self.attention(
layernorm_output,
attention_mask=attention_mask,
mem=mem,
)
hidden_states = hidden_states + attention_output
layernorm_output = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(layernorm_output)
output = hidden_states + mlp_output
return output
def build_attention(self):
return MultiheadAttention(
self.hidden_size,
self.num_attention_heads,
attention_dropout_prob=self.attention_dropout_prob,
output_dropout_prob=self.output_dropout_prob,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
bias_dropout_fusion=self.bias_dropout_fusion,
scale_mask_softmax_fusion=self.scale_mask_softmax_fusion,
apply_query_key_layer_scaling=self.apply_query_key_layer_scaling,
attention_scale=self.attention_scale,
layer_idx=self.layer_idx,
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn.functional as F
from oneflow import nn
import libai.utils.distributed as dist
from libai.config import configurable
from libai.inference.generator.generation_utils import Generator
from libai.layers import LayerNorm, LMLogits, ParallelCrossEntropyLoss
from libai.models.utils import init_method_normal, scaled_init_method_normal
from projects.GLM.layers.embedding_layer import GLMEmbedding
from projects.GLM.layers.transformer_layer import TransformerLayer
class Transformer(nn.Module):
def __init__(
self,
num_layers,
hidden_size,
num_attention_heads,
attention_dropout_prob=0.0,
output_dropout_prob=0.0,
layernorm_epsilon=1.0e-5,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
attention_scale=1.0,
):
super().__init__()
self.num_layers = num_layers
def build_layer(layer_number):
return TransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_scale=attention_scale,
layer_idx=layer_number,
)
self.layers = nn.ModuleList([build_layer(i) for i in range(self.num_layers)])
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, layer_idx=-1)
def forward(self, hidden_states, attention_mask, memory_states=None):
mem_layers = [hidden_states.detach()]
for i, layer in enumerate(self.layers):
mem_i = memory_states[i] if memory_states is not None else None
hidden_states = layer(hidden_states, attention_mask, mem=mem_i)
mem_layers.append(hidden_states.detach())
output = self.final_layernorm(hidden_states)
return output, mem_layers
class GLMModel(nn.Module):
@configurable
def __init__(
self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
max_sequence_length=1024,
embedding_dropout_prob=0.0,
attention_dropout_prob=0.0,
output_dropout_prob=0.0,
layernorm_epsilon=1e-5,
initializer_range=0.02,
use_scaled_init_for_output_weights=True,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
amp_enabled=False,
block_position_encoding=False,
attention_scale=1.0,
padding_idx=None,
):
super().__init__()
init_method = init_method_normal(sigma=initializer_range, mean=0)
if use_scaled_init_for_output_weights:
output_layer_init_method = scaled_init_method_normal(initializer_range, num_layers)
else:
output_layer_init_method = init_method
self.embeddings = GLMEmbedding(
vocab_size,
hidden_size,
max_sequence_length,
padding_idx=padding_idx,
init_method=init_method,
embedding_dropout_prob=embedding_dropout_prob,
amp_enabled=amp_enabled,
block_position_encoding=block_position_encoding,
)
self.transformer = Transformer(
num_layers,
hidden_size,
num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_scale=attention_scale,
)
self.lm_head = LMLogits(vocab_size, bias=False)
@classmethod
def from_config(cls, cfg):
return {
"num_layers": cfg.num_layers,
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"num_attention_heads": cfg.num_attention_heads,
"max_sequence_length": cfg.max_sequence_length,
"embedding_dropout_prob": cfg.embedding_dropout_prob,
"attention_dropout_prob": cfg.attention_dropout_prob,
"output_dropout_prob": cfg.output_dropout_prob,
"layernorm_epsilon": cfg.layernorm_epsilon,
"initializer_range": cfg.initializer_range,
"use_scaled_init_for_output_weights": cfg.use_scaled_init_for_output_weights,
"bias_gelu_fusion": cfg.bias_gelu_fusion,
"bias_dropout_fusion": cfg.bias_dropout_fusion,
"scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
"apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
"amp_enabled": cfg.amp_enabled,
"block_position_encoding": cfg.block_position_encoding,
"attention_scale": cfg.attention_scale,
"padding_idx": cfg.padding_idx,
}
def forward(
self,
input_ids,
position_ids=None,
attention_mask=None,
memory_states=None,
output_predict=True,
):
input_ids = input_ids.to_global(placement=dist.get_layer_placement(0))
position_ids = (
position_ids.to_global(placement=dist.get_layer_placement(0))
if position_ids is not None
else None
)
attention_mask = (
attention_mask.to_global(placement=dist.get_layer_placement(0))
if attention_mask is not None
else None
)
batch_size, query_length = input_ids.size()
memory_length = memory_states[0].size(1) if memory_states is not None else 0
is_scalar = flow.numel(attention_mask) == 1
is_sep = is_scalar or flow.numel(attention_mask) == batch_size
if is_sep:
sep = attention_mask.item() if is_scalar else attention_mask
attention_mask = self.build_mask_matrix(
batch_size, query_length, sep, memory_length=memory_length, is_scalar=is_scalar
)
else:
if attention_mask.dim() == 2:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
attention_mask = attention_mask[:, :, :, -query_length - memory_length :]
input_embeds = self.embeddings(input_ids, position_ids)
logits, mem_layers = self.transformer(
input_embeds, attention_mask=attention_mask, memory_states=memory_states
)
mem_layers = self.update_mems(mem_layers, memory_states)
if output_predict:
logits = self.lm_head(logits, self.embeddings.word_embeddings.weight)
return (logits, mem_layers)
@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
# Old API in OneFlow 0.8
if hasattr(module_block, "origin"):
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).activation_checkpointing = True
def build_mask_matrix(self, batch_size, seq_length, sep, memory_length=0, is_scalar=False):
m = flow.tril(
flow.ones((1, seq_length, seq_length)),
)
if is_scalar:
m[0, :, : int(sep)] = 1
else:
m = m.expand(batch_size, -1, -1)
ids = flow.arange(seq_length, device=sep.device, dtype=sep.dtype).view(1, -1)
mask = ids < sep.view(-1, 1)
m = m.masked_fill(mask.unsqueeze(1).expand_as(m), 1)
if memory_length > 0:
m = m.expand(batch_size, -1, -1)
m = flow.cat((flow.ones((batch_size, seq_length, memory_length)), m), dim=2)
m = m.unsqueeze(1)
m = m.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
return m
def update_mems(self, hiddens, mems):
memory_length = mems[0].size(1) if mems is not None else 0
query_length = hiddens[0].size(1)
new_memory_length = memory_length + query_length
new_mems = []
for i in range(len(hiddens)):
if new_memory_length <= query_length:
new_mems.append(hiddens[i][:, -new_memory_length:])
else:
new_mems.append(
flow.cat((mems[i][:, -new_memory_length + query_length :], hiddens[i]), dim=1)
)
return new_mems
class GLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_func = ParallelCrossEntropyLoss()
def forward(self, logits, labels):
lm_loss = self.loss_func(logits, labels)
lm_loss = lm_loss.mean()
return {"lm_loss": lm_loss}
class GLMForMultipleChoice(nn.Module):
def __init__(self, cfg):
super().__init__()
self.glm = GLMModel(cfg)
self.loss_func = GLMLoss()
def forward(
self,
input_ids=None,
position_ids=None,
attention_mask=None,
choice_ids=None,
choice_indices=None,
labels=None,
mems=None,
**kwargs,
):
lm_logits, mem_layers = self.glm(
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
memory_states=mems,
**kwargs,
)
outputs = F.log_softmax(lm_logits, dim=-1)
log_probs = []
for output, choices, choice_index in zip(outputs, choice_ids, choice_indices):
log_probs_single = []
for choice, choice_target_id in zip(choices, choice_index):
tmp = output[choice_target_id, choice]
log_probs_single.append(tmp.sum())
log_probs.append(flow.stack(log_probs_single))
log_probs = flow.stack(log_probs)
loss = None
if labels is not None:
loss = self.loss_func(log_probs, labels)
return {"loss": loss, "logits": log_probs, "lm_logits": lm_logits, "mems": mem_layers}
class GLMForConditionalGeneration(nn.Module, Generator):
@configurable
def __init__(
self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
max_sequence_length=1024,
embedding_dropout_prob=0.0,
attention_dropout_prob=0.0,
output_dropout_prob=0.0,
layernorm_epsilon=1e-5,
initializer_range=0.02,
use_scaled_init_for_output_weights=True,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
amp_enabled=False,
block_position_encoding=False,
attention_scale=1.0,
padding_idx=None,
cfg=None,
):
super().__init__()
self.cfg = cfg
self.glm = GLMModel(
num_layers=num_layers,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=embedding_dropout_prob,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
initializer_range=initializer_range,
use_scaled_init_for_output_weights=use_scaled_init_for_output_weights,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
amp_enabled=amp_enabled,
block_position_encoding=block_position_encoding,
attention_scale=attention_scale,
padding_idx=padding_idx,
cfg=cfg,
)
self.loss_func = GLMLoss()
@classmethod
def from_config(cls, cfg):
return {
"num_layers": cfg.num_layers,
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"num_attention_heads": cfg.num_attention_heads,
"max_sequence_length": cfg.max_sequence_length,
"embedding_dropout_prob": cfg.embedding_dropout_prob,
"attention_dropout_prob": cfg.attention_dropout_prob,
"output_dropout_prob": cfg.output_dropout_prob,
"layernorm_epsilon": cfg.layernorm_epsilon,
"initializer_range": cfg.initializer_range,
"use_scaled_init_for_output_weights": cfg.use_scaled_init_for_output_weights,
"bias_gelu_fusion": cfg.bias_gelu_fusion,
"bias_dropout_fusion": cfg.bias_dropout_fusion,
"scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
"apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
"amp_enabled": cfg.amp_enabled,
"block_position_encoding": cfg.block_position_encoding,
"attention_scale": cfg.attention_scale,
"padding_idx": cfg.padding_idx,
"cfg": cfg,
}
def forward(
self,
input_ids=None,
position_ids=None,
attention_mask=None,
labels=None,
memory_states=None,
**kwargs,
):
lm_logits, mems = self.glm(
input_ids, position_ids, attention_mask, memory_states=memory_states, **kwargs
)
loss = None
if labels is not None:
loss = self.loss_func(lm_logits, labels)
return {"loss": loss, "logits": lm_logits, "mems": mems}
def _reorder_cache(self, past, beam_idx):
if past is None:
return past
reordered_decoder_past = ()
for layer_past_states in past:
beam_idx = beam_idx.to_global(placement=layer_past_states.placement)
reordered_decoder_past = reordered_decoder_past + (
layer_past_states.index_select(0, beam_idx),
)
return reordered_decoder_past
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
position_ids=None,
generation_attention_mask=None,
**kwargs,
):
attention_mask = generation_attention_mask
# only last token for inputs_ids if past is defined in kwargs
seq_length = input_ids.shape[1]
if past:
if position_ids is not None:
position_ids = position_ids[:, :, seq_length - 1].unsqueeze(-1)
if attention_mask is not None:
attention_mask = attention_mask[:, :, seq_length - 1, :seq_length].unsqueeze(-2)
input_ids = input_ids[:, -1].unsqueeze(-1)
else:
if position_ids is not None:
position_ids = position_ids[:, :, :seq_length]
if attention_mask is not None:
attention_mask = attention_mask[:, :, :seq_length, :seq_length]
return {
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"memory_states": past,
}
@staticmethod
def set_pipeline_stage_id(model: nn.Module):
dist_utils = dist.get_dist_util()
if hasattr(model.glm.transformer.final_layernorm, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
if isinstance(module_block.origin, GLMEmbedding):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, TransformerLayer):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, (LMLogits, GLMLoss)):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.glm.transformer.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), GLMEmbedding):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), (LMLogits, GLMLoss)):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.glm.transformer.final_layernorm.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
# GLM
In 2017, Google proposed the Transformer architecture, and pre-trained models such as BERT, GPT, and T5 have kept emerging ever since, repeatedly setting new SOTA records across tasks. Last year, Tsinghua University proposed the GLM model (https://github.com/THUDM/GLM). Unlike the pre-training architectures above, it adopts an autoregressive blank-filling approach and achieves solid results on the three main NLP task families: natural language understanding, unconditional generation, and conditional generation.
LiBai mainly implements the inference part of GLM. For training-related content, see:
- [Accelerating GLM large-model training: up to 3x performance, 1/3 GPU memory savings, low cost to get started](https://mp.weixin.qq.com/s/dkTGXuJV38KuLb4_LmM20Q)
- https://github.com/Oneflow-Inc/one-glm
## GLM-Inference
When a model is so large that a single GPU cannot hold its parameters, convenient distributed training and inference become a necessity, and the ecosystem has produced tools to match.
The LiBai model library, built on OneFlow, lowers the barrier to distributed execution to a minimum: users do not need to worry about how the model is partitioned across devices and only have to change a few configuration values to switch between distributed strategies, while still enjoying excellent acceleration.
A GLM built with LiBai can easily run model parallel + pipeline parallel inference, which neatly solves the problem of a large model not fitting on a single GPU.
So how can users leverage LiBai, a library for large-scale model training and inference, to build the distributed inference part of GLM? The small example below explains it.
### Distributed inference has a natural advantage
Model parameters are essentially a collection of tensors, i.e. matrices; the parameters of a large model are simply large matrices. A parallelism strategy splits a large matrix into several smaller ones and places them on different GPUs or devices. The basic LinearLayer is implemented in LiBai as follows:
```python
class Linear1D(nn.Module):
def __init__(self, in_features, out_features, parallel="data", layer_idx=0, ...):
super().__init__()
if parallel == "col":
weight_sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.split(0)])
elif parallel == "row":
weight_sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.split(1)])
elif parallel == "data":
weight_sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
else:
raise KeyError(f"{parallel} is not supported! Only support ('data', 'row' and 'col')")
self.weight = flow.nn.Parameter(
flow.empty(
(out_features, in_features),
dtype=flow.float32,
placement=dist.get_layer_placement(layer_idx), # for pipeline parallelism placement
sbp=weight_sbp,
)
)
init_method(self.weight)
...
def forward(self, x):
...
```
Here, the user chooses how to split the Linear layer's weight matrix and the data matrix: the SBP signature in OneFlow decides whether a matrix is split column-wise, split row-wise, or replicated (model parallelism vs. data parallelism), while the Placement decides which GPU this LinearLayer lives on (pipeline parallelism).
Thanks to the design of the layers in LiBai and the SBP and Placement attributes that every OneFlow tensor carries, a model built from these layers can implement data parallelism, model parallelism, and pipeline parallelism with very little effort; a small illustration of these two attributes follows.
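As a standalone illustration (not part of the LiBai source), the snippet below creates a global tensor directly with an SBP signature and a placement. It assumes 4 GPUs arranged as a hypothetical 2x2 device mesh and an arbitrary weight shape:
```python
import oneflow as flow

# Hypothetical 2x2 device mesh over 4 GPUs: the first mesh axis is typically used
# for data/pipeline parallelism, the second for tensor (model) parallelism.
placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]])

# A weight replicated along the first mesh axis and column-split (dim 0 of the
# weight) along the second mesh axis -- the same signature Linear1D uses for
# parallel="col".
weight = flow.randn(
    1024, 1024,
    placement=placement,
    sbp=[flow.sbp.broadcast, flow.sbp.split(0)],
)

# Switching the same tensor to a fully replicated layout only needs to_global.
replicated = weight.to_global(sbp=[flow.sbp.broadcast, flow.sbp.broadcast])
print(replicated.shape)
```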
### GLM inference demo
This section shows a convenient 4-GPU `model parallel + pipeline parallel` inference demo of GLM in LiBai. The models are available on HuggingFace: https://huggingface.co/models?filter=glm
#### File structure of glm-10b
```bash
$ tree path/to/glm-10b
path/to/glm-10b
├── added_tokens.json
├── vocab.json
├── merges.txt
├── config.json
└── pytorch_model.bin
```
#### Inference
Run the following code:
```bash
# Before running, set `pad_token_id=0, eos_token_id=50258, bos_token_id=50000` in glm_inference.py
python3 -m oneflow.distributed.launch --nproc_per_node 4 demo.py
```
```python
# model parallel + pipeline parallel demo
import oneflow as flow
from projects.GLM.tokenizer.glm_tokenizer import GLMGPT2Tokenizer
from libai.utils import distributed as dist
from projects.GLM.configs.glm_inference import cfg
from projects.GLM.modeling_glm import GLMForConditionalGeneration
from projects.GLM.utils.glm_loader import GLMLoaderHuggerFace
from omegaconf import DictConfig
# Only the parallel configuration needs to be set
parallel_config = DictConfig(
dict(
data_parallel_size=1,
tensor_parallel_size=2,
pipeline_parallel_size=2,
pipeline_num_layers=2 * 24
)
)
dist.setup_dist_util(parallel_config)
tokenizer = GLMGPT2Tokenizer.from_pretrained("/path/to/glm-10b")
input_ids = tokenizer.encode(
[
"Ng is an adjunct professor at [MASK] (formerly associate professor and Director of its Stanford AI Lab or SAIL ). Also a pioneer in online education, Ng co-founded Coursera and deeplearning.ai."
],
return_tensors="of",
)
inputs = {"input_ids": input_ids, "attention_mask": flow.ones(input_ids.size())}
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)
sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
placement = dist.get_layer_placement(0)
loader = GLMLoaderHuggerFace(GLMForConditionalGeneration, cfg, "/path/to/glm-10b")
model = loader.load()
outputs = model.generate(
inputs=inputs['input_ids'].to_global(sbp=sbp, placement=placement),
position_ids=inputs['position_ids'].to_global(sbp=sbp, placement=placement),
generation_attention_mask=inputs['generation_attention_mask'].to_global(sbp=sbp, placement=placement),
max_length=512
)
res = tokenizer.decode(outputs[0])
if dist.is_main_process():
print(res)
>>> [CLS] Ng is an adjunct professor at [MASK] (formerly associate professor and Director of its Stanford AI Lab or SAIL ). Also a pioneer in online education, Ng co-founded Coursera and deeplearning.ai.<|endoftext|> <|startofpiece|> Stanford University and a co-founder of <|endofpiece|>
```
#### File structure of glm-10b-chinese
```bash
$ tree path/to/glm-10b-chinese
path/to/glm-10b-chinese
├── added_tokens.json
├── cog-pretrain.model
├── config.json
└── pytorch_model.bin
```
#### Inference
Run the following code:
```bash
# Before running, set `pad_token_id=50000, eos_token_id=50007, bos_token_id=None` in glm_inference.py
python3 -m oneflow.distributed.launch --nproc_per_node 4 demo.py
```
```python
# model parallel + pipeline parallel demo
import oneflow as flow
from projects.GLM.tokenizer.glm_tokenizer import GLMChineseTokenzier
from libai.utils import distributed as dist
from projects.GLM.configs.glm_inference import cfg
from projects.GLM.modeling_glm import GLMForConditionalGeneration
from projects.GLM.utils.glm_loader import GLMLoaderHuggerFace
from omegaconf import DictConfig
# Only the parallel configuration needs to be set
parallel_config = DictConfig(
dict(
data_parallel_size=1,
tensor_parallel_size=2,
pipeline_parallel_size=2,
pipeline_num_layers=2 * 24
)
)
dist.setup_dist_util(parallel_config)
tokenizer = GLMChineseTokenzier.from_pretrained("/path/to/glm-10b-chinese")
input_ids = tokenizer.encode(
[
"凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。"
],
return_tensors="of",
)
inputs = {"input_ids": input_ids, "attention_mask": flow.ones(input_ids.size())}
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)
sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
placement = dist.get_layer_placement(0)
loader = GLMLoaderHuggerFace(
GLMForConditionalGeneration,
cfg,
"/path/to/glm-10b-chinese",
embedding_dropout_prob=0,
attention_dropout_prob=0,
output_dropout_prob=0,
)
model = loader.load()
outputs = model.generate(
inputs=inputs['input_ids'].to_global(sbp=sbp, placement=placement),
position_ids=inputs['position_ids'].to_global(sbp=sbp, placement=placement),
generation_attention_mask=inputs['generation_attention_mask'].to_global(sbp=sbp, placement=placement),
max_length=512
)
res = tokenizer.decode(outputs[0])
if dist.is_main_process():
print(res)
>>> [CLS] 凯旋门位于意大利米兰市古城堡旁1807年为纪念 [MASK] 而建,门高25米,顶上矗立两武士青铜古兵车铸像 <|endoftext|> <|startofpiece|> 拿破仑军队攻克米兰城 <|endofpiece|>
```
#### Inference with a model trained by One-GLM
Loading OneFlow models in LiBai is just as convenient: if you want to run inference with a model trained by one-glm, simply replace GLMLoaderHuggerFace in the demo above with GLMLoaderLiBai, as sketched below.
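For reference, a minimal sketch of that swap, assuming GLMLoaderLiBai accepts the same arguments as GLMLoaderHuggerFace and using a placeholder checkpoint path:
```python
from projects.GLM.configs.glm_inference import cfg
from projects.GLM.modeling_glm import GLMForConditionalGeneration
from projects.GLM.utils.glm_loader import GLMLoaderLiBai

# Hypothetical path to a checkpoint produced by one-glm / LiBai training.
loader = GLMLoaderLiBai(GLMForConditionalGeneration, cfg, "/path/to/one_glm_checkpoint")
model = loader.load()
```
The rest of the demo (tokenization, `build_inputs_for_generation`, and `model.generate`) stays unchanged.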
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from shutil import copyfile
from typing import List, Optional, Tuple
import oneflow as flow
import sentencepiece as spm
from libai.tokenizer import BertTokenizer, GPT2Tokenizer, PreTrainedTokenizer, RobertaTokenizer
logger = logging.getLogger(__name__)
class GLMTokenizerMixin(PreTrainedTokenizer):
@property
def sop_token(self) -> Optional[str]:
return "<|startofpiece|>"
@property
def sop_token_id(self) -> Optional[int]:
"""
`Optional[int]`: Id of the start token in the vocabulary, used when training a model with
autoregressive blank filling.
"""
return self.convert_tokens_to_ids(self.sop_token)
@property
def eop_token(self) -> Optional[str]:
return "<|endofpiece|>"
@property
def eop_token_id(self) -> Optional[int]:
"""
`Optional[int]`: Id of the end token in the vocabulary, used when training a model with
autoregressive blank filling.
"""
return self.convert_tokens_to_ids(self.eop_token)
@property
def gmask_token_id(self) -> int:
return self.convert_tokens_to_ids("[gMASK]")
@property
def smask_token_id(self) -> int:
return self.convert_tokens_to_ids("[sMASK]")
@property
def mask_token_ids(self):
return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]
def _build_input_for_multiple_choice(self, context, choices):
context_id = context["input_ids"]
if flow.is_tensor(context_id):
context_id = context_id.tolist()
division = len(context_id)
mask_position = context_id.index(self.mask_token_id)
token = flow.tensor(context_id, dtype=flow.long)
attention_mask = [context["attention_mask"].expand(division, -1)]
position_id = flow.arange(division, dtype=flow.long)
block_position_id = flow.zeros(division, dtype=flow.long)
choice_ids, choice_indices = [], []
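        # For each candidate: append <|startofpiece|> followed by the choice tokens (shifted by one
        # for next-token prediction), reuse the [MASK] position as their position id, count block
        # positions 1..len(choice), and apply a causal (lower-triangular) mask within the choice.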
for choice_str in choices:
res = self.encode(choice_str)
choice = flow.tensor(res, dtype=flow.long)
choice_ids.append(choice)
choice_indices.append(
flow.arange(len(token), len(token) + len(choice), dtype=flow.long)
)
attention_mask.append(flow.tril(flow.ones((len(choice), len(choice)), dtype=flow.long)))
token = flow.cat(
(token, flow.tensor([self.sop_token_id], dtype=flow.long), choice[:-1])
)
position_id = flow.cat(
(position_id, flow.tensor([mask_position] * len(choice), dtype=flow.long))
)
block_position_id = flow.cat(
(block_position_id, flow.arange(1, 1 + len(choice), dtype=flow.long))
)
attention_mask = flow.block_diag(*attention_mask)
attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)
return {
"input_ids": token,
"position_ids": flow.stack((position_id, block_position_id)),
"attention_mask": attention_mask,
"choice_ids": choice_ids,
"choice_indices": choice_indices,
}
def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
pad_length = max_seq_length - len(tokens)
attention_mask = flow.nn.functional.pad(
attention_mask,
(0, pad_length, 0, pad_length),
mode="constant",
value=0,
)
tokens = flow.cat((tokens, flow.zeros(pad_length, dtype=flow.long)))
position_ids = flow.cat(
(position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1
)
return tokens, position_ids, attention_mask
def _collate(self, samples):
TILE = 1
length_to_pad = (
(max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE
)
token_batch, position_id_batch, attention_mask_batch = [], [], []
choices_batch, choice_target_ids_batch = [], []
for sample in samples:
token, position_id, attention_mask = self._pad_batch(
sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
)
token_batch.append(token)
position_id_batch.append(position_id)
attention_mask_batch.append(attention_mask)
choices_batch.append(sample["choice_ids"])
choice_target_ids_batch.append(sample["choice_indices"])
return {
"input_ids": flow.stack(token_batch),
"position_ids": flow.stack(position_id_batch),
"attention_mask": flow.stack(attention_mask_batch).unsqueeze(1),
"choice_ids": choices_batch,
"choice_indices": choice_target_ids_batch,
}
def build_inputs_for_multiple_choice(self, model_input, choices, max_length=None):
samples = [
{key: value[i] for key, value in model_input.items()}
for i in range(len(model_input["input_ids"]))
]
samples = [
self._build_input_for_multiple_choice(sample, choice)
for sample, choice in zip(samples, choices)
]
inputs = self._collate(samples)
return inputs
def build_inputs_for_generation(
self, model_input, max_gen_length=512, targets=None, padding=False
):
mask_ids = self.mask_token_ids
input_ids = model_input["input_ids"]
batch_size, seq_length = input_ids.shape[:2]
position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
position_ids, block_position_ids = [], []
labels = None
if targets is not None:
is_batched = isinstance(targets, (list, tuple))
targets = self.encode(targets)
if not is_batched:
targets = [targets]
assert len(targets) == len(input_ids)
targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
if not padding:
max_gen_length = max(map(len, targets))
targets = [[self.sop_token_id] + target for target in targets]
labels = [target[1:] for target in targets]
targets = [
target + [self.pad_token_id] * (max_gen_length + 1 - len(target))
for target in targets
]
labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
targets = flow.tensor(targets, dtype=input_ids.dtype)
labels = flow.tensor(labels, dtype=input_ids.dtype)
labels = flow.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
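        # GLM uses 2D position ids: generated positions all reuse the position of the first mask
        # token ([MASK]/[sMASK]/[gMASK]) in each sample, while block position ids run
        # 1..max_gen_length over the generated span.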
for i in range(batch_size):
mask_positions = []
for mask_id in mask_ids:
mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()
if not mask_positions:
raise ValueError("Cannot find mask token in the input")
mask_positions.sort()
mask_pos = mask_positions[0]
position_ids.append(position_id + [mask_pos] * max_gen_length)
block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))
position_ids = flow.tensor(position_ids, dtype=input_ids.dtype)
block_position_ids = flow.tensor(block_position_ids, dtype=input_ids.dtype)
position_ids = flow.stack((position_ids, block_position_ids), dim=1)
attention_mask = model_input["attention_mask"]
attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)
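        # Context queries never attend to generated tokens (the zero block below), while generated
        # queries attend to the full context plus a causal (lower-triangular) mask over previously
        # generated tokens.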
generation_attention_mask = (
flow.cat(
[
attention_mask.new_zeros((seq_length, max_gen_length)),
flow.tril(attention_mask.new_ones((max_gen_length, max_gen_length))),
],
dim=0,
)
.unsqueeze(0)
.expand(batch_size, -1, -1)
)
attention_mask = flow.cat((attention_mask, generation_attention_mask), dim=2)
attention_mask = attention_mask.unsqueeze(1)
if targets is None:
input_ids = flow.cat(
(input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1
)
else:
input_ids = flow.cat((input_ids, targets[:, :-1]), dim=1)
batch = {"input_ids": input_ids, "position_ids": position_ids}
if labels is None:
batch["generation_attention_mask"] = attention_mask
else:
batch["attention_mask"] = attention_mask
batch["labels"] = labels
return batch
class GLMRobertaTokenizer(RobertaTokenizer, GLMTokenizerMixin):
model_input_names = ["input_ids", "position_ids", "attention_mask"]
truncation_side: str = "left"
@property
def gmask_token_id(self) -> int:
raise NotImplementedError("The model doesn't support gMASK")
@property
def smask_token_id(self) -> int:
raise NotImplementedError("The model doesn't support sMASK")
@property
def mask_token_ids(self):
return [self.mask_token_id]
class GLMChineseTokenzier(GLMTokenizerMixin):
vocab_files_names = {"vocab_file": "cog-pretrain.model"}
truncation_side: str = "left"
def __init__(
self,
vocab_file,
eos_token="<|endoftext|>",
unk_token="[UNK]",
pad_token="<|endoftext|>",
additional_special_tokens=["<|startofpiece|>", "<|endofpiece|>", "[gMASK]", "[sMASK]"],
add_bos_token=False,
**kwargs,
):
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.add_bos_token = add_bos_token
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
self._eos_token = "<|endoftext|>"
self._unk_token = "[UNK]"
self._pad_token = "<|endoftext|>"
self._cls_token = "[CLS]"
self._mask_token = "[MASK]"
@property
def vocab_size(self):
return len(self.sp_model)
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.PieceToId(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.sp_model.IdToPiece(index)
def convert_tokens_to_string(self, tokens):
return self.sp_model.decode(tokens)
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "")
+ self.vocab_files_names["vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
self.vocab_file
):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
        Build model inputs from a single sequence for sequence classification tasks by
        concatenating and adding special tokens. A GLM sequence has the following format:
        - single sequence: ``[CLS] X <|endoftext|>``
        Sequence pairs are not supported here, so ``token_ids_1`` must be ``None``.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the
appropriate special tokens.
"""
assert token_ids_1 is None
cls = [self.cls_token_id]
eos = [self.eos_token_id]
return cls + token_ids_0 + eos
class GLMGPT2Tokenizer(GPT2Tokenizer, GLMTokenizerMixin):
model_input_names = ["input_ids", "position_ids", "attention_mask"]
truncation_side: str = "left"
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
add_bos_token=False,
**kwargs,
):
super().__init__(
vocab_file,
merges_file,
errors,
unk_token,
bos_token,
eos_token,
add_bos_token,
**kwargs,
)
self.cls_token = "[CLS]"
self.mask_token = "[MASK]"
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
        Build model inputs from a single sequence for sequence classification tasks by
        concatenating and adding special tokens. A GLM sequence has the following format:
        - single sequence: ``[CLS] X <|endoftext|>``
        Sequence pairs are not supported here, so ``token_ids_1`` must be ``None``.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the
appropriate special tokens.
"""
assert token_ids_1 is None
cls = [self.cls_token_id]
eos = [self.eos_token_id]
return cls + token_ids_0 + eos
class GLMBertTokenizer(BertTokenizer, GLMTokenizerMixin):
model_input_names = ["input_ids", "position_ids", "attention_mask"]
truncation_side: str = "left"
@property
def gmask_token_id(self) -> int:
raise NotImplementedError("The model doesn't support gMASK")
@property
def smask_token_id(self) -> int:
raise NotImplementedError("The model doesn't support sMASK")
@property
def mask_token_ids(self):
return [self.mask_token_id]
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from libai.models.utils import ModelLoaderHuggerFace, ModelLoaderLiBai
class GLMLoaderHuggerFace(ModelLoaderHuggerFace):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        """NOTE: base_model_prefix_1 is GLM's prefix in Transformers;
        base_model_prefix_2 is GLM's prefix in LiBai."""
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        self.base_model_prefix_1 = "glm"
        self.base_model_prefix_2 = "glm"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
old_keys = list(oneflow_state_dict.keys())
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix1 = self.base_model_prefix_1 + "." if has_prefix else ""
prefix2 = "glm." if has_prefix else ""
# Convert Embedding layers.
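        # e.g. "glm.word_embeddings.weight" (Transformers) -> "glm.embeddings.word_embeddings.weight" (LiBai)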
new_key = prefix2 + "embeddings.word_embeddings.weight"
old_keys.remove(prefix1 + "word_embeddings.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "word_embeddings.weight")
if cfg.get("block_position_encoding", False) is True:
new_key = prefix2 + "embeddings.position_embeddings.weight"
old_keys.remove(prefix1 + "transformer.position_embeddings.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(
prefix1 + "transformer.position_embeddings.weight"
)
new_key = prefix2 + "embeddings.block_position_embeddings.weight"
old_keys.remove(prefix1 + "transformer.block_position_embeddings.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(
prefix1 + "transformer.block_position_embeddings.weight"
)
# Convert other layers.
for key in old_keys:
if "query_key_value" in key:
qkv = oneflow_state_dict.pop(key)
qkv = self._fix_qkv_ordering(qkv, head_size, num_heads)
oneflow_state_dict[prefix2 + key] = qkv
else:
oneflow_state_dict[prefix2 + key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
for k, v in cfg_dict.items():
self._update_cfg(k, v)
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class GLMLoaderLiBai(ModelLoaderLiBai):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        # base_model_prefix_1 is referenced by _convert_state_dict below, so set both prefixes.
        self.base_model_prefix_1 = "glm"
        self.base_model_prefix_2 = "glm"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
old_keys = list(oneflow_state_dict.keys())
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix1 = self.base_model_prefix_1 + "." if has_prefix else ""
prefix2 = "glm." if has_prefix else ""
# Convert Embedding layers.
new_key = prefix2 + "embeddings.word_embeddings.weight"
old_keys.remove(prefix1 + "word_embeddings.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "word_embeddings.weight")
if cfg.get("block_position_encoding", False) is True:
new_key = prefix2 + "embeddings.position_embeddings.weight"
old_keys.remove(prefix1 + "transformer.position_embeddings.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(
prefix1 + "transformer.position_embeddings.weight"
)
new_key = prefix2 + "embeddings.block_position_embeddings.weight"
old_keys.remove(prefix1 + "transformer.block_position_embeddings.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(
prefix1 + "transformer.block_position_embeddings.weight"
)
# Convert other layers.
for key in old_keys:
if "query_key_value" in key:
qkv = oneflow_state_dict.pop(key)
qkv = self._fix_qkv_ordering(qkv, head_size, num_heads)
oneflow_state_dict[prefix2 + key] = qkv
else:
oneflow_state_dict[prefix2 + key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
## MAE in LiBai
**Masked Autoencoders Are Scalable Vision Learners**
Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick
[[`arXiv`](https://arxiv.org/abs/2111.06377)] [[`BibTeX`](#Citation)]
<p align="center">
<img src="https://user-images.githubusercontent.com/11435359/146857310-f258c86c-fde6-48e8-9cee-badd2b21bd2c.png" width="480">
</p>
This is the OneFlow re-implementation of MAE based on [LiBai](https://libai.readthedocs.io/).
## Catalog
- [x] MAE pretraining code
- [x] MAE finetune code
## Supported parallel modes and tasks
Based on [libai.layers](https://libai.readthedocs.io/en/latest/modules/libai.layers.html), the MAE model is automatically configured with the following parallelism modes.
<table class="docutils">
<tbody>
<tr>
<th width="80"> Model </th>
<th valign="bottom" align="left" width="120">Data Parallel</th>
<th valign="bottom" align="left" width="120">Tensor Parallel</th>
<th valign="bottom" align="left" width="120">Pipeline Parallel</th>
</tr>
<tr>
<td align="left"> <b> MAE pretrain </b> </td>
<td align="left">&#10004;</td>
<td align="left">-</td>
<td align="left">-</td>
</tr>
<tr>
<td align="left"> <b> MAE finetune </b> </td>
<td align="left">&#10004;</td>
<td align="left">&#10004;</td>
<td align="left">&#10004;</td>
</tr>
</tbody>
</table>
## Usage
### Installation
Please see [LiBai Installation](https://libai.readthedocs.io/en/latest/tutorials/get_started/Installation.html) to install LiBai.
### Prepare the Data
Please see [Prepare the Data](https://libai.readthedocs.io/en/latest/tutorials/get_started/quick_run.html#prepare-the-data).
### Pretraining
Pretrain MAE on 8 GPUs using data parallelism:
```bash
cd /path/to/libai
bash tools/train.sh projects/MAE/train_net.py projects/MAE/configs/mae_pretraining.py 8
```
### Finetuning
1. Set up the weights for finetuning in [mae_finetune.py](./configs/mae_finetune.py) as follows:
```python
# mae_finetune.py
finetune.enable = True # only load weight if enable is True
finetune.weight_style = "oneflow" # Set "oneflow" for loading oneflow checkpoints
finetune.path = "/path/to/checkpoint" # the checkpoint directory
```
If you feel confused about the checkpoint format here, please refer to [Load and Save a Checkpoint in LiBai](https://libai.readthedocs.io/en/latest/tutorials/basics/Load_and_Save_Checkpoint.html) for more details.
2. Finetune MAE on 8 GPUs using data parallelism:
```bash
cd /path/to/libai
bash tools/train.sh projects/MAE/train_net.py projects/MAE/configs/mae_finetune.py 8
```
**Notes:** If you want to finetune MAE models using different parallel strategies, please refer to the [Distributed Configuration Tutorial](https://libai.readthedocs.io/en/latest/tutorials/basics/Distributed_Configuration.html); a minimal configuration sketch follows.
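The sketch below assumes the standard `train.dist` fields described in that tutorial and a hybrid 2 × 2 × 2 layout over 8 GPUs; the field values are illustrative placeholders, not taken from the MAE configs.
```python
# Hypothetical snippet for projects/MAE/configs/mae_finetune.py -- adjust sizes to your hardware.
train.dist.data_parallel_size = 2      # number of data-parallel replicas
train.dist.tensor_parallel_size = 2    # split each layer's weights across 2 GPUs
train.dist.pipeline_parallel_size = 2  # split the model into 2 pipeline stages
train.dist.pipeline_num_layers = 12    # total transformer depth (12 for ViT-Base)
```
The product of the three parallel sizes (here 2 × 2 × 2 = 8) must equal the total number of GPUs passed to `tools/train.sh`.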
### Evaluation
Evaluate the MAE model under LiBai on 8 GPUs:
```bash
cd /path/to/libai
bash tools/train.sh projects/MAE/train_net.py projects/MAE/configs/mae_finetune.py 8 --eval-only
```
## Advanced Usage
### Finetune MAE with a PyTorch pretrained checkpoint
You can download PyTorch pretrained weights from the [MAE official repo](https://github.com/facebookresearch/mae#fine-tuning-with-pre-trained-checkpoints) and finetune them in LiBai by updating [mae_finetune.py](./configs/mae_finetune.py) as follows:
```python
finetune.enable = True # only load weight if enable is True
finetune.weight_style = "pytorch" # Set "pytorch" for loading torch checkpoints
finetune.path = "/path/to/mae_finetuned_vit_base.pth"
```
Run finetuning on 8 GPUs:
```bash
cd /path/to/libai
bash tools/train.sh projects/MAE/train_net.py projects/MAE/configs/mae_finetune.py 8
```
## Citation
```BibTeX
@article{he2021masked,
title={Masked autoencoders are scalable vision learners},
author={He, Kaiming and Chen, Xinlei and Xie, Saining and Li, Yanghao and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv preprint arXiv:2111.06377},
year={2021}
}
```