Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
500be01c
Unverified
Commit
500be01c
authored
Oct 06, 2020
by
Sam Shleifer
Committed by
GitHub
Oct 06, 2020
Browse files
[s2s] save first batch to json for debugging purposes (#6810)
parent
2b574e7c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
0 deletions
+20
-0
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+16
-0
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+4
-0
No files found.
examples/seq2seq/finetune.py
View file @
500be01c
...
@@ -33,6 +33,7 @@ from utils import (
...
@@ -33,6 +33,7 @@ from utils import (
lmap
,
lmap
,
pickle_save
,
pickle_save
,
save_git_info
,
save_git_info
,
save_json
,
use_task_specific_params
,
use_task_specific_params
,
)
)
...
@@ -105,6 +106,7 @@ class SummarizationModule(BaseTransformer):
...
@@ -105,6 +106,7 @@ class SummarizationModule(BaseTransformer):
self
.
dataset_class
=
(
self
.
dataset_class
=
(
Seq2SeqDataset
if
hasattr
(
self
.
tokenizer
,
"prepare_seq2seq_batch"
)
else
LegacySeq2SeqDataset
Seq2SeqDataset
if
hasattr
(
self
.
tokenizer
,
"prepare_seq2seq_batch"
)
else
LegacySeq2SeqDataset
)
)
self
.
already_saved_batch
=
False
self
.
eval_beams
=
self
.
model
.
config
.
num_beams
if
self
.
hparams
.
eval_beams
is
None
else
self
.
hparams
.
eval_beams
self
.
eval_beams
=
self
.
model
.
config
.
num_beams
if
self
.
hparams
.
eval_beams
is
None
else
self
.
hparams
.
eval_beams
if
self
.
hparams
.
eval_max_gen_length
is
not
None
:
if
self
.
hparams
.
eval_max_gen_length
is
not
None
:
self
.
eval_max_length
=
self
.
hparams
.
eval_max_gen_length
self
.
eval_max_length
=
self
.
hparams
.
eval_max_gen_length
...
@@ -112,6 +114,17 @@ class SummarizationModule(BaseTransformer):
...
@@ -112,6 +114,17 @@ class SummarizationModule(BaseTransformer):
self
.
eval_max_length
=
self
.
model
.
config
.
max_length
self
.
eval_max_length
=
self
.
model
.
config
.
max_length
self
.
val_metric
=
self
.
default_val_metric
if
self
.
hparams
.
val_metric
is
None
else
self
.
hparams
.
val_metric
self
.
val_metric
=
self
.
default_val_metric
if
self
.
hparams
.
val_metric
is
None
else
self
.
hparams
.
val_metric
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
    """Dump one batch to JSON files under ``self.output_dir`` for debugging.

    Two files are written:

    * ``text_batch.json`` -- each entry passed through
      ``self.tokenizer.batch_decode``; entries whose key contains ``"mask"``
      are not decoded and only their ``shape`` is stored.
    * ``tok_batch.json`` -- every entry converted with ``tolist()``.

    Afterwards ``self.already_saved_batch`` is set to ``True``.

    Returns:
        The mapping that was written to ``text_batch.json``.
    """
    decoded = {}
    for name, tensor in batch.items():
        if "mask" in name:
            # Keys containing "mask" are not decoded; only the shape is kept.
            decoded[name] = tensor.shape
        else:
            decoded[name] = self.tokenizer.batch_decode(tensor.tolist())
    save_json(decoded, Path(self.output_dir) / "text_batch.json")

    raw_ids = {name: tensor.tolist() for name, tensor in batch.items()}
    save_json(raw_ids, Path(self.output_dir) / "tok_batch.json")

    self.already_saved_batch = True
    return decoded
def forward(self, input_ids, **kwargs):
    """Call ``self.model`` with ``input_ids`` and forward all keyword arguments unchanged."""
    model_outputs = self.model(input_ids, **kwargs)
    return model_outputs
...
@@ -129,6 +142,9 @@ class SummarizationModule(BaseTransformer):
...
@@ -129,6 +142,9 @@ class SummarizationModule(BaseTransformer):
decoder_input_ids
=
self
.
model
.
_shift_right
(
tgt_ids
)
decoder_input_ids
=
self
.
model
.
_shift_right
(
tgt_ids
)
else
:
else
:
decoder_input_ids
=
shift_tokens_right
(
tgt_ids
,
pad_token_id
)
decoder_input_ids
=
shift_tokens_right
(
tgt_ids
,
pad_token_id
)
if
not
self
.
already_saved_batch
:
# This would be slightly better if it only happened on rank zero
batch
[
"decoder_input_ids"
]
=
decoder_input_ids
self
.
save_readable_batch
(
batch
)
outputs
=
self
(
src_ids
,
attention_mask
=
src_mask
,
decoder_input_ids
=
decoder_input_ids
,
use_cache
=
False
)
outputs
=
self
(
src_ids
,
attention_mask
=
src_mask
,
decoder_input_ids
=
decoder_input_ids
,
use_cache
=
False
)
lm_logits
=
outputs
[
0
]
lm_logits
=
outputs
[
0
]
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
500be01c
...
@@ -422,6 +422,10 @@ def test_finetune(model):
...
@@ -422,6 +422,10 @@ def test_finetune(model):
assert
bart
.
decoder
.
embed_tokens
==
bart
.
encoder
.
embed_tokens
assert
bart
.
decoder
.
embed_tokens
==
bart
.
encoder
.
embed_tokens
assert
bart
.
decoder
.
embed_tokens
==
bart
.
shared
assert
bart
.
decoder
.
embed_tokens
==
bart
.
shared
example_batch
=
load_json
(
module
.
output_dir
/
"text_batch.json"
)
assert
isinstance
(
example_batch
,
dict
)
assert
len
(
example_batch
)
>=
4
def
test_finetune_extra_model_args
():
def
test_finetune_extra_model_args
():
args_d
:
dict
=
CHEAP_ARGS
.
copy
()
args_d
:
dict
=
CHEAP_ARGS
.
copy
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment