Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d0486c8b
Unverified
Commit
d0486c8b
authored
Jul 15, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 15, 2020
Browse files
[cleanup] T5 test, warnings (#5761)
parent
ec0a945c
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
97 deletions
+55
-97
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+1
-3
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+10
-2
tests/test_modeling_t5.py
tests/test_modeling_t5.py
+44
-92
No files found.
examples/seq2seq/run_eval.py
View file @
d0486c8b
...
...
@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
for
batch
in
tqdm
(
list
(
chunks
(
examples
,
batch_size
))):
if
"t5"
in
model_name
:
batch
=
[
model
.
config
.
prefix
+
text
for
text
in
batch
]
batch
=
tokenizer
(
batch
,
max_length
=
1024
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"max_length"
).
to
(
device
)
batch
=
tokenizer
(
batch
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"max_length"
).
to
(
device
)
input_ids
,
attention_mask
=
trim_batch
(
**
batch
,
pad_token_id
=
tokenizer
.
pad_token_id
)
summaries
=
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
**
gen_kwargs
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
...
...
examples/seq2seq/utils.py
View file @
d0486c8b
...
...
@@ -2,6 +2,7 @@ import itertools
import
json
import
os
import
pickle
from
logging
import
getLogger
from
pathlib
import
Path
from
typing
import
Callable
,
Dict
,
Iterable
,
List
...
...
@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
return
iter
(
sort_idx
)
logger
=
getLogger
(
__name__
)
def
use_task_specific_params
(
model
,
task
):
# u
pdate config with summarization specific params
"""U
pdate config with summarization specific params
."""
task_specific_params
=
model
.
config
.
task_specific_params
if
task_specific_params
is
not
None
:
model
.
config
.
update
(
task_specific_params
.
get
(
task
,
{}))
pars
=
task_specific_params
.
get
(
task
,
{})
logger
.
info
(
f
"using task specific params for
{
task
}
:
{
pars
}
"
)
model
.
config
.
update
(
pars
)
def
pickle_load
(
path
):
...
...
tests/test_modeling_t5.py
View file @
d0486c8b
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment