chenpangpang / transformers
Commit 1360daca authored Mar 05, 2020 by sshleifer

cleanup deltas

parent 810079de
Showing 2 changed files with 7 additions and 8 deletions

src/transformers/modeling_bart.py   +2 -3
tests/test_modeling_bart.py         +5 -5
src/transformers/modeling_bart.py
@@ -640,9 +640,8 @@ class SelfAttention(nn.Module):
             reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights_float = F.softmax(attn_weights, dim=-1)
-        attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
-        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_weights = F.softmax(attn_weights, dim=-1)
+        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
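This hunk drops the separate attn_weights_float tensor kept by the fairseq-style attention and applies softmax and dropout to attn_weights directly. A minimal sketch (not part of the commit, names illustrative), assuming fp32 weights: in that case type_as() is a no-op, so the kept two-line path and the removed three-line path produce the same probabilities.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
attn_weights = torch.randn(8, 5, 5)   # (bsz * num_heads, tgt_len, src_len)

# Kept path: softmax over source positions, then dropout on the probabilities.
attn_probs = F.dropout(F.softmax(attn_weights, dim=-1), p=0.0, training=True)

# Removed path: softmax into a separate tensor, cast back with type_as().
attn_weights_float = F.softmax(attn_weights, dim=-1)
attn_probs_old = F.dropout(attn_weights_float.type_as(attn_weights), p=0.0, training=True)

assert torch.equal(attn_probs, attn_probs_old)   # identical with p=0.0 and matching dtypes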
tests/test_modeling_bart.py
@@ -243,15 +243,15 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
         )
-        lm_model = BartForMaskedLM(config).to(torch_device)
-        context = _long_tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]])
-        summary = _long_tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]])
+        lm_model = BartForMaskedLM(config)
+        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
+        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
         logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
         expected_shape = (*summary.shape, config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)

     def test_generate_beam_search(self):
-        input_ids = _long_tensor([[71, 82, 2], [68, 34, 2]])
+        input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
         config = BartConfig(
             vocab_size=self.vocab_size,
             d_model=24,
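The swapped lines in this hunk build the same integer token-id tensors two ways: through the _long_tensor helper (presumably defined in the test module on the suite's torch_device) versus torch.Tensor(...).long(), which creates a float32 tensor and casts it to int64. A minimal sketch, with _long_tensor written below as an assumed stand-in for that helper:

import torch

def _long_tensor(tok_lst, device="cpu"):
    # Assumed mirror of the test helper: build int64 token ids on the given device.
    return torch.tensor(tok_lst, dtype=torch.long, device=device)

ids = [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]

a = _long_tensor(ids)          # removed style
b = torch.Tensor(ids).long()   # added style: float32 tensor cast to int64

assert a.dtype == b.dtype == torch.long
assert torch.equal(a, b)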
@@ -264,7 +264,7 @@ class BartHeadTests(unittest.TestCase):
             max_position_embeddings=48,
             output_past=True,
         )
-        lm_model = BartForMaskedLM(config).to(torch_device)
+        lm_model = BartForMaskedLM(config)
         lm_model.eval()
         new_input_ids = lm_model.generate(
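The only change in this last hunk is dropping .to(torch_device) from the model construction. A small sketch of what that call does (illustrative, not from the commit): nn.Module.to() moves the parameters in place and returns the same module, so keeping or dropping it only decides which device the test runs on; torch_device below stands in for the test suite's global device constant.

import torch
import torch.nn as nn

torch_device = "cuda" if torch.cuda.is_available() else "cpu"   # assumed test-suite global

model = nn.Linear(8, 8)          # placeholder for BartForMaskedLM(config)
moved = model.to(torch_device)   # removed style: place the model on torch_device
assert moved is model            # Module.to() returns the module itself

model.eval()                     # matches the hunk's unchanged next line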