transformers (chenpangpang) · Commits · a84adddd

Commit a84adddd, authored Sep 12, 2019 by thomwolf
    convert all models
Parent: 969d3ae9

Showing 2 changed files with 1184 additions and 27 deletions:
    pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py    +76   -27
    pytorch_transformers/modeling_tf_transfo_xl.py                +1108  -0

pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -18,10 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import argparse
 import tensorflow as tf

-from pytorch_transformers import is_torch_available
+from pytorch_transformers import is_torch_available, cached_path

 from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
                                   GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
@@ -31,26 +32,36 @@ from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt
 if is_torch_available():
     import torch
     import numpy as np
-    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel, XLMWithLMHeadModel
+    from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,)
 else:
-    BertForPreTraining, GPT2LMHeadModel = None, None
+    (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,) = (
+        None, None, None, None, None, None, None, None, None, None, None, None,)

 import logging
 logging.basicConfig(level=logging.INFO)

 MODEL_CLASSES = {
-    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
-    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
-    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel),
-    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel),
+    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining,
+             BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel,
+             GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel,
+              XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel,
+            XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
 }

 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

-    config_class, model_class, loading_fct, pt_model_class = MODEL_CLASSES[model_type]
+    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

     # Initialise TF model
     config = config_class.from_json_file(config_file)
@@ -68,8 +79,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
         tfo = tf_model(tf_inputs, training=False)  # build the network

         pt_model = pt_model_class.from_pretrained(None,
                                                   config=config,
                                                   state_dict=torch.load(pytorch_checkpoint_path,
                                                                         map_location='cpu'))
         pt_inputs = torch.tensor(inputs_list)
         with torch.no_grad():
@@ -79,42 +90,80 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
         np_tf = tfo[0].numpy()
         diff = np.amax(np.abs(np_pt - np_tf))
         print("Max absolute difference between models outputs {}".format(diff))
         assert diff <= 1e-3, "Error, model absolute difference is >1e-3"

     # Save pytorch-model
     print("Save TensorFlow model to {}".format(tf_dump_path))
-    tf_model.save_weights(tf_dump_path)
+    tf_model.save_weights(tf_dump_path, save_format='h5')
+
+
+def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False):
+    assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
+
+    if args_model_type is None:
+        model_types = list(MODEL_CLASSES.keys())
+    else:
+        model_types = [args_model_type]
+
+    for j, model_type in enumerate(model_types, start=1):
+        print("=" * 100)
+        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
+        print("=" * 100)
+        if model_type not in MODEL_CLASSES:
+            raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
+
+        config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+
+        for i, shortcut_name in enumerate(aws_config_map.keys(), start=1):
+            print("-" * 100)
+            print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
+            print("-" * 100)
+            config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
+            model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
+
+            convert_pt_checkpoint_to_tf(model_type, model_file, config_file,
+                                        os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'),
+                                        compare_with_pt_model=compare_with_pt_model)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--model_type",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Model type selcted in the list of {}.".format(list(MODEL_CLASSES.keys())))
+    parser.add_argument("--tf_dump_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path to the output Tensorflow dump file.")
+    parser.add_argument("--model_type",
+                        default=None,
+                        type=str,
+                        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
     parser.add_argument("--pytorch_checkpoint_path",
                         default=None,
                         type=str,
-                        required=True,
-                        help="Path to the PyTorch checkpoint path.")
+                        help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
+                             "If not given, will download and convert all the checkpoints from AWS.")
     parser.add_argument("--config_file",
                         default=None,
                         type=str,
-                        required=True,
-                        help="The config json file corresponding to the pre-trained model.\n"
-                             "This specifies the model architecture.")
-    parser.add_argument("--tf_dump_path",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Path to the output Tensorflow dump file.")
+                        help="The config json file corresponding to the pre-trained model.\n"
+                             "This specifies the model architecture. If not given and "
+                             "--pytorch_checkpoint_path is not given or is a shortcut name"
+                             "use the configuration associated to teh shortcut name on the AWS")
     parser.add_argument("--compare_with_pt_model",
                         action='store_true',
                         help="Compare Tensorflow and PyTorch model predictions.")
     args = parser.parse_args()

-    convert_pt_checkpoint_to_tf(args.model_type.lower(),
-                                args.pytorch_checkpoint_path,
-                                args.config_file,
-                                args.tf_dump_path,
-                                compare_with_pt_model=args.compare_with_pt_model)
+    if args.pytorch_checkpoint_path is not None:
+        convert_pt_checkpoint_to_tf(args.model_type.lower(),
+                                    args.pytorch_checkpoint_path,
+                                    args.config_file,
+                                    args.tf_dump_path,
+                                    compare_with_pt_model=args.compare_with_pt_model)
+    else:
+        convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
+                                         args.tf_dump_path,
+                                         compare_with_pt_model=args.compare_with_pt_model)
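For orientation (not part of the commit): the new `convert_all_pt_checkpoints_to_tf` entry point sweeps every shortcut name in the AWS config maps instead of converting a single checkpoint, and the CLI now routes to it whenever `--pytorch_checkpoint_path` is omitted. Below is a minimal sketch of calling the single-checkpoint converter programmatically; the import path and all file paths are placeholders assumed for illustration. Note that `convert_all_pt_checkpoints_to_tf` as committed asserts on the module-level `args`, so the all-checkpoints sweep is really meant to be driven through the script's own `__main__` CLI.

# Hypothetical usage sketch, assuming the script is importable as a module
# and that the placeholder paths below point to real files.
import os
from pytorch_transformers.convert_pytorch_checkpoint_to_tf2 import convert_pt_checkpoint_to_tf

tf_dump_dir = "/tmp/tf2_dumps"          # placeholder output directory
os.makedirs(tf_dump_dir, exist_ok=True)

# Convert one locally saved PyTorch checkpoint (the pre-existing single-model path);
# compare_with_pt_model=True additionally checks TF vs. PyTorch outputs, which
# requires torch to be available.
convert_pt_checkpoint_to_tf(model_type="bert",
                            pytorch_checkpoint_path="/path/to/pytorch_model.bin",
                            config_file="/path/to/config.json",
                            tf_dump_path=os.path.join(tf_dump_dir, "bert-tf_model.h5"),
                            compare_with_pt_model=True)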
pytorch_transformers/modeling_tf_transfo_xl.py (new file, 0 → 100644)
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 Transformer XL model.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import json
import math
import logging
import collections
import sys
from io import open

import numpy as np
import tensorflow as tf

from .configuration_transfo_xl import TransfoXLConfig
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)

TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
}


def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    # build the network
    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
    tf_inputs = tf.constant(inputs_list)
    tfo = tf_model(tf_inputs, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = self.CoreNet(inp)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output
class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(MultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
        else:
            c = h

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            outputs = [h + attn_out]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(h + attn_out)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(RelMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m, :m] = torch.triu(mask[:m, :m])
        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:, :, None, None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
        x_padded = x_padded.view(*x_padded_shape)

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # qlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + self.r_w_bias                                    # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None], -1e30).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None], -1e30).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D

        qlen, bsz = w.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                                   # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))                  # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                              # 1 x klen x 1 x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
        attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                     mems=mems, head_mask=head_mask)
        ff_output = self.pos_ff(attn_outputs[0])

        outputs = [ff_output] + attn_outputs[1:]

        return outputs
class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
        attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                     attn_mask=dec_attn_mask,
                                     mems=mems, head_mask=head_mask)
        ff_output = self.pos_ff(attn_outputs[0])

        outputs = [ff_output] + attn_outputs[1:]

        return outputs
class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                         d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
        attn_outputs = self.dec_attn(dec_inp, r,
                                     attn_mask=dec_attn_mask,
                                     mems=mems, head_mask=head_mask)
        ff_output = self.pos_ff(attn_outputs[0])

        outputs = [ff_output] + attn_outputs[1:]

        return outputs
class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed
class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        if self.config.init == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif self.config.init == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """ Initialize the weights.
        """
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            if hasattr(m, 'weight') and m.weight is not None:
                self._init_weight(m.weight)
            if hasattr(m, 'bias') and m.bias is not None:
                self._init_bias(m.bias)
        elif classname.find('AdaptiveEmbedding') != -1:
            if hasattr(m, 'emb_projs'):
                for i in range(len(m.emb_projs)):
                    if m.emb_projs[i] is not None:
                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find('Embedding') != -1:
            if hasattr(m, 'weight'):
                self._init_weight(m.weight)
        elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
            if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
                self._init_weight(m.cluster_weight)
            if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for i in range(len(m.out_projs)):
                    if m.out_projs[i] is not None:
                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find('LayerNorm') != -1:
            if hasattr(m, 'weight'):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if hasattr(m, 'bias') and m.bias is not None:
                self._init_bias(m.bias)
        else:
            if hasattr(m, 'r_emb'):
                self._init_weight(m.r_emb)
            if hasattr(m, 'r_w_bias'):
                self._init_weight(m.r_w_bias)
            if hasattr(m, 'r_r_bias'):
                self._init_weight(m.r_r_bias)
            if hasattr(m, 'r_bias'):
                self._init_bias(m.r_bias)

    def set_num_special_tokens(self, num_special_tokens):
        pass
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
TRANSFO_XL_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
                      TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLModel(TransfoXLPreTrainedModel):
    r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
"""
    def __init__(self, config):
        super(TransfoXLModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.n_token = config.n_token

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head

        self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
                                          div_val=config.div_val)

        self.drop = nn.Dropout(config.dropout)

        self.n_layer = config.n_layer

        self.tgt_len = config.tgt_len
        self.mem_len = config.mem_len
        self.ext_len = config.ext_len
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

        self.attn_type = config.attn_type

        if not config.untie_r:
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

        self.layers = nn.ModuleList()
        if config.attn_type == 0: # the default attention
            for i in range(config.n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions)
                )
        elif config.attn_type == 1: # learnable embeddings
            for i in range(config.n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions)
                )
        elif config.attn_type in [2, 3]: # absolute embeddings
            for i in range(config.n_layer):
                self.layers.append(
                    DecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions)
                )

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0: # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 1: # learnable
            self.r_emb = nn.Parameter(torch.FloatTensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.FloatTensor(
                    self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2: # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = nn.Parameter(torch.FloatTensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        return self.word_emb

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def _prune_heads(self, heads):
        logger.info("Head pruning is not implemented for Transformer-XL model")
        pass

    def init_mems(self, data):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer):
                empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model,
                                    dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None, head_mask=None):
        qlen, bsz = dec_inp.size()

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                             + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:, :, None]

        hids = []
        attentions = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
                                      r_bias, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        # We transpose back here to shape [bsz, len, hidden_dim]
        outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
        if self.output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = list(t.transpose(0, 1).contiguous() for t in hids)
            outputs.append(hids)
        if self.output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
            outputs.append(attentions)

        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)

    def forward(self, input_ids, mems=None, head_mask=None):
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        input_ids = input_ids.transpose(0, 1).contiguous()

        if mems is None:
            mems = self.init_mems(input_ids)
        outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)

        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)


@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
    TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    r"""
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``None`` if ``lm_labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
We don't output them when the loss is computed to speedup adaptive softmax decoding.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
prediction_scores, mems = outputs[:2]
"""
    def __init__(self, config):
        super(TransfoXLLMHeadModel, self).__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax
        # use sampled softmax
        if config.sample_softmax > 0:
            self.out_layer = nn.Linear(config.d_model, config.n_token)
            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                    config.cutoffs, div_val=config.div_val)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """
        # sampled softmax
        if self.sample_softmax > 0:
            if self.config.tie_weight:
                self.out_layer.weight = self.transformer.word_emb.weight
        # adaptive softmax (including standard softmax)
        else:
            if self.config.tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self._tie_or_clone_weights(self.crit.out_layers[i],
                                               self.transformer.word_emb.emb_layers[i])
            if self.config.tie_projs:
                for i, tie_proj in enumerate(self.config.tie_projs):
                    if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                        if self.config.torchscript:
                            self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
                        else:
                            self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
                    elif tie_proj and self.config.div_val != 1:
                        if self.config.torchscript:
                            self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
                        else:
                            self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, data):
        return self.transformer.init_mems(data)

    def forward(self, input_ids, mems=None, head_mask=None, labels=None):
        bsz = input_ids.size(0)
        tgt_len = input_ids.size(1)

        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
        outputs = transformer_outputs[1:]
        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
            outputs = [softmax_output] + outputs
            if labels is not None:
                # TODO: This is not implemented
                raise NotImplementedError
        else:
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
                outputs = [softmax_output] + outputs
            else:
                softmax_output = softmax_output.view(bsz, tgt_len)
                outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
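The docstrings above describe `mems` as hidden states that can be passed back in "to speed up sequential decoding and attend to longer context." As a minimal sketch (not part of the commit) of what that looks like, building on the Examples already shown in the docstrings and using the PyTorch classes this file was ported from, assuming the pretrained `transfo-xl-wt103` weights are available:

import torch
from pytorch_transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
model.eval()

# First segment: no memory yet, so the model initialises empty mems internally.
input_ids_1 = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # batch size 1
with torch.no_grad():
    prediction_scores_1, mems = model(input_ids_1)[:2]

# Second segment: pass the returned mems so the new tokens can attend to the
# cached hidden states of the first segment (longer effective context).
input_ids_2 = torch.tensor(tokenizer.encode("He likes long walks")).unsqueeze(0)
with torch.no_grad():
    prediction_scores_2, mems = model(input_ids_2, mems=mems)[:2]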