chenpangpang/transformers · Commit 58918c76 (unverified)

[bart] add config.extra_pos_embeddings to facilitate reuse (#5190)

Authored Jun 23, 2020 by Sam Shleifer, committed via GitHub on Jun 23, 2020
Parent: b28b5371

Showing 5 changed files with 35 additions and 32 deletions (+35 -32).
Files changed:
  src/transformers/configuration_bart.py   +4   -0
  src/transformers/modeling_bart.py        +16  -15
  src/transformers/modeling_roberta.py     +14  -1
  src/transformers/modeling_utils.py       +0   -14
  tests/test_modeling_roberta.py           +1   -2
src/transformers/configuration_bart.py

@@ -41,6 +41,7 @@ class BartConfig(PretrainedConfig):
     def __init__(
         self,
         activation_dropout=0.0,
+        extra_pos_embeddings=2,
         activation_function="gelu",
         vocab_size=50265,
         d_model=1024,
@@ -118,6 +119,9 @@ class BartConfig(PretrainedConfig):
         # Classifier stuff
         self.classif_dropout = classifier_dropout
+
+        # pos embedding offset
+        self.extra_pos_embeddings = self.pad_token_id + 1

     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
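A quick usage sketch of the new option (not part of the diff): BartConfig now accepts extra_pos_embeddings as a keyword, but as the second hunk above shows, __init__ then assigns self.extra_pos_embeddings = self.pad_token_id + 1, which is 2 for the default pad_token_id of 1.

    from transformers import BartConfig

    # Hypothetical example; the keyword is accepted, but in this commit the value
    # stored on the config is derived from pad_token_id.
    config = BartConfig(extra_pos_embeddings=2)
    print(config.pad_token_id)          # 1 (default)
    print(config.extra_pos_embeddings)  # 2, i.e. config.pad_token_id + 1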
src/transformers/modeling_bart.py

@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
 from .activations import ACT2FN
 from .configuration_bart import BartConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
+from .modeling_utils import PreTrainedModel


 logger = logging.getLogger(__name__)

@@ -96,6 +96,7 @@ BART_INPUTS_DOCSTRING = r"""
 def invert_mask(attention_mask):
     """Turns 1->0, 0->1, False->True, True-> False"""
     assert attention_mask.dim() == 2
     return attention_mask.eq(0)

@@ -261,7 +262,7 @@ class BartEncoder(nn.Module):
             )
         else:
             self.embed_positions = LearnedPositionalEmbedding(
-                config.max_position_embeddings, embed_dim, self.padding_idx,
+                config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
             )
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()

@@ -435,7 +436,7 @@ class BartDecoder(nn.Module):
             )
         else:
             self.embed_positions = LearnedPositionalEmbedding(
-                config.max_position_embeddings, config.d_model, self.padding_idx,
+                config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
             )
         self.layers = nn.ModuleList(
             [DecoderLayer(config) for _ in range(config.decoder_layers)]

@@ -745,23 +746,23 @@ class LearnedPositionalEmbedding(nn.Embedding):
     position ids are passed to the forward function.
     """

-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int,):
-        # if padding_idx is specified then offset the embedding ids by
-        # this index and adjust num_embeddings appropriately
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
+        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models dont have this hack
+        self.offset = offset
         assert padding_idx is not None
-        num_embeddings += padding_idx + 1  # WHY?
+        num_embeddings += offset
         super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

-    def forward(self, input, use_cache=False):
+    def forward(self, input_ids, use_cache=False):
         """Input is expected to be of size [bsz x seqlen]."""
-        if use_cache:  # the position is our current step in the decoded sequence
-            pos = int(self.padding_idx + input.size(1))
-            positions = input.data.new(1, 1).fill_(pos)
+        bsz, seq_len = input_ids.shape[:2]
+        if use_cache:
+            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
         else:
-            positions = create_position_ids_from_input_ids(input, self.padding_idx)
-        return super().forward(positions)
+            # starts at 0, ends at 1-seq_len
+            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
+        return super().forward(positions + self.offset)


 def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
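The net effect of the LearnedPositionalEmbedding change: instead of baking in the BART-specific padding_idx + 1 shift (the old "# WHY?" line), the table is widened by an explicit offset and every position id is shifted by that offset at lookup time, so other models can reuse the class with a different offset (including 0). A minimal standalone sketch of that scheme, using a plain nn.Embedding rather than the library class, with made-up sizes:

    import torch
    import torch.nn as nn

    # Hypothetical numbers for illustration only.
    max_position_embeddings, d_model, offset = 10, 4, 2

    # The table gets `offset` extra rows, mirroring `num_embeddings += offset`.
    table = nn.Embedding(max_position_embeddings + offset, d_model)

    seq_len = 5
    positions = torch.arange(seq_len, dtype=torch.long)  # 0 .. seq_len - 1
    embedded = table(positions + offset)                 # rows offset .. offset + seq_len - 1
    print(embedded.shape)                                # torch.Size([5, 4])

    # Incremental decoding (use_cache=True) looks up only the latest position:
    latest = table(torch.tensor([[seq_len - 1]]) + offset)
    print(latest.shape)                                  # torch.Size([1, 1, 4])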
src/transformers/modeling_roberta.py

@@ -26,7 +26,6 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .configuration_roberta import RobertaConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
-from .modeling_utils import create_position_ids_from_input_ids


 logger = logging.getLogger(__name__)

@@ -733,3 +732,17 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
             outputs = (total_loss,) + outputs

         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """ Replace non-padding symbols with their position numbers. Position numbers begin at
+    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
+    `utils.make_positions`.
+
+    :param torch.Tensor x:
+    :return torch.Tensor:
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
+    return incremental_indices.long() + padding_idx
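For reference, a worked example of what the relocated helper computes, written out standalone to mirror the function body above; the ids are made up, with 1 standing in for RoBERTa's pad token id:

    import torch

    input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # two trailing pad tokens
    padding_idx = 1

    mask = input_ids.ne(padding_idx).int()                                 # [[1, 1, 1, 1, 0, 0]]
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask   # [[1, 2, 3, 4, 0, 0]]
    position_ids = incremental_indices.long() + padding_idx                # [[2, 3, 4, 5, 1, 1]]
    print(position_ids)  # non-padding positions start at padding_idx + 1; pads stay at padding_idx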
src/transformers/modeling_utils.py

@@ -2090,20 +2090,6 @@ class SequenceSummary(nn.Module):
         return output


-def create_position_ids_from_input_ids(input_ids, padding_idx):
-    """ Replace non-padding symbols with their position numbers. Position numbers begin at
-    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
-    `utils.make_positions`.
-
-    :param torch.Tensor x:
-    :return torch.Tensor:
-    """
-    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
-    mask = input_ids.ne(padding_idx).int()
-    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
-    return incremental_indices.long() + padding_idx
-
-
 def prune_linear_layer(layer, index, dim=0):
     """ Prune a linear layer (a model parameters) to keep only entries in index.
         Return the pruned layer as a new layer with requires_grad=True.
tests/test_modeling_roberta.py

@@ -34,9 +34,8 @@ if is_torch_available():
         RobertaForSequenceClassification,
         RobertaForTokenClassification,
     )
-    from transformers.modeling_roberta import RobertaEmbeddings
+    from transformers.modeling_roberta import RobertaEmbeddings, create_position_ids_from_input_ids
     from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
-    from transformers.modeling_utils import create_position_ids_from_input_ids


 class RobertaModelTester:
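Beyond the import relocation above, a hypothetical check (not part of this commit's test changes) of the wiring introduced in modeling_bart.py: since LearnedPositionalEmbedding now does num_embeddings += offset and both BartEncoder and BartDecoder pass config.extra_pos_embeddings, the position tables should carry that many extra rows. A sketch, assuming a small throwaway config:

    from transformers import BartConfig, BartModel

    config = BartConfig(
        vocab_size=99,
        d_model=16,
        encoder_layers=1,
        decoder_layers=1,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=20,
    )
    model = BartModel(config)

    expected = config.max_position_embeddings + config.extra_pos_embeddings
    assert model.encoder.embed_positions.num_embeddings == expected
    assert model.decoder.embed_positions.weight.shape[0] == expected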