chenpangpang / transformers / Commits / 91ccbae7

Commit 91ccbae7, authored Oct 29, 2019 by Lysandre, committed by Lysandre Debut on Nov 26, 2019

Accepts multiple sizes

parent c0c20883

Changes: 1 changed file with 58 additions and 81 deletions (+58 / -81)

transformers/modeling_albert.py (view file @ 91ccbae7)
```diff
@@ -5,6 +5,7 @@ import logging
 import torch
 import torch.nn as nn

 from transformers.configuration_albert import AlbertConfig
+from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, gelu_new

 logger = logging.getLogger(__name__)


 def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
```
```diff
@@ -32,14 +33,14 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
     print(model)
     for name, array in zip(names, arrays):
         print(name)
         og = name
         name = name.replace("transformer/group_0/inner_group_0", "transformer")
         name = name.replace("LayerNorm", "layer_norm")
         name = name.replace("ffn_1", "ffn")
         name = name.replace("ffn/intermediate/output", "ffn_output")
         name = name.replace("attention_1", "attention")
         name = name.replace("cls/predictions/transform", "predictions")
-        name = name.replace("transformer/layer_norm_1", "transformer/attention/output/LayerNorm")
+        name = name.replace("transformer/LayerNorm_1", "transformer/attention/LayerNorm")
         name = name.split('/')
         print(name)
```
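To make the renaming pass above easier to follow, here is a minimal sketch (not part of the diff) of what the chain of `str.replace` calls does to a single checkpoint variable name before it is matched to PyTorch modules. The sample name is an assumption, chosen to look like the ALBERT TF naming pattern the function targets.

```python
# Illustrative only: a hypothetical ALBERT TF variable name run through the same replacements.
name = "bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/kernel"
name = name.replace("transformer/group_0/inner_group_0", "transformer")
name = name.replace("LayerNorm", "layer_norm")
name = name.replace("ffn_1", "ffn")
name = name.replace("ffn/intermediate/output", "ffn_output")
name = name.replace("attention_1", "attention")
name = name.replace("cls/predictions/transform", "predictions")
print(name.split('/'))
# ['bert', 'encoder', 'transformer', 'ffn_output', 'dense', 'kernel']
```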
```diff
@@ -84,44 +85,22 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
     return model


-class AlbertEmbeddings(nn.Module):
+class AlbertEmbeddings(BertEmbeddings):
     """
     Construct the embeddings from word, position and token_type embeddings.
     """
     def __init__(self, config):
-        super(AlbertEmbeddings, self).__init__()
+        super(AlbertEmbeddings, self).__init__(config)

         self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=0)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
-        self.layer_norm = torch.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
-        seq_length = input_ids.size(1)
-        if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
-            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
-
-        word_embeddings = self.word_embeddings(input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
-
-        embeddings = word_embeddings + position_embeddings + token_type_embeddings
-        embeddings = self.layer_norm(embeddings)
-        embeddings = self.dropout(embeddings)
-        return embeddings
+        self.LayerNorm = torch.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)

     def get_word_embeddings_table(self):
         return self.word_embeddings


-class AlbertModel(nn.Module):
+class AlbertModel(BertModel):

     def __init__(self, config):
-        super(AlbertModel, self).__init__()
+        super(AlbertModel, self).__init__(config)

         self.config = config
         self.embeddings = AlbertEmbeddings(config)
```
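A short aside on why AlbertEmbeddings keeps its own tables sized with `config.embedding_size` even after inheriting from BertEmbeddings: ALBERT factorizes the embedding matrix, so the vocabulary lives in a small embedding space that is later projected up to `hidden_size`. The figures below are assumed, ALBERT-base-like values, not read from the diff.

```python
# Assumed example figures (roughly ALBERT-base): vocab 30000, E=128, H=768.
vocab_size, embedding_size, hidden_size = 30000, 128, 768

factorized = vocab_size * embedding_size + embedding_size * hidden_size  # V*E + E*H
unfactorized = vocab_size * hidden_size                                  # V*H
print(factorized, unfactorized)  # ~3.9M vs ~23.0M embedding parameters
```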
```diff
@@ -129,6 +108,7 @@ class AlbertModel(nn.Module):
         self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
         self.pooler_activation = nn.Tanh()

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
```
```diff
@@ -166,7 +146,7 @@ class AlbertForMaskedLM(nn.Module):
         self.config = config
         self.bert = AlbertModel(config)
-        self.layer_norm = nn.LayerNorm(config.embedding_size)
+        self.LayerNorm = nn.LayerNorm(config.embedding_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
         self.dense = nn.Linear(config.hidden_size, config.embedding_size)
         self.word_embeddings = nn.Linear(config.embedding_size, config.vocab_size)
```
```diff
@@ -182,39 +162,47 @@ class AlbertForMaskedLM(nn.Module):
         hidden_states = self.bert(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)[0]
         hidden_states = self.dense(hidden_states)
         hidden_states = gelu_new(hidden_states)
-        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
         logits = self.word_embeddings(hidden_states)

         return logits


-class AlbertAttention(nn.Module):
+class AlbertAttention(BertSelfAttention):
     def __init__(self, config):
-        super(AlbertAttention, self).__init__()
-        if config.hidden_size % config.num_attention_heads != 0:
-            raise ValueError(
-                "The hidden size (%d) is not a multiple of the number of attention "
-                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
-        self.output_attentions = config.output_attentions
+        super(AlbertAttention, self).__init__(config)

         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size

         self.query = nn.Linear(config.hidden_size, self.all_head_size)
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+        self.hidden_size = config.hidden_size
+        self.attention_head_size = config.hidden_size // config.num_attention_heads

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

-    def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)

         self.pruned_heads = set()

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.num_attention_heads, self.attention_head_size)
         heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
         for head in heads:
             # Compute how many pruned heads are before the head and move the index accordingly
             head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()

         # Prune linear layers
         self.query = prune_linear_layer(self.query, index)
         self.key = prune_linear_layer(self.key, index)
         self.value = prune_linear_layer(self.value, index)
         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

         # Update hyper params and store pruned heads
         self.num_attention_heads = self.num_attention_heads - len(heads)
         self.all_head_size = self.attention_head_size * self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)

     def forward(self, input_ids, attention_mask=None, head_mask=None):
         mixed_query_layer = self.query(input_ids)
```
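As a quick illustration of the index arithmetic in `prune_heads` above, here is a standalone sketch with hypothetical sizes (4 heads of size 2, pruning heads 1 and 3). It only reproduces the mask/index computation, not the layer surgery.

```python
import torch

num_attention_heads, attention_head_size = 4, 2   # hypothetical sizes for the example
heads = {1, 3}                                     # heads to prune

mask = torch.ones(num_attention_heads, attention_head_size)
for head in heads:
    mask[head] = 0                                 # zero out the rows belonging to pruned heads
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
print(index)  # tensor([0, 1, 4, 5]) -> the rows kept by prune_linear_layer
```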
```diff
@@ -248,7 +236,8 @@ class AlbertAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         reshaped_context_layer = context_layer.view(*new_context_layer_shape)

-        w = self.dense.weight.T.view(16, 64, 1024)
+        print(self.dense.weight.T.shape)
+        w = self.dense.weight.T.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
         b = self.dense.bias

         projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
```
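The replaced line above is the crux of the commit message: instead of hard-coding 16 heads x 64 dims x 1024 hidden, the output projection is reshaped from the config, so the same code accepts the base/large/etc. sizes. The sketch below, using small assumed sizes, checks that the einsum path is numerically equivalent to merging the heads and applying the dense layer directly.

```python
import torch
import torch.nn as nn

batch, seq_len, num_heads, head_size = 2, 5, 4, 8   # assumed small sizes for the check
hidden_size = num_heads * head_size

dense = nn.Linear(hidden_size, hidden_size)
context_layer = torch.randn(batch, seq_len, num_heads, head_size)  # already permuted to (b, f, n, d)

# Reference: merge the heads and apply the Linear layer as usual.
reference = dense(context_layer.reshape(batch, seq_len, hidden_size))

# The diff's path: reshape dense.weight.T to (heads, head_size, hidden) and contract with einsum.
w = dense.weight.T.view(num_heads, head_size, hidden_size)
b = dense.bias
projected = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b

print(torch.allclose(reference, projected, atol=1e-6))  # True
```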
```diff
@@ -262,7 +251,7 @@ class AlbertTransformer(nn.Module):
         super(AlbertTransformer, self).__init__()

         self.config = config
-        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attention = AlbertAttention(config)
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
```
```diff
@@ -273,18 +262,11 @@ class AlbertTransformer(nn.Module):
         ffn_output = self.ffn(attention_output)
         ffn_output = gelu_new(ffn_output)
         ffn_output = self.ffn_output(ffn_output)
-        hidden_states = self.layer_norm(ffn_output + attention_output)
+        hidden_states = self.LayerNorm(ffn_output + attention_output)

         return hidden_states


-def gelu_new(x):
-    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
-        Also see https://arxiv.org/abs/1606.08415
-    """
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
 class AlbertEncoder(nn.Module):
     def __init__(self, config):
         super(AlbertEncoder, self).__init__()
```
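Since the local `gelu_new` definition is dropped above in favour of the one imported from `modeling_bert`, here is a small sanity-check sketch (not part of the diff) comparing that tanh approximation against the exact erf-based GELU; the deviation stays tiny across the usual activation range.

```python
import math
import torch

def gelu_tanh_approx(x):
    # The approximation the removed local gelu_new implemented.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

def gelu_exact(x):
    # Exact GELU via the Gaussian CDF.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.linspace(-5, 5, steps=101)
print(torch.max(torch.abs(gelu_tanh_approx(x) - gelu_exact(x))))  # on the order of 1e-3 or smaller
```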
```diff
@@ -305,27 +287,22 @@ class AlbertEncoder(nn.Module):
             outputs = outputs + (all_attentions,)
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)


 # config = AlbertConfig.from_json_file("config.json")
 # # model = AlbertForMaskedLM(config)
 # model = AlbertModel(config)
 # model = load_tf_weights_in_albert(model, config, "albert/albert")
 # print(model)
 # input_ids = torch.tensor([[31, 51, 99], [15, 5, 0]])
 # input_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
 # segment_ids = torch.tensor([[0, 0, 1], [0, 0, 0]])
 # # sequence_output, pooled_outputs = model()

 model_size = "base"
 config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}.json".format(model_size))
 model = AlbertModel(config)
 model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}/albert-{}".format(model_size, model_size))
 model.eval()

 print(sum(p.numel() for p in model.parameters() if p.requires_grad))

 # logits = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)[1]

 input_ids = [[31, 51, 99, 88, 54, 34, 23, 23, 12], [15, 5, 0, 88, 54, 34, 23, 23, 12]]
 input_mask = [[1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0, 0, 0]]
 segment_ids = [[0, 0, 1, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0]]

 # embeddings_output =
 # print("pooled output", logits)
 # # print("Pooled output", pooled_outputs)

 pt_input_ids = torch.tensor(input_ids)
 pt_input_mask = torch.tensor(input_mask)
 pt_segment_ids = torch.tensor(segment_ids)

-config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config.json")
-model = AlbertModel(config)
-model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert/albert")
\ No newline at end of file
+pt_dict = {"input_ids": pt_input_ids, "attention_mask": pt_input_mask, "token_type_ids": pt_segment_ids}
+pt_output = model(**pt_dict)
+
+print(pt_output)
\ No newline at end of file
```