ModelZoo / BLOOM_oneflow

Commit 9fdb7dab, authored Mar 30, 2023 by yuguo960516
Commit message: bloom
Pipeline #150 failed with stages in 0 seconds
Changes: 332 | Pipelines: 1

Showing 20 changed files with 5148 additions and 0 deletions (+5148 / -0)
libai/models/swin_transformer.py (+772, -0)
libai/models/swin_transformer_v2.py (+900, -0)
libai/models/t5_model.py (+518, -0)
libai/models/utils/__init__.py (+24, -0)
libai/models/utils/graph_base.py (+138, -0)
libai/models/utils/model_loader/README.md (+39, -0)
libai/models/utils/model_loader/__init__.py (+0, -0)
libai/models/utils/model_loader/base_loader.py (+603, -0)
libai/models/utils/model_loader/bert_loader.py (+263, -0)
libai/models/utils/model_loader/gpt_loader.py (+174, -0)
libai/models/utils/model_loader/roberta_loader.py (+226, -0)
libai/models/utils/model_loader/swin_loader.py (+298, -0)
libai/models/utils/model_loader/swinv2_loader.py (+316, -0)
libai/models/utils/model_loader/vit_loader.py (+225, -0)
libai/models/utils/weight_init.py (+37, -0)
libai/models/vision_transformer.py (+267, -0)
libai/onnx_export/gpt2_to_onnx.py (+86, -0)
libai/onnx_export/onnx_inference/gpt2_onnx_infer.py (+64, -0)
libai/onnx_export/onnx_inference/t5_onnx_infer.py (+68, -0)
libai/onnx_export/t5_to_onnx.py (+130, -0)
libai/models/swin_transformer.py (new file, mode 100644)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from flowvision.layers import trunc_normal_
from flowvision.models import to_2tuple

from libai.config.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear
from libai.utils import distributed as dist
def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
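

# Shape walk-through: window_partition maps (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
# and window_reverse is its exact inverse back to (B, H, W, C); e.g. H = W = 14 with window_size = 7
# yields 4*B windows of shape (7, 7, C).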
class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        fused_bias_add_dropout=False,
        layer_idx=0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            flow.zeros(
                (2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                num_heads,
                placement=dist.get_layer_placement(layer_idx),
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            )
        )  # 2*Wh-1 * 2*Ww-1, nH
        trunc_normal_(self.relative_position_bias_table, std=0.02)

        # get pair-wise relative position index for each token inside the window
        coords_h = flow.arange(self.window_size[0])
        coords_w = flow.arange(self.window_size[1])
        coords = flow.stack(flow.meshgrid(*[coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = flow.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] = (
            relative_coords[:, :, 0] + self.window_size[0] - 1
        )  # shift to start from 0
        relative_coords[:, :, 1] = relative_coords[:, :, 1] + self.window_size[1] - 1
        relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer(
            "relative_position_index",
            relative_position_index.to_global(
                placement=dist.get_layer_placement(layer_idx),
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            ),
        )

        self.qkv = Linear(dim, dim * 3, bias=qkv_bias, layer_idx=layer_idx)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = Linear(dim, dim, layer_idx=layer_idx)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.fused_bias_add_dropout = fused_bias_add_dropout
        self.p = proj_drop

    def forward(self, x, mask):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        # attn = flow.matmul(q, k.transpose(-2, -1))
        attn = flow.matmul(q, k, transpose_b=True)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        unsqueeze_relative_position_bias = relative_position_bias.unsqueeze(0)
        attn = attn + unsqueeze_relative_position_bias

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = flow.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
        if self.fused_bias_add_dropout:
            x = flow._C.matmul(x, self.proj.weight, transpose_a=False, transpose_b=True)
            x = flow._C.fused_bias_add_dropout(x, self.proj.bias, p=self.p, axis=2)
        else:
            x = self.proj(x)
            x = self.proj_drop(x)
        return x
class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=LayerNorm,
        layer_idx=0,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.layer_idx = layer_idx
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim, layer_idx=layer_idx)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            fused_bias_add_dropout=True,
            layer_idx=layer_idx,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim, layer_idx=layer_idx)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            hidden_size=dim,
            ffn_hidden_size=mlp_hidden_dim,
            output_dropout_prob=drop,
            bias_gelu_fusion=True,
            bias_dropout_fusion=True,
            layer_idx=layer_idx,
        )

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = flow.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt = cnt + 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
                attn_mask == 0, float(0.0)
            )
            attn_mask = attn_mask.to_global(
                placement=dist.get_layer_placement(layer_idx),
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            )
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = flow.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = flow.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
class PatchMerging(nn.Module):
    """Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=LayerNorm, layer_idx=0):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = Linear(4 * dim, 2 * dim, bias=False, layer_idx=layer_idx)
        self.norm = norm_layer(4 * dim, layer_idx=layer_idx)
        self.layer_idx = layer_idx

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = flow.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
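

# Shape walk-through: PatchMerging takes (B, H*W, C), gathers each 2x2 spatial neighbourhood into
# (B, H/2*W/2, 4*C), and the bias-free Linear reduction then projects to (B, H/2*W/2, 2*C),
# halving the spatial resolution while doubling the channel width between stages.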
class PatchEmbed(nn.Module):
    """Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, layer_idx=0):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        ).to_global(
            placement=dist.get_layer_placement(layer_idx),
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim, layer_idx=layer_idx)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
        downsample (nn.Module | None, optional): Downsample at the end of the layer. Default: None
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=LayerNorm,
        downsample=None,
        layer_id_offset=0,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.layer_id_offset = layer_id_offset

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    layer_idx=layer_id_offset + i,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution,
                dim=dim,
                norm_layer=norm_layer,
                layer_idx=layer_id_offset + depth - 1,
            )
        else:
            self.downsample = None

    def forward(self, x):
        layer_idx = self.layer_id_offset
        for i in range(len(self.blocks)):
            x = x.to_global(placement=dist.get_layer_placement(layer_idx))
            x = self.blocks[i](x)
            layer_idx += 1
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class SwinTransformer(nn.Module):
    """Swin Transformer in LiBai.
    LiBai implementation of:
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/pdf/2103.14030>`_
    Args:
        img_size (int, tuple(int)): Input image size. Default 224
        patch_size (int, tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: libai.layers.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        loss_func (callable, optional): Loss function for computing the total loss
            between logits and labels
    """

    @configurable
    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=LayerNorm,
        ape=False,
        patch_norm=True,
        loss_func=None,
        **kwargs,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            layer_idx=0,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                flow.zeros(1, num_patches, embed_dim),
                placement=dist.get_layer_placement(0),
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        layer_id_offset = 0
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(
                    patches_resolution[0] // (2 ** i_layer),
                    patches_resolution[1] // (2 ** i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                layer_id_offset=layer_id_offset,
            )
            layer_id_offset += depths[i_layer]
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features, layer_idx=-1)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = (
            Linear(self.num_features, num_classes, layer_idx=-1)
            if num_classes > 0
            else nn.Identity()
        )

        # Loss func
        self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @classmethod
    def from_config(cls, cfg):
        return {
            "img_size": cfg.img_size,
            "patch_size": cfg.patch_size,
            "in_chans": cfg.in_chans,
            "num_classes": cfg.num_classes,
            "embed_dim": cfg.embed_dim,
            "depths": cfg.depths,
            "num_heads": cfg.num_heads,
            "window_size": cfg.window_size,
            "mlp_ratio": cfg.mlp_ratio,
            "qkv_bias": cfg.qkv_bias,
            "qk_scale": cfg.qk_scale,
            "drop_rate": cfg.drop_rate,
            "drop_path_rate": cfg.drop_path_rate,
            "ape": cfg.ape,
            "patch_norm": cfg.patch_norm,
            "loss_func": cfg.loss_func,
        }

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = flow.flatten(x, 1)
        return x

    def forward(self, images, labels=None):
        """
        Args:
            images (flow.Tensor): training samples.
            labels (flow.LongTensor, optional): training targets
        Returns:
            dict:
                A dict containing :code:`loss_value` or :code:`logits`
                depending on training or evaluation mode.
                :code:`{"losses": loss_value}` when training,
                :code:`{"prediction_scores": logits}` when evaluating.
        """
        x = self.forward_features(images)
        x = self.head(x)
        if labels is not None and self.training:
            losses = self.loss_func(x, labels)
            return {"losses": losses}
        else:
            return {"prediction_scores": x}

    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        # Set pipeline parallelism stage_id
        if hasattr(model.patch_embed, "config"):
            # Old API in OneFlow 0.8
            model.patch_embed.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            for module_block in model.modules():
                if isinstance(module_block.origin, SwinTransformerBlock):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.origin, PatchMerging):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
            model.norm.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.avgpool.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
        else:
            model.patch_embed.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.to(nn.Module), PatchMerging):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
            model.norm.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.avgpool.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )

    @staticmethod
    def set_activation_checkpoint(model):
        for module_block in model.modules():
            if hasattr(module_block, "origin"):
                # Old API in OneFlow 0.8
                if isinstance(module_block.origin, SwinTransformerBlock):
                    module_block.config.activation_checkpointing = True
            else:
                if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
                    module_block.to(flow.nn.graph.GraphModule).activation_checkpointing = True
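
The file above wires the Swin Transformer into LiBai's classification interface (a dict with "losses" in training mode, "prediction_scores" in eval mode). The sketch below is not part of the commit; it assumes LiBai's distributed utility has been configured (e.g. via dist.setup_dist_util) so that dist.get_layer_placement(0) resolves to the local device, and the batch size of 2 is purely illustrative:

    import oneflow as flow
    from libai.models.swin_transformer import SwinTransformer
    from libai.utils import distributed as dist

    # Defaults match the docstring: embed_dim=96, depths=[2, 2, 6, 2], num_classes=1000.
    model = SwinTransformer()
    placement = dist.get_layer_placement(0)
    sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])

    # Inputs are made global so they are consistent with the model's global parameters.
    images = flow.randn(2, 3, 224, 224).to_global(placement=placement, sbp=sbp)
    labels = flow.randint(0, 1000, (2,)).to_global(placement=placement, sbp=sbp)

    model.train()
    print(model(images, labels))  # {"losses": ...}
    model.eval()
    print(model(images))          # {"prediction_scores": ...}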
libai/models/swin_transformer_v2.py (new file, mode 100644)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
from flowvision.layers import trunc_normal_
from flowvision.models import to_2tuple

from libai.config.config import configurable
from libai.layers import MLP, DropPath, LayerNorm, Linear
from libai.utils import distributed as dist
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
            Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        attn_drop=0.0,
        proj_drop=0.0,
        pretrained_window_size=[0, 0],
        fused_bias_add_dropout=False,
        layer_idx=0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.fused_bias_add_dropout = fused_bias_add_dropout
        self.num_heads = num_heads
        self.layer_idx = layer_idx
        self.p = proj_drop

        self.logit_scale = nn.Parameter(
            flow.log(
                10
                * flow.ones(
                    1,
                    num_heads,
                    1,
                    1,
                    placement=dist.get_layer_placement(layer_idx),
                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                )
            ),
            requires_grad=True,
        )

        # NOTE: generate meta network, using mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            Linear(2, 512, bias=True, layer_idx=layer_idx),
            nn.ReLU(inplace=True),
            Linear(512, num_heads, bias=False, layer_idx=layer_idx),
        ).to_global(
            placement=dist.get_layer_placement(layer_idx),
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
        )

        # NOTE: get relative_coords_table
        relative_coords_h = flow.arange(
            -(self.window_size[0] - 1), self.window_size[0], dtype=flow.float32
        )
        relative_coords_w = flow.arange(
            -(self.window_size[1] - 1), self.window_size[1], dtype=flow.float32
        )
        relative_coords_table = (
            flow.stack(flow.meshgrid(*[relative_coords_h, relative_coords_w]))
            .permute(1, 2, 0)
            .contiguous()
            .unsqueeze(0)
        )  # 1, 2*Wh-1, 2*Ww-1, 2
        # NOTE: For any relative coordinate, constrain it to -8~8 (window size)
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0] / (
                pretrained_window_size[0] - 1
            )
            relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1] / (
                pretrained_window_size[1] - 1
            )
        else:
            relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0] / (
                self.window_size[0] - 1
            )
            relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1] / (
                self.window_size[1] - 1
            )
        relative_coords_table = relative_coords_table * 8
        # NOTE: y=sign(x)*log(|x|+1)
        relative_coords_table = (
            flow.sign(relative_coords_table)
            * flow.log2(flow.abs(relative_coords_table) + 1.0)
            / math.log2(8.0)
        )
        self.register_buffer(
            "relative_coords_table",
            relative_coords_table.to_global(
                placement=dist.get_layer_placement(layer_idx),
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            ),
        )

        # NOTE: get pair-wise relative position index for each token inside the window
        coords_h = flow.arange(self.window_size[0])
        coords_w = flow.arange(self.window_size[1])
        coords = flow.stack(flow.meshgrid(*[coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = flow.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] = (
            relative_coords[:, :, 0] + self.window_size[0] - 1
        )  # shift to start from 0
        relative_coords[:, :, 1] = relative_coords[:, :, 1] + self.window_size[1] - 1
        relative_coords[:, :, 0] = relative_coords[:, :, 0] * (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer(
            "relative_position_index",
            relative_position_index.to_global(
                placement=dist.get_layer_placement(layer_idx),
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            ),
        )

        self.qkv = Linear(dim, dim * 3, bias=False, layer_idx=layer_idx)
        if qkv_bias:
            self.q_bias = nn.Parameter(
                flow.zeros(
                    dim,
                    placement=dist.get_layer_placement(layer_idx),
                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                )
            )
            self.v_bias = nn.Parameter(
                flow.zeros(
                    dim,
                    placement=dist.get_layer_placement(layer_idx),
                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                )
            )
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = Linear(dim, dim, layer_idx=layer_idx)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = flow.concat(
                [
                    self.q_bias,
                    flow.zeros(
                        self.v_bias.shape,
                        requires_grad=False,
                        placement=dist.get_layer_placement(
                            self.layer_idx, device_type=self.v_bias.placement.type
                        ),
                        sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                    ),
                    self.v_bias,
                ],
                dim=0,
            )
        qkv = self.qkv(x) + qkv_bias.unsqueeze(0).unsqueeze(0)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # NOTE: cosine attention
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        # NOTE: a learnable scalar
        logit_scale = flow.clamp(self.logit_scale, min=-1e6, max=math.log(1.0 / 0.01)).exp()
        attn = attn * logit_scale

        # NOTE: use relative_coords_table and meta network to generate relative_position_bias
        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
            -1, self.num_heads
        )
        relative_position_bias = relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # NOTE: constrained to a range of -16~16
        relative_position_bias = 16 * flow.sigmoid(relative_position_bias).unsqueeze(0)
        attn = attn + relative_position_bias

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        if self.fused_bias_add_dropout:
            x = flow._C.matmul(x, self.proj.weight, transpose_a=False, transpose_b=True)
            x = flow._C.fused_bias_add_dropout(x, self.proj.bias, p=self.p, axis=2)
        else:
            x = self.proj(x)
            x = self.proj_drop(x)
        return x
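

# Bias walk-through: the relative coordinates above are normalized by the (pretrained or current)
# window size, scaled to roughly [-8, 8], and log-spaced via sign(x) * log2(|x| + 1) / log2(8);
# cpb_mlp then maps each 2-D coordinate to a per-head bias, which 16 * sigmoid(.) bounds before it
# is added to the cosine-attention logits.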
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=LayerNorm,
        pretrained_window_size=0,
        layer_idx=0,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.layer_idx = layer_idx
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim, layer_idx=layer_idx)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size),
            fused_bias_add_dropout=True,
            layer_idx=layer_idx,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim, layer_idx=layer_idx)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            hidden_size=dim,
            ffn_hidden_size=mlp_hidden_dim,
            output_dropout_prob=drop,
            bias_gelu_fusion=True,
            bias_dropout_fusion=True,
            layer_idx=layer_idx,
        )

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = flow.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt = cnt + 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = (
                attn_mask.masked_fill(attn_mask != 0, float(-100.0))
                .masked_fill(attn_mask == 0, float(0.0))
                .to_global(
                    placement=dist.get_layer_placement(layer_idx),
                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                )
            )
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = flow.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = flow.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # NOTE: res-post-norm
        x = shortcut + self.drop_path(self.norm1(x))
        # NOTE: res-post-norm
        x = x + self.drop_path(self.norm2(self.mlp(x)))

        return x
class PatchMerging(nn.Module):
    """Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: libai.layers.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=LayerNorm, layer_idx=0):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = Linear(4 * dim, 2 * dim, bias=False, layer_idx=layer_idx)
        # NOTE: swinv2-> 2*dim, swin-> 4*dim
        self.norm = norm_layer(2 * dim, layer_idx=layer_idx)
        self.layer_idx = layer_idx

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = flow.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        # NOTE: post-norm (reduction before norm), a change in swin-v2 compared to swin
        x = self.reduction(x)
        x = self.norm(x)

        return x
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
            Default: None
        pretrained_window_size (int): Local window size in pre-training.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=LayerNorm,
        downsample=None,
        pretrained_window_size=0,
        layer_id_offset=0,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.layer_id_offset = layer_id_offset

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    pretrained_window_size=pretrained_window_size,
                    layer_idx=layer_id_offset + i,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution,
                dim=dim,
                norm_layer=norm_layer,
                layer_idx=layer_id_offset + depth - 1,
            )
        else:
            self.downsample = None

    def forward(self, x):
        layer_idx = self.layer_id_offset
        for blk in self.blocks:
            x = x.to_global(
                placement=dist.get_layer_placement(layer_idx, device_type=x.placement.type)
            )
            x = blk(x)
            layer_idx += 1
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def _init_respostnorm(self):
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, layer_idx=0):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        ).to_global(
            placement=dist.get_layer_placement(layer_idx),
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
class SwinTransformerV2(nn.Module):
    r"""Swin Transformer
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
    """

    @configurable
    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=LayerNorm,
        ape=False,
        patch_norm=True,
        pretrained_window_sizes=[0, 0, 0, 0],
        loss_func=None,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            layer_idx=0,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                flow.zeros(
                    1,
                    num_patches,
                    embed_dim,
                    placement=dist.get_layer_placement(0),
                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                )
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in flow.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        layer_id_offset = 0
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(
                    patches_resolution[0] // (2 ** i_layer),
                    patches_resolution[1] // (2 ** i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                pretrained_window_size=pretrained_window_sizes[i_layer],
                layer_id_offset=layer_id_offset,
            )
            layer_id_offset += depths[i_layer]
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features, layer_idx=-1)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = (
            Linear(self.num_features, num_classes, layer_idx=-1)
            if num_classes > 0
            else nn.Identity()
        )
        self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func

        self.apply(self._init_weights)
        for bly in self.layers:
            bly._init_respostnorm()

    def _init_weights(self, m):
        if isinstance(m, Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @classmethod
    def from_config(cls, cfg):
        return {
            "img_size": cfg.img_size,
            "patch_size": cfg.patch_size,
            "in_chans": cfg.in_chans,
            "num_classes": cfg.num_classes,
            "embed_dim": cfg.embed_dim,
            "depths": cfg.depths,
            "num_heads": cfg.num_heads,
            "window_size": cfg.window_size,
            "mlp_ratio": cfg.mlp_ratio,
            "qkv_bias": cfg.qkv_bias,
            "drop_rate": cfg.drop_rate,
            "drop_path_rate": cfg.drop_path_rate,
            "ape": cfg.ape,
            "patch_norm": cfg.patch_norm,
            "pretrained_window_sizes": cfg.pretrained_window_sizes,
            "loss_func": cfg.loss_func,
        }

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = flow.flatten(x, 1)
        return x

    def forward(self, images, labels=None):
        """
        Args:
            images (flow.Tensor): training samples.
            labels (flow.LongTensor, optional): training targets
        Returns:
            dict:
                A dict containing :code:`loss_value` or :code:`logits`
                depending on training or evaluation mode.
                :code:`{"losses": loss_value}` when training,
                :code:`{"prediction_scores": logits}` when evaluating.
        """
        x = self.forward_features(images)
        x = self.head(x)
        if labels is not None and self.training:
            losses = self.loss_func(x, labels)
            return {"losses": losses}
        else:
            return {"prediction_scores": x}

    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        # Set pipeline parallelism stage_id
        if hasattr(model.patch_embed, "config"):
            # Old API in OneFlow 0.8
            model.patch_embed.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            for module_block in model.modules():
                if isinstance(module_block.origin, SwinTransformerBlock):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.origin, PatchMerging):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
            model.norm.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.avgpool.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
        else:
            model.patch_embed.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.to(nn.Module), PatchMerging):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
            model.norm.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.avgpool.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )

    @staticmethod
    def set_activation_checkpoint(model):
        for module_block in model.modules():
            if hasattr(module_block, "origin"):
                # Old API in OneFlow 0.8
                if isinstance(module_block.origin, SwinTransformerBlock):
                    module_block.config.activation_checkpointing = True
            else:
                if isinstance(module_block.to(nn.Module), SwinTransformerBlock):
                    module_block.to(flow.nn.graph.GraphModule).activation_checkpointing = True
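
Relative to the v1 file above, this version applies LayerNorm after the attention and MLP branches (res-post-norm), replaces scaled dot-product attention with cosine attention scaled by a clamped learnable logit_scale, and swaps the relative-position-bias table for the cpb_mlp meta network. A minimal construction sketch, not part of the commit and under the same environment assumptions as the earlier example (the argument values are illustrative):

    from libai.models.swin_transformer_v2 import SwinTransformerV2

    model = SwinTransformerV2(
        img_size=224,
        window_size=7,
        pretrained_window_sizes=[0, 0, 0, 0],  # 0 -> normalize bias coords by the current window size
    )
    # Same forward contract as SwinTransformer: {"losses": ...} when training with labels,
    # {"prediction_scores": ...} in eval mode.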
libai/models/t5_model.py (new file, mode 100644)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn

from libai.config import configurable
from libai.layers import (
    Embedding,
    LayerNorm,
    LMLogits,
    ParallelCrossEntropyLoss,
    TransformerLayer,
    VocabEmbedding,
)
from libai.layers.attention import AttnMaskType
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist


class ExtendedMask(flow.nn.Module):
    def forward(self, attention_mask):
        return attention_mask.unsqueeze(1)


class T5Embedding(flow.nn.Module):
    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
        embedding_dropout_prob,
        init_method=flow.nn.init.xavier_normal_,
        amp_enabled=False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.word_embeddings = VocabEmbedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size,
            init_method=init_method,
            amp_enabled=amp_enabled,
        )
        self.position_embeddings = Embedding(
            num_embeddings=max_sequence_length,
            embedding_dim=hidden_size,
            init_method=init_method,
            amp_enabled=amp_enabled,
        )
        self.position_ids = flow.arange(
            max_sequence_length,
            dtype=flow.long,
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
            placement=dist.get_layer_placement(0),
        ).unsqueeze(0)
        self.embedding_dropout = flow.nn.Dropout(embedding_dropout_prob)

    def forward(self, input_ids, past_length=0):
        seq_length = input_ids.size()[1]
        position_ids = self.position_ids[:, past_length : past_length + seq_length]
        position_ids = position_ids.expand_as(input_ids).to_global(sbp=input_ids.sbp)

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
        embeddings = self.embedding_dropout(embeddings)
        return embeddings


class T5Model(flow.nn.Module):
    """T5 Model that outputs logits.

    Args:
        vocab_size (int): The size of vocabulary file.
        hidden_size (int): The size of hidden states.
        hidden_layers (int): The number of ``TransformerLayer`` in the encoder and decoder.
        num_attention_heads (int):
            The number of attention heads for each attention layer of ``TransformerLayer``.
        intermediate_size (int):
            The size of intermediate layer in feed-forward network for each ``TransformerLayer``.
        embedding_dropout_prob (float): The dropout ratio for the output of T5Embedding Layer.
        hidden_dropout_prob (float): The dropout ratio for the output for each ``TransformerLayer``.
        attention_probs_dropout_prob (float):
            The dropout ratio for the output of each attention layer in ``TransformerLayer``.
        max_position_embeddings (int):
            Max sequence length of input, defines the shape of Position Embeddings
            in ``T5Embedding``.
        initializer_range (float, optional):
            Sigma of the normal distribution in the initialization method. Defaults to 0.02.
        layernorm_eps (float, optional): The epsilon of LayerNorm layer. Defaults to 1e-12.
        bias_gelu_fusion (bool, optional):
            Whether or not to fuse the computing of bias and gelu. Defaults to ``False``.
        bias_dropout_fusion (bool, optional):
            Whether or not to fuse the computing of dropout and bias. Defaults to ``False``.
        scale_mask_softmax_fusion (bool, optional):
            Whether to fuse the computing of mask and softmax in attention layers.
            Defaults to ``False``.
        apply_query_key_layer_scaling (bool, optional):
            Whether or not to use layer index related scaling in computing attention scores.
            If ``True``, the scaling factor equals to sqrt(d) * (layer_index + 1).
            Defaults to ``True``.
        apply_residual_post_layernorm (bool, optional):
            If set ``True``, use original BERT residual connection ordering otherwise use Megatron
            BERT residual connection which is more stable when scaling model size introduced in
            https://arxiv.org/pdf/1909.08053.pdf.
            Default: ``False``.
        amp_enabled (bool, optional):
            Whether or not to set fp16 for embedding weight in T5 model. Defaults to ``False``.
    """

    @configurable
    def __init__(
        self,
        vocab_size,
        hidden_size,
        hidden_layers,
        num_attention_heads,
        intermediate_size,
        embedding_dropout_prob,
        hidden_dropout_prob,
        attention_probs_dropout_prob,
        max_position_embeddings,
        initializer_range=0.02,
        layernorm_eps=1e-12,
        bias_gelu_fusion=False,
        bias_dropout_fusion=False,
        scale_mask_softmax_fusion=False,
        apply_query_key_layer_scaling=True,
        apply_residual_post_layernorm=False,
        amp_enabled=False,
    ) -> None:
        super().__init__()
        init_method = init_method_normal(initializer_range)
        scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers)

        self.embedding = T5Embedding(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            max_sequence_length=max_position_embeddings,
            embedding_dropout_prob=embedding_dropout_prob,
            init_method=init_method,
            amp_enabled=amp_enabled,
        )
        self.extended_attn_mask = ExtendedMask()

        encoder_layers = flow.nn.ModuleList(
            [
                TransformerLayer(
                    hidden_size=hidden_size,
                    ffn_hidden_size=intermediate_size,
                    num_attention_heads=num_attention_heads,
                    is_decoder=False,
                    attention_dropout_prob=attention_probs_dropout_prob,
                    output_dropout_prob=hidden_dropout_prob,
                    layernorm_epsilon=layernorm_eps,
                    init_method=init_method,
                    output_layer_init_method=scaled_init_method,
                    bias_gelu_fusion=bias_gelu_fusion,
                    bias_dropout_fusion=bias_dropout_fusion,
                    scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                    apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                    apply_residual_post_layernorm=apply_residual_post_layernorm,
                    attn_mask_type=AttnMaskType.padding,
                    layer_idx=i,
                )
                for i in range(hidden_layers)
            ]
        )
        encoder_final_layernorm = LayerNorm(
            (hidden_size,),
            eps=layernorm_eps,
            layer_idx=hidden_layers - 1,
        )
        self.encoder = flow.nn.Sequential()
        self.encoder.add_module("layers", encoder_layers)
        self.encoder.add_module("final_layernorm", encoder_final_layernorm)

        decoder_layers = flow.nn.ModuleList(
            [
                TransformerLayer(
                    hidden_size=hidden_size,
                    ffn_hidden_size=intermediate_size,
                    num_attention_heads=num_attention_heads,
                    is_decoder=True,
                    attention_dropout_prob=attention_probs_dropout_prob,
                    output_dropout_prob=hidden_dropout_prob,
                    layernorm_epsilon=layernorm_eps,
                    init_method=init_method,
                    output_layer_init_method=scaled_init_method,
                    bias_gelu_fusion=bias_gelu_fusion,
                    bias_dropout_fusion=bias_dropout_fusion,
                    scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                    apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                    attn_mask_type=AttnMaskType.padding,
                    layer_idx=i,
                )
                for i in range(hidden_layers, 2 * hidden_layers)
            ]
        )
        decoder_final_layernorm = LayerNorm(
            (hidden_size,),
            eps=layernorm_eps,
            layer_idx=2 * hidden_layers - 1,
        )
        self.decoder = flow.nn.Sequential()
        self.decoder.add_module("layers", decoder_layers)
        self.decoder.add_module("final_layernorm", decoder_final_layernorm)

        self.past_key_values = [None] * len(self.decoder.layers)
        self.encoder_states = None
        self.past_length = 0

        self.lm_head = LMLogits(vocab_size, bias=True)

    @classmethod
    def from_config(cls, cfg):
        return {
            "vocab_size": cfg.vocab_size,
            "hidden_size": cfg.hidden_size,
            "hidden_layers": cfg.hidden_layers,
            "num_attention_heads": cfg.num_attention_heads,
            "intermediate_size": cfg.intermediate_size,
            "embedding_dropout_prob": cfg.embedding_dropout_prob,
            "hidden_dropout_prob": cfg.hidden_dropout_prob,
            "attention_probs_dropout_prob": cfg.attention_probs_dropout_prob,
            "max_position_embeddings": cfg.max_position_embeddings,
            "initializer_range": cfg.initializer_range,
            "layernorm_eps": cfg.layernorm_eps,
            "bias_gelu_fusion": cfg.bias_gelu_fusion,
            "bias_dropout_fusion": cfg.bias_dropout_fusion,
            "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
            "apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
            "apply_residual_post_layernorm": cfg.apply_residual_post_layernorm,
            "amp_enabled": cfg.amp_enabled,
        }

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attn_mask,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
        use_cache=False,
    ):
        """
        Args:
            encoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for encoder.
            decoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for decoder.
            encoder_attn_mask (flow.BoolTensor):
                Mask for encoder to avoid performing attention on
                padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            decoder_attn_mask (flow.BoolTensor):
                Mask for decoder to avoid performing attention on subsequent token indices.
                Mask values have the same meaning as encoder_attn_mask.
            encoder_decoder_attn_mask (flow.BoolTensor):
                Mask for decoder to avoid performing attention on encoder padded token indices.
                Mask values have the same meaning as encoder_attn_mask.
            use_cache (bool, optional):
                It will be set to True, when the model is in the inference
                phase and used for incremental decoding. Defaults to False.

        Returns:
            flow.Tensor: logits
        """
        encoder_input_ids = encoder_input_ids.to_global(placement=dist.get_layer_placement(0))
        decoder_input_ids = decoder_input_ids.to_global(placement=dist.get_layer_placement(0))
        encoder_attn_mask = encoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
        decoder_attn_mask = decoder_attn_mask.to_global(placement=dist.get_layer_placement(0))
        encoder_decoder_attn_mask = encoder_decoder_attn_mask.to_global(
            placement=dist.get_layer_placement(0)
        )

        if use_cache and self.encoder_states is not None:
            encoder_states = self.encoder_states
        else:
            self.set_cache(encoder_states=None, past_key_values=None)
            encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
            enc_embedding_output = self.embedding(encoder_input_ids)
            enc_hidden_states = enc_embedding_output
            for layer in self.encoder.layers:
                enc_hidden_states = layer(enc_hidden_states, encoder_attn_mask)
            encoder_states = self.encoder.final_layernorm(enc_hidden_states)

        decoder_attn_mask = self.extended_attn_mask(decoder_attn_mask)
        encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)

        dec_embedding_output = self.embedding(decoder_input_ids, self.past_length)
        dec_hidden_states = dec_embedding_output
        if use_cache:
            presents = []
        for layer, past_key_value in zip(self.decoder.layers, self.past_key_values):
            dec_hidden_states = layer(
                dec_hidden_states,
                decoder_attn_mask,
                encoder_states,
                encoder_decoder_attn_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
            )
            if use_cache:
                dec_hidden_states, present = dec_hidden_states
                presents.append(present)
        if use_cache:
            self.set_cache(encoder_states, past_key_values=presents)

        decoder_states = self.decoder.final_layernorm(dec_hidden_states)
        logits = self.lm_head(decoder_states, self.embedding.word_embeddings.weight)
        return logits

    def set_cache(self, encoder_states, past_key_values):
        self.encoder_states = encoder_states
        self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        if past_key_values is None:
            past_key_values = [None] * len(self.decoder.layers)

        assert len(past_key_values) == len(self.decoder.layers), (
            f"past_key_values's length {len(past_key_values)} doesn't match "
            f"decoder num_layers' length {self.decoder.layers}"
        )
        self.past_key_values = past_key_values


class T5Loss(flow.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lm_loss = ParallelCrossEntropyLoss()

    def forward(self, logits, lm_labels, loss_mask):
        lm_loss = self.lm_loss(logits, lm_labels)
        loss_mask = loss_mask.to_global(placement=lm_loss.placement)
        loss_mask = loss_mask.float()
        denominator = loss_mask.sum().to_global(
            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
        )
        lm_loss = flow._C.amp_white_identity(lm_loss)
        lm_loss = flow._C.amp_black_identity(lm_loss)
        masked_lm_loss = flow.sum(lm_loss.view(-1) * loss_mask.view(-1)) / denominator
        masked_lm_loss = masked_lm_loss.to_global(
            sbp=dist.get_nd_sbp([flow.sbp.partial_sum, flow.sbp.broadcast])
        )
        return {"masked_lm_loss": masked_lm_loss}


class T5ForPreTraining(flow.nn.Module):
    """
    T5 Model with classification head on top.
    """

    def __init__(self, cfg) -> None:
        super().__init__()
        self.t5_model = T5Model(cfg)
        self.loss_func = T5Loss()

    def set_cache(self, encoder_states, past_key_values):
        self.t5_model.set_cache(encoder_states, past_key_values)

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attn_mask,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
        lm_labels=None,
        loss_mask=None,
        use_cache=False,
    ):
        """
        Args:
            encoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for encoder.
            decoder_input_ids (flow.LongTensor):
                Indices of input sequence tokens in vocabulary for decoder.
            encoder_attn_mask (flow.BoolTensor):
                Mask for encoder to avoid performing attention on
                padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            decoder_attn_mask (flow.BoolTensor):
                Mask for decoder to avoid performing attention on subsequent token indices.
                Mask values have the same meaning as encoder_attn_mask.
            encoder_decoder_attn_mask (flow.BoolTensor):
                Mask for decoder to avoid performing attention on encoder padded token indices.
                Mask values have the same meaning as encoder_attn_mask.
            lm_labels (flow.LongTensor, optional): Labels for computing the masked
                language modeling loss. Indices should be in `[-1, 0, ..., config.vocab_size]`.
                None for evaluating.
            loss_mask (flow.BoolTensor, optional):
                Mask to avoid performing loss computing on ignored tokens.
                Tokens with indices set to `-1` are ignored (masked), the loss is only computed
                for the tokens with labels in `[0, ..., config.vocab_size]`.
                None for evaluating.
            use_cache (bool, optional):
                It will be set to True, when the model is in the inference
                phase and used for incremental decoding. Defaults to False.

        Returns:
            dict:
                A dict containing :code:`loss_value` or :code:`logits`
                depending on training or evaluation mode.
                :code:`{"masked_lm_loss": loss_value}` when training,
                :code:`{"prediction_scores": logits}` when evaluating.
        """
        logits = self.t5_model(
            encoder_input_ids,
            decoder_input_ids,
            encoder_attn_mask,
            decoder_attn_mask,
            encoder_decoder_attn_mask,
            use_cache=use_cache,
        )
        if lm_labels is not None:
            lm_loss = self.loss_func(logits, lm_labels, loss_mask)
            return lm_loss
        else:
            return {
                "prediction_scores": logits,
            }

    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        # Set pipeline parallelism stage_id
        if hasattr(model.t5_model.encoder.final_layernorm, "config"):
            # Old API in OneFlow 0.8
            for module_block in model.modules():
                if isinstance(module_block.origin, T5Embedding):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, ExtendedMask):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, TransformerLayer):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.origin, LMLogits):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )
                elif isinstance(module_block.origin, T5Loss):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )

            model.t5_model.encoder.final_layernorm.config.set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
            )
            model.t5_model.decoder.final_layernorm.config.set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
            )
        else:
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), T5Embedding):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), ExtendedMask):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), TransformerLayer):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )
                elif isinstance(module_block.to(nn.Module), LMLogits):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )
                elif isinstance(module_block.to(nn.Module), T5Loss):
                    module_block.to(nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
                    )

            model.t5_model.encoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.encoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.encoder.final_layernorm.layer_idx),
            )
            model.t5_model.decoder.final_layernorm.to(nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
                dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
            )
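For orientation, below is a rough, untested usage sketch of `T5Model` in eval mode. The tiny sizes, the all-ones masks, and their [batch, seq, seq] shape are assumptions made only to show the call signature, and it presumes the LiBai distributed context has been initialized (single device by default).

import oneflow as flow
from libai.utils import distributed as dist

# Hypothetical tiny model; real sizes come from a LiBai config.
t5 = T5Model(
    vocab_size=100,
    hidden_size=64,
    hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    embedding_dropout_prob=0.0,
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=32,
)

placement = dist.get_layer_placement(0)
sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])

enc_ids = flow.zeros(1, 8, dtype=flow.long, placement=placement, sbp=sbp)
dec_ids = flow.zeros(1, 8, dtype=flow.long, placement=placement, sbp=sbp)
# Assumed mask layout: [batch, query_len, key_len]; ExtendedMask adds the head dim.
enc_mask = flow.ones(1, 8, 8, dtype=flow.bool, placement=placement, sbp=sbp)
dec_mask = flow.ones(1, 8, 8, dtype=flow.bool, placement=placement, sbp=sbp)
enc_dec_mask = flow.ones(1, 8, 8, dtype=flow.bool, placement=placement, sbp=sbp)

logits = t5(enc_ids, dec_ids, enc_mask, dec_mask, enc_dec_mask)  # roughly (1, 8, vocab_size)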
libai/models/utils/__init__.py
0 → 100644
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .graph_base import GraphBase
from .weight_init import init_method_normal, scaled_init_method_normal
from .model_loader.base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai
from .model_loader.bert_loader import BertLoaderHuggerFace, BertLoaderLiBai
from .model_loader.roberta_loader import RobertaLoaderHuggerFace, RobertaLoaderLiBai
from .model_loader.gpt_loader import GPT2LoaderHuggerFace, GPT2LoaderLiBai
from .model_loader.swin_loader import SwinLoaderHuggerFace, SwinLoaderLiBai
from .model_loader.swinv2_loader import SwinV2LoaderHuggerFace, SwinV2LoaderLiBai
from .model_loader.vit_loader import ViTLoaderHuggerFace, ViTLoaderLiBai
libai/models/utils/graph_base.py
0 → 100644
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import oneflow as flow
from oneflow import nn

from libai.layers import TransformerLayer
from libai.utils import distributed as dist

logger = logging.getLogger(__name__)


class GraphBase(nn.Graph):
    def __init__(
        self,
        model: nn.Module,
        optimizer: flow.optim.Optimizer = None,
        lr_scheduler: flow.optim.lr_scheduler = None,
        fp16=False,
        activation_checkpoint=False,
        grad_acc_steps=1,
        zero_optim=False,
        zero_stage=0,
        is_train=True,
        auto_parallel_conf=None,
    ):
        super().__init__()
        self.model = model
        self.is_train = is_train

        if is_train:
            self.add_optimizer(optimizer, lr_sch=lr_scheduler)
            if fp16:
                self.config.enable_amp(True)
                grad_scaler = flow.amp.GradScaler(
                    init_scale=65536.0 * dist.get_data_parallel_size(),
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=2000,
                )
                self.set_grad_scaler(grad_scaler)
            if grad_acc_steps > 1:
                self.config.set_gradient_accumulation_steps(grad_acc_steps)
            if activation_checkpoint:
                self.set_activation_checkpoint()
            if zero_optim:
                self.config.enable_zero(True, stage=zero_stage)

        self.set_pipeline_stage_id()

        self.config.allow_fuse_add_to_output(True)
        self.config.allow_fuse_model_update_ops(True)
        self.config.allow_fuse_cast_scale(True)

        # Enable cuda stream for computation and communication as the same stream.
        # This will reduce memory when using model parallelism.
        dist_util = dist.get_dist_util()
        if dist_util.is_tensor_model_parallel() or dist_util.is_pipeline_model_parallel():
            flow.boxing.nccl.enable_use_compute_stream(True)

        # auto_parallel
        if auto_parallel_conf is not None and auto_parallel_conf.enabled:
            try:
                self.config.enable_auto_parallel(True)
                self.config.enable_auto_parallel_ignore_user_sbp_config(
                    auto_parallel_conf.enable_auto_parallel_ignore_user_sbp_config
                )
                self.config.set_auto_parallel_computation_cost_ratio(0.05)
                self.config.set_auto_parallel_wait_time(1.65e4)
                self.config.enable_auto_parallel_trunk_algo(auto_parallel_conf.trunk_algo)
                self.config.enable_auto_parallel_sbp_collector(auto_parallel_conf.sbp_collector)
            except RuntimeWarning:
                import warnings

                warnings.warn(
                    "The version of oneflow don't support auto_parallel.\n"
                    "Please reinstall the oneflow nightly:\n"
                    "python3 -m pip install --pre oneflow -f https://staging.oneflow.info/branch/master/[PLATFORM]"  # noqa
                )

    def build(self, **kwargs):
        if self.is_train:
            logger.info(
                "Start compiling the train graph which may take some time. "
                "Please wait for a moment ..."
            )
            loss_dict = self.model(**kwargs)
            losses = sum(v for k, v in loss_dict.items() if "loss" in k)
            losses.backward()
            return loss_dict
        else:
            logger.info(
                "Start compiling the eval graph which may take some time. "
                "Please wait for a moment ..."
            )
            return self.model(**kwargs)

    def set_activation_checkpoint(self):
        if hasattr(self.model, "origin"):
            if hasattr(type(self.model.origin), "set_activation_checkpoint"):
                type(self.model.origin).set_activation_checkpoint(self.model)
            else:
                for module_block in self.model.modules():
                    if isinstance(module_block.origin, TransformerLayer):
                        module_block.config.activation_checkpointing = True
        else:
            if hasattr(type(self.model.to(nn.Module)), "set_activation_checkpoint"):
                type(self.model.to(nn.Module)).set_activation_checkpoint(self.model)
            else:
                for module_block in self.model.modules():
                    if isinstance(module_block.to(nn.Module), TransformerLayer):
                        module_block.to(nn.graph.GraphModule).activation_checkpointing = True

    def set_pipeline_stage_id(self):
        if hasattr(self.model, "origin"):
            if hasattr(type(self.model.origin), "set_pipeline_stage_id"):
                type(self.model.origin).set_pipeline_stage_id(self.model)
        else:
            if hasattr(type(self.model.to(nn.Module)), "set_pipeline_stage_id"):
                type(self.model.to(nn.Module)).set_pipeline_stage_id(self.model)
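For context, here is a sketch of how `GraphBase` is typically wrapped around a model for static-graph training. The model factory, optimizer choice, and scheduler settings are placeholders rather than values from a LiBai config; the only assumption about the model is that its forward returns a dict whose loss entries contain "loss" in their keys, as `build()` above expects.

import oneflow as flow

# `build_my_model()` is a placeholder for any flow.nn.Module whose forward
# returns a dict like {"total_loss": ...} when called with keyword arguments.
model = build_my_model()
optimizer = flow.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = flow.optim.lr_scheduler.StepLR(optimizer, step_size=1000)

train_graph = GraphBase(
    model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    fp16=True,
    activation_checkpoint=True,
    grad_acc_steps=4,
    is_train=True,
)

# The first call compiles the graph; subsequent calls run one training step each.
loss_dict = train_graph(**batch)  # `batch` is a dict of global tensors matching model.forward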
libai/models/utils/model_loader/README.md
0 → 100644
## Introduction

Here are the weight loaders currently supported in LiBai. You can use them to load models pretrained with LiBai as well as models hosted on the Hugging Face Hub.

## Weight Loader On LiBai

- [BERT Loader](./bert_loader.py)
- [RoBERTa Loader](./roberta_loader.py)
- [GPT2 Loader](./gpt_loader.py)
- [MT5 Loader](../../../../projects/MT5/utils/mt5_loader.py)
- [Swin Loader](./swin_loader.py)
- [SwinV2 Loader](./swinv2_loader.py)
- [ViT Loader](./vit_loader.py)

## How To Use

We can easily load a pretrained BERT as follows:

```python
import libai
from libai.models.utils import BertLoaderHuggerFace, BertLoaderLiBai
from configs.common.models.bert import cfg

# load huggingface weight
loader = BertLoaderHuggerFace(
    model=libai.models.BertModel,
    libai_cfg=cfg,
    pretrained_model_path="path/to/huggingface_pretrained_model_directory",
    hidden_dropout_prob=0,
    apply_residual_post_layernorm=True,
)
bert = loader.load()

# load libai weight
loader = BertLoaderLiBai(
    model=libai.models.BertModel,
    libai_cfg=cfg,
    pretrained_model_path="path/to/libai_pretrained_model_directory",
    hidden_dropout_prob=0,
)
bert = loader.load()
```
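Loading other architectures follows the same pattern. As an illustrative sketch (the config import path and model class below simply mirror the BERT example and are assumptions, so adjust them to your setup), a Hugging Face GPT-2 checkpoint could be loaded with `GPT2LoaderHuggerFace`:

```python
import libai
from libai.models.utils import GPT2LoaderHuggerFace
from configs.common.models.gpt import cfg  # assumed config path

loader = GPT2LoaderHuggerFace(
    model=libai.models.GPTModel,  # assumed model class name
    libai_cfg=cfg,
    pretrained_model_path="path/to/huggingface_gpt2_directory",
)
gpt2 = loader.load()
```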
libai/models/utils/model_loader/__init__.py
0 → 100644
libai/models/utils/model_loader/base_loader.py
0 → 100644
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import logging
import os

import omegaconf
import oneflow as flow
from termcolor import colored

import libai.utils.distributed as dist
from libai.config import LazyCall
from libai.models.build import build_model

logger = logging.getLogger(__name__)

WEIGHTS_NAME_PT = "pytorch_model.bin"
CONFIG_NAME = "config.json"


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    """load state dict into model

    Args:
        model_to_load (nn.Module): Model to be loaded.
        state_dict (OrderedDict): State dict of pretrained model.
        start_prefix (str): Start prefix.

    Returns:
        list: error message about loading.
    """
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        module._load_from_state_dict(*args)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load, prefix=start_prefix)
    return error_msgs


class ModelLoader(object):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        """Class used to load the [`transformers`](https://huggingface.co/models) pretrained model
        or `OneFlow` pretrained model.

        Args:
            model (libai.models): Model to be loaded in Libai.
            libai_cfg (dict): The config of model in LiBai, you can import it from
                `libai.config.configs.common.models`.
            pretrained_model_path (str): The directory path of pretrained model,
                which contains model weights file and config file.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether to return a dictionary containing missing keys, unexpected keys
                and error messages.
        """
        self.model = model
        self.libai_cfg = libai_cfg
        self.pretrained_model_path = pretrained_model_path
        self.kwargs = kwargs
        self.output_loading_info = kwargs.pop("output_loading_info", False)

    def _state_dict_to_global(self, flow_state_dict=None, mode="libai"):
        """Tensor in OneFlow state dict to global according to model's sbp and placement.

        Args:
            flow_state_dict (OrderedDict): State dict of OneFlow's pretrained model.
        """
        assert mode in ["libai", "pytorch"], f"not support for mode {mode}"
        if mode == "libai" or dist.is_main_process():
            prefix = self.base_model_prefix_2
            # Checkpoint
            has_prefix_module = any(
                s.startswith(self.base_model_prefix_2) for s in flow_state_dict.keys()
            )
            # Module
            expects_prefix_module = any(
                s.startswith(prefix) for s in self.model.state_dict().keys()
            )
            start_prefix = "" if has_prefix_module else prefix + "."
            loaded_keys = [start_prefix + key for key in flow_state_dict.keys()]
        else:
            prefix, has_prefix_module, expects_prefix_module, loaded_keys = [None] * 4
            flow_state_dict = collections.OrderedDict()

        prefix = dist.broadcast_py_object(prefix, src=0)
        has_prefix_module = dist.broadcast_py_object(has_prefix_module, src=0)
        expects_prefix_module = dist.broadcast_py_object(expects_prefix_module, src=0)
        loaded_keys = dist.broadcast_py_object(loaded_keys, src=0)

        # to global
        for key, value in self.model.state_dict().items():
            if not expects_prefix_module:
                key = prefix + "." + key
            if key in loaded_keys:
                if not has_prefix_module:
                    key = ".".join(key.split(".")[1:])
                if mode == "pytorch":
                    flow_state_dict[key] = flow.to_global(
                        flow_state_dict[key] if dist.is_main_process() else flow.Tensor(None),
                        sbp=flow.sbp.broadcast,
                        placement=flow.placement("cpu", ranks=[0]),
                    )
                flow_state_dict[key] = flow.to_global(
                    flow_state_dict[key],
                    sbp=value.sbp,
                    placement=flow.placement("cpu", ranks=list(value.placement.ranks)),
                )
        return flow_state_dict

    def _load_pretrained_model(
        self,
        model,
        state_dict,
        pretrained_model_path,
        ignore_mismatched_sizes=False,
    ):
        """Load pretrained model.

        Args:
            model (libai.models): The model to be loaded.
            state_dict (OrderedDict): state dict.
            loaded_keys (list): keys of state dict.
            pretrained_model_path (str): pretrained model path.
            ignore_mismatched_sizes (bool):
                Whether or not to raise an error if some of the weights
                from the checkpoint do not have the same size as the
                weights of the model, defaults to `False`.
        """
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        prefix = self.base_model_prefix_2
        loaded_keys = state_dict.keys()

        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        else:
            has_prefix_module = False
            expects_prefix_module = False

        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
            expected_keys = [
                ".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys
            ]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        start_prefix = ""
        model_to_load = model
        if (
            len(self.base_model_prefix_2) > 0
            and not hasattr(model, self.base_model_prefix_2)
            and has_prefix_module
        ):
            start_prefix = self.base_model_prefix_2 + "."
        if (
            len(self.base_model_prefix_2) > 0
            and hasattr(model, self.base_model_prefix_2)
            and not has_prefix_module
        ):
            model_to_load = getattr(model, self.base_model_prefix_2)
            if any(key in expected_keys_not_prefixed for key in loaded_keys):
                raise ValueError("The state dict of the model you are loading is corrupted.")

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        model_key = ".".join(checkpoint_key.split(".")[1:])

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (
                                checkpoint_key,
                                state_dict[checkpoint_key].shape,
                                model_state_dict[model_key].shape,
                            )
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys

        if state_dict is not None:
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

        if dist.get_local_rank() == 0:
            if len(error_msgs) > 0:
                error_msg = "\n\t".join(error_msgs)
                raise RuntimeError(
                    f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
                )

            if len(unexpected_keys) > 0:
                logger.warning(
                    f"Some weights of the model checkpoint at {pretrained_model_path} "
                    "were not used when "
                    f"initializing {model.__class__.__name__}: \n{unexpected_keys}\n"
                )
            else:
                logger.info(
                    f"All model checkpoint weights were used when initializing "
                    f"{model.__class__.__name__}.\n"
                )

            if len(missing_keys) > 0:
                logger.warning(
                    f"Some weights of {model.__class__.__name__} were not initialized "
                    f"from the model checkpoint at {pretrained_model_path}: \n"
                    f"{missing_keys}\n"
                )
            elif len(mismatched_keys) == 0:
                logger.info(
                    f"All the weights of {model.__class__.__name__} were initialized "
                    f"from the model checkpoint at {pretrained_model_path}.\n"
                )

            if len(mismatched_keys) > 0:
                mismatched_warning = "\n".join(
                    [
                        f"- {key}: found shape {shape1} in the checkpoint and {shape2} "
                        "in the model instantiated"
                        for key, shape1, shape2 in mismatched_keys
                    ]
                )
                logger.warning(
                    f"Some weights of {model.__class__.__name__} were not initialized "
                    f"from the model checkpoint at {pretrained_model_path} "
                    f"and are newly initialized because the shapes did not "
                    f"match:\n{mismatched_warning}\n"
                )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs


class ModelLoaderLiBai(ModelLoader):
    """Class used to load `OneFlow` pretrained model.

    Args:
        model (libai.models): Model to be loaded in Libai.
        libai_cfg (dict): The config of model in LiBai, you can import it from
            `libai.config.configs.common.models`.
        pretrained_model_path (str): The directory path of pretrained model,
            which contains model weights file and config file.
        output_loading_info (`bool`, *optional*, defaults to `False`):
            Whether to return a dictionary containing missing keys, unexpected keys
            and error messages.
    """

    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        self.base_model_prefix_2 = None  # prefix in LiBai

    def _load_flow_state_dict(self, state_dict_file):
        # load oneflow_model
        state_dict = flow.load(state_dict_file, global_src_rank=0)
        return state_dict

    def load(self):
        """Load model.

        # For example:

        # .. code-block:: python

        >>> import libai
        >>> from libai.config.configs.common.models.bert import cfg
        >>> from model_loader import BertLoaderLiBai
        >>> loader = BertLoaderLiBai(
                libai.models.BertModel,
                cfg,
                'path/bert-base-chinese'
            )
        >>> bert = loader.load()

        """
        if dist.is_main_process():
            assert os.path.isdir(
                self.pretrained_model_path
            ), f"{self.pretrained_model_path} must be a directory"

        flow_state_dict = self._load_flow_state_dict(self.pretrained_model_path)

        # Instance model
        if isinstance(self.model, omegaconf.dictconfig.DictConfig):
            self.model.cfg = self.libai_cfg
            self.model = build_model(self.model)
        else:
            self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))

        # State_dict to global
        self._state_dict_to_global(flow_state_dict, mode="libai")

        # Load
        (
            model,
            missing_keys,
            unexpected_keys,
            mismatched_keys,
            error_msgs,
        ) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)

        if self.output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        return model


class ModelLoaderHuggerFace(ModelLoader):
    """Class used to load the [`transformers`](https://huggingface.co/models)
    pretrained model.
    """

    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)

        self.base_model_prefix_1 = None  # prefix in Transformers
        self.base_model_prefix_2 = None  # prefix in LiBai
        self.origin_libai_cfg = copy.deepcopy(self.libai_cfg)
        self.changed_keys = set()  # Store the changed configuration

    def _convert_tensor(self, tensor):
        """Convert PyTorch tensor to OneFlow tensor.

        Args:
            tensor (torch.Tensor): The source tensor.

        Returns:
            flow.Tensor: The target tensor.
        """
        tensor = tensor.float()
        return flow.Tensor(tensor.detach().cpu().numpy())

    def _convert_tensors(self, torch_state_dict):
        for k, v in torch_state_dict.items():
            torch_state_dict[k] = self._convert_tensor(v)
        return torch_state_dict

    def _fix_key(self, state_dict):
        """Fix the key in state dict: Convert "gamma" to "weight" and "beta" to "bias".

        Args:
            state_dict (OrderedDict): state dict of pretrained model.

        Returns:
            OrderedDict: State dict after fix key.
        """
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)
        return state_dict

    def _fix_qkv_ordering(self, qkv, head_size, num_heads, hidden_size=None, checkpoint_version=0.0):
        # TODO(xzp): Different versions checkpoint
        hidden_size = (head_size * num_heads) if hidden_size is None else hidden_size
        num_of_qkv = qkv.shape[0] // (head_size * num_heads)
        mode = "weight" if qkv.ndim > 1 else "bias"
        if mode == "weight":
            qkv = qkv.view([num_of_qkv, num_heads, head_size, hidden_size])
            qkv = (
                qkv.permute(1, 0, 2, 3)
                .contiguous()
                .view(num_of_qkv * head_size * num_heads, hidden_size)
            )
        elif mode == "bias":
            qkv = qkv.view(num_of_qkv, num_heads, head_size)
            qkv = qkv.permute(1, 0, 2).contiguous().view(-1)
        return qkv

    def _convert_state_dict(self, flow_state_dict, cfg):
        """A function used to convert the checkpoint file of Huggingface to LiBai.

        Args:
            torch_state_dict (OrderedDict): torch state dict.
            cfg (dict): model's default config dict in LiBai.

        Returns:
            OrderedDict: flow state dict.
        """
        raise NotImplementedError("_convert_state_dict not implemented")

    def _load_config_from_json(self, config_file):
        """load config from `config.json`, and update default config.

        Args:
            config_file (str): Path of config file.
        """
        raise NotImplementedError("_load_config_from_json not implemented")

    def _load_torch_state_dict(self, state_dict_file):
        try:
            import torch
        except ImportError:
            raise ImportError("Load torch state dict need torch.")

        # load pytorch_model.bin
        state_dict = torch.load(state_dict_file, map_location="cpu")
        return state_dict

    def _update_cfg(self, keys_libai, value_target):
        """Update the libai_cfg according to target_cfg.

        Args:
            keys_libai (str): The key of libai_cfg.
            value_target (int | float): The value of target_cfg.
        """
        if keys_libai not in self.libai_cfg.keys():
            return
        if self.libai_cfg[keys_libai] != value_target:
            self.libai_cfg[keys_libai] = value_target

    def _update_cfg_log(self):
        if dist.get_local_rank() == 0:
            for key in sorted(self.libai_cfg):
                if self.origin_libai_cfg[key] == self.libai_cfg[key]:
                    continue
                self.changed_keys.add(key)
                temp_key = colored(key, "yellow")
                logger.info(
                    f"changed libai model cfg {temp_key}: "
                    f"{self.origin_libai_cfg[key]} -> {self.libai_cfg[key]}"
                )
            logger.warning(
                "The following model configurations have been modified according "
                "to `config.json` or kwargs: \n"
                f"{self.changed_keys}\n"
            )
            if dist.get_pipeline_parallel_size() > 1:
                logger.warning(
                    colored(
                        "If you use pipeline parallel, please "
                        "confirm the setting of `train.dist.pipeline_num_layers`\n",
                        "red",
                    )
                )

    def load(self):
        """Load model.

        # For example:

        # .. code-block:: python

        >>> import libai
        >>> from configs.common.models.bert import cfg
        >>> from libai.models.utils import BertLoaderHugger
        >>> loader = BertLoaderHugger(
                libai.models.BertModel,
                cfg,
                'path/bert-base-chinese'
            )
        >>> bert = loader.load()

        """
        if dist.is_main_process():
            if os.path.isdir(self.pretrained_model_path):
                # state_dict file pytorch
                if os.path.isfile(os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)):
                    model_file = os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)
                else:
                    raise EnvironmentError(
                        f"Error no file named {WEIGHTS_NAME_PT} found "
                        f"in directory {self.pretrained_model_path}."
                    )

                # config file
                if os.path.isfile(os.path.join(self.pretrained_model_path, CONFIG_NAME)):
                    config_file = os.path.join(self.pretrained_model_path, CONFIG_NAME)
                    # Load config and update config.
                    self._load_config_from_json(config_file)
                else:
                    import warnings

                    warnings.warn(
                        f"Error no file named {CONFIG_NAME} found in directory "
                        f"{self.pretrained_model_path}",
                        RuntimeWarning,
                    )
            else:
                raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")

            logger.info("loading torch model...")
            torch_state_dict = self._load_torch_state_dict(model_file)
            torch_state_dict = self._fix_key(torch_state_dict)
            logger.info("transferring torch model into oneflow model...")
            flow_state_dict = self._convert_tensors(torch_state_dict)
            flow_state_dict = self._convert_state_dict(torch_state_dict, self.libai_cfg)
        else:
            flow_state_dict = None

        self.libai_cfg = dist.broadcast_py_object(self.libai_cfg, src=0)

        # Instance model
        logger.info("building LiBai model...")
        if isinstance(self.model, omegaconf.dictconfig.DictConfig):
            self.model.cfg = self.libai_cfg
            self.model = build_model(self.model)
        else:
            self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))

        # State_dict to global
        logger.info("transferring state_dict local to global...")
        flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="pytorch")

        logger.info("loading model weights into LiBai...")
        # Load
        (
            model,
            missing_keys,
            unexpected_keys,
            mismatched_keys,
            error_msgs,
        ) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)

        if self.output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        return model
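To make the row regrouping done by `_fix_qkv_ordering` above concrete, here is a small standalone sketch with arbitrary toy sizes (2 heads, head size 2, hidden size 4). It repeats the same view/permute/view steps on a fused QKV weight whose rows are laid out as [all q rows | all k rows | all v rows] and shows that afterwards each head's q, k and v rows are adjacent, which is the layout a fused `query_key_value` parameter expects. This is an illustrative check, not part of the loader.

import oneflow as flow

# Toy sizes chosen only for illustration.
num_heads, head_size, hidden = 2, 2, 4

q = flow.arange(num_heads * head_size * hidden).reshape(num_heads * head_size, hidden)
k = q + 100
v = q + 200
qkv = flow.cat((q, k, v), dim=0).float()  # shape (12, 4): q rows, then k rows, then v rows

# Same steps as the "weight" branch of _fix_qkv_ordering with num_of_qkv == 3.
reordered = (
    qkv.view([3, num_heads, head_size, hidden])
    .permute(1, 0, 2, 3)
    .contiguous()
    .view(3 * num_heads * head_size, hidden)
)
# Row order is now: head0 q-rows, head0 k-rows, head0 v-rows, head1 q-rows, ...
print(reordered[:, 0])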
libai/models/utils/model_loader/bert_loader.py
0 → 100644
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import oneflow as flow

from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai


class BertLoaderHuggerFace(ModelLoaderHuggerFace):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)

        """NOTE: base_model_prefix_1 is BERT's prefix in Transformers.
        base_model_prefix_2 is BERT's prefix in LiBai."""
        self.base_model_prefix_1 = "bert"
        self.base_model_prefix_2 = "bert"

    def _convert_state_dict(self, flow_state_dict, cfg):
        """Convert state_dict's keys to match model.

        Args:
            flow_state_dict (OrderedDict): model state dict.
            cfg (dict): model's default config dict in LiBai.

        Returns:
            OrderedDict: flow state dict.
        """
        # The converted checkpoint.
        oneflow_state_dict = flow_state_dict.copy()

        # Get configs
        num_heads = cfg.get("num_attention_heads")
        hidden_size = cfg.get("hidden_size")
        layers = cfg.get("hidden_layers")
        head_size = int(hidden_size / num_heads)

        # prefix
        has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
        prefix = "bert." if has_prefix else ""
        index_idx = 3 if has_prefix else 2
        qkv_idx = 6 if has_prefix else 5

        old_keys = oneflow_state_dict.keys()
        for key in list(old_keys):
            # Convert bert's embedding layers
            if "embeddings" in key:
                if "word_embeddings" in key:
                    new_key = key.replace("word_embeddings", "vocab_embeddings")
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "token_type_embeddings" in key:
                    new_key = key.replace("token_type_embeddings", "tokentype_embeddings")
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "LayerNorm.weight" in key:
                    new_key = prefix + "encoders.0.input_layernorm.weight"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "LayerNorm.bias" in key:
                    new_key = prefix + "encoders.0.input_layernorm.bias"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                else:
                    oneflow_state_dict[key] = oneflow_state_dict[key]
            # Convert bert's attention layers
            elif "attention" in key:
                if "self" in key:
                    index = key.split(".")[index_idx]
                    if (
                        prefix + "encoders." + index + ".self_attention.query_key_value.weight"
                        in oneflow_state_dict.keys()
                    ):
                        continue
                    q_w = key.replace(key.split(".")[qkv_idx], "query").replace(
                        key.split(".")[qkv_idx + 1], "weight"
                    )
                    k_w = q_w.replace("query", "key")
                    v_w = q_w.replace("query", "value")
                    q_b = q_w.replace("weight", "bias")
                    k_b = k_w.replace("weight", "bias")
                    v_b = v_w.replace("weight", "bias")
                    qkv_w = flow.cat(
                        (
                            oneflow_state_dict.pop(q_w),
                            oneflow_state_dict.pop(k_w),
                            oneflow_state_dict.pop(v_w),
                        ),
                        dim=0,
                    )
                    qkv_b = flow.cat(
                        (
                            oneflow_state_dict.pop(q_b),
                            oneflow_state_dict.pop(k_b),
                            oneflow_state_dict.pop(v_b),
                        ),
                        dim=-1,
                    )
                    qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads)
                    qkv_b = self._fix_qkv_ordering(qkv_b, head_size, num_heads)
                    new_key = (
                        prefix + "encoders." + index + ".self_attention.query_key_value.weight"
                    )
                    oneflow_state_dict[new_key] = qkv_w
                    new_key = prefix + "encoders." + index + ".self_attention.query_key_value.bias"
                    oneflow_state_dict[new_key] = qkv_b
                elif "output" in key:
                    index = key.split(".")[index_idx]
                    if "dense" in key:
                        if "weight" in key:
                            new_key = prefix + "encoders." + index + ".self_attention.dense.weight"
                            oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                        elif "bias" in key:
                            new_key = prefix + "encoders." + index + ".self_attention.dense.bias"
                            oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                    elif "LayerNorm" in key:
                        if "weight" in key:
                            new_key = (
                                prefix + "encoders." + index + ".post_attention_layernorm.weight"
                            )
                            oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                        elif "bias" in key:
                            new_key = (
                                prefix + "encoders." + index + ".post_attention_layernorm.bias"
                            )
                            oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
            # Convert bert's intermediate layers
            elif "intermediate" in key:
                index = key.split(".")[index_idx]
                if (
                    prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
                    in oneflow_state_dict.keys()
                ):
                    continue
                if "weight" in key:
                    w = key
                    b = key.replace("weight", "bias")
                    new_key = prefix + "encoders." + index + ".mlp.dense_h_to_4h.weight"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
                    new_key = new_key.replace("weight", "bias")
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
            # Convert bert's output layers
            elif "output" in key:
                index = key.split(".")[index_idx]
                if "dense.weight" in key:
                    if (
                        prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
                        in oneflow_state_dict.keys()
                    ):
                        continue
                    w = key
                    b = w.replace("weight", "bias")
                    new_key = prefix + "encoders." + index + ".mlp.dense_4h_to_h.weight"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
                    new_key = new_key.replace("weight", "bias")
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
                elif "LayerNorm.weight" in key:
                    if (
                        prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
                        in oneflow_state_dict.keys()
                    ):
                        continue
                    w = key
                    b = w.replace("weight", "bias")
                    if index == str(layers - 1):
                        new_key = prefix + "final_layernorm.weight"
                        oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
                        new_key = new_key.replace("weight", "bias")
                        oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
                        continue
                    new_key = prefix + "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
                    new_key = new_key.replace("weight", "bias")
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)
            # Convert bert's pooler layers
            elif "pooler" in key:
                if "weight" in key:
                    new_key = prefix + "pooler.dense.weight"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "bias" in key:
                    new_key = prefix + "pooler.dense.bias"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
            # Convert cls_head layers
            elif "cls" in key:
                if "predictions.bias" in key:
                    new_key = "cls_head.lm_logits.bias"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "dense.weight" in key:
                    new_key = "cls_head.predictions.dense.weight"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "dense.bias" in key:
                    new_key = "cls_head.predictions.dense.bias"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "LayerNorm.weight" in key:
                    new_key = "cls_head.predictions.layernorm.weight"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "LayerNorm.bias" in key:
                    new_key = "cls_head.predictions.layernorm.bias"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "seq_relationship" in key:
                    new_key = key.replace("cls", "cls_head")
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
            else:
                oneflow_state_dict[key] = oneflow_state_dict.pop(key)

        return oneflow_state_dict

    def _load_config_from_json(self, config_file):
        """load config from `config.json`, and update default config.

        Args:
            config_file (str): Path of config file.
        """
        with open(config_file, mode="r", encoding="utf-8") as f:
            cfg_dict = json.load(f)

        # update libai_cfg by config.json
        self._update_cfg("vocab_size", cfg_dict["vocab_size"])
        self._update_cfg("hidden_size", cfg_dict["hidden_size"])
        self._update_cfg("hidden_layers", cfg_dict["num_hidden_layers"])
        self._update_cfg("num_attention_heads", cfg_dict["num_attention_heads"])
        self._update_cfg("intermediate_size", cfg_dict["intermediate_size"])
        self._update_cfg("hidden_dropout_prob", cfg_dict["hidden_dropout_prob"])
        self._update_cfg("attention_probs_dropout_prob", cfg_dict["attention_probs_dropout_prob"])
        self._update_cfg("max_position_embeddings", cfg_dict["max_position_embeddings"])
        self._update_cfg("num_tokentypes", cfg_dict["type_vocab_size"])
        self._update_cfg("initializer_range", cfg_dict["initializer_range"])
        self._update_cfg("layernorm_eps", cfg_dict["layer_norm_eps"])

        # update libai_cfg by kwargs
        for k, v in self.kwargs.items():
            self._update_cfg(k, v)

        # use original BERT residual connection ordering
        self.libai_cfg.apply_residual_post_layernorm = True

        self._update_cfg_log()


class BertLoaderLiBai(ModelLoaderLiBai):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        self.base_model_prefix_2 = "bert"
libai/models/utils/model_loader/gpt_loader.py
0 → 100644
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai


class GPT2LoaderHuggerFace(ModelLoaderHuggerFace):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)

        """NOTE: base_model_prefix_1 is GPT's prefix in Transformers.
        base_model_prefix_2 is GPT's prefix in LiBai."""
        self.base_model_prefix_1 = "transformer"
        self.base_model_prefix_2 = "GPT_model"

    def _convert_state_dict(self, flow_state_dict, cfg):
        """Convert state_dict's keys to match model.

        Args:
            flow_state_dict (OrderedDict): model state dict.
            cfg (dict): model's default config dict in LiBai.

        Returns:
            OrderedDict: flow state dict.
        """
        # The converted checkpoint.
        oneflow_state_dict = flow_state_dict.copy()
        old_keys = list(oneflow_state_dict.keys())

        # Get configs
        num_heads = cfg.get("num_attention_heads")
        hidden_size = cfg.get("hidden_size")
        head_size = int(hidden_size / num_heads)

        # prefix
        has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
        prefix1 = self.base_model_prefix_1 + "." if has_prefix else ""
        prefix2 = "GPT_model.transformer."
        layer_idx = 2 if has_prefix else 1

        # Convert Embedding layers.
        new_key = "GPT_model.embeddings.token_embeddings.weight"
        old_keys.remove(prefix1 + "wte.weight")
        oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "wte.weight")
        new_key = "GPT_model.embeddings.position_embeddings.weight"
        old_keys.remove(prefix1 + "wpe.weight")
        oneflow_state_dict[new_key] = oneflow_state_dict.pop(prefix1 + "wpe.weight")

        for key in old_keys:
            keys = key.split(".")
            if layer_idx >= len(keys):
                continue
            layer = keys[layer_idx]
            # Convert transformer layers.
            if "h." in key:
                if "ln_1" in key:
                    if "weight" in key:
                        new_key = prefix2 + "layers." + layer + ".input_layernorm.weight"
                    else:
                        new_key = prefix2 + "layers." + layer + ".input_layernorm.bias"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "ln_2" in key:
                    if "weight" in key:
                        new_key = prefix2 + "layers." + layer + ".post_attention_layernorm.weight"
                    else:
                        new_key = prefix2 + "layers." + layer + ".post_attention_layernorm.bias"
                    oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
                elif "attn" in key:
                    if "c_attn" in key:
                        if "weight" in key:
                            new_key = (
                                prefix2 + "layers." + layer + ".self_attention.query_key_value.weight"
                            )
                        else:
                            new_key = (
                                prefix2 + "layers." + layer + ".self_attention.query_key_value.bias"
                            )
                        qkv = oneflow_state_dict.pop(key)
                        if qkv.ndim > 1:
                            qkv = qkv.transpose(1, 0)
                        qkv = self._fix_qkv_ordering(qkv, head_size, num_heads)
                        oneflow_state_dict[new_key] = qkv
                    elif "c_proj" in key:
                        if "weight" in key:
                            new_key = prefix2 + "layers." + layer + ".self_attention.dense.weight"
                        elif "bias" in key:
                            new_key = prefix2 + "layers." + layer + ".self_attention.dense.bias"
                        value = oneflow_state_dict.pop(key)
                        if value.ndim > 1:
                            value = value.transpose(1, 0)
                        oneflow_state_dict[new_key] = value
                elif "mlp" in key:
                    if "c_fc" in key:
                        if "weight" in key:
                            new_key = prefix2 + "layers." + layer + ".mlp.dense_h_to_4h.weight"
                        elif "bias" in key:
                            new_key = prefix2 + "layers." + layer + ".mlp.dense_h_to_4h.bias"
                        value = oneflow_state_dict.pop(key)
                        if value.ndim > 1:
                            value = value.transpose(1, 0)
                        oneflow_state_dict[new_key] = value
                    elif "c_proj" in key:
                        if "weight" in key:
                            new_key = prefix2 + "layers." + layer + ".mlp.dense_4h_to_h.weight"
                        elif "bias" in key:
                            new_key = prefix2 + "layers." + layer + ".mlp.dense_4h_to_h.bias"
                        value = oneflow_state_dict.pop(key)
                        if value.ndim > 1:
                            value = value.transpose(1, 0)
                        oneflow_state_dict[new_key] = value
            elif "ln_f" in key:
                if "weight" in key:
                    new_key = prefix2 + "layernorm_f.weight"
                elif "bias" in key:
                    new_key = prefix2 + "layernorm_f.bias"
                oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)

        return oneflow_state_dict

    def _load_config_from_json(self, config_file):
        """load config from `config.json`, and update default config.

        Args:
            config_file (str): Path of config file.
        """
        with open(config_file, mode="r", encoding="utf-8") as f:
            cfg_dict = json.load(f)

        # update libai_cfg by config.json
        self._update_cfg("hidden_layers", cfg_dict["n_layer"])
        self._update_cfg("hidden_size", cfg_dict["n_embd"])
        self._update_cfg("num_attention_heads", cfg_dict["n_head"])
        self._update_cfg("max_seq_length", cfg_dict["n_positions"])
        self._update_cfg("embedding_dropout_prob", cfg_dict["embd_pdrop"])
        self._update_cfg("attention_dropout_prob", cfg_dict["attn_pdrop"])
        self._update_cfg("output_dropout_prob", cfg_dict["resid_pdrop"])
        self._update_cfg("layernorm_epsilon", cfg_dict["layer_norm_epsilon"])
        self._update_cfg("vocab_size", cfg_dict["vocab_size"])
        self._update_cfg("initializer_range", cfg_dict["initializer_range"])
        self._update_cfg(
            "ffn_hidden_size",
            cfg_dict.get("n_inner")
            if cfg_dict.get("n_inner") is not None
            else 4 * self.libai_cfg["hidden_size"],
        )

        # update libai_cfg by kwargs
        for k, v in self.kwargs.items():
            self._update_cfg(k, v)

        self._update_cfg_log()


class GPT2LoaderLiBai(ModelLoaderLiBai):
    def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
        self.base_model_prefix_2 = "GPT_model"
libai/models/utils/model_loader/roberta_loader.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
oneflow
as
flow
from
.bert_loader
import
BertLoaderHuggerFace
,
BertLoaderLiBai
class
RobertaLoaderHuggerFace
(
BertLoaderHuggerFace
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
"""NOTE: base_model_prefix_1 is RoBERTa's prefix in Transformers,
base_model_prefix_2 is RoBERTa's prefix in LiBai."""
self
.
base_model_prefix_1
=
"roberta"
self
.
base_model_prefix_2
=
"roberta"
def
_convert_state_dict
(
self
,
flow_state_dict
,
cfg
):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict
=
flow_state_dict
.
copy
()
# Get configs
num_heads
=
cfg
.
get
(
"num_attention_heads"
)
hidden_size
=
cfg
.
get
(
"hidden_size"
)
layers
=
cfg
.
get
(
"hidden_layers"
)
head_size
=
int
(
hidden_size
/
num_heads
)
# prefix
has_prefix
=
any
(
s
.
startswith
(
self
.
base_model_prefix_1
)
for
s
in
oneflow_state_dict
)
prefix
=
"roberta."
if
has_prefix
else
""
index_idx
=
3
if
has_prefix
else
2
qkv_idx
=
6
if
has_prefix
else
5
old_keys
=
oneflow_state_dict
.
keys
()
for
key
in
list
(
old_keys
):
# Convert roberta's embedding layers
if
"embeddings"
in
key
:
if
"word_embeddings"
in
key
:
new_key
=
key
.
replace
(
"word_embeddings"
,
"vocab_embeddings"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"token_type_embeddings"
in
key
:
new_key
=
key
.
replace
(
"token_type_embeddings"
,
"tokentype_embeddings"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"LayerNorm.weight"
in
key
:
new_key
=
prefix
+
"encoders.0.input_layernorm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"LayerNorm.bias"
in
key
:
new_key
=
prefix
+
"encoders.0.input_layernorm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
else
:
oneflow_state_dict
[
key
]
=
oneflow_state_dict
[
key
]
# Convert roberta's attention layers
elif
"attention"
in
key
:
if
"self"
in
key
:
index
=
key
.
split
(
"."
)[
index_idx
]
if
(
prefix
+
"encoders."
+
index
+
".self_attention.query_key_value.weight"
in
oneflow_state_dict
.
keys
()
):
continue
q_w
=
key
.
replace
(
key
.
split
(
"."
)[
qkv_idx
],
"query"
).
replace
(
key
.
split
(
"."
)[
qkv_idx
+
1
],
"weight"
)
k_w
=
q_w
.
replace
(
"query"
,
"key"
)
v_w
=
q_w
.
replace
(
"query"
,
"value"
)
q_b
=
q_w
.
replace
(
"weight"
,
"bias"
)
k_b
=
k_w
.
replace
(
"weight"
,
"bias"
)
v_b
=
v_w
.
replace
(
"weight"
,
"bias"
)
qkv_w
=
flow
.
cat
(
(
oneflow_state_dict
.
pop
(
q_w
),
oneflow_state_dict
.
pop
(
k_w
),
oneflow_state_dict
.
pop
(
v_w
),
),
dim
=
0
,
)
qkv_b
=
flow
.
cat
(
(
oneflow_state_dict
.
pop
(
q_b
),
oneflow_state_dict
.
pop
(
k_b
),
oneflow_state_dict
.
pop
(
v_b
),
),
dim
=-
1
,
)
qkv_w
=
self
.
_fix_qkv_ordering
(
qkv_w
,
head_size
,
num_heads
)
qkv_b
=
self
.
_fix_qkv_ordering
(
qkv_b
,
head_size
,
num_heads
)
new_key
=
(
prefix
+
"encoders."
+
index
+
".self_attention.query_key_value.weight"
)
oneflow_state_dict
[
new_key
]
=
qkv_w
new_key
=
prefix
+
"encoders."
+
index
+
".self_attention.query_key_value.bias"
oneflow_state_dict
[
new_key
]
=
qkv_b
elif
"output"
in
key
:
index
=
key
.
split
(
"."
)[
index_idx
]
if
"dense"
in
key
:
if
"weight"
in
key
:
new_key
=
prefix
+
"encoders."
+
index
+
".self_attention.dense.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
prefix
+
"encoders."
+
index
+
".self_attention.dense.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"LayerNorm"
in
key
:
if
"weight"
in
key
:
new_key
=
(
prefix
+
"encoders."
+
index
+
".post_attention_layernorm.weight"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
(
prefix
+
"encoders."
+
index
+
".post_attention_layernorm.bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert roberta's intermediate layers
elif
"intermediate"
in
key
:
index
=
key
.
split
(
"."
)[
index_idx
]
if
(
prefix
+
"encoders."
+
index
+
".mlp.dense_h_to_4h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
if
"weight"
in
key
:
w
=
key
b
=
key
.
replace
(
"weight"
,
"bias"
)
new_key
=
prefix
+
"encoders."
+
index
+
".mlp.dense_h_to_4h.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
# Convert roberta's output layers
elif
"output"
in
key
:
index
=
key
.
split
(
"."
)[
index_idx
]
if
"dense.weight"
in
key
:
if
(
prefix
+
"encoders."
+
index
+
".mlp.dense_4h_to_h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
w
.
replace
(
"weight"
,
"bias"
)
new_key
=
prefix
+
"encoders."
+
index
+
".mlp.dense_4h_to_h.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"LayerNorm.weight"
in
key
:
if
(
prefix
+
"encoders."
+
str
(
int
(
index
)
+
1
)
+
".input_layernorm.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
w
.
replace
(
"weight"
,
"bias"
)
if
index
==
str
(
layers
-
1
):
new_key
=
prefix
+
"final_layernorm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
continue
new_key
=
prefix
+
"encoders."
+
str
(
int
(
index
)
+
1
)
+
".input_layernorm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
# Convert roberta's pooler layers
elif
"pooler"
in
key
:
if
"weight"
in
key
:
new_key
=
prefix
+
"pooler.dense.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
prefix
+
"pooler.dense.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert lm_head layers
elif
"lm_head"
in
key
:
if
"layer_norm.weight"
in
key
:
new_key
=
"lm_head.layernorm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"layer_norm.bias"
in
key
:
new_key
=
"lm_head.layernorm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"seq_relationship"
in
key
:
new_key
=
key
.
replace
(
"cls"
,
"cls_head"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"lm_head.bias"
in
key
:
new_key
=
"lm_head.lm_logits.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
else
:
oneflow_state_dict
[
key
]
=
oneflow_state_dict
.
pop
(
key
)
else
:
oneflow_state_dict
[
key
]
=
oneflow_state_dict
.
pop
(
key
)
return
oneflow_state_dict
class
RobertaLoaderLiBai
(
BertLoaderLiBai
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
self
.
base_model_prefix_2
=
"roberta"
libai/models/utils/model_loader/swin_loader.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
oneflow
as
flow
from
.base_loader
import
ModelLoaderHuggerFace
,
ModelLoaderLiBai
class
SwinLoaderHuggerFace
(
ModelLoaderHuggerFace
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
"""NOTE: base_model_prefix_1 is SWIN's prefix in Transformers.
base_model_prefix_2 is SWIN's prefix in LiBai."""
self
.
base_model_prefix_1
=
"swin"
self
.
base_model_prefix_2
=
""
def
_convert_state_dict
(
self
,
flow_state_dict
,
cfg
=
None
):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict
=
flow_state_dict
.
copy
()
# prefix
has_prefix
=
any
(
s
.
startswith
(
self
.
base_model_prefix_1
)
for
s
in
oneflow_state_dict
)
index_idx_1
=
3
if
has_prefix
else
2
index_idx_2
=
5
if
has_prefix
else
4
old_keys
=
oneflow_state_dict
.
keys
()
for
key
in
list
(
old_keys
):
# Convert swin's embedding layers
if
"embeddings"
in
key
:
if
"patch_embeddings.projection"
in
key
:
if
"weight"
in
key
:
new_key
=
"patch_embed.proj.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"patch_embed.proj.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"norm"
in
key
:
if
"weight"
in
key
:
new_key
=
"patch_embed.norm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"patch_embed.norm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert swin's layernorm layers
elif
"layernorm_before"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"weight"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm1.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm1.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"layernorm_after"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"weight"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm2.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm2.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert swin's attention layers
elif
"attention"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"self"
in
key
:
if
(
"relative_position_bias_table"
in
key
):
# convert relative_position_bias_table but not index
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.relative_position_bias_table"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"relative_position_index"
in
key
:
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.relative_position_index"
)
oneflow_state_dict
.
pop
(
key
)
else
:
if
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.qkv.weight"
in
oneflow_state_dict
.
keys
()
):
continue
q_w
=
key
k_w
=
q_w
.
replace
(
"query"
,
"key"
)
v_w
=
q_w
.
replace
(
"query"
,
"value"
)
q_b
=
q_w
.
replace
(
"weight"
,
"bias"
)
k_b
=
k_w
.
replace
(
"weight"
,
"bias"
)
v_b
=
v_w
.
replace
(
"weight"
,
"bias"
)
qkv_w
=
flow
.
cat
(
(
oneflow_state_dict
.
pop
(
q_w
),
oneflow_state_dict
.
pop
(
k_w
),
oneflow_state_dict
.
pop
(
v_w
),
),
dim
=
0
,
)
qkv_b
=
flow
.
cat
(
(
oneflow_state_dict
.
pop
(
q_b
),
oneflow_state_dict
.
pop
(
k_b
),
oneflow_state_dict
.
pop
(
v_b
),
),
dim
=-
1
,
)
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.qkv.weight"
)
oneflow_state_dict
[
new_key
]
=
qkv_w
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
qkv_b
elif
"output"
in
key
:
if
"dense"
in
key
:
if
"weight"
in
key
:
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.proj.weight"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
if
"bias"
in
key
:
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.proj.bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"intermediate"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"weight"
in
key
:
if
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_h_to_4h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
key
.
replace
(
"weight"
,
"bias"
)
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"output"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"dense.weight"
in
key
:
if
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_4h_to_h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
w
.
replace
(
"weight"
,
"bias"
)
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"downsample"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
if
"reduction.weight"
in
key
:
new_key
=
"layers."
+
index_layer
+
".downsample.reduction.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"norm"
in
key
:
if
(
"layers."
+
index_layer
+
".downsample.norm.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
w
.
replace
(
"weight"
,
"bias"
)
new_key
=
"layers."
+
index_layer
+
".downsample.norm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"layernorm"
in
key
:
if
"weight"
in
key
:
new_key
=
"norm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"norm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"classifier"
in
key
:
if
"weight"
in
key
:
new_key
=
"head.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"head.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
else
:
oneflow_state_dict
[
key
]
=
oneflow_state_dict
.
pop
(
key
)
return
oneflow_state_dict
def
_load_config_from_json
(
self
,
config_file
):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with
open
(
config_file
,
mode
=
"r"
,
encoding
=
"utf-8"
)
as
f
:
cfg_dict
=
json
.
load
(
f
)
# update libai_cfg by config.json
self
.
_update_cfg
(
"img_size"
,
cfg_dict
[
"image_size"
])
self
.
_update_cfg
(
"patch_size"
,
cfg_dict
[
"patch_size"
])
self
.
_update_cfg
(
"embed_dim"
,
cfg_dict
[
"embed_dim"
])
self
.
_update_cfg
(
"depths"
,
cfg_dict
[
"depths"
])
self
.
_update_cfg
(
"num_heads"
,
cfg_dict
[
"num_heads"
])
self
.
_update_cfg
(
"window_size"
,
cfg_dict
[
"window_size"
])
self
.
_update_cfg
(
"mlp_ratio"
,
cfg_dict
[
"mlp_ratio"
])
self
.
_update_cfg
(
"qkv_bias"
,
cfg_dict
[
"qkv_bias"
])
self
.
_update_cfg
(
"drop_path_rate"
,
cfg_dict
[
"drop_path_rate"
])
# update libai_cfg by kwargs
for
k
,
v
in
self
.
kwargs
.
items
():
self
.
_update_cfg
(
k
,
v
)
self
.
_update_cfg_log
()
class
SwinLoaderLiBai
(
ModelLoaderLiBai
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
self
.
base_model_prefix_2
=
""
libai/models/utils/model_loader/swinv2_loader.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
oneflow
as
flow
from
.base_loader
import
ModelLoaderHuggerFace
,
ModelLoaderLiBai
class
SwinV2LoaderHuggerFace
(
ModelLoaderHuggerFace
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
"""NOTE: base_model_prefix_1 is SWINV2's prefix in Transformers.
base_model_prefix_2 is SWINV2's prefix in LiBai."""
self
.
base_model_prefix_1
=
"swinv2"
self
.
base_model_prefix_2
=
""
def
_convert_state_dict
(
self
,
flow_state_dict
,
cfg
=
None
):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict
=
flow_state_dict
.
copy
()
# prefix
has_prefix
=
any
(
s
.
startswith
(
self
.
base_model_prefix_1
)
for
s
in
oneflow_state_dict
)
index_idx_1
=
3
if
has_prefix
else
2
index_idx_2
=
5
if
has_prefix
else
4
old_keys
=
oneflow_state_dict
.
keys
()
for
key
in
list
(
old_keys
):
# Convert swinv2's embedding layers
if
"embeddings"
in
key
:
if
"patch_embeddings.projection"
in
key
:
if
"weight"
in
key
:
new_key
=
"patch_embed.proj.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
if
"bias"
in
key
:
new_key
=
"patch_embed.proj.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"norm"
in
key
:
if
"weight"
in
key
:
new_key
=
"patch_embed.norm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
if
"bias"
in
key
:
new_key
=
"patch_embed.norm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert swinv2's layernorm layers
elif
"layernorm_before"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"weight"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm1.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm1.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"layernorm_after"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"weight"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm2.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".norm2.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert swinv2's attention layers
elif
"attention"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"self"
in
key
:
if
(
"relative_position_bias_table"
in
key
):
# convert relative_position_bias_table but not index
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.relative_position_bias_table"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"relative_position_index"
in
key
:
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.relative_position_index"
)
oneflow_state_dict
.
pop
(
key
)
elif
"continuous_position_bias_mlp"
in
key
:
if
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.cpb_mlp"
+
".0.weight"
)
in
oneflow_state_dict
.
keys
():
continue
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.cpb_mlp"
)
m_1_w
=
key
m_1_b
=
key
.
replace
(
".0.weight"
,
".0.bias"
)
m_2_w
=
key
.
replace
(
".0.weight"
,
".2.weight"
)
oneflow_state_dict
[
new_key
+
".0.weight"
]
=
oneflow_state_dict
.
pop
(
m_1_w
)
oneflow_state_dict
[
new_key
+
".0.bias"
]
=
oneflow_state_dict
.
pop
(
m_1_b
)
oneflow_state_dict
[
new_key
+
".2.weight"
]
=
oneflow_state_dict
.
pop
(
m_2_w
)
elif
"logit_scale"
in
key
:
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.logit_scale"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)[
None
,
...]
else
:
if
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.qkv.weight"
in
oneflow_state_dict
.
keys
()
):
continue
q_w
=
key
k_w
=
q_w
.
replace
(
"query"
,
"key"
)
v_w
=
q_w
.
replace
(
"query"
,
"value"
)
q_b
=
q_w
.
replace
(
"weight"
,
"bias"
)
v_b
=
v_w
.
replace
(
"weight"
,
"bias"
)
qkv_w
=
flow
.
cat
(
(
oneflow_state_dict
.
pop
(
q_w
),
oneflow_state_dict
.
pop
(
k_w
),
oneflow_state_dict
.
pop
(
v_w
),
),
dim
=
0
,
)
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.qkv.weight"
)
oneflow_state_dict
[
new_key
]
=
qkv_w
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.q_bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
q_b
)
new_key
=
new_key
.
replace
(
"q_bias"
,
"v_bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
v_b
)
elif
"output"
in
key
:
if
"dense"
in
key
:
if
"weight"
in
key
:
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.proj.weight"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
if
"bias"
in
key
:
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".attn.proj.bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"intermediate"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"weight"
in
key
:
if
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_h_to_4h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
key
.
replace
(
"weight"
,
"bias"
)
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_h_to_4h.weight"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"output"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
index_block
=
key
.
split
(
"."
)[
index_idx_2
]
if
"dense.weight"
in
key
:
if
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_4h_to_h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
w
.
replace
(
"weight"
,
"bias"
)
new_key
=
(
"layers."
+
index_layer
+
".blocks."
+
index_block
+
".mlp.dense_4h_to_h.weight"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"downsample"
in
key
:
index_layer
=
key
.
split
(
"."
)[
index_idx_1
]
if
"reduction.weight"
in
key
:
new_key
=
"layers."
+
index_layer
+
".downsample.reduction.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"norm"
in
key
:
if
(
"layers."
+
index_layer
+
".downsample.norm.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
w
.
replace
(
"weight"
,
"bias"
)
new_key
=
"layers."
+
index_layer
+
".downsample.norm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"layernorm"
in
key
:
if
"weight"
in
key
:
new_key
=
"norm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"norm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"classifier"
in
key
:
if
"weight"
in
key
:
new_key
=
"head.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"head.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
else
:
oneflow_state_dict
[
key
]
=
oneflow_state_dict
.
pop
(
key
)
return
oneflow_state_dict
def
_load_config_from_json
(
self
,
config_file
):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with
open
(
config_file
,
mode
=
"r"
,
encoding
=
"utf-8"
)
as
f
:
cfg_dict
=
json
.
load
(
f
)
# update libai_cfg by config.json
self
.
_update_cfg
(
"img_size"
,
cfg_dict
[
"image_size"
])
self
.
_update_cfg
(
"patch_size"
,
cfg_dict
[
"patch_size"
])
self
.
_update_cfg
(
"embed_dim"
,
cfg_dict
[
"embed_dim"
])
self
.
_update_cfg
(
"depths"
,
cfg_dict
[
"depths"
])
self
.
_update_cfg
(
"num_heads"
,
cfg_dict
[
"num_heads"
])
self
.
_update_cfg
(
"window_size"
,
cfg_dict
[
"window_size"
])
self
.
_update_cfg
(
"mlp_ratio"
,
cfg_dict
[
"mlp_ratio"
])
self
.
_update_cfg
(
"qkv_bias"
,
cfg_dict
[
"qkv_bias"
])
self
.
_update_cfg
(
"drop_path_rate"
,
cfg_dict
[
"drop_path_rate"
])
self
.
_update_cfg
(
"pretrained_window_sizes"
,
cfg_dict
[
"pretrained_window_sizes"
])
# update libai_cfg by kwargs
for
k
,
v
in
self
.
kwargs
.
items
():
self
.
_update_cfg
(
k
,
v
)
self
.
_update_cfg_log
()
class
SwinV2LoaderLiBai
(
ModelLoaderLiBai
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
self
.
base_model_prefix_2
=
""
libai/models/utils/model_loader/vit_loader.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
oneflow
as
flow
from
.base_loader
import
ModelLoaderHuggerFace
,
ModelLoaderLiBai
class
ViTLoaderHuggerFace
(
ModelLoaderHuggerFace
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
"""NOTE: base_model_prefix_1 is ViT's prefix in Transformers.
base_model_prefix_2 is ViT's prefix in LiBai."""
self
.
base_model_prefix_1
=
"vit"
self
.
base_model_prefix_2
=
""
def
_convert_state_dict
(
self
,
flow_state_dict
,
cfg
=
None
):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict
=
flow_state_dict
.
copy
()
# Get configs
num_heads
=
cfg
.
get
(
"num_heads"
)
hidden_size
=
cfg
.
get
(
"embed_dim"
)
head_size
=
int
(
hidden_size
/
num_heads
)
# prefix
has_prefix
=
any
(
s
.
startswith
(
self
.
base_model_prefix_1
)
for
s
in
oneflow_state_dict
)
index_idx
=
3
if
has_prefix
else
2
old_keys
=
oneflow_state_dict
.
keys
()
for
key
in
list
(
old_keys
):
# Convert vit's embedding layers
if
"embeddings"
in
key
:
if
"cls_token"
in
key
:
new_key
=
"cls_token"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"position_embeddings"
in
key
:
new_key
=
"pos_embed"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"patch_embeddings.projection"
in
key
:
if
"weight"
in
key
:
new_key
=
"patch_embed.proj.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"patch_embed.proj.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert vit's layernorm layers
elif
"layernorm_before"
in
key
:
index_block
=
key
.
split
(
"."
)[
index_idx
]
if
"weight"
in
key
:
new_key
=
"blocks."
+
index_block
+
".input_layernorm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"blocks."
+
index_block
+
".input_layernorm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"layernorm_after"
in
key
:
index_block
=
key
.
split
(
"."
)[
index_idx
]
if
"weight"
in
key
:
new_key
=
"blocks."
+
index_block
+
".post_attention_layernorm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"blocks."
+
index_block
+
".post_attention_layernorm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
# Convert vit's attention layers
elif
"attention"
in
key
:
index_block
=
key
.
split
(
"."
)[
index_idx
]
if
"attention.attention"
in
key
:
if
(
"blocks."
+
index_block
+
".self_attention.query_key_value.weight"
in
oneflow_state_dict
.
keys
()
):
continue
q_w
=
key
k_w
=
q_w
.
replace
(
"query"
,
"key"
)
v_w
=
q_w
.
replace
(
"query"
,
"value"
)
q_b
=
q_w
.
replace
(
"weight"
,
"bias"
)
k_b
=
k_w
.
replace
(
"weight"
,
"bias"
)
v_b
=
v_w
.
replace
(
"weight"
,
"bias"
)
qkv_w
=
flow
.
cat
(
(
oneflow_state_dict
.
pop
(
q_w
),
oneflow_state_dict
.
pop
(
k_w
),
oneflow_state_dict
.
pop
(
v_w
),
),
dim
=
0
,
)
qkv_b
=
flow
.
cat
(
(
oneflow_state_dict
.
pop
(
q_b
),
oneflow_state_dict
.
pop
(
k_b
),
oneflow_state_dict
.
pop
(
v_b
),
),
dim
=-
1
,
)
qkv_w
=
self
.
_fix_qkv_ordering
(
qkv_w
,
head_size
,
num_heads
)
qkv_b
=
self
.
_fix_qkv_ordering
(
qkv_b
,
head_size
,
num_heads
)
new_key
=
"blocks."
+
index_block
+
".self_attention.query_key_value.weight"
oneflow_state_dict
[
new_key
]
=
qkv_w
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
qkv_b
elif
"output"
in
key
:
if
"dense"
in
key
:
if
"weight"
in
key
:
new_key
=
"blocks."
+
index_block
+
".self_attention.dense.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
if
"bias"
in
key
:
new_key
=
"blocks."
+
index_block
+
".self_attention.dense.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"intermediate"
in
key
:
index_block
=
key
.
split
(
"."
)[
index_idx
]
if
"weight"
in
key
:
if
(
"blocks."
+
index_block
+
".mlp.dense_h_to_4h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
key
.
replace
(
"weight"
,
"bias"
)
new_key
=
"blocks."
+
index_block
+
".mlp.dense_h_to_4h.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"output"
in
key
:
index_block
=
key
.
split
(
"."
)[
index_idx
]
if
"dense.weight"
in
key
:
if
(
"blocks."
+
index_block
+
".mlp.dense_4h_to_h.weight"
in
oneflow_state_dict
.
keys
()
):
continue
w
=
key
b
=
w
.
replace
(
"weight"
,
"bias"
)
new_key
=
"blocks."
+
index_block
+
".mlp.dense_4h_to_h.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
w
)
new_key
=
new_key
.
replace
(
"weight"
,
"bias"
)
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
b
)
elif
"layernorm"
in
key
:
if
"weight"
in
key
:
new_key
=
"norm.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"norm.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"classifier"
in
key
:
if
"weight"
in
key
:
new_key
=
"head.weight"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
elif
"bias"
in
key
:
new_key
=
"head.bias"
oneflow_state_dict
[
new_key
]
=
oneflow_state_dict
.
pop
(
key
)
else
:
oneflow_state_dict
[
key
]
=
oneflow_state_dict
.
pop
(
key
)
return
oneflow_state_dict
def
_load_config_from_json
(
self
,
config_file
):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with
open
(
config_file
,
mode
=
"r"
,
encoding
=
"utf-8"
)
as
f
:
cfg_dict
=
json
.
load
(
f
)
# update libai_cfg by config.json
self
.
_update_cfg
(
"img_size"
,
cfg_dict
[
"image_size"
])
self
.
_update_cfg
(
"patch_size"
,
cfg_dict
[
"patch_size"
])
self
.
_update_cfg
(
"in_chans"
,
cfg_dict
[
"num_channels"
])
self
.
_update_cfg
(
"embed_dim"
,
cfg_dict
[
"hidden_size"
])
self
.
_update_cfg
(
"depth"
,
cfg_dict
[
"num_hidden_layers"
])
self
.
_update_cfg
(
"num_heads"
,
cfg_dict
[
"num_attention_heads"
])
self
.
_update_cfg
(
"attn_drop_rate"
,
cfg_dict
[
"attention_probs_dropout_prob"
])
self
.
_update_cfg
(
"drop_rate"
,
cfg_dict
[
"hidden_dropout_prob"
])
# update libai_cfg by kwargs
for
k
,
v
in
self
.
kwargs
.
items
():
self
.
_update_cfg
(
k
,
v
)
self
.
_update_cfg_log
()
class
ViTLoaderLiBai
(
ModelLoaderLiBai
):
def
__init__
(
self
,
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
):
super
().
__init__
(
model
,
libai_cfg
,
pretrained_model_path
,
**
kwargs
)
self
.
base_model_prefix_2
=
""
libai/models/utils/weight_init.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import oneflow.nn as nn


def init_method_normal(sigma, mean=0.0):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return nn.init.normal_(tensor, mean=mean, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers, mean=0.0):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return nn.init.normal_(tensor, mean=mean, std=std)

    return init_
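A minimal sketch of how these two factories might be used; the `nn.Linear` layer and sizes below are placeholders, not part of this diff. `scaled_init_method_normal` simply shrinks the standard deviation by sqrt(2 * num_layers), the usual scaling applied to residual-output projections in deep transformers.

import oneflow.nn as nn

from libai.models.utils.weight_init import init_method_normal, scaled_init_method_normal

init_fn = init_method_normal(sigma=0.02)
scaled_init_fn = scaled_init_method_normal(sigma=0.02, num_layers=12)

dense = nn.Linear(768, 768)   # placeholder layer
init_fn(dense.weight)         # N(0, 0.02)
scaled_init_fn(dense.weight)  # N(0, 0.02 / sqrt(2 * 12))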
libai/models/vision_transformer.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn as nn
from flowvision.layers.weight_init import trunc_normal_

import libai.utils.distributed as dist
from libai.config.config import configurable
from libai.layers import LayerNorm, Linear, PatchEmbedding, TransformerLayer


class VisionTransformer(nn.Module):
    """Vision Transformer in LiBai.

    LiBai's implementation of:
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`_

    Args:
        img_size (int, tuple(int)): input image size
        patch_size (int, tuple(int)): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): depth of transformer
        num_heads (int): number of attention heads
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim
        drop_rate (float): dropout rate
        attn_drop_rate (float): attention dropout rate
        drop_path_rate (float): stochastic depth rate
        num_classes (int): number of classes for classification head
        loss_func (callable, optional): loss function for computing the total loss
            between logits and labels
    """

    @configurable
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        num_classes=1000,
        loss_func=None,
    ):
        super().__init__()
        self.img_size = img_size
        self.num_classes = num_classes

        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        ffn_size = int(embed_dim * mlp_ratio)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(
            flow.zeros(
                1,
                1,
                embed_dim,
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                placement=dist.get_layer_placement(0),
            )
        )
        self.pos_embed = nn.Parameter(
            flow.zeros(
                1,
                num_patches + 1,
                embed_dim,
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                placement=dist.get_layer_placement(0),
            )
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in flow.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                TransformerLayer(
                    hidden_size=embed_dim,
                    ffn_hidden_size=ffn_size,
                    num_attention_heads=num_heads,
                    attention_dropout_prob=attn_drop_rate,
                    output_dropout_prob=drop_rate,
                    drop_path_prob=dpr[i],
                    layer_idx=i,
                )
                for i in range(depth)
            ]
        )
        self.norm = LayerNorm(embed_dim, layer_idx=-1)
        self.head = Linear(embed_dim, num_classes, layer_idx=-1)

        # loss func
        self.loss_func = nn.CrossEntropyLoss() if loss_func is None else loss_func

        # weight init
        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    @classmethod
    def from_config(cls, cfg):
        return {
            "img_size": cfg.img_size,
            "patch_size": cfg.patch_size,
            "in_chans": cfg.in_chans,
            "embed_dim": cfg.embed_dim,
            "depth": cfg.depth,
            "num_heads": cfg.num_heads,
            "mlp_ratio": cfg.mlp_ratio,
            "drop_rate": cfg.drop_rate,
            "attn_drop_rate": cfg.attn_drop_rate,
            "drop_path_rate": cfg.drop_path_rate,
            "num_classes": cfg.num_classes,
            "loss_func": cfg.loss_func,
        }

    def forward_features(self, x):
        # patch embedding
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        cls_token = cls_token.to_global(sbp=x.sbp, placement=cls_token.placement)
        x = flow.cat((cls_token, x), dim=1)

        # position embedding
        pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
        pos_embed = pos_embed.to_global(sbp=x.sbp, placement=pos_embed.placement)
        x = self.pos_drop(x + pos_embed)

        # transformer block
        x = self.blocks(x)
        return x

    def forward_head(self, x):
        x = self.norm(x)
        outcome = x[:, 0]
        outcome = self.head(outcome)
        return outcome

    def forward(self, images, labels=None):
        """
        Args:
            images (flow.Tensor): training samples.
            labels (flow.LongTensor, optional): training targets

        Returns:
            dict:
                A dict containing :code:`loss_value` or :code:`logits`
                depending on training or evaluation mode.
                :code:`{"losses": loss_value}` when training,
                :code:`{"prediction_scores": logits}` when evaluating.
        """
        x = self.forward_features(images)
        x = self.forward_head(x)
        if labels is not None and self.training:
            losses = self.loss_func(x, labels)
            return {"losses": losses}
        else:
            return {"prediction_scores": x}

    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        # Set pipeline parallelism stage_id
        if hasattr(model.pos_embed, "config"):
            # Old API in OneFlow 0.8
            for module_block in model.modules():
                if isinstance(module_block.origin, PatchEmbedding):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, TransformerLayer):
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )

            # Set pos_embed and cls_token stage id
            model.pos_embed.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.cls_token.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.config.set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.norm.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
        else:
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), PatchEmbedding):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), TransformerLayer):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )

            # Set pos_embed and cls_token stage id
            model.pos_embed.to(flow.nn.graph.GraphTensor).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.cls_token.to(flow.nn.graph.GraphTensor).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.pos_drop.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
            )
            model.norm.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.head.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
            model.loss_func.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
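For reference, a minimal single-process sketch of driving the VisionTransformer above in eval mode. It assumes LiBai's distributed context has already been initialized for a single GPU and that `@configurable` allows direct keyword construction, as elsewhere in LiBai; the shapes follow the constructor defaults and are illustrative only.

import oneflow as flow

from libai.models.vision_transformer import VisionTransformer
from libai.utils import distributed as dist

model = VisionTransformer(img_size=224, patch_size=16, embed_dim=192, depth=12, num_heads=3)
model.eval()

# Global (sbp/placement-annotated) dummy batch of two 224x224 RGB images.
images = flow.randn(
    2, 3, 224, 224,
    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
    placement=dist.get_layer_placement(0),
)
out = model(images)  # eval mode returns {"prediction_scores": logits with shape (2, 1000)}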
libai/onnx_export/gpt2_to_onnx.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check

from libai.config import LazyConfig
from libai.models.utils import GPT2LoaderLiBai
from projects.MagicPrompt.gpt2 import GPTModel


def get_model(config_file):
    cfg = LazyConfig.load(config_file)
    cfg.model.cfg.pretrained_model_path = None
    cfg.dataloader = None
    cfg.tokenization = None
    print("Building model....")
    loader = GPT2LoaderLiBai(GPTModel, cfg.cfg, "/path/to/model")
    model = loader.load()
    print("Build model finished.")
    return model


class gpt2Graph(nn.Graph):
    def __init__(self, eager_model):
        super().__init__()
        self.model = eager_model

    def build(
        self,
        input_ids,
    ):
        out = self.model(
            input_ids,
        )
        return out


if __name__ == "__main__":
    model = get_model("projects/MagicPrompt/configs/gpt2_inference.py")
    model.eval()
    gpt2_graph = gpt2Graph(model)
    # Build the static graph model
    input_ids = flow.ones(
        1, 5, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
    )
    # check your model.forward is valid
    # output = gpt2_graph(
    #     input_ids
    # )
    print("Compiling the graph, which may take some time, please wait for a moment....")
    gpt2_graph._compile(
        input_ids,
    )

    convert_to_onnx_and_check(
        gpt2_graph,
        external_data=False,
        opset=11,
        flow_weight_dir=None,
        onnx_model_path="./",
        dynamic_batch_size=False,
        device="gpu_global",
        input_tensor_range=[0, 10],
    )
libai/onnx_export/onnx_inference/gpt2_onnx_infer.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List

import numpy as np
import onnxruntime as ort


class OnnxModel:
    def __init__(
        self,
        onnx_filename,
        providers: List[str] = None,
        ort_optimize: bool = True,
    ):
        ort_sess_opt = ort.SessionOptions()
        ort_sess_opt.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
            if ort_optimize
            else ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        )
        if providers is None:
            if ort.__version__ > "1.9.0":
                providers = [
                    "TensorrtExecutionProvider",
                    "CUDAExecutionProvider",
                    "CPUExecutionProvider",
                ]
            else:
                providers = ["CPUExecutionProvider"]
        self.sess = ort.InferenceSession(
            onnx_filename, sess_options=ort_sess_opt, providers=providers
        )

    def forward(self, input_list):
        ipt_dict = OrderedDict()
        for idx, ipt in enumerate(self.sess.get_inputs()):
            ipt_dict[ipt.name] = input_list[idx]
        onnx_res = self.sess.run([], ipt_dict)
        return onnx_res


if __name__ == "__main__":
    onnx_model = OnnxModel("model.onnx")
    input_list = [
        np.ones((1, 5)).astype(np.int64),
    ]
    print(onnx_model.forward(input_list))
libai/onnx_export/onnx_inference/t5_onnx_infer.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import List

import numpy as np
import onnxruntime as ort


class OnnxModel:
    def __init__(
        self,
        onnx_filename,
        providers: List[str] = None,
        ort_optimize: bool = True,
    ):
        ort_sess_opt = ort.SessionOptions()
        ort_sess_opt.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
            if ort_optimize
            else ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        )
        if providers is None:
            if ort.__version__ > "1.9.0":
                providers = [
                    "TensorrtExecutionProvider",
                    "CUDAExecutionProvider",
                    "CPUExecutionProvider",
                ]
            else:
                providers = ["CPUExecutionProvider"]
        self.sess = ort.InferenceSession(
            onnx_filename, sess_options=ort_sess_opt, providers=providers
        )

    def forward(self, input_list):
        ipt_dict = OrderedDict()
        for idx, ipt in enumerate(self.sess.get_inputs()):
            ipt_dict[ipt.name] = input_list[idx]
        onnx_res = self.sess.run([], ipt_dict)
        return onnx_res


if __name__ == "__main__":
    onnx_model = OnnxModel("model.onnx")
    input_list = [
        np.ones((1, 5)).astype(np.int64),
        np.ones((1, 3)).astype(np.int64),
        np.ones((1, 5, 5)).astype(bool),
        np.ones((1, 3, 3)).astype(bool),
        np.ones((1, 3, 5)).astype(bool),
    ]
    print(onnx_model.forward(input_list))
libai/onnx_export/t5_to_onnx.py
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check

from libai.config import LazyConfig
from projects.MT5.mt5_model import MT5Model
from projects.MT5.utils.mt5_loader import T5LoaderHuggerFace


def get_model(config_file):
    cfg = LazyConfig.load(config_file)
    cfg.model.cfg.model_type = "mt5"
    cfg.model.cfg.pretrained_model_path = None
    cfg.dataloader = None
    cfg.tokenization = None
    print("Building model....")
    loader = T5LoaderHuggerFace(MT5Model, cfg.model.cfg, "/path/to/model")
    model = loader.load()
    print("Build model finished.")
    return model


class t5Graph(nn.Graph):
    def __init__(self, eager_model):
        super().__init__()
        self.model = eager_model

    def build(
        self,
        encoder_input_ids,
        encoder_attn_mask,
        decoder_input_ids,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
    ):
        out = self.model(
            encoder_input_ids,
            encoder_attn_mask,
            decoder_input_ids,
            decoder_attn_mask,
            encoder_decoder_attn_mask,
        )
        return out


if __name__ == "__main__":
    model = get_model("projects/MT5/configs/mt5_pretrain.py")
    model.eval()
    t5_graph = t5Graph(model)
    # Build the static graph model
    encoder_input_ids = flow.ones(
        1, 5, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
    )
    encoder_attn_mask = flow.ones(
        1, 3, dtype=flow.int64, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
    )
    decoder_input_ids = flow.ones(
        1,
        5,
        5,
        dtype=flow.bool,
        sbp=flow.sbp.broadcast,
        placement=flow.placement("cuda", ranks=[0]),
    )
    decoder_attn_mask = flow.ones(
        1,
        3,
        3,
        dtype=flow.bool,
        sbp=flow.sbp.broadcast,
        placement=flow.placement("cuda", ranks=[0]),
    )
    encoder_decoder_attn_mask = flow.ones(
        1,
        3,
        5,
        dtype=flow.bool,
        sbp=flow.sbp.broadcast,
        placement=flow.placement("cuda", ranks=[0]),
    )
    # check your model.forward is valid
    # output = t5_graph(
    #     encoder_input_ids,
    #     encoder_attn_mask,
    #     decoder_input_ids,
    #     decoder_attn_mask,
    #     encoder_decoder_attn_mask
    # )
    # print(output)
    print("Compiling the graph, which may take some time, please wait for a moment....")
    t5_graph._compile(
        encoder_input_ids,
        encoder_attn_mask,
        decoder_input_ids,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
    )

    convert_to_onnx_and_check(
        t5_graph,
        external_data=False,
        opset=11,
        flow_weight_dir=None,
        onnx_model_path="./",
        dynamic_batch_size=False,
        device="gpu_global",
        input_tensor_range=[0, 10],
    )