Commit a332cc9f (parent 1ba21f96)
Authored Mar 11, 2020 by Patrick von Platen

finalize generation merge
Showing 4 changed files with 10 additions and 13 deletions. The changes move BART onto the list-valued eos_token_ids convention used by the merged generation code and drop merge leftovers:

src/transformers/configuration_bart.py   +3 −3
src/transformers/modeling_bart.py        +3 −3
src/transformers/modeling_utils.py       +0 −3
tests/test_modeling_bart.py              +4 −4
src/transformers/configuration_bart.py
@@ -40,8 +40,9 @@ class BartConfig(PretrainedConfig):
         self,
         activation_dropout=0.0,
         vocab_size=50265,
+        bos_token_id=0,
         pad_token_id=1,
-        eos_token_id=2,
+        eos_token_ids=[2],
         d_model=1024,
         encoder_ffn_dim=4096,
         encoder_layers=12,
@@ -58,7 +59,6 @@ class BartConfig(PretrainedConfig):
         classifier_dropout=0.0,
         output_past=False,
         num_labels=3,
-        bos_token_id=0,
         is_encoder_decoder=True,
         **common_kwargs
     ):
@@ -73,12 +73,12 @@ class BartConfig(PretrainedConfig):
             output_past=output_past,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            eos_token_ids=eos_token_ids,
             is_encoder_decoder=is_encoder_decoder,
             **common_kwargs,
         )
         self.vocab_size = vocab_size
         self.d_model = d_model  # encoder_embed_dim and decoder_embed_dim
-        self.eos_token_id = eos_token_id
         self.encoder_ffn_dim = encoder_ffn_dim
         self.encoder_layers = self.num_hidden_layers = encoder_layers
         self.encoder_attention_heads = encoder_attention_heads
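Taken together, the configuration change replaces the scalar eos_token_id attribute with the list-valued eos_token_ids, which is now forwarded to the superclass. A minimal sketch of the resulting surface (assuming a transformers checkout at this commit; 2 is BART's default end-of-sequence id per the defaults above):

    from transformers import BartConfig

    config = BartConfig()  # defaults now include eos_token_ids=[2]

    # Callers that used to read the scalar config.eos_token_id now index
    # the first entry of the list instead:
    eos_token_id = config.eos_token_ids[0]
    assert eos_token_id == 2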
src/transformers/modeling_bart.py
@@ -962,8 +962,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def prepare_scores_for_generation(self, scores, cur_len, max_length):
         if cur_len == 1:
             self._force_token_ids_generation(scores, self.config.bos_token_id)
-        if cur_len == max_length - 1:
-            self._force_token_ids_generation(scores, self.config.eos_token_ids)
+        if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
+            self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
         return scores

     @staticmethod
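prepare_scores_for_generation leans on _force_token_ids_generation, whose body is outside this diff. As a rough, self-contained illustration of what forcing a token at a given decoding step means (the helper name and signature below are illustrative, not taken from the repository):

    import torch

    def force_token_id(scores: torch.Tensor, token_id: int) -> torch.Tensor:
        # Push every vocabulary entry except `token_id` to -inf so that
        # both greedy search and sampling must pick the forced token.
        mask = torch.full_like(scores, float("-inf"))
        mask[:, token_id] = 0.0
        return scores + mask

    scores = torch.randn(2, 50265)              # (batch, vocab) logits
    scores = force_token_id(scores, token_id=2)
    assert scores.argmax(dim=-1).eq(2).all()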
@@ -1056,7 +1056,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
         )
         x = outputs[0]  # last hidden state
-        eos_mask = input_ids.eq(self.config.eos_token_id)
+        eos_mask = input_ids.eq(self.config.eos_token_ids[0])
         if len(torch.unique(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
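The classification-head change only swaps the attribute the <eos> mask is built from; the gather itself is unchanged. Isolated with toy tensors (illustrative sizes, eos id 2), it picks the hidden state sitting on each row's final <eos> token:

    import torch

    eos_id = 2
    input_ids = torch.tensor([[0, 7, 9, 2], [0, 5, 2, 1]])
    x = torch.randn(2, 4, 8)                  # (batch, seq_len, hidden)

    eos_mask = input_ids.eq(eos_id)           # same op as the new diff line
    # Flatten the masked states back to (batch, n_eos, hidden) and keep the
    # state at the last <eos> of every row; requires equal <eos> counts,
    # hence the ValueError guard in the model.
    sentence_repr = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
    assert sentence_repr.shape == (2, 8)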
src/transformers/modeling_utils.py
@@ -840,14 +840,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         )
         # shape: (batch_size * num_return_sequences * num_beams, cur_len)
         if self.config.is_encoder_decoder:
-            eos_token_id = eos_token_ids[0]
             assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
-            assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id"
             # encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                # eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
                 bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
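With the eos-based bootstrap and its assert deleted, encoder-decoder generation now always seeds the decoder with a single bos column while the original input_ids move to encoder_inputs. That initialization in isolation (illustrative sizes; device argument dropped for brevity):

    import torch

    effective_batch_size, num_beams, bos_token_id = 3, 2, 0
    input_ids = torch.full(
        (effective_batch_size * num_beams, 1),  # one decoder token per beam
        bos_token_id,
        dtype=torch.long,
    )
    assert input_ids.shape == (6, 1)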
tests/test_modeling_bart.py
@@ -82,7 +82,7 @@ class ModelTester:
             dropout=self.hidden_dropout_prob,
             attention_dropout=self.attention_probs_dropout_prob,
             max_position_embeddings=self.max_position_embeddings,
-            eos_token_ids=self.eos_token_id,
+            eos_token_ids=[self.eos_token_id],
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            output_past=output_past,
-            eos_token_id=2,
+            eos_token_ids=[2],
            pad_token_id=1,
            bos_token_id=0,
        )
@@ -276,7 +276,7 @@ class BartHeadTests(unittest.TestCase):
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            output_past=True,
-            eos_token_ids=2,
+            eos_token_ids=[2],
            pad_token_id=1,
            bos_token_id=0,
        )
@@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase):
        new_input_ids = lm_model.generate(
            input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
        )
-        self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
+        self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
        # TODO(SS): uneven length batches, empty inputs

    def test_shift_tokens_right(self):