chenpangpang/transformers, commit af4b98ed
Authored Sep 21, 2020 by Stas Bekman, committed via GitHub on Sep 21, 2020
[s2s] adjust finetune + test to work with fsmt (#7263)
parent 8d562a2d

Showing 2 changed files with 22 additions and 9 deletions
examples/seq2seq/finetune.py (+12 -6)
examples/seq2seq/test_seq2seq_examples.py (+10 -3)

examples/seq2seq/finetune.py
@@ -61,6 +61,8 @@ class SummarizationModule(BaseTransformer):
         pickle_save(self.hparams, self.hparams_save_path)
         self.step_count = 0
         self.metrics = defaultdict(list)
+        self.model_type = self.config.model_type
+        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
 
         self.dataset_kwargs: dict = dict(
             data_dir=self.hparams.data_dir,
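The new vocab_size attribute is needed because FSMT keeps separate source and target vocabularies, so its config exposes src_vocab_size/tgt_vocab_size rather than a single vocab_size. Not part of the commit, but a minimal sketch of that config difference, using the tiny checkpoints referenced in the test file below:

    from transformers import BartConfig, FSMTConfig

    # FSMT (ported fairseq wmt19 translation models) has separate source/target
    # vocabularies, so there is no single `vocab_size` attribute on its config.
    fsmt_cfg = FSMTConfig.from_pretrained("stas/tiny-wmt19-en-de")
    print(fsmt_cfg.model_type, fsmt_cfg.src_vocab_size, fsmt_cfg.tgt_vocab_size)

    # BART-like models expose one shared `vocab_size`, which the old code relied on.
    bart_cfg = BartConfig.from_pretrained("sshleifer/bart-tiny-random")
    print(bart_cfg.model_type, bart_cfg.vocab_size)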
@@ -106,14 +108,18 @@ class SummarizationModule(BaseTransformer):
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        try:
-            freeze_params(self.model.model.shared)
+        if self.model_type == "t5":
+            freeze_params(self.model.shared)
+            for d in [self.model.encoder, self.model.decoder]:
+                freeze_params(d.embed_tokens)
+        elif self.model_type == "fsmt":
             for d in [self.model.model.encoder, self.model.model.decoder]:
                 freeze_params(d.embed_positions)
                 freeze_params(d.embed_tokens)
-        except AttributeError:
-            freeze_params(self.model.shared)
-            for d in [self.model.encoder, self.model.decoder]:
+        else:
+            freeze_params(self.model.model.shared)
+            for d in [self.model.model.encoder, self.model.model.decoder]:
+                freeze_params(d.embed_positions)
                 freeze_params(d.embed_tokens)
 
     def forward(self, input_ids, **kwargs):
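For context on the new "fsmt" branch: FSMT has no shared embedding matrix (encoder and decoder each own their token and positional embeddings), so neither path of the old try/except applied to it. Not part of the commit, but a minimal sketch of what the new branch does against the tiny test checkpoint; freeze_params here is a local stand-in for the helper imported by finetune.py:

    from transformers import FSMTForConditionalGeneration


    def freeze_params(module):
        # same idea as the utils helper used by finetune.py: disable gradients
        for p in module.parameters():
            p.requires_grad = False


    model = FSMTForConditionalGeneration.from_pretrained("stas/tiny-wmt19-en-de")
    fsmt = model.model  # FSMTModel: no `.shared` embedding, unlike BartModel

    for d in [fsmt.encoder, fsmt.decoder]:
        freeze_params(d.embed_positions)
        freeze_params(d.embed_tokens)

    assert not fsmt.decoder.embed_tokens.weight.requires_grad
    assert not fsmt.encoder.embed_positions.weight.requires_grad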
@@ -140,7 +146,7 @@ class SummarizationModule(BaseTransformer):
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
             ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
-            assert lm_logits.shape[-1] == self.model.config.vocab_size
+            assert lm_logits.shape[-1] == self.vocab_size
             loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)

examples/seq2seq/test_seq2seq_examples.py
@@ -103,6 +103,7 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
 BART_TINY = "sshleifer/bart-tiny-random"
 MBART_TINY = "sshleifer/tiny-mbart"
 MARIAN_TINY = "sshleifer/tiny-marian-en-de"
+FSMT_TINY = "stas/tiny-wmt19-en-de"
 
 stream_handler = logging.StreamHandler(sys.stdout)
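FSMT_TINY points at a tiny random en-de checkpoint that keeps the test fast while exercising the full FSMT code path. Not part of the commit, but a minimal sketch of loading it directly, assuming the checkpoint is reachable on the Hugging Face hub:

    from transformers import FSMTForConditionalGeneration, FSMTTokenizer

    FSMT_TINY = "stas/tiny-wmt19-en-de"

    tokenizer = FSMTTokenizer.from_pretrained(FSMT_TINY)
    model = FSMTForConditionalGeneration.from_pretrained(FSMT_TINY)

    # the weights are tiny and random, so the translation is garbage; the point
    # is that the model/tokenizer classes load and generate() runs end to end
    batch = tokenizer(["Machine learning is great"], return_tensors="pt")
    generated = model.generate(**batch, max_length=10)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))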
@@ -374,11 +375,11 @@ def test_run_eval_search(model):
 @pytest.mark.parametrize(
     "model",
-    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
+    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
 )
 def test_finetune(model):
     args_d: dict = CHEAP_ARGS.copy()
-    task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
+    task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
     args_d["label_smoothing"] = 0.1 if task == "translation" else 0
     tmp_dir = make_test_data_dir()
@@ -407,7 +408,13 @@ def test_finetune(model):
         lm_head = module.model.lm_head
         assert not lm_head.weight.requires_grad
         assert (lm_head.weight == input_embeds.weight).all().item()
+    elif model == FSMT_TINY:
+        fsmt = module.model.model
+        embed_pos = fsmt.decoder.embed_positions
+        assert not embed_pos.weight.requires_grad
+        assert not fsmt.decoder.embed_tokens.weight.requires_grad
+        # check that embeds are not the same
+        assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
     else:
         bart = module.model.model
         embed_pos = bart.decoder.embed_positions
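The new elif branch checks that FSMT's decoder embeddings got frozen and that, unlike BART, its encoder and decoder token embeddings are distinct modules (separate source/target vocabularies mean nothing is shared). Not part of the commit, but a minimal sketch of that structural difference on the tiny checkpoints:

    from transformers import BartForConditionalGeneration, FSMTForConditionalGeneration

    fsmt = FSMTForConditionalGeneration.from_pretrained("stas/tiny-wmt19-en-de").model
    bart = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").model

    # FSMT: encoder and decoder token embeddings are different modules, which is
    # what the new `!=` assert relies on (nn.Module does not define __eq__, so
    # the comparison falls back to object identity).
    assert fsmt.decoder.embed_tokens is not fsmt.encoder.embed_tokens

    # BART: one shared embedding matrix is reused by both encoder and decoder.
    assert bart.decoder.embed_tokens.weight.data_ptr() == bart.encoder.embed_tokens.weight.data_ptr()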