Unverified commit 5c5af879, authored Mar 03, 2020 by Sam Shleifer, committed by GitHub on Mar 03, 2020
Parent commit: a088d75e

[Bart] dont call .forward (#3094)
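The change below replaces every explicit `.forward(...)` call on a submodule with a direct call to the module itself. Calling a module through `nn.Module.__call__` runs any registered forward pre-hooks and forward hooks before and after dispatching to `forward()`, whereas calling `.forward()` directly silently skips them. A minimal sketch of the difference, using a plain `nn.Linear` rather than the BART modules:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
calls = []
# Record every time the hook machinery is exercised.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook fired"))

x = torch.randn(2, 4)
_ = layer(x)          # goes through __call__, so the hook fires
_ = layer.forward(x)  # bypasses __call__, so the hook is skipped

print(calls)  # ['hook fired'], one entry despite two forward passes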
Showing 1 changed file with 11 additions and 11 deletions

src/transformers/modeling_bart.py  +11  -11
@@ -208,7 +208,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn.forward(
+        x, attn_weights = self.self_attn(
             query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
@@ -292,7 +292,7 @@ class BartEncoder(nn.Module):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 attn = None
             else:
-                x, attn = encoder_layer.forward(x, attention_mask)
+                x, attn = encoder_layer(x, attention_mask)

             if self.output_attentions:
                 all_attentions.append(attn)
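The surrounding context is the LayerDrop pattern (Fan et al., 2019): during training, each encoder layer is skipped entirely with probability `layerdrop`, and a skipped layer contributes no attention weights. A generic sketch of that loop, with assumed names and not taken verbatim from the file:

import random

def run_encoder_layers(x, layers, layerdrop, training, attention_mask=None):
    all_attentions = []
    for layer in layers:
        # With probability `layerdrop`, skip the whole layer during training.
        if training and random.uniform(0, 1) < layerdrop:
            attn = None
        else:
            # Invoke the module itself so registered hooks still run.
            x, attn = layer(x, attention_mask)
        all_attentions.append(attn)
    return x, all_attentions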
@@ -356,7 +356,7 @@ class DecoderLayer(nn.Module):
         if layer_state is None:
             layer_state = {}
         # next line mutates layer state
-        x, self_attn_weights = self.self_attn.forward(
+        x, self_attn_weights = self.self_attn(
             query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
@@ -365,7 +365,7 @@ class DecoderLayer(nn.Module):
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
-        x, encoder_attn_weights = self.encoder_attn.forward(
+        x, encoder_attn_weights = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,  # could be None
             value=encoder_hidden_states,
@@ -449,7 +449,7 @@ class BartDecoder(nn.Module):
                 - attentions
         """
         # embed positions
-        positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode)
+        positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
         if self.generation_mode:
             input_ids = input_ids[:, -1:]
@@ -475,7 +475,7 @@ class BartDecoder(nn.Module):
                 continue
             layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
-            x, layer_self_attn, layer_past = decoder_layer.forward(
+            x, layer_self_attn, layer_past = decoder_layer(
                 x,
                 encoder_hidden_states,
                 encoder_padding_mask,
@@ -836,10 +836,10 @@ class BartModel(PretrainedBartModel):
         )
         assert decoder_input_ids is not None
         if encoder_outputs is None:
-            encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
+            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         assert isinstance(encoder_outputs, tuple)
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        decoder_outputs = self.decoder.forward(
+        decoder_outputs = self.decoder(
             decoder_input_ids,
             encoder_outputs[0],
             attention_mask,
@@ -925,7 +925,7 @@ class BartForMaskedLM(PretrainedBartModel):
            outputs = model(input_ids=input_ids, lm_labels=input_ids)
            loss, prediction_scores = outputs[:2]
         """
-        outputs = self.model.forward(
+        outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -933,7 +933,7 @@ class BartForMaskedLM(PretrainedBartModel):
             decoder_attention_mask=decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
         )
-        lm_logits = self.lm_head.forward(outputs[0])
+        lm_logits = self.lm_head(outputs[0])
         outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             loss_fct = nn.CrossEntropyLoss()
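For context, the `lm_labels` branch that follows this hunk computes a token-level cross-entropy over the vocabulary. A minimal sketch of that computation with assumed shapes (batch, seq_len, vocab_size), not copied from the file:

import torch
import torch.nn as nn

def masked_lm_loss(lm_logits: torch.Tensor, lm_labels: torch.Tensor) -> torch.Tensor:
    # Flatten (batch, seq_len, vocab_size) logits and (batch, seq_len) labels
    # so nn.CrossEntropyLoss sees one prediction per token position.
    loss_fct = nn.CrossEntropyLoss()
    vocab_size = lm_logits.size(-1)
    return loss_fct(lm_logits.view(-1, vocab_size), lm_labels.view(-1))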
@@ -1308,7 +1308,7 @@ class BartForSequenceClassification(PretrainedBartModel):
            loss, logits = outputs[:2]
         """
-        outputs = self.model.forward(
+        outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,