Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d0486c8b
Unverified
Commit
d0486c8b
authored
Jul 15, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 15, 2020
Browse files
[cleanup] T5 test, warnings (#5761)
parent
ec0a945c
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
97 deletions
+55
-97
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+1
-3
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+10
-2
tests/test_modeling_t5.py
tests/test_modeling_t5.py
+44
-92
No files found.
examples/seq2seq/run_eval.py
View file @
d0486c8b
...
@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
...
@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
for
batch
in
tqdm
(
list
(
chunks
(
examples
,
batch_size
))):
for
batch
in
tqdm
(
list
(
chunks
(
examples
,
batch_size
))):
if
"t5"
in
model_name
:
if
"t5"
in
model_name
:
batch
=
[
model
.
config
.
prefix
+
text
for
text
in
batch
]
batch
=
[
model
.
config
.
prefix
+
text
for
text
in
batch
]
batch
=
tokenizer
(
batch
,
max_length
=
1024
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"max_length"
).
to
(
batch
=
tokenizer
(
batch
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"max_length"
).
to
(
device
)
device
)
input_ids
,
attention_mask
=
trim_batch
(
**
batch
,
pad_token_id
=
tokenizer
.
pad_token_id
)
input_ids
,
attention_mask
=
trim_batch
(
**
batch
,
pad_token_id
=
tokenizer
.
pad_token_id
)
summaries
=
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
**
gen_kwargs
)
summaries
=
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
**
gen_kwargs
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
...
...
examples/seq2seq/utils.py
View file @
d0486c8b
...
@@ -2,6 +2,7 @@ import itertools
...
@@ -2,6 +2,7 @@ import itertools
import
json
import
json
import
os
import
os
import
pickle
import
pickle
from
logging
import
getLogger
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Callable
,
Dict
,
Iterable
,
List
from
typing
import
Callable
,
Dict
,
Iterable
,
List
...
@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
...
@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
return
iter
(
sort_idx
)
return
iter
(
sort_idx
)
logger
=
getLogger
(
__name__
)
def
use_task_specific_params
(
model
,
task
):
def
use_task_specific_params
(
model
,
task
):
# u
pdate config with summarization specific params
"""U
pdate config with summarization specific params
."""
task_specific_params
=
model
.
config
.
task_specific_params
task_specific_params
=
model
.
config
.
task_specific_params
if
task_specific_params
is
not
None
:
if
task_specific_params
is
not
None
:
model
.
config
.
update
(
task_specific_params
.
get
(
task
,
{}))
pars
=
task_specific_params
.
get
(
task
,
{})
logger
.
info
(
f
"using task specific params for
{
task
}
:
{
pars
}
"
)
model
.
config
.
update
(
pars
)
def
pickle_load
(
path
):
def
pickle_load
(
path
):
...
...
tests/test_modeling_t5.py
View file @
d0486c8b
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment