Commit fa2ccbc0
Authored Dec 21, 2019 by Aymeric Augustin

Fix E266 flake8 warning (x90).

Parent: 2ab78325
This page shows 10 of the commit's 30 changed files, with 47 additions and 45 deletions (+47 / -45).
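For reference, E266 is the pycodestyle rule "too many leading '#' for block comment", which flake8 reports; the fix applied throughout this commit is to keep block comments to a single '#' followed by a space. The sketch below is illustrative only (it is not part of the diff) and assumes flake8 is installed locally.

# Reproducing the check locally (assumption: flake8 is available on PATH):
#     flake8 --select=E266 transformers/
#
# Before the fix, block comments with doubled or stacked hashes trigger E266:
##### residual connection
## Required parameters
#
# After the fix, a single '#' followed by a space satisfies the rule:
# residual connection
# Required parameters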
transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py    +1  -1
transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py       +1  -1
transformers/modeling_distilbert.py                                   +2  -2
transformers/modeling_tf_distilbert.py                                +2  -2
transformers/modeling_tf_pytorch_utils.py                             +4  -2
transformers/modeling_tf_transfo_xl.py                                +10 -10
transformers/modeling_tf_xlnet.py                                     +8  -8
transformers/modeling_transfo_xl.py                                   +10 -10
transformers/modeling_xlnet.py                                        +8  -8
transformers/optimization_tf.py                                       +1  -1
transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py

@@ -70,7 +70,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
     )
transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py

@@ -82,7 +82,7 @@ def convert_xlnet_checkpoint_to_pytorch(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
transformers/modeling_distilbert.py

@@ -47,7 +47,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
-### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
+# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
 def gelu(x):
     return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

@@ -327,7 +327,7 @@ class Transformer(nn.Module):
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
-### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
+# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
 class DistilBertPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.
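The unchanged context in the first hunk above is DistilBERT's exact, erf-based GELU. As a quick sanity check (illustrative only; it assumes a PyTorch build that provides torch.nn.functional.gelu, which is not part of this diff), the local helper can be compared against the built-in implementation:

import math

import torch
import torch.nn.functional as F


def gelu(x):
    # Exact GELU, as in the hunk context above: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


if __name__ == "__main__":
    x = torch.linspace(-4.0, 4.0, steps=9)
    # F.gelu defaults to the same exact (erf-based) formulation, so the two
    # implementations should agree to within floating-point tolerance.
    print(torch.allclose(gelu(x), F.gelu(x), atol=1e-6))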
transformers/modeling_tf_distilbert.py

@@ -42,7 +42,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
-### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
+# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
 def gelu(x):
     """ Gaussian Error Linear Unit.
     Original Implementation of the gelu activation function in Google Bert repo when initially created.

@@ -463,7 +463,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         return tfmr_output  # last-layer hidden-state, (all hidden_states), (all attentions)
-### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
+# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
 class TFDistilBertPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.
transformers/modeling_tf_pytorch_utils.py

@@ -67,7 +67,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
 #####################
-### PyTorch => TF 2.0
+# PyTorch => TF 2.0 #
+#####################
 def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):

@@ -197,7 +198,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
 #####################
-### TF 2.0 => PyTorch
+# TF 2.0 => PyTorch #
+#####################
 def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
transformers/modeling_tf_transfo_xl.py

@@ -79,23 +79,23 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
     def call(self, inp, training=False):
         if self.pre_lnorm:
-            ##### layer normalization + positionwise feed-forward
+            # layer normalization + positionwise feed-forward
             core_out = self.layer_norm(inp)
             core_out = self.layer_1(core_out)
             core_out = self.drop_1(core_out, training=training)
             core_out = self.layer_2(core_out)
             core_out = self.drop_2(core_out, training=training)
-            ##### residual connection
+            # residual connection
             output = core_out + inp
         else:
-            ##### positionwise feed-forward
+            # positionwise feed-forward
             core_out = self.layer_1(inp)
             core_out = self.drop_1(core_out, training=training)
             core_out = self.layer_2(core_out)
             core_out = self.drop_2(core_out, training=training)
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
         return output

@@ -206,7 +206,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head))  # qlen x n_head x d_head
-        #### compute attention score
+        # compute attention score
         rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
         AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k)  # qlen x klen x bsz x n_head

@@ -218,7 +218,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         attn_score = AC + BD
         attn_score = attn_score * self.scale
-        #### compute attention probability
+        # compute attention probability
         if attn_mask is not None:
             attn_mask_t = attn_mask[:, :, None, None]
             attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

@@ -231,22 +231,22 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         if head_mask is not None:
             attn_prob = attn_prob * head_mask
-        #### compute attention vector
+        # compute attention vector
         attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)
         # [qlen x bsz x n_head x d_head]
         attn_vec_sizes = shape_list(attn_vec)
         attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))
-        ##### linear projection
+        # linear projection
         attn_out = self.o_net(attn_vec)
         attn_out = self.drop(attn_out, training=training)
         if self.pre_lnorm:
-            ##### residual connection
+            # residual connection
             outputs = [w + attn_out]
         else:
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             outputs = [self.layer_norm(w + attn_out)]
         if self.output_attentions:
transformers/modeling_tf_xlnet.py

@@ -190,7 +190,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         (h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs
         if g is not None:
-            ###### Two-stream attention with relative positional encoding.
+            # Two-stream attention with relative positional encoding.
             # content based attention score
             if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)

@@ -206,7 +206,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
             # position-based key head
             k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
-            ##### h-stream
+            # h-stream
             # content-stream query head
             q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)

@@ -221,7 +221,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
             # post processing
             output_h = self.post_attention([h, attn_vec_h], training=training)
-            ##### g-stream
+            # g-stream
             # query-stream query head
             q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)

@@ -251,7 +251,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
                 attn_prob = attn_prob_h, attn_prob_g
         else:
-            ###### Multi-head attention with relative positional encoding
+            # Multi-head attention with relative positional encoding
             if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)
             else:

@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
-        ##### Attention mask
+        # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
             attn_mask = self.create_mask(qlen, mlen)

@@ -597,7 +597,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             non_tgt_mask = None
-        ##### Word embeddings and prepare h & g hidden states
+        # Word embeddings and prepare h & g hidden states
         if inputs_embeds is not None:
             word_emb_k = inputs_embeds
         else:

@@ -612,7 +612,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             output_g = None
-        ##### Segment embedding
+        # Segment embedding
         if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)

@@ -624,7 +624,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             seg_mat = None
-        ##### Positional encoding
+        # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
         pos_emb = self.dropout(pos_emb, training=training)
transformers/modeling_transfo_xl.py

@@ -213,16 +213,16 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
-            ##### layer normalization + positionwise feed-forward
+            # layer normalization + positionwise feed-forward
             core_out = self.CoreNet(self.layer_norm(inp))
-            ##### residual connection
+            # residual connection
             output = core_out + inp
         else:
-            ##### positionwise feed-forward
+            # positionwise feed-forward
             core_out = self.CoreNet(inp)
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
         return output

@@ -316,7 +316,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # qlen x n_head x d_head
-        #### compute attention score
+        # compute attention score
         rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
         AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head

@@ -328,7 +328,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         attn_score = AC + BD
         attn_score.mul_(self.scale)
-        #### compute attention probability
+        # compute attention probability
         if attn_mask is not None and torch.sum(attn_mask).item():
             attn_mask = attn_mask == 1  # Switch to bool
             if attn_mask.dim() == 2:

@@ -352,21 +352,21 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         if head_mask is not None:
             attn_prob = attn_prob * head_mask
-        #### compute attention vector
+        # compute attention vector
         attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
         # [qlen x bsz x n_head x d_head]
         attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
-        ##### linear projection
+        # linear projection
         attn_out = self.o_net(attn_vec)
         attn_out = self.drop(attn_out)
         if self.pre_lnorm:
-            ##### residual connection
+            # residual connection
             outputs = [w + attn_out]
         else:
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             outputs = [self.layer_norm(w + attn_out)]
         if self.output_attentions:
transformers/modeling_xlnet.py

@@ -330,7 +330,7 @@ class XLNetRelativeAttention(nn.Module):
     def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
         if g is not None:
-            ###### Two-stream attention with relative positional encoding.
+            # Two-stream attention with relative positional encoding.
             # content based attention score
             if mems is not None and mems.dim() > 1:
                 cat = torch.cat([mems, h], dim=0)

@@ -346,7 +346,7 @@ class XLNetRelativeAttention(nn.Module):
             # position-based key head
             k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
-            ##### h-stream
+            # h-stream
             # content-stream query head
             q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)

@@ -361,7 +361,7 @@ class XLNetRelativeAttention(nn.Module):
             # post processing
             output_h = self.post_attention(h, attn_vec_h)
-            ##### g-stream
+            # g-stream
             # query-stream query head
             q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)

@@ -391,7 +391,7 @@ class XLNetRelativeAttention(nn.Module):
                 attn_prob = attn_prob_h, attn_prob_g
         else:
-            ###### Multi-head attention with relative positional encoding
+            # Multi-head attention with relative positional encoding
             if mems is not None and mems.dim() > 1:
                 cat = torch.cat([mems, h], dim=0)
             else:

@@ -804,7 +804,7 @@ class XLNetModel(XLNetPreTrainedModel):
         dtype_float = next(self.parameters()).dtype
         device = next(self.parameters()).device
-        ##### Attention mask
+        # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
             attn_mask = self.create_mask(qlen, mlen)

@@ -849,7 +849,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             non_tgt_mask = None
-        ##### Word embeddings and prepare h & g hidden states
+        # Word embeddings and prepare h & g hidden states
         if inputs_embeds is not None:
             word_emb_k = inputs_embeds
         else:

@@ -864,7 +864,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             output_g = None
-        ##### Segment embedding
+        # Segment embedding
         if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:

@@ -879,7 +879,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             seg_mat = None
-        ##### Positional encoding
+        # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
         pos_emb = self.dropout(pos_emb)
transformers/optimization_tf.py

@@ -178,7 +178,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
         return True
-## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
+# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
 class GradientAccumulator(object):
     """Distribution strategies-aware gradient accumulation utility."""