chenpangpang / transformers / Commits / 500be01c

Unverified commit 500be01c, authored Oct 06, 2020 by Sam Shleifer, committed by GitHub on Oct 06, 2020.

[s2s] save first batch to json for debugging purposes (#6810)
parent 2b574e7c

Showing 2 changed files with 20 additions and 0 deletions (+20 / -0):

examples/seq2seq/finetune.py               +16 / -0
examples/seq2seq/test_seq2seq_examples.py   +4 / -0
examples/seq2seq/finetune.py
@@ -33,6 +33,7 @@ from utils import (
     lmap,
     pickle_save,
     save_git_info,
+    save_json,
     use_task_specific_params,
 )
@@ -105,6 +106,7 @@ class SummarizationModule(BaseTransformer):
         self.dataset_class = (
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
+        self.already_saved_batch = False
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         if self.hparams.eval_max_gen_length is not None:
             self.eval_max_length = self.hparams.eval_max_gen_length
@@ -112,6 +114,17 @@ class SummarizationModule(BaseTransformer):
             self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

+    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
+        """A debugging utility"""
+        readable_batch = {
+            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
+        }
+        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
+        save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")
+
+        self.already_saved_batch = True
+        return readable_batch
+
     def forward(self, input_ids, **kwargs):
         return self.model(input_ids, **kwargs)
@@ -129,6 +142,9 @@ class SummarizationModule(BaseTransformer):
             decoder_input_ids = self.model._shift_right(tgt_ids)
         else:
             decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
+        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
+            batch["decoder_input_ids"] = decoder_input_ids
+            self.save_readable_batch(batch)

         outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         lm_logits = outputs[0]
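The heart of the change is the new save_readable_batch helper: on the first training step it decodes every non-mask tensor in the batch back to text with tokenizer.batch_decode and writes both a human-readable copy (text_batch.json) and a raw token-id copy (tok_batch.json) to the output directory. For reference, a rough standalone sketch of the same decode-and-dump pattern outside the Lightning module is shown below; the helper name, the output_dir argument, and the facebook/bart-base checkpoint are illustrative assumptions, not part of this commit.

from pathlib import Path
import json

from transformers import AutoTokenizer


def save_readable_batch(batch, tokenizer, output_dir):
    # Sketch of the debugging pattern added in this commit (assumed standalone
    # helper, not the library API): decode non-mask tensors back to text and
    # dump both readable and raw token-id views of the batch to JSON.
    readable = {
        k: tokenizer.batch_decode(v.tolist()) if "mask" not in k else list(v.shape)
        for k, v in batch.items()
    }
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    with open(out / "text_batch.json", "w") as f:
        json.dump(readable, f, indent=2)
    with open(out / "tok_batch.json", "w") as f:
        json.dump({k: v.tolist() for k, v in batch.items()}, f, indent=2)
    return readable


if __name__ == "__main__":
    # Any seq2seq tokenizer works; facebook/bart-base is just an example checkpoint.
    tok = AutoTokenizer.from_pretrained("facebook/bart-base")
    batch = dict(tok(["hello world", "a second example"], return_tensors="pt", padding=True))
    save_readable_batch(batch, tok, "debug_output")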
examples/seq2seq/test_seq2seq_examples.py
@@ -422,6 +422,10 @@ def test_finetune(model):
     assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
     assert bart.decoder.embed_tokens == bart.shared

+    example_batch = load_json(module.output_dir / "text_batch.json")
+    assert isinstance(example_batch, dict)
+    assert len(example_batch) >= 4
+

 def test_finetune_extra_model_args():
     args_d: dict = CHEAP_ARGS.copy()
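The new test assertions load the JSON written during test_finetune and check that it is a dict with at least four entries. After a real fine-tuning run the saved first batch can be inspected the same way; a minimal sketch follows, assuming a placeholder output directory "dbart" (whatever --output_dir was passed to finetune.py), with keys that typically include input_ids, attention_mask, labels, and decoder_input_ids.

import json
from pathlib import Path

output_dir = Path("dbart")  # placeholder: the run's --output_dir

with open(output_dir / "text_batch.json") as f:
    example_batch = json.load(f)

# Decoded strings for token-id tensors, tensor shapes for the mask entries.
for key, value in example_batch.items():
    print(key, value[:1] if isinstance(value, list) else value)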