# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from flowvision.layers import trunc_normal_
from flowvision.models import to_2tuple
from libai.config.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear
from libai.utils import distributed as dist
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
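# Added illustration (not part of the model code): window_partition and window_reverse are
# inverses whenever H and W are divisible by window_size. The helper below is a hypothetical,
# self-contained sketch using only the two functions above and standard oneflow ops; it is
# never called by the model.
def _window_roundtrip_sketch(B=2, H=8, W=8, C=3, window_size=4):
    x = flow.randn(B, H, W, C)
    windows = window_partition(x, window_size)  # (B * H/ws * W/ws, ws, ws, C)
    expected = (B * (H // window_size) * (W // window_size), window_size, window_size, C)
    assert tuple(windows.shape) == expected
    x_rec = window_reverse(windows, window_size, H, W)  # back to (B, H, W, C)
    assert tuple(x_rec.shape) == (B, H, W, C)
    return (x - x_rec).abs().max()  # expected to be exactly 0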
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
fused_bias_add_dropout=False,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
flow.zeros(
(2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
) # 2*Wh-1 * 2*Ww-1, nH
trunc_normal_(self.relative_position_bias_table, std=0.02)
# get pair-wise relative position index for each token inside the window
coords_h = flow.arange(self.window_size[0])
coords_w = flow.arange(self.window_size[1])
coords = flow.stack(flow.meshgrid(*[coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = flow.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] = (
relative_coords[:, :, 0] + self.window_size[0] - 1
) # shift to start from 0
relative_coords[:, :, 1] = relative_coords[:, :, 1] + self.window_size[1] - 1
relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * self.window_size[1] - 1)
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer(
"relative_position_index",
relative_position_index.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)
self.qkv = Linear(dim, dim * 3, bias=qkv_bias, layer_idx=layer_idx)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = Linear(dim, dim, layer_idx=layer_idx)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
self.fused_bias_add_dropout = fused_bias_add_dropout
self.p = proj_drop
def forward(self, x, mask):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
# attn = flow.matmul(q, k.transpose(-2, -1))
attn = flow.matmul(q, k, transpose_b=True)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
unsqueeze_relative_position_bias = relative_position_bias.unsqueeze(0)
attn = attn + unsqueeze_relative_position_bias
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = flow.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
if self.fused_bias_add_dropout:
x = flow._C.matmul(x, self.proj.weight, transpose_a=False, transpose_b=True)
x = flow._C.fused_bias_add_dropout(x, self.proj.bias, p=self.p, axis=2)
else:
x = self.proj(x)
x = self.proj_drop(x)
return x
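# Added illustration (not called by the model): the relative_position_index buffer built in
# WindowAttention.__init__ maps every (query, key) pair inside a window to one row of
# relative_position_bias_table; for a (Wh, Ww) window the indices cover
# [0, (2*Wh - 1) * (2*Ww - 1) - 1]. The hypothetical helper below repeats the same arithmetic
# on plain local tensors (no placement/sbp), purely as a worked example.
def _relative_position_index_sketch(window_size=(2, 2)):
    coords_h = flow.arange(window_size[0])
    coords_w = flow.arange(window_size[1])
    coords = flow.stack(flow.meshgrid(coords_h, coords_w))  # 2, Wh, Ww
    coords_flatten = flow.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] = relative_coords[:, :, 0] + window_size[0] - 1  # start from 0
    relative_coords[:, :, 1] = relative_coords[:, :, 1] + window_size[1] - 1
    relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * window_size[1] - 1)
    # For a 2x2 window this is a 4x4 matrix with values in [0, 8].
    return relative_coords.sum(-1)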
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=LayerNorm,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.layer_idx = layer_idx
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim, layer_idx=layer_idx)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
fused_bias_add_dropout=True,
layer_idx=layer_idx,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim, layer_idx=layer_idx)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
hidden_size=dim,
ffn_hidden_size=mlp_hidden_dim,
output_dropout_prob=drop,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
layer_idx=layer_idx,
)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = flow.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt = cnt + 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
attn_mask = attn_mask.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = flow.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = flow.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
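# Added illustration (not called by the model): for a shifted block, the attn_mask built in
# SwinTransformerBlock.__init__ labels up to 3 x 3 regions of the cyclically shifted image and
# assigns -100 (effectively -inf) to attention between tokens that came from different regions.
# This hypothetical helper repeats that construction on local tensors for one small case.
def _sw_msa_mask_sketch(H=8, W=8, window_size=4, shift_size=2):
    img_mask = flow.zeros((1, H, W, 1))  # 1 H W 1
    region_slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    cnt = 0
    for h in region_slices:
        for w in region_slices:
            img_mask[:, h, w, :] = cnt
            cnt = cnt + 1
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # 0 where query and key belong to the same region, -100 otherwise
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))
    return attn_mask.masked_fill(attn_mask == 0, float(0.0))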
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=LayerNorm, layer_idx=0):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = Linear(4 * dim, 2 * dim, bias=False, layer_idx=layer_idx)
self.norm = norm_layer(4 * dim, layer_idx=layer_idx)
self.layer_idx = layer_idx
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = flow.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, layer_idx=0
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
).to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, layer_idx=layer_idx)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
downsample (nn.Module | None, optional): Downsample at the end of the layer. Default: None
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=LayerNorm,
downsample=None,
layer_id_offset=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.layer_id_offset = layer_id_offset
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
layer_idx=layer_id_offset + i,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=norm_layer,
layer_idx=layer_id_offset + depth - 1,
)
else:
self.downsample = None
def forward(self, x):
layer_idx = self.layer_id_offset
for i in range(len(self.blocks)):
x = x.to_global(placement=dist.get_layer_placement(layer_idx))
x = self.blocks[i](x)
layer_idx += 1
if self.downsample is not None:
x = self.downsample(x)
return x
class SwinTransformer(nn.Module):
"""Swin Transformer in LiBai.
    LiBai implementation of:
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/pdf/2103.14030>`_
Args:
img_size (int, tuple(int)): Input image size. Default 224
patch_size (int, tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: libai.layers.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
loss_func (callable, optional): Loss function for computing the total loss
between logits and labels
"""
@configurable
def __init__(
self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=LayerNorm,
ape=False,
patch_norm=True,
loss_func=None,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
layer_idx=0,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(
flow.zeros(1, num_patches, embed_dim),
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
layer_id_offset = 0
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(
patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer),
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
layer_id_offset=layer_id_offset,
)
layer_id_offset += depths[i_layer]
self.layers.append(layer)
self.norm = norm_layer(self.num_features, layer_idx=-1)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = (
Linear(self.num_features, num_classes, layer_idx=-1)
if num_classes > 0
else nn.Identity()
)
# Loss func
self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@classmethod
def from_config(cls, cfg):
return {
"img_size": cfg.img_size,
"patch_size": cfg.patch_size,
"in_chans": cfg.in_chans,
"num_classes": cfg.num_classes,
"embed_dim": cfg.embed_dim,
"depths": cfg.depths,
"num_heads": cfg.num_heads,
"window_size": cfg.window_size,
"mlp_ratio": cfg.mlp_ratio,
"qkv_bias": cfg.qkv_bias,
"qk_scale": cfg.qk_scale,
"drop_rate": cfg.drop_rate,
"drop_path_rate": cfg.drop_path_rate,
"ape": cfg.ape,
"patch_norm": cfg.patch_norm,
"loss_func": cfg.loss_func,
}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = flow.flatten(x, 1)
return x
def forward(self, images, labels=None):
"""
Args:
images (flow.Tensor): training samples.
labels (flow.LongTensor, optional): training targets
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"losses": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
"""
x = self.forward_features(images)
x = self.head(x)
if labels is not None and self.training:
losses = self.loss_func(x, labels)
return {"losses": losses}
else:
return {"prediction_scores": x}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.patch_embed, "config"):
# Old API in OneFlow 0.8
model.patch_embed.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, PatchMerging):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
model.patch_embed.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), PatchMerging):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
if hasattr(module_block, "origin"):
# Old API in OneFlow 0.8
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).activation_checkpointing = True
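# Added illustration (not called by the model): the stochastic-depth schedule built in
# SwinTransformer.__init__ is one linearly increasing list over all blocks, split per stage.
# For the defaults depths=[2, 2, 6, 2] and drop_path_rate=0.1, flow.linspace yields 12 rates
# from 0.0 to 0.1, and each BasicLayer receives the contiguous slice for its own blocks.
# The helper name and defaults below are hypothetical.
def _drop_path_schedule_sketch(depths=(2, 2, 6, 2), drop_path_rate=0.1):
    dpr = [x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))]
    return [dpr[sum(depths[:i]) : sum(depths[: i + 1])] for i in range(len(depths))]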
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
from flowvision.layers import trunc_normal_
from flowvision.models import to_2tuple
from libai.config.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear
from libai.utils import distributed as dist
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.0,
proj_drop=0.0,
pretrained_window_size=[0, 0],
fused_bias_add_dropout=False,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.fused_bias_add_dropout = fused_bias_add_dropout
self.num_heads = num_heads
self.layer_idx = layer_idx
self.p = proj_drop
self.logit_scale = nn.Parameter(
flow.log(
10
* flow.ones(
1,
num_heads,
1,
1,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
),
requires_grad=True,
)
        # NOTE: meta network: a small MLP that generates the continuous relative position bias
self.cpb_mlp = nn.Sequential(
Linear(2, 512, bias=True, layer_idx=layer_idx),
nn.ReLU(inplace=True),
Linear(512, num_heads, bias=False, layer_idx=layer_idx),
).to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
# NOTE: get relative_coords_table
relative_coords_h = flow.arange(
-(self.window_size[0] - 1), self.window_size[0], dtype=flow.float32
)
relative_coords_w = flow.arange(
-(self.window_size[1] - 1), self.window_size[1], dtype=flow.float32
)
relative_coords_table = (
flow.stack(flow.meshgrid(*[relative_coords_h, relative_coords_w]))
.permute(1, 2, 0)
.contiguous()
.unsqueeze(0)
) # 1, 2*Wh-1, 2*Ww-1, 2
        # NOTE: normalize relative coordinates to the range [-8, 8] before the log-spaced transform
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0] / (
pretrained_window_size[0] - 1
)
relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1] / (
pretrained_window_size[1] - 1
)
else:
relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0] / (
self.window_size[0] - 1
)
relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1] / (
self.window_size[1] - 1
)
relative_coords_table = relative_coords_table * 8
        # NOTE: log-spaced coordinates: y = sign(x) * log2(|x| + 1) / log2(8)
relative_coords_table = (
flow.sign(relative_coords_table)
* flow.log2(flow.abs(relative_coords_table) + 1.0)
/ math.log2(8.0)
)
self.register_buffer(
"relative_coords_table",
relative_coords_table.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)
# NOTE: get pair-wise relative position index for each token inside the window
coords_h = flow.arange(self.window_size[0])
coords_w = flow.arange(self.window_size[1])
coords = flow.stack(flow.meshgrid(*[coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = flow.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] = (
relative_coords[:, :, 0] + self.window_size[0] - 1
) # shift to start from 0
relative_coords[:, :, 1] = relative_coords[:, :, 1] + self.window_size[1] - 1
relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * self.window_size[1] - 1)
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer(
"relative_position_index",
relative_position_index.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
)
self.qkv = Linear(dim, dim * 3, bias=False, layer_idx=layer_idx)
if qkv_bias:
self.q_bias = nn.Parameter(
flow.zeros(
dim,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
self.v_bias = nn.Parameter(
flow.zeros(
dim,
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = Linear(dim, dim, layer_idx=layer_idx)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = flow.concat(
[
self.q_bias,
flow.zeros(
self.v_bias.shape,
requires_grad=False,
placement=dist.get_layer_placement(
self.layer_idx, device_type=self.v_bias.placement.type
),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
self.v_bias,
],
dim=0,
)
qkv = self.qkv(x) + qkv_bias.unsqueeze(0).unsqueeze(0)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# NOTE: cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        # NOTE: learnable per-head logit scale, clamped so that its exponential is at most 1 / 0.01
logit_scale = flow.clamp(self.logit_scale, min=-1e6, max=math.log(1.0 / 0.01)).exp()
attn = attn * logit_scale
# NOTE: use relative_coords_table and meta network to generate relative_position_bias
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
-1, self.num_heads
)
relative_position_bias = relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
        # NOTE: the sigmoid keeps the relative position bias bounded in (0, 16)
relative_position_bias = 16 * flow.sigmoid(relative_position_bias).unsqueeze(0)
attn = attn + relative_position_bias
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
if self.fused_bias_add_dropout:
x = flow._C.matmul(x, self.proj.weight, transpose_a=False, transpose_b=True)
x = flow._C.fused_bias_add_dropout(x, self.proj.bias, p=self.p, axis=2)
else:
x = self.proj(x)
x = self.proj_drop(x)
return x
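# Added illustration (not called by the model): Swin-V2 replaces scaled dot-product attention
# with cosine attention: q and k are L2-normalized, and the scores are multiplied by a learned
# per-head logit scale clamped from above at log(1/0.01). The hypothetical helper below shows
# only that core computation on local random tensors.
def _cosine_attention_sketch(num_heads=3, N=16, head_dim=8):
    q = flow.randn(1, num_heads, N, head_dim)
    k = flow.randn(1, num_heads, N, head_dim)
    logit_scale = flow.log(10 * flow.ones(1, num_heads, 1, 1))
    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    scale = flow.clamp(logit_scale, min=-1e6, max=math.log(1.0 / 0.01)).exp()
    # Each entry is a cosine similarity in [-1, 1] times a bounded scale (at most 100).
    return attn * scale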
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
pretrained_window_size (int): Window size in pre-training.
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=LayerNorm,
pretrained_window_size=0,
layer_idx=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.layer_idx = layer_idx
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim, layer_idx=layer_idx)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size),
fused_bias_add_dropout=True,
layer_idx=layer_idx,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim, layer_idx=layer_idx)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
hidden_size=dim,
ffn_hidden_size=mlp_hidden_dim,
output_dropout_prob=drop,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
layer_idx=layer_idx,
)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = flow.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt = cnt + 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = (
attn_mask.masked_fill(attn_mask != 0, float(-100.0))
.masked_fill(attn_mask == 0, float(0.0))
.to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = flow.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = flow.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
        # NOTE: res-post-norm: unlike Swin-V1 (pre-norm), LayerNorm is applied to the
        # attention output before the residual add
x = shortcut + self.drop_path(self.norm1(x))
        # NOTE: res-post-norm applied to the MLP output as well
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=LayerNorm, layer_idx=0):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = Linear(4 * dim, 2 * dim, bias=False, layer_idx=layer_idx)
        # NOTE: Swin-V2 normalizes the reduced output (2 * dim),
        # while Swin-V1 normalizes the 4 * dim concatenation
self.norm = norm_layer(2 * dim, layer_idx=layer_idx)
self.layer_idx = layer_idx
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = flow.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
        # NOTE: post-norm: reduction is applied before normalization,
        # a Swin-V2 change relative to Swin-V1
x = self.reduction(x)
x = self.norm(x)
return x
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
Default: None
pretrained_window_size (int): Local window size in pre-training.
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=LayerNorm,
downsample=None,
pretrained_window_size=0,
layer_id_offset=0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.layer_id_offset = layer_id_offset
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size,
layer_idx=layer_id_offset + i,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution,
dim=dim,
norm_layer=norm_layer,
layer_idx=layer_id_offset + depth - 1,
)
else:
self.downsample = None
def forward(self, x):
layer_idx = self.layer_id_offset
for blk in self.blocks:
x = x.to_global(
placement=dist.get_layer_placement(layer_idx, device_type=x.placement.type)
)
x = blk(x)
layer_idx += 1
if self.downsample is not None:
x = self.downsample(x)
return x
def _init_respostnorm(self):
for blk in self.blocks:
nn.init.constant_(blk.norm1.bias, 0)
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
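    # Note (added): zero-initializing the weight and bias of norm1/norm2 above makes each
    # residual branch output zero at the start of training, so every block initially acts as
    # an identity mapping; this res-post-norm initialization is intended to stabilize training.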
class PatchEmbed(nn.Module):
r"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, layer_idx=0
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
).to_global(
placement=dist.get_layer_placement(layer_idx),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model \
({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
class SwinTransformerV2(nn.Module):
r"""Swin Transformer
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: libai.layers.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
"""
@configurable
def __init__(
self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=LayerNorm,
ape=False,
patch_norm=True,
pretrained_window_sizes=[0, 0, 0, 0],
loss_func=None,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
layer_idx=0,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(
flow.zeros(
1,
num_patches,
embed_dim,
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
)
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
layer_id_offset = 0
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(
patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer),
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer],
layer_id_offset=layer_id_offset,
)
layer_id_offset += depths[i_layer]
self.layers.append(layer)
self.norm = norm_layer(self.num_features, layer_idx=-1)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = (
Linear(self.num_features, num_classes, layer_idx=-1)
if num_classes > 0
else nn.Identity()
)
self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
self.apply(self._init_weights)
for bly in self.layers:
bly._init_respostnorm()
def _init_weights(self, m):
if isinstance(m, Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@classmethod
def from_config(cls, cfg):
return {
"img_size": cfg.img_size,
"patch_size": cfg.patch_size,
"in_chans": cfg.in_chans,
"num_classes": cfg.num_classes,
"embed_dim": cfg.embed_dim,
"depths": cfg.depths,
"num_heads": cfg.num_heads,
"window_size": cfg.window_size,
"mlp_ratio": cfg.mlp_ratio,
"qkv_bias": cfg.qkv_bias,
"drop_rate": cfg.drop_rate,
"drop_path_rate": cfg.drop_path_rate,
"ape": cfg.ape,
"patch_norm": cfg.patch_norm,
"pretrained_window_sizes": cfg.pretrained_window_sizes,
"loss_func": cfg.loss_func,
}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = flow.flatten(x, 1)
return x
def forward(self, images, labels=None):
"""
Args:
images (flow.Tensor): training samples.
labels (flow.LongTensor, optional): training targets
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"losses": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
"""
x = self.forward_features(images)
x = self.head(x)
if labels is not None and self.training:
losses = self.loss_func(x, labels)
return {"losses": losses}
else:
return {"prediction_scores": x}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.patch_embed, "config"):
# Old API in OneFlow 0.8
model.patch_embed.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, PatchMerging):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
model.patch_embed.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), PatchMerging):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
model.norm.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.avgpool.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
@staticmethod
def set_activation_checkpoint(model):
for module_block in model.modules():
if hasattr(module_block, "origin"):
# Old API in OneFlow 0.8
if isinstance(module_block.origin, SwinTransformerBlock):
module_block.config.activation_checkpointing = True
else:
if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
module_block.to(flow.nn.graph.GraphModule).activation_checkpointing = True
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from libai.config import configurable
from libai.layers import (
Embedding,
LayerNorm,
LMLogits,
ParallelCrossEntropyLoss,
TransformerLayer,
VocabEmbedding,
)
from libai.layers.attention import AttnMaskType
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist
class ExtendedMask(flow.nn.Module):
def forward(self, attention_mask):
return attention_mask.unsqueeze(1)
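# Added illustration (not called by the model): ExtendedMask only inserts a head dimension,
# so a (batch, seq, seq) attention mask broadcasts over all attention heads as
# (batch, 1, seq, seq) inside the transformer layers. The helper name is hypothetical.
def _extended_mask_sketch(batch=2, seq=5):
    mask = flow.ones(batch, seq, seq, dtype=flow.bool)
    extended = ExtendedMask()(mask)
    assert tuple(extended.shape) == (batch, 1, seq, seq)
    return extended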
class T5Embedding(flow.nn.Module):
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method=flow.nn.init.xavier_normal_,
amp_enabled=False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.word_embeddings = VocabEmbedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.position_embeddings = Embedding(
num_embeddings=max_sequence_length,
embedding_dim=hidden_size,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.position_ids = flow.arange(
max_sequence_length,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
).unsqueeze(0)
self.embedding_dropout = flow.nn.Dropout(embedding_dropout_prob)
def forward(self, input_ids, past_length=0):
seq_length = input_ids.size()[1]
position_ids = self.position_ids[:, past_length : past_length + seq_length]
position_ids = position_ids.expand_as(input_ids).to_global(sbp=input_ids.sbp)
word_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
embeddings = self.embedding_dropout(embeddings)
return embeddings
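# Added illustration (not called by the model): during incremental decoding, T5Embedding slices
# the position ids starting at past_length, so generation step t receives position id
# past_length + t. The hypothetical helper below shows the slicing with a plain local arange.
def _position_id_slice_sketch(max_sequence_length=16, past_length=4, seq_length=3):
    position_ids = flow.arange(max_sequence_length, dtype=flow.long).unsqueeze(0)
    return position_ids[:, past_length : past_length + seq_length]  # [[4, 5, 6]]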
class T5Model(flow.nn.Module):
"""T5 Model that outputs logits.
Args:
vocab_size (int): The size of vocabulary file.
hidden_size (int): The size of hidden states.
hidden_layers (int): The number of ``TransformerLayer`` in the encoder and decoder.
num_attention_heads (int):
The number of attention heads for each attention layer of ``TransformerLayer``.
intermediate_size (int):
The size of intermediate layer in feed-forward network for each ``TransformerLayer``.
embedding_dropout_prob (float): The dropout ratio for the output of T5Embedding Layer.
hidden_dropout_prob (float): The dropout ratio for the output for each ``TransformerLayer``.
attention_probs_dropout_prob (float):
The dropout ratio for the output of each attention layer in ``TransformerLayer``.
max_position_embeddings (int):
Max sequence length of input, defines the shape of Position Embeddings
in ``T5Emebedding``.
initializer_range (float, optional):
Sigma of the normal distribution in the initialization method. Defaults to 0.02.
layernorm_eps (float, optional): The epsilon of LayerNorm layer. Defaults to 1e-12.
bias_gelu_fusion (bool, optional):
Whether or not to fuse the computing of bias and gelu. Defaults to ``False``.
bias_dropout_fusion (bool, optional):
Whether or not to fuse the computing of dropout and bias. Defaults to ``False``.
scale_mask_softmax_fusion (bool, optional):
Whether to fuse the computing of mask and softmax in attention layers.
Defaults to ``False``.
apply_query_key_layer_scaling (bool, optional):
Whether or not to use layer index related scaling in computing attention scores.
            If ``True``, the scaling factor equals sqrt(d) * (layer_index + 1).
Defaults to ``True``.
apply_residual_post_layernorm (bool, optional):
            If set to ``True``, use the original BERT residual connection ordering; otherwise
            use the Megatron-BERT residual connection, which is more stable when scaling
            model size, as introduced in https://arxiv.org/pdf/1909.08053.pdf.
Default: ``False``.
amp_enabled (bool, optional):
Whether or not to set fp16 for embedding weight in T5 model. Defaults to ``False``.
"""
@configurable
def __init__(
self,
vocab_size,
hidden_size,
hidden_layers,
num_attention_heads,
intermediate_size,
embedding_dropout_prob,
hidden_dropout_prob,
attention_probs_dropout_prob,
max_position_embeddings,
initializer_range=0.02,
layernorm_eps=1e-12,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=True,
apply_residual_post_layernorm=False,
amp_enabled=False,
) -> None:
super().__init__()
init_method = init_method_normal(initializer_range)
scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers)
self.embedding = T5Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_position_embeddings,
embedding_dropout_prob=embedding_dropout_prob,
init_method=init_method,
amp_enabled=amp_enabled,
)
self.extended_attn_mask = ExtendedMask()
encoder_layers = flow.nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
num_attention_heads=num_attention_heads,
is_decoder=False,
attention_dropout_prob=attention_probs_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_eps,
init_method=init_method,
output_layer_init_method=scaled_init_method,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
apply_residual_post_layernorm=apply_residual_post_layernorm,
attn_mask_type=AttnMaskType.padding,
layer_idx=i,
)
for i in range(hidden_layers)
]
)
encoder_final_layernorm = LayerNorm(
(hidden_size,),
eps=layernorm_eps,
layer_idx=hidden_layers - 1,
)
self.encoder = flow.nn.Sequential()
self.encoder.add_module("layers", encoder_layers)
self.encoder.add_module("final_layernorm", encoder_final_layernorm)
decoder_layers = flow.nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
num_attention_heads=num_attention_heads,
is_decoder=True,
attention_dropout_prob=attention_probs_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_eps,
init_method=init_method,
output_layer_init_method=scaled_init_method,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attn_mask_type=AttnMaskType.padding,
layer_idx=i,
)
for i in range(hidden_layers, 2 * hidden_layers)
]
)
decoder_final_layernorm = LayerNorm(
(hidden_size,),
eps=layernorm_eps,
layer_idx=2 * hidden_layers - 1,
)
self.decoder = flow.nn.Sequential()
self.decoder.add_module("layers", decoder_layers)
self.decoder.add_module("final_layernorm", decoder_final_layernorm)
self.past_key_values = [None] * len(self.decoder.layers)
self.encoder_states = None
self.past_length = 0
self.lm_head = LMLogits(vocab_size, bias=True)
@classmethod
def from_config(cls, cfg):
return {
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"hidden_layers": cfg.hidden_layers,
"num_attention_heads": cfg.num_attention_heads,
"intermediate_size": cfg.intermediate_size,
"embedding_dropout_prob": cfg.embedding_dropout_prob,
"hidden_dropout_prob": cfg.hidden_dropout_prob,
"attention_probs_dropout_prob": cfg.attention_probs_dropout_prob,
"max_position_embeddings": cfg.max_position_embeddings,
"initializer_range": cfg.initializer_range,
"layernorm_eps": cfg.layernorm_eps,
"bias_gelu_fusion": cfg.bias_gelu_fusion,
"bias_dropout_fusion": cfg.bias_dropout_fusion,
"scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
"apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
"apply_residual_post_layernorm": cfg.apply_residual_post_layernorm,
"amp_enabled": cfg.amp_enabled,
}
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
use_cache=False,
):
"""
Args:
encoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for encoder.
decoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for decoder.
encoder_attn_mask (flow.BoolTensor):
Mask for encoder to avoid performing attention on
padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on subsequent token indices.
Mask values have the same meaning as encoder_attn_mask.
encoder_decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on encoder padded token indices.
Mask values have the same meaning as encoder_attn_mask.
use_cache (bool, optional):
Set to True when the model is in the inference phase and is
used for incremental decoding. Defaults to False.
Returns:
flow.Tensor: logits
"""
encoder_input_ids = encoder_input_ids.to_global(placement=dist.get_layer_placement(0))
decoder_input_ids = decoder_input_ids.to_global(placement=dist.get_layer_placement(0))
encoder_attn_mask = encoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
decoder_attn_mask = decoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
encoder_decoder_attn_mask = encoder_decoder_attn_mask.to_global(
placement=dist.get_layer_placement(0)
)
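# With use_cache, reuse the previously computed encoder output; otherwise clear the
# cache and run the full encoder below.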
if use_cache and self.encoder_states is not None:
encoder_states = self.encoder_states
else:
self.set_cache(encoder_states=None, past_key_values=None)
encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
enc_embedding_output = self.embedding(encoder_input_ids)
enc_hidden_states = enc_embedding_output
for layer in self.encoder.layers:
enc_hidden_states = layer(enc_hidden_states, encoder_attn_mask)
encoder_states = self.encoder.final_layernorm(enc_hidden_states)
decoder_attn_mask = self.extended_attn_mask(decoder_attn_mask)
encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)
dec_embedding_output = self.embedding(decoder_input_ids, self.past_length)
dec_hidden_states = dec_embedding_output
if use_cache:
presents = []
for layer, past_key_value in zip(self.decoder.layers, self.past_key_values):
dec_hidden_states = layer(
dec_hidden_states,
decoder_attn_mask,
encoder_states,
encoder_decoder_attn_mask,
past_key_value=past_key_value,
use_cache=use_cache,
)
if use_cache:
dec_hidden_states, present = dec_hidden_states
presents.append(present)
if use_cache:
self.set_cache(encoder_states, past_key_values=presents)
decoder_states = self.decoder.final_layernorm(dec_hidden_states)
logits = self.lm_head(decoder_states, self.embedding.word_embeddings.weight)
return logits
def set_cache(self, encoder_states, past_key_values):
self.encoder_states = encoder_states
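# dim 2 of the cached key tensor holds the number of tokens decoded so far.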
self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
if past_key_values is None:
past_key_values = [None] * len(self.decoder.layers)
assert len(past_key_values) == len(self.decoder.layers), (
f"past_key_values's length {len(past_key_values)} doesn't match "
f"decoder num_layers' length {self.decoder.layers}"
)
self.past_key_values = past_key_values
class T5Loss(flow.nn.Module):
def __init__(self) -> None:
super().__init__()
self.lm_loss = ParallelCrossEntropyLoss()
def forward(self, logits, lm_labels, loss_mask):
lm_loss = self.lm_loss(logits, lm_labels)
loss_mask = loss_mask.to_global(placement=lm_loss.placement)
loss_mask = loss_mask.float()
denominator = loss_mask.sum().to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
)
lm_loss = flow._C.amp_white_identity(lm_loss)
lm_loss = flow._C.amp_black_identity(lm_loss)
masked_lm_loss = flow.sum(lm_loss.view(-1) * loss_mask.view(-1)) / denominator
masked_lm_loss = masked_lm_loss.to_global(
sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast])
)
return {"masked_lm_loss": masked_lm_loss}
class T5ForPreTraining(flow.nn.Module):
"""
T5 Model with classification head on top.
"""
def __init__(self, cfg) -> None:
super().__init__()
self.t5_model = T5Model(cfg)
self.loss_func = T5Loss()
def set_cache(self, encoder_states, past_key_values):
self.t5_model.set_cache(encoder_states, past_key_values)
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
lm_labels=None,
loss_mask=None,
use_cache=False,
):
"""
Args:
encoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for encoder.
decoder_input_ids (flow.LongTensor):
Indices of input sequence tokens in vocabulary for decoder.
encoder_attn_mask (flow.BoolTensor):
Mask for encoder to avoid performing attention on
padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on subsequent token indices.
Mask values have the same meaning as encoder_attn_mask.
encoder_decoder_attn_mask (flow.BoolTensor):
Mask for decoder to avoid performing attention on encoder padded token indices.
Mask values have the same meaning as encoder_attn_mask.
lm_labels (flow.LongTensor, optional): Labels for computing the masked
language modeling loss. Indices should be in `[-1, 0, ..., config.vocab_size]`.
None for evaluating.
loss_mask (flow.BoolTensor, optional):
Mask to avoid performing loss computing on ignored tokens.
Tokens with indices set to `-1` are ignored (masked), the loss is only computed
for the tokens with labels in `[0, ..., config.vocab_size]`.
None for evaluating.
use_cache (bool, optional):
Set to True when the model is in the inference phase and is
used for incremental decoding. Defaults to False.
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"masked_lm_loss": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
"""
logits = self.t5_model(
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
use_cache=use_cache,
)
if lm_labels is not None:
lm_loss = self.loss_func(logits, lm_labels, loss_mask)
return lm_loss
else:
return {
"prediction_scores": logits,
}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.t5_model.encoder.final_layernorm, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
if isinstance(module_block.origin, T5Embedding):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, ExtendedMask):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, TransformerLayer):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.origin, LMLogits):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
elif isinstance(module_block.origin, T5Loss):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.t5_model.encoder.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
)
model.t5_model.decoder.final_layernorm.config.set_stage(
dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
)
else:
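# Newer graph API: the wrapped module is reached via to(nn.Module) and its
# stage config via to(nn.graph.GraphModule).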
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), T5Embedding):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), ExtendedMask):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
elif isinstance(module_block.to(nn.Module), LMLogits):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
elif isinstance(module_block.to(nn.Module), T5Loss):
module_block.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.t5_model.encoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
)
model.t5_model.decoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .graph_base import GraphBase
from .weight_init import init_method_normal, scaled_init_method_normal
from .model_loader.base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
from .model_loader.bert_loader import BertLoaderHuggerFace, BertLoaderLiBai
from .model_loader.roberta_loader import RobertaLoaderHuggerFace, RobertaLoaderLiBai
from .model_loader.gpt_loader import GPT2LoaderHuggerFace, GPT2LoaderLiBai
from .model_loader.swin_loader import SwinLoaderHuggerFace, SwinLoaderLiBai
from .model_loader.swinv2_loader import SwinV2LoaderHuggerFace, SwinV2LoaderLiBai
from .model_loader.vit_loader import ViTLoaderHuggerFace, ViTLoaderLiBai
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import oneflow as flow
from oneflow import nn
from libai.layers import TransformerLayer
from libai.utils import distributed as dist
logger = logging.getLogger(__name__)
class GraphBase(nn.Graph):
def __init__(
self,
model: nn.Module,
optimizer: flow.optim.Optimizer = None,
lr_scheduler: flow.optim.lr_scheduler = None,
fp16=False,
activation_checkpoint=False,
grad_acc_steps=1,
zero_optim=False,
zero_stage=0,
is_train=True,
auto_parallel_conf=None,
):
super().__init__()
self.model = model
self.is_train = is_train
if is_train:
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
if fp16:
self.config.enable_amp(True)
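# Dynamic loss scaling for AMP: start from a large initial scale (scaled by the
# data-parallel size) and let the scaler grow or back off automatically.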
grad_scaler = flow.amp.GradScaler(
init_scale=65536.0 * dist.get_data_parallel_size(),
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
)
self.set_grad_scaler(grad_scaler)
if grad_acc_steps > 1:
self.config.set_gradient_accumulation_steps(grad_acc_steps)
if activation_checkpoint:
self.set_activation_checkpoint()
if zero_optim:
self.config.enable_zero(True, stage=zero_stage)
self.set_pipeline_stage_id()
self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_model_update_ops(True)
self.config.allow_fuse_cast_scale(True)
# Use the same cuda stream for computation and communication.
# This reduces memory usage when using model parallelism.
dist_util = dist.get_dist_util()
if dist_util.is_tensor_model_parallel() or dist_util.is_pipeline_model_parallel():
flow.boxing.nccl.enable_use_compute_stream(True)
# auto_parallel
if auto_parallel_conf is not None and auto_parallel_conf.enabled:
try:
self.config.enable_auto_parallel(True)
self.config.enable_auto_parallel_ignore_user_sbp_config(
auto_parallel_conf.enable_auto_parallel_ignore_user_sbp_config
)
self.config.set_auto_parallel_computation_cost_ratio(0.05)
self.config.set_auto_parallel_wait_time(1.65e4)
self.config.enable_auto_parallel_trunk_algo(auto_parallel_conf.trunk_algo)
self.config.enable_auto_parallel_sbp_collector(auto_parallel_conf.sbp_collector)
except RuntimeWarning:
import warnings
warnings.warn(
"The version of oneflow don't support auto_parallel.\n"
"Please reinstall the oneflow nightly:\n"
"python3 -m pip install --pre oneflow -f https://staging.oneflow.info/branch/master/[PLATFORM]" # noqa
)
def build(self, **kwargs):
if self.is_train:
logger.info(
"Start compiling the train graph which may take some time. "
"Please wait for a moment ..."
)
loss_dict = self.model(**kwargs)
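# Sum every returned entry whose key contains "loss" and backpropagate through the total.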
losses = sum(v for k, v in loss_dict.items() if "loss" in k)
losses.backward()
return loss_dict
else:
logger.info(
"Start compiling the eval graph which may take some time. "
"Please wait for a moment ..."
)
return self.model(**kwargs)
def set_activation_checkpoint(self):
if hasattr(self.model, "origin"):
if hasattr(type(self.model.origin), "set_activation_checkpoint"):
type(self.model.origin).set_activation_checkpoint(self.model)
else:
for module_block in self.model.modules():
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True
else:
if hasattr(type(self.model.to(nn.Module)), "set_activation_checkpoint"):
type(self.model.to(nn.Module)).set_activation_checkpoint(self.model)
else:
for module_block in self.model.modules():
if isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(nn.graph.GraphModule).activation_checkpointing = True
def set_pipeline_stage_id(self):
if hasattr(self.model, "origin"):
if hasattr(type(self.model.origin), "set_pipeline_stage_id"):
type(self.model.origin).set_pipeline_stage_id(self.model)
else:
if hasattr(type(self.model.to(nn.Module)), "set_pipeline_stage_id"):
type(self.model.to(nn.Module)).set_pipeline_stage_id(self.model)
## Introduction
Here are the Weight Loaders currently supported in LiBai. You can use them to load models trained with LiBai as well as pretrained models hosted on Hugging Face.
## Weight Loader On LiBai
- [BERT Loader](./bert_loader.py)
- [RoBERTa Loader](./roberta_loader.py)
- [GPT2 Loader](./gpt_loader.py)
- [MT5 Loader](../../../../projects/MT5/utils/mt5_loader.py)
- [Swin Loader](./swin_loader.py)
- [SwinV2 Loader](./swinv2_loader.py)
- [ViT Loader](./vit_loader.py)
## How To Use
We can easily load a pretrained BERT model as follows:
```python
import libai
from libai.models.utils import BertLoaderHuggerFace, BertLoaderLiBai
from configs.common.models.bert import cfg
# load huggingface weight
loader = BertLoaderHuggerFace(
model=libai.models.BertModel,
libai_cfg=cfg,
pretrained_model_path="path/to/huggingface_pretrained_model_directory",
hidden_dropout_prob=0,
apply_residual_post_layernorm=True
)
bert = loader.load()
# load libai weight
loader = BertLoaderLiBai(
model=libai.models.BertModel,
libai_cfg=cfg,
pretrained_model_path='path/to/libai_pretrained_model_directory',
hidden_dropout_prob=0,
)
bert = loader.load()
```
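Every loader also accepts `output_loading_info=True` (see `ModelLoader` in `base_loader.py`); `load()` then additionally returns a dict with the missing, unexpected and mismatched keys as well as any error messages. A minimal sketch reusing the BERT setup above:
```python
# Sketch: also retrieve the loading diagnostics (dict keys as defined in base_loader.py).
loader = BertLoaderHuggerFace(
    model=libai.models.BertModel,
    libai_cfg=cfg,
    pretrained_model_path="path/to/huggingface_pretrained_model_directory",
    output_loading_info=True,
)
bert, loading_info = loader.load()
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
```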
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import logging
import os
import omegaconf
import oneflow as flow
from termcolor import colored
import libai.utils.distributed as dist
from libai.config import LazyCall
from libai.models.build import build_model
logger = logging.getLogger(__name__)
WEIGHTS_NAME_PT = "pytorch_model.bin"
CONFIG_NAME = "config.json"
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
"""load state dict into model
Args:
model_to_load (nn.Module): Model to be loaded.
state_dict (OrderedDict): State dict of pretrained model.
start_prefix (str): Start prefix.
Returns:
list: error message about loading.
"""
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
def load(module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(model_to_load, prefix=start_prefix)
return error_msgs
class ModelLoader(object):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
"""Class used to load the [`transformers`](https://huggingface.co/models) pretrained model
or `OneFlow` pretrained model.
Args:
model (libai.models): Model to be loaded in Libai.
libai_cfg (dict): The config of model in LiBai, you can import it from
`libai.config.configs.common.models`.
pretrained_model_path (str): The directory path of pretrained model,
which contains model weights file and config file.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary containing missing keys, unexpected keys
and error messages.
"""
self.model = model
self.libai_cfg = libai_cfg
self.pretrained_model_path = pretrained_model_path
self.kwargs = kwargs
self.output_loading_info = kwargs.pop("output_loading_info", False)
def _state_dict_to_global(self, flow_state_dict=None, mode="libai"):
"""Tensor in OneFlow state dict to global according to model's sbp and placement.
Args:
flow_state_dict (OrderedDict): State dict of OneFlow's pretrained model.
"""
assert mode in ["libai", "pytorch"], f"not support for mode {mode}"
if mode == "libai" or dist.is_main_process():
prefix = self.base_model_prefix_2
# Checkpoint
has_prefix_module = any(
s.startswith(self.base_model_prefix_2) for s in flow_state_dict.keys()
)
# Module
expects_prefix_module = any(
s.startswith(prefix) for s in self.model.state_dict().keys()
)
start_prefix = "" if has_prefix_module else prefix + "."
loaded_keys = [start_prefix + key for key in flow_state_dict.keys()]
else:
prefix, has_prefix_module, expects_prefix_module, loaded_keys = [None] * 4
flow_state_dict = collections.OrderedDict()
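# Broadcast the prefix/key metadata from rank 0 so that every rank iterates over the
# same keys in the same order.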
prefix = dist.broadcast_py_object(prefix, src=0)
has_prefix_module = dist.broadcast_py_object(has_prefix_module, src=0)
expects_prefix_module = dist.broadcast_py_object(expects_prefix_module, src=0)
loaded_keys = dist.broadcast_py_object(loaded_keys, src=0)
# to global
for key, value in self.model.state_dict().items():
if not expects_prefix_module:
key = prefix + "." + key
if key in loaded_keys:
if not has_prefix_module:
key = ".".join(key.split(".")[1:])
if mode == "pytorch":
flow_state_dict[key] = flow.to_global(
flow_state_dict[key] if dist.is_main_process() else flow.Tensor(None),
sbp=flow.sbp.broadcast,
placement=flow.placement("cpu", ranks=[0]),
)
flow_state_dict[key] = flow.to_global(
flow_state_dict[key],
sbp=value.sbp,
placement=flow.placement("cpu", ranks=list(value.placement.ranks)),
)
return flow_state_dict
def _load_pretrained_model(
self,
model,
state_dict,
pretrained_model_path,
ignore_mismatched_sizes=False,
):
"""Load pretrained model.
Args:
model (libai.models): The model to be loaded.
state_dict (OrderedDict): state dict.
pretrained_model_path (str): pretrained model path.
ignore_mismatched_sizes (bool):
Whether or not to raise an error if some of the weights
from the checkpoint do not have the same size as the
weights of the model, defaults to `False`.
"""
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
prefix = self.base_model_prefix_2
loaded_keys = state_dict.keys()
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module
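# The checkpoint keys and the model keys may differ by the base-model prefix;
# align them before computing missing/unexpected keys.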
if remove_prefix_from_model:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
expected_keys = [
".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys
]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
start_prefix = ""
model_to_load = model
if (
len(self.base_model_prefix_2) > 0
and not hasattr(model, self.base_model_prefix_2)
and has_prefix_module
):
start_prefix = self.base_model_prefix_2 + "."
if (
len(self.base_model_prefix_2) > 0
and hasattr(model, self.base_model_prefix_2)
and not has_prefix_module
):
model_to_load = getattr(model, self.base_model_prefix_2)
if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError("The state dict of the model you are loading is corrupted.")
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(
checkpoint_key,
state_dict[checkpoint_key].shape,
model_state_dict[model_key].shape,
)
)
del state_dict[checkpoint_key]
return mismatched_keys
if state_dict is not None:
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
if dist.get_local_rank() == 0:
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
raise RuntimeError(
f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_path} "
"were not used when "
f"initializing {model.__class__.__name__}:\n {unexpected_keys}\n"
)
else:
logger.info(
f"All model checkpoint weights were used when initializing "
f"{model.__class__.__name__}.\n"
)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized "
f"from the model checkpoint at {pretrained_model_path}:\n "
f"{missing_keys} \n"
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized "
f"from the model checkpoint at {pretrained_model_path}.\n"
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2}"
"in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized"
f"from the model checkpoint at {pretrained_model_path} "
f"and are newly initialized because the shapes did not"
f"match:\n{mismatched_warning}\n"
)
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
class ModelLoaderLiBai(ModelLoader):
"""Class used to load `OneFlow` pretrained model.
Args:
model (libai.models): Model to be loaded in Libai.
libai_cfg (dict): The config of model in LiBai, you can import it from
`libai.config.configs.common.models`.
pretrained_model_path (str): The directory path of pretrained model,
which contains model weights file and config file.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary containing missing keys, unexpected keys
and error messages.
"""
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = None # prefix in LiBai
def _load_flow_state_dict(self, state_dict_file):
# load oneflow_model
state_dict = flow.load(state_dict_file, global_src_rank=0)
return state_dict
def load(self):
"""Load model.
# For example:
# .. code-block:: python
>>> import libai
>>> from libai.config.configs.common.models.bert import cfg
>>> from model_loader import BertLoaderLiBai
>>> loader = BertLoaderLiBai(
libai.models.BertModel,
cfg,
'path/bert-base-chinese'
)
>>> bert = loader.load()
"""
if dist.is_main_process():
assert os.path.isdir(
self.pretrained_model_path
), f"{self.pretrained_model_path} must be a directory"
flow_state_dict = self._load_flow_state_dict(self.pretrained_model_path)
# Instance model
if isinstance(self.model, omegaconf.dictconfig.DictConfig):
self.model.cfg = self.libai_cfg
self.model = build_model(self.model)
else:
self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))
# State_dict to global
self._state_dict_to_global(flow_state_dict, mode="libai")
# Load
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
error_msgs,
) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)
if self.output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
class ModelLoaderHuggerFace(ModelLoader):
"""Class used to load the [`transformers`](https://huggingface.co/models)
pretrained model.
"""
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_1 = None # prefix in Transformers
self.base_model_prefix_2 = None # prefix in LiBai
self.origin_libai_cfg = copy.deepcopy(self.libai_cfg)
self.changed_keys = set() # Store the changed configuration
def _convert_tensor(self, tensor):
"""Convert PyTorch tensor to OneFlow tensor.
Args:
tensor (torch.Tensor): The source tensor.
Returns:
flow.Tensor: The target tensor.
"""
tensor = tensor.float()
return flow.Tensor(tensor.detach().cpu().numpy())
def _convert_tensors(self, torch_state_dict):
for k, v in torch_state_dict.items():
torch_state_dict[k] = self._convert_tensor(v)
return torch_state_dict
def _fix_key(self, state_dict):
"""Fix the key in state dict: Convert "gamma" to "weight" and "beta" to "bias".
Args:
state_dict (OrderedDict): state dict of pretrained model.
Returns:
OrderedDict: State dict after fix key.
"""
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
return state_dict
def _fix_qkv_ordering(
self, qkv, head_size, num_heads, hidden_size=None, checkpoint_version=0.0
):
# TODO(xzp): Different versions checkpoint
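# The HF checkpoint stacks all query rows, then all key rows, then all value rows;
# LiBai's fused query_key_value expects them grouped per attention head, so reshape
# and swap the first two axes.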
hidden_size = (head_size * num_heads) if hidden_size is None else hidden_size
num_of_qkv = qkv.shape[0] // (head_size * num_heads)
mode = "weight" if qkv.ndim > 1 else "bias"
if mode == "weight":
qkv = qkv.view([num_of_qkv, num_heads, head_size, hidden_size])
qkv = (
qkv.permute(1, 0, 2, 3)
.contiguous()
.view(num_of_qkv * head_size * num_heads, hidden_size)
)
elif mode == "bias":
qkv = qkv.view(num_of_qkv, num_heads, head_size)
qkv = qkv.permute(1, 0, 2).contiguous().view(-1)
return qkv
def _convert_state_dict(self, flow_state_dict, cfg):
"""A function used to convert the checkpoint file of Huggingface to LiBai.
Args:
flow_state_dict (OrderedDict): state dict whose tensors have already been converted to OneFlow.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
raise NotImplementedError("_convert_state_dict not implemented")
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
raise NotImplementedError("_load_config_from_json not implemented")
def _load_torch_state_dict(self, state_dict_file):
try:
import torch
except ImportError:
raise ImportError("Loading a torch state dict requires torch to be installed.")
# load pytorch_model.bin
state_dict = torch.load(state_dict_file, map_location="cpu")
return state_dict
def _update_cfg(self, keys_libai, value_target):
"""Update the libai_cfg according to target_cfg.
Args:
keys_libai (str): The key of libai_cfg.
value_target (int | float): The value of target_cfg.
"""
if keys_libai not in self.libai_cfg.keys():
return
if self.libai_cfg[keys_libai] != value_target:
self.libai_cfg[keys_libai] = value_target
def _update_cfg_log(self):
if dist.get_local_rank() == 0:
for key in sorted(self.libai_cfg):
if self.origin_libai_cfg[key] == self.libai_cfg[key]:
continue
self.changed_keys.add(key)
temp_key = colored(key, "yellow")
logger.info(
f"changed libai model cfg {temp_key} : "
f"{self.origin_libai_cfg[key]} -> {self.libai_cfg[key]} "
)
logger.warning(
"The following model configurations has been modified according "
"to `config.json` or kwargs: \n"
f"{self.changed_keys} \n"
)
if dist.get_pipeline_parallel_size() > 1:
logger.warning(
colored(
"If you use pipeline parallel, please "
"confirm the setting of `train.dist.pipeline_num_layers` \n",
"red",
)
)
def load(self):
"""Load model.
# For example:
# .. code-block:: python
>>> import libai
>>> from configs.common.models.bert import cfg
>>> from libai.models.utils import BertLoaderHuggerFace
>>> loader = BertLoaderHuggerFace(
libai.models.BertModel,
cfg,
'path/bert-base-chinese'
)
>>> bert = loader.load()
"""
if dist.is_main_process():
if os.path.isdir(self.pretrained_model_path):
# state_dict file pytorch
if os.path.isfile(os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)):
model_file = os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME_PT} found"
f"in directory {self.pretrained_model_path}."
)
# config file
if os.path.isfile(os.path.join(self.pretrained_model_path, CONFIG_NAME)):
config_file = os.path.join(self.pretrained_model_path, CONFIG_NAME)
# Load config and update config.
self._load_config_from_json(config_file)
else:
import warnings
warnings.warn(
f"Error no file named {CONFIG_NAME} found in directory"
f"{self.pretrained_model_path}",
RuntimeWarning,
)
else:
raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")
logger.info("loading torch model...")
torch_state_dict = self._load_torch_state_dict(model_file)
torch_state_dict = self._fix_key(torch_state_dict)
logger.info("transfering torch model into oneflow model...")
flow_state_dict = self._convert_tensors(torch_state_dict)
flow_state_dict = self._convert_state_dict(torch_state_dict, self.libai_cfg)
else:
flow_state_dict = None
self.libai_cfg = dist.broadcast_py_object(self.libai_cfg, src=0)
# Instance model
logger.info("building LiBai model...")
if isinstance(self.model, omegaconf.dictconfig.DictConfig):
self.model.cfg = self.libai_cfg
self.model = build_model(self.model)
else:
self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))
# State_dict to global
logger.info("transfering state_dict local to global...")
flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="pytorch")
logger.info("loading model weights into LiBai...")
# Load
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
error_msgs,
) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)
if self.output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class BertLoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is BERT's prefix in Transformers.
base_model_prefix_2 is BERT's prefix in LiBai."""
self.base_model_prefix_1 = "bert"
self.base_model_prefix_2 = "bert"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
layers = cfg.get("hidden_layers")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix = "bert." if has_prefix else ""
index_idx = 3 if has_prefix else 2
qkv_idx = 6 if has_prefix else 5
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert bert's embedding layers
if "embeddings" in key:
if "word_embeddings" in key:
new_key = key.replace("word_embeddings", "vocab_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "token_type_embeddings" in key:
new_key = key.replace("token_type_embeddings", "tokentype_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.weight" in key:
new_key = prefix + "encoders.0.input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.bias" in key:
new_key = prefix + "encoders.0.input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict[key]
# Convert bert's attention layers
elif "attention" in key:
if "self" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
in oneflow_state_dict.keys()
):
continue
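# Fuse the separate query/key/value weights and biases into a single
# query_key_value tensor and reorder per head to match LiBai's fused attention layout.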
q_w = key.replace(key.split(".")[qkv_idx], "query").replace(
key.split(".")[qkv_idx + 1], "weight"
)
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)
qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads)
qkv_b = self._fix_qkv_ordering(qkv_b, head_size, num_heads)
new_key = (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = prefix + "encoders." + index + ".self_attention.query_key_value.bias"
oneflow_state_dict[new_key] = qkv_b
elif "output" in key:
index = key.split(".")[index_idx]
if "dense" in key:
if "weight" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm" in key:
if "weight" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert bert's intermediate layers
elif "intermediate" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
if "weight" in key:
w = key
b = key.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert bert's output layers
elif "output" in key:
index = key.split(".")[index_idx]
if "dense.weight" in key:
if (
prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
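# The LayerNorm after block i's output in the HF checkpoint corresponds to LiBai's
# input_layernorm of block i+1; the LayerNorm of the last block maps to final_layernorm.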
elif "LayerNorm.weight" in key:
if (
prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
if index == str(layers - 1):
new_key = prefix + "final_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
continue
new_key = prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert bert's pooler layers
elif "pooler" in key:
if "weight" in key:
new_key = prefix + "pooler.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "pooler.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert cls_head layers
elif "cls" in key:
if "predictions.bias" in key:
new_key = "cls_head.lm_logits.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "dense.weight" in key:
new_key = "cls_head.predictions.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "dense.bias" in key:
new_key = "cls_head.predictions.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.weight" in key:
new_key = "cls_head.predictions.layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.bias" in key:
new_key = "cls_head.predictions.layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "seq_relationship" in key:
new_key = key.replace("cls", "cls_head")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("vocab_size", cfg_dict["vocab_size"])
self._update_cfg("hidden_size", cfg_dict["hidden_size"])
self._update_cfg("hidden_layers", cfg_dict["num_hidden_layers"])
self._update_cfg("num_attention_heads", cfg_dict["num_attention_heads"])
self._update_cfg("intermediate_size", cfg_dict["intermediate_size"])
self._update_cfg("hidden_dropout_prob", cfg_dict["hidden_dropout_prob"])
self._update_cfg("attention_probs_dropout_prob", cfg_dict["attention_probs_dropout_prob"])
self._update_cfg("max_position_embeddings", cfg_dict["max_position_embeddings"])
self._update_cfg("num_tokentypes", cfg_dict["type_vocab_size"])
self._update_cfg("initializer_range", cfg_dict["initializer_range"])
self._update_cfg("layernorm_eps", cfg_dict["layer_norm_eps"])
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
# use original BERT residual connection ordering
self.libai_cfg.apply_residual_post_layernorm = True
self._update_cfg_log()
class BertLoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "bert"
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class GPT2LoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is GPT's prefix in Transformers.
base_model_prefix_2 is GPT's prefix in LiBai."""
self.base_model_prefix_1 = "transformer"
self.base_model_prefix_2 = "GPT_model"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
old_keys = list(oneflow_state_dict.keys())
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix1 = self.base_model_prefix_1 + "." if has_prefix else ""
prefix2 = "GPT_model.transformer."
layer_idx = 2 if has_prefix else 1
# Convert Embedding layers.
new_key = "GPT_model.embeddings.token_embeddings.weight"
old_keys.remove(prefix1 + "wte.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "wte.weight")
new_key = "GPT_model.embeddings.position_embeddings.weight"
old_keys.remove(prefix1 + "wpe.weight")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "wpe.weight")
for key in old_keys:
keys = key.split(".")
if layer_idx >= len(keys):
continue
layer = keys[layer_idx]
# Convert transformer layers.
if "h." in key:
if "ln_1" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".input_layernorm.weight"
else:
new_key = prefix2 + "layers." + layer + ".input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "ln_2" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".post_attention_layernorm.weight"
else:
new_key = prefix2 + "layers." + layer + ".post_attention_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "attn" in key:
if "c_attn" in key:
if "weight" in key:
new_key = (
prefix2
+ "layers."
+ layer
+ ".self_attention.query_key_value.weight"
)
else:
new_key = (
prefix2 + "layers." + layer + ".self_attention.query_key_value.bias"
)
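# HF GPT-2 uses Conv1D layers whose weights are the transpose of Linear weights,
# so 2-D tensors are transposed before the per-head reordering.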
qkv = oneflow_state_dict.pop(key)
if qkv.ndim > 1:
qkv = qkv.transpose(1, 0)
qkv = self._fix_qkv_ordering(qkv, head_size, num_heads)
oneflow_state_dict[new_key] = qkv
elif "c_proj" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".self_attention.dense.weight"
elif "bias" in key:
new_key = prefix2 + "layers." + layer + ".self_attention.dense.bias"
value = oneflow_state_dict.pop(key)
if value.ndim > 1:
value = value.transpose(1, 0)
oneflow_state_dict[new_key] = value
elif "mlp" in key:
if "c_fc" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_h_to_4h.weight"
elif "bias" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_h_to_4h.bias"
value = oneflow_state_dict.pop(key)
if value.ndim > 1:
value = value.transpose(1, 0)
oneflow_state_dict[new_key] = value
elif "c_proj" in key:
if "weight" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_4h_to_h.weight"
elif "bias" in key:
new_key = prefix2 + "layers." + layer + ".mlp.dense_4h_to_h.bias"
value = oneflow_state_dict.pop(key)
if value.ndim > 1:
value = value.transpose(1, 0)
oneflow_state_dict[new_key] = value
elif "ln_f" in key:
if "weight" in key:
new_key = prefix2 + "layernorm_f.weight"
elif "bias" in key:
new_key = prefix2 + "layernorm_f.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("hidden_layers", cfg_dict["n_layer"])
self._update_cfg("hidden_size", cfg_dict["n_embd"])
self._update_cfg("num_attention_heads", cfg_dict["n_head"])
self._update_cfg("max_seq_length", cfg_dict["n_positions"])
self._update_cfg("embedding_dropout_prob", cfg_dict["embd_pdrop"])
self._update_cfg("attention_dropout_prob", cfg_dict["attn_pdrop"])
self._update_cfg("output_dropout_prob", cfg_dict["resid_pdrop"])
self._update_cfg("layernorm_epsilon", cfg_dict["layer_norm_epsilon"])
self._update_cfg("vocab_size", cfg_dict["vocab_size"])
self._update_cfg("initializer_range", cfg_dict["initializer_range"])
self._update_cfg(
"ffn_hidden_size",
cfg_dict.get("n_inner")
if cfg_dict.get("n_inner") is not None
else 4 * self.libai_cfg["hidden_size"],
)
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class GPT2LoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "GPT_model"
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from .bert_loader import BertLoaderHuggerFace, BertLoaderLiBai
class RobertaLoaderHuggerFace(BertLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is RoBERTa's prefix in Transformers,
base_model_prefix_2 is RoBERTa's prefix in LiBai."""
self.base_model_prefix_1 = "roberta"
self.base_model_prefix_2 = "roberta"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# Get configs
num_heads = cfg.get("num_attention_heads")
hidden_size = cfg.get("hidden_size")
layers = cfg.get("hidden_layers")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix = "roberta." if has_prefix else ""
index_idx = 3 if has_prefix else 2
qkv_idx = 6 if has_prefix else 5
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert roberta's embedding layers
if "embeddings" in key:
if "word_embeddings" in key:
new_key = key.replace("word_embeddings", "vocab_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "token_type_embeddings" in key:
new_key = key.replace("token_type_embeddings", "tokentype_embeddings")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.weight" in key:
new_key = prefix + "encoders.0.input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm.bias" in key:
new_key = prefix + "encoders.0.input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict[key]
# Convert roberta's attention layers
elif "attention" in key:
if "self" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key.replace(key.split(".")[qkv_idx], "query").replace(
key.split(".")[qkv_idx + 1], "weight"
)
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)
qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads)
qkv_b = self._fix_qkv_ordering(qkv_b, head_size, num_heads)
new_key = (
prefix + "encoders." + index + ".self_attention.query_key_value.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = prefix + "encoders." + index + ".self_attention.query_key_value.bias"
oneflow_state_dict[new_key] = qkv_b
elif "output" in key:
index = key.split(".")[index_idx]
if "dense" in key:
if "weight" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "encoders." + index + ".self_attention.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "LayerNorm" in key:
if "weight" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = (
prefix + "encoders." + index + ".post_attention_layernorm.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert roberta's intermediate layers
elif "intermediate" in key:
index = key.split(".")[index_idx]
if (
prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
if "weight" in key:
w = key
b = key.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert roberta's output layers
elif "output" in key:
index = key.split(".")[index_idx]
if "dense.weight" in key:
if (
prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "LayerNorm.weight" in key:
if (
prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
if index == str(layers - 1):
new_key = prefix + "final_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
continue
new_key = prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
# Convert roberta's pooler layers
elif "pooler" in key:
if "weight" in key:
new_key = prefix + "pooler.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = prefix + "pooler.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert lm_head layers
elif "lm_head" in key:
if "layer_norm.weight" in key:
new_key = "lm_head.layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "layer_norm.bias" in key:
new_key = "lm_head.layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "seq_relationship" in key:
new_key = key.replace("cls", "cls_head")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "lm_head.bias" in key:
new_key = "lm_head.lm_logits.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
class RobertaLoaderLiBai(BertLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "roberta"
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class SwinLoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is SWIN's prefix in Transformers.
base_model_prefix_2 is SWIN's prefix in LiBai."""
self.base_model_prefix_1 = "swin"
self.base_model_prefix_2 = ""
def _convert_state_dict(self, flow_state_dict, cfg=None):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
index_idx_1 = 3 if has_prefix else 2
index_idx_2 = 5 if has_prefix else 4
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert swin's embedding layers
if "embeddings" in key:
if "patch_embeddings.projection" in key:
if "weight" in key:
new_key = "patch_embed.proj.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "patch_embed.proj.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if "weight" in key:
new_key = "patch_embed.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "patch_embed.norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swin's layernorm layers
elif "layernorm_before" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "layernorm_after" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swin's attention layers
elif "attention" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "self" in key:
if (
"relative_position_bias_table" in key
): # convert relative_position_bias_table but not index
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_bias_table"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "relative_position_index" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_index"
)
oneflow_state_dict.pop(key)
else:
if (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = qkv_b
elif "output" in key:
if "dense" in key:
if "weight" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "intermediate" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = key.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "output" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "dense.weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "downsample" in key:
index_layer = key.split(".")[index_idx_1]
if "reduction.weight" in key:
new_key = "layers." + index_layer + ".downsample.reduction.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if (
"layers." + index_layer + ".downsample.norm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = "layers." + index_layer + ".downsample.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "layernorm" in key:
if "weight" in key:
new_key = "norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "classifier" in key:
if "weight" in key:
new_key = "head.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "head.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("img_size", cfg_dict["image_size"])
self._update_cfg("patch_size", cfg_dict["patch_size"])
self._update_cfg("embed_dim", cfg_dict["embed_dim"])
self._update_cfg("depths", cfg_dict["depths"])
self._update_cfg("num_heads", cfg_dict["num_heads"])
self._update_cfg("window_size", cfg_dict["window_size"])
self._update_cfg("mlp_ratio", cfg_dict["mlp_ratio"])
self._update_cfg("qkv_bias", cfg_dict["qkv_bias"])
self._update_cfg("drop_path_rate", cfg_dict["drop_path_rate"])
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
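# An illustrative sketch of the renaming performed by `_convert_state_dict` above
# on a single key. The sample HuggingFace-style key in the example call is an
# assumption used only to show the index positions the converter relies on
# (layer index at split position 3 and block index at position 5 when the
# "swin." prefix is present).
def _demo_swin_layernorm_key_rename(hf_key):
    """Illustrative only: map one `layernorm_before` key to its LiBai name."""
    parts = hf_key.split(".")
    index_layer, index_block = parts[3], parts[5]
    suffix = parts[-1]  # "weight" or "bias"
    return "layers." + index_layer + ".blocks." + index_block + ".norm1." + suffix
# e.g. _demo_swin_layernorm_key_rename("swin.encoder.layers.0.blocks.1.layernorm_before.weight")
# returns "layers.0.blocks.1.norm1.weight".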
class SwinLoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = ""
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class SwinV2LoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is SWINV2's prefix in Transformers.
base_model_prefix_2 is SWINV2's prefix in LiBai."""
self.base_model_prefix_1 = "swinv2"
self.base_model_prefix_2 = ""
def _convert_state_dict(self, flow_state_dict, cfg=None):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
index_idx_1 = 3 if has_prefix else 2
index_idx_2 = 5 if has_prefix else 4
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert swinv2's embedding layers
if "embeddings" in key:
if "patch_embeddings.projection" in key:
if "weight" in key:
new_key = "patch_embed.proj.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = "patch_embed.proj.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if "weight" in key:
new_key = "patch_embed.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = "patch_embed.norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swinv2's layernorm layers
elif "layernorm_before" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm1.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "layernorm_after" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "layers." + index_layer + ".blocks." + index_block + ".norm2.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert swinv2's attention layers
elif "attention" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "self" in key:
if (
"relative_position_bias_table" in key
): # convert relative_position_bias_table but not index
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_bias_table"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "relative_position_index" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.relative_position_index"
)
oneflow_state_dict.pop(key)
elif "continuous_position_bias_mlp" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.cpb_mlp"
+ ".0.weight"
) in oneflow_state_dict.keys():
continue
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.cpb_mlp"
)
m_1_w = key
m_1_b = key.replace(".0.weight", ".0.bias")
m_2_w = key.replace(".0.weight", ".2.weight")
oneflow_state_dict[new_key + ".0.weight"] = oneflow_state_dict.pop(m_1_w)
oneflow_state_dict[new_key + ".0.bias"] = oneflow_state_dict.pop(m_1_b)
oneflow_state_dict[new_key + ".2.weight"] = oneflow_state_dict.pop(m_2_w)
elif "logit_scale" in key:
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.logit_scale"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)[None, ...]
else:
if (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.qkv.weight"
)
oneflow_state_dict[new_key] = qkv_w
new_key = (
"layers." + index_layer + ".blocks." + index_block + ".attn.q_bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(q_b)
new_key = new_key.replace("q_bias", "v_bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(v_b)
elif "output" in key:
if "dense" in key:
if "weight" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".attn.proj.bias"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "intermediate" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = key.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "output" in key:
index_layer = key.split(".")[index_idx_1]
index_block = key.split(".")[index_idx_2]
if "dense.weight" in key:
if (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = (
"layers."
+ index_layer
+ ".blocks."
+ index_block
+ ".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "downsample" in key:
index_layer = key.split(".")[index_idx_1]
if "reduction.weight" in key:
new_key = "layers." + index_layer + ".downsample.reduction.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "norm" in key:
if (
"layers." + index_layer + ".downsample.norm.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = "layers." + index_layer + ".downsample.norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "layernorm" in key:
if "weight" in key:
new_key = "norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "classifier" in key:
if "weight" in key:
new_key = "head.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "head.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("img_size", cfg_dict["image_size"])
self._update_cfg("patch_size", cfg_dict["patch_size"])
self._update_cfg("embed_dim", cfg_dict["embed_dim"])
self._update_cfg("depths", cfg_dict["depths"])
self._update_cfg("num_heads", cfg_dict["num_heads"])
self._update_cfg("window_size", cfg_dict["window_size"])
self._update_cfg("mlp_ratio", cfg_dict["mlp_ratio"])
self._update_cfg("qkv_bias", cfg_dict["qkv_bias"])
self._update_cfg("drop_path_rate", cfg_dict["drop_path_rate"])
self._update_cfg("pretrained_window_sizes", cfg_dict["pretrained_window_sizes"])
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class SwinV2LoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = ""
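# A short note on the SwinV2-specific attention handling above: unlike Swin,
# SwinV2 keeps separate query/value biases and no key bias, so the converter
# writes `attn.q_bias` / `attn.v_bias` instead of a fused `attn.qkv.bias`.
# The helper below only sketches the destination key names for one block; the
# layer/block indices are placeholders.
def _demo_swinv2_bias_keys(index_layer="0", index_block="0"):
    """Illustrative only: LiBai destination keys for SwinV2 attention biases."""
    base = "layers." + index_layer + ".blocks." + index_block + ".attn."
    return base + "q_bias", base + "v_bias"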
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import oneflow as flow
from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
class ViTLoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is ViT's prefix in Transformers.
base_model_prefix_2 is ViT's prefix in LiBai."""
self.base_model_prefix_1 = "vit"
self.base_model_prefix_2 = ""
def _convert_state_dict(self, flow_state_dict, cfg=None):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
# Get configs
num_heads = cfg.get("num_heads")
hidden_size = cfg.get("embed_dim")
head_size = int(hidden_size / num_heads)
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
index_idx = 3 if has_prefix else 2
old_keys = oneflow_state_dict.keys()
for key in list(old_keys):
# Convert vit's embedding layers
if "embeddings" in key:
if "cls_token" in key:
new_key = "cls_token"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "position_embeddings" in key:
new_key = "pos_embed"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "patch_embeddings.projection" in key:
if "weight" in key:
new_key = "patch_embed.proj.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "patch_embed.proj.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert vit's layernorm layers
elif "layernorm_before" in key:
index_block = key.split(".")[index_idx]
if "weight" in key:
new_key = "blocks." + index_block + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "blocks." + index_block + ".input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "layernorm_after" in key:
index_block = key.split(".")[index_idx]
if "weight" in key:
new_key = "blocks." + index_block + ".post_attention_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "blocks." + index_block + ".post_attention_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
# Convert vit's attention layers
elif "attention" in key:
index_block = key.split(".")[index_idx]
if "attention.attention" in key:
if (
"blocks." + index_block + ".self_attention.query_key_value.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)
qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads)
qkv_b = self._fix_qkv_ordering(qkv_b, head_size, num_heads)
new_key = "blocks." + index_block + ".self_attention.query_key_value.weight"
oneflow_state_dict[new_key] = qkv_w
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = qkv_b
elif "output" in key:
if "dense" in key:
if "weight" in key:
new_key = "blocks." + index_block + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = "blocks." + index_block + ".self_attention.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "intermediate" in key:
index_block = key.split(".")[index_idx]
if "weight" in key:
if (
"blocks." + index_block + ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = key.replace("weight", "bias")
new_key = "blocks." + index_block + ".mlp.dense_h_to_4h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "output" in key:
index_block = key.split(".")[index_idx]
if "dense.weight" in key:
if (
"blocks." + index_block + ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = "blocks." + index_block + ".mlp.dense_4h_to_h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
elif "layernorm" in key:
if "weight" in key:
new_key = "norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "classifier" in key:
if "weight" in key:
new_key = "head.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "head.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
# update libai_cfg by config.json
self._update_cfg("img_size", cfg_dict["image_size"])
self._update_cfg("patch_size", cfg_dict["patch_size"])
self._update_cfg("in_chans", cfg_dict["num_channels"])
self._update_cfg("embed_dim", cfg_dict["hidden_size"])
self._update_cfg("depth", cfg_dict["num_hidden_layers"])
self._update_cfg("num_heads", cfg_dict["num_attention_heads"])
self._update_cfg("attn_drop_rate", cfg_dict["attention_probs_dropout_prob"])
self._update_cfg("drop_rate", cfg_dict["hidden_dropout_prob"])
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class ViTLoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = ""
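# A minimal sketch of the fused QKV shapes produced by the ViT converter above:
# three (hidden_size, hidden_size) weights are concatenated along dim 0 into a
# (3 * hidden_size, hidden_size) matrix before `_fix_qkv_ordering` regroups the
# rows per attention head. The hidden size below is a placeholder.
def _demo_vit_fused_qkv_shapes(hidden_size=768):
    """Illustrative only: shape of the fused query_key_value weight."""
    q = flow.zeros(hidden_size, hidden_size)
    qkv_w = flow.cat((q, q, q), dim=0)  # (3 * hidden_size, hidden_size)
    return qkv_w.shape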
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import oneflow.nn as nn
def init_method_normal(sigma, mean=0.0):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return nn.init.normal_(tensor, mean=mean, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers, mean=0.0):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return nn.init.normal_(tensor, mean=mean, std=std)
return init_
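# A brief usage sketch of the initializers above, applied to a plain tensor.
# The tensor shape and sigma value are arbitrary placeholders.
def _demo_scaled_init(num_layers=12, sigma=0.02):
    """Illustrative only: fill a weight with the scaled normal init."""
    import oneflow as flow

    weight = flow.empty(4, 4)
    scaled_init_method_normal(sigma, num_layers)(weight)
    return weight  # drawn from N(0, sigma / sqrt(2 * num_layers))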
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from flowvision.layers.weight_init import trunc_normal_
import libai.utils.distributed as dist
from libai.config.config import configurable
from libai.layers import LayerNorm, Linear, PatchEmbedding, TransformerLayer
class VisionTransformer(nn.Module):
"""Vision Transformer in LiBai.
LiBai's implementation of:
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
<https://arxiv.org/abs/2010.11929>`_
Args:
img_size (int, tuple(int)): input image size
patch_size (int, tuple(int)): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
num_classes (int): number of classes for classification head
loss_func (callable, optional): loss function for computing the total loss
between logits and labels
"""
@configurable
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4.0,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
num_classes=1000,
loss_func=None,
):
super().__init__()
self.img_size = img_size
self.num_classes = num_classes
self.patch_embed = PatchEmbedding(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
ffn_size = int(embed_dim * mlp_ratio)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(
flow.zeros(
1,
1,
embed_dim,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
)
self.pos_embed = nn.Parameter(
flow.zeros(
1,
num_patches + 1,
embed_dim,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in flow.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.Sequential(
*[
TransformerLayer(
hidden_size=embed_dim,
ffn_hidden_size=ffn_size,
num_attention_heads=num_heads,
attention_dropout_prob=attn_drop_rate,
output_dropout_prob=drop_rate,
drop_path_prob=dpr[i],
layer_idx=i,
)
for i in range(depth)
]
)
self.norm = LayerNorm(embed_dim, layer_idx=-1)
self.head = Linear(embed_dim, num_classes, layer_idx=-1)
# loss func
self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func
# weight init
trunc_normal_(self.pos_embed, std=0.02)
trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
@classmethod
def from_config(cls, cfg):
return {
"img_size": cfg.img_size,
"patch_size": cfg.patch_size,
"in_chans": cfg.in_chans,
"embed_dim": cfg.embed_dim,
"depth": cfg.depth,
"num_heads": cfg.num_heads,
"mlp_ratio": cfg.mlp_ratio,
"drop_rate": cfg.drop_rate,
"attn_drop_rate": cfg.attn_drop_rate,
"drop_path_rate": cfg.drop_path_rate,
"num_classes": cfg.num_classes,
"loss_func": cfg.loss_func,
}
def forward_features(self, x):
# patch embedding
x = self.patch_embed(x)
cls_token = self.cls_token.expand(
x.shape[0], -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
cls_token = cls_token.to_global(sbp=x.sbp, placement=cls_token.placement)
x = flow.cat((cls_token, x), dim=1)
# position embedding
pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
pos_embed = pos_embed.to_global(sbp=x.sbp, placement=pos_embed.placement)
x = self.pos_drop(x + pos_embed)
# transformer block
x = self.blocks(x)
return x
def forward_head(self, x):
x = self.norm(x)
outcome = x[:, 0]
outcome = self.head(outcome)
return outcome
def forward(self, images, labels=None):
"""
Args:
images (flow.Tensor): training samples.
labels (flow.LongTensor, optional): training targets
Returns:
dict:
A dict containing :code:`loss_value` or :code:`logits`
depending on training or evaluation mode.
:code:`{"losses": loss_value}` when training,
:code:`{"prediction_scores": logits}` when evaluating.
"""
x = self.forward_features(images)
x = self.forward_head(x)
if labels is not None and self.training:
losses = self.loss_func(x, labels)
return {"losses": losses}
else:
return {"prediction_scores": x}
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
# Set pipeline parallelism stage_id
if hasattr(model.pos_embed, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
if isinstance(module_block.origin, PatchEmbedding):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, TransformerLayer):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
# Set pos_embed and cls_token stage id
model.pos_embed.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.cls_token.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.norm.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), PatchEmbedding):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
# Set pos_embed and cls_token stage id
model.pos_embed.to(flow.nn.graph.GraphTensor).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.cls_token.to(flow.nn.graph.GraphTensor).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
model.norm.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.head.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
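# A minimal usage sketch of the VisionTransformer above. Running it needs a
# working OneFlow global-tensor environment (the parameters are placed via
# `dist.get_layer_placement`); the tiny depth and the batch of ones are
# placeholders for illustration.
def _demo_vit_forward():
    """Illustrative only: build a small ViT and run one evaluation forward pass."""
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3)
    model.eval()
    images = flow.ones(
        1,
        3,
        224,
        224,
        sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
        placement=dist.get_layer_placement(0),
    )
    return model(images)["prediction_scores"]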
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
from libai.config import LazyConfig
from libai.models.utils import GPT2LoaderLiBai
from projects.MagicPrompt.gpt2 import GPTModel
def get_model(config_file):
cfg = LazyConfig.load(config_file)
cfg.model.cfg.pretrained_model_path = None
cfg.dataloader = None
cfg.tokenization = None
print("Building model....")
    loader = GPT2LoaderLiBai(GPTModel, cfg.model.cfg, "/path/to/model")
model = loader.load()
print("Build model finished.")
return model
class gpt2Graph(nn.Graph):
def __init__(self, eager_model):
super().__init__()
self.model = eager_model
def build(
self,
input_ids,
):
out = self.model(
input_ids,
)
return out
if __name__ == "__main__":
model = get_model("projects/MagicPrompt/configs/gpt2_inference.py")
model.eval()
gpt2_graph = gpt2Graph(model)
# Build the static graph model
input_ids = flow.ones(
1, 5, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
)
# check your model.forward is valid
# output = gpt2_graph(
# input_ids
# )
print("Compiling the graph which may make some time, please wait for a moment....")
gpt2_graph._compile(
input_ids,
)
convert_to_onnx_and_check(
gpt2_graph,
external_data=False,
opset=11,
flow_weight_dir=None,
onnx_model_path="./",
dynamic_batch_size=False,
device="gpu_global",
input_tensor_range=[0, 10],
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List
import numpy as np
import onnxruntime as ort
class OnnxModel:
def __init__(
self,
onnx_filename,
providers: List[str] = None,
ort_optimize: bool = True,
):
ort_sess_opt = ort.SessionOptions()
ort_sess_opt.graph_optimization_level = (
ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
if ort_optimize
else ort.GraphOptimizationLevel.ORT_DISABLE_ALL
)
if providers is None:
if ort.__version__ > "1.9.0":
providers = [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
else:
providers = ["CPUExecutionProvider"]
self.sess = ort.InferenceSession(
onnx_filename, sess_options=ort_sess_opt, providers=providers
)
def forward(self, input_list):
ipt_dict = OrderedDict()
for idx, ipt in enumerate(self.sess.get_inputs()):
ipt_dict[ipt.name] = input_list[idx]
onnx_res = self.sess.run([], ipt_dict)
return onnx_res
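# A short hedged sketch: when assembling an `input_list` by hand (as in the
# `__main__` block below), printing the session's expected inputs first helps
# keep the order, shapes, and dtypes in sync with the exported graph.
def _demo_print_onnx_inputs(onnx_model):
    """Illustrative only: list the input names, shapes, and dtypes of an OnnxModel."""
    for ipt in onnx_model.sess.get_inputs():
        print(ipt.name, ipt.shape, ipt.type)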
if __name__ == "__main__":
onnx_model = OnnxModel("model.onnx")
input_list = [
        np.ones((1, 5)).astype(np.int64),
]
print(onnx_model.forward(input_list))
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List
import numpy as np
import onnxruntime as ort
class OnnxModel:
def __init__(
self,
onnx_filename,
providers: List[str] = None,
ort_optimize: bool = True,
):
ort_sess_opt = ort.SessionOptions()
ort_sess_opt.graph_optimization_level = (
ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
if ort_optimize
else ort.GraphOptimizationLevel.ORT_DISABLE_ALL
)
if providers is None:
            # Compare the version numerically; a plain string comparison would
            # misorder versions such as "1.10.0" against "1.9.0".
            if tuple(int(x) for x in ort.__version__.split(".")[:3] if x.isdigit()) > (1, 9, 0):
providers = [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
else:
providers = ["CPUExecutionProvider"]
self.sess = ort.InferenceSession(
onnx_filename, sess_options=ort_sess_opt, providers=providers
)
def forward(self, input_list):
ipt_dict = OrderedDict()
for idx, ipt in enumerate(self.sess.get_inputs()):
ipt_dict[ipt.name] = input_list[idx]
onnx_res = self.sess.run([], ipt_dict)
return onnx_res
if __name__ == "__main__":
onnx_model = OnnxModel("model.onnx")
input_list = [
np.ones((1, 5)).astype(np.int64),
np.ones((1, 3)).astype(np.int64),
np.ones((1, 5, 5)).astype(bool),
np.ones((1, 3, 3)).astype(bool),
np.ones((1, 3, 5)).astype(bool),
]
print(onnx_model.forward(input_list))
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
from libai.config import LazyConfig
from projects.MT5.mt5_model import MT5Model
from projects.MT5.utils.mt5_loader import T5LoaderHuggerFace
def get_model(config_file):
cfg = LazyConfig.load(config_file)
cfg.model.cfg.model_type = "mt5"
cfg.model.cfg.pretrained_model_path = None
cfg.dataloader = None
cfg.tokenization = None
print("Building model....")
loader = T5LoaderHuggerFace(MT5Model, cfg.model.cfg, "/path/to/model")
model = loader.load()
print("Build model finished.")
return model
class t5Graph(nn.Graph):
def __init__(self, eager_model):
super().__init__()
self.model = eager_model
def build(
self,
encoder_input_ids,
encoder_attn_mask,
decoder_input_ids,
decoder_attn_mask,
encoder_decoder_attn_mask,
):
out = self.model(
encoder_input_ids,
encoder_attn_mask,
decoder_input_ids,
decoder_attn_mask,
encoder_decoder_attn_mask,
)
return out
if __name__ == "__main__":
model = get_model("projects/MT5/configs/mt5_pretrain.py")
model.eval()
t5_graph = t5Graph(model)
# Build the static graph model
encoder_input_ids = flow.ones(
1, 5, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
)
encoder_attn_mask = flow.ones(
1, 3, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
)
decoder_input_ids = flow.ones(
1,
5,
5,
dtype=flow.bool,
sbp=flow.sbp.broadcast,
placement=flow.placement("cuda", ranks=[0]),
)
decoder_attn_mask = flow.ones(
1,
3,
3,
dtype=flow.bool,
sbp=flow.sbp.broadcast,
placement=flow.placement("cuda", ranks=[0]),
)
encoder_decoder_attn_mask = flow.ones(
1,
3,
5,
dtype=flow.bool,
sbp=flow.sbp.broadcast,
placement=flow.placement("cuda", ranks=[0]),
)
# check your model.forward is valid
# output = t5_graph(
# encoder_input_ids,
# encoder_attn_mask,
# decoder_input_ids,
# decoder_attn_mask,
# encoder_decoder_attn_mask
# )
# print(output)
print("Compiling the graph which may make some time, please wait for a moment....")
t5_graph._compile(
encoder_input_ids,
encoder_attn_mask,
decoder_input_ids,
decoder_attn_mask,
encoder_decoder_attn_mask,
)
convert_to_onnx_and_check(
t5_graph,
external_data=False,
opset=11,
flow_weight_dir=None,
onnx_model_path="./",
dynamic_batch_size=False,
device="gpu_global",
input_tensor_range=[0, 10],
)