Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2db1e2f4
Unverified
Commit
2db1e2f4
authored
Jun 18, 2020
by
Sam Shleifer
Committed by
GitHub
Jun 18, 2020
Browse files
[cleanup] remove redundant code in SummarizationDataset (#5119)
parent
5f721ad6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
10 deletions
+2
-10
examples/summarization/utils.py
examples/summarization/utils.py
+2
-10
No files found.
examples/summarization/utils.py
View file @
2db1e2f4
...
@@ -13,8 +13,6 @@ from torch import nn
...
@@ -13,8 +13,6 @@ from torch import nn
from
torch.utils.data
import
Dataset
,
Sampler
from
torch.utils.data
import
Dataset
,
Sampler
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
BartTokenizer
def
encode_file
(
def
encode_file
(
tokenizer
,
tokenizer
,
...
@@ -85,7 +83,7 @@ class SummarizationDataset(Dataset):
...
@@ -85,7 +83,7 @@ class SummarizationDataset(Dataset):
prefix
=
""
,
prefix
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
tok_name
=
"T5"
if
not
isinstance
(
tokenizer
,
BartTokenizer
)
else
""
tok_name
=
tokenizer
.
__class__
.
__name__
.
lower
().
rstrip
(
"tokenizer"
)
self
.
source
=
encode_file
(
self
.
source
=
encode_file
(
tokenizer
,
tokenizer
,
os
.
path
.
join
(
data_dir
,
type_path
+
".source"
),
os
.
path
.
join
(
data_dir
,
type_path
+
".source"
),
...
@@ -94,16 +92,10 @@ class SummarizationDataset(Dataset):
...
@@ -94,16 +92,10 @@ class SummarizationDataset(Dataset):
prefix
=
prefix
,
prefix
=
prefix
,
tok_name
=
tok_name
,
tok_name
=
tok_name
,
)
)
if
type_path
==
"train"
:
tgt_path
=
os
.
path
.
join
(
data_dir
,
type_path
+
".target"
)
tgt_path
=
os
.
path
.
join
(
data_dir
,
type_path
+
".target"
)
else
:
tgt_path
=
os
.
path
.
join
(
data_dir
,
type_path
+
".target"
)
self
.
target
=
encode_file
(
self
.
target
=
encode_file
(
tokenizer
,
tgt_path
,
max_target_length
,
overwrite_cache
=
overwrite_cache
,
tok_name
=
tok_name
tokenizer
,
tgt_path
,
max_target_length
,
overwrite_cache
=
overwrite_cache
,
tok_name
=
tok_name
)
)
self
.
source
=
encode_file
(
tokenizer
,
os
.
path
.
join
(
data_dir
,
type_path
+
".source"
),
max_source_length
)
self
.
target
=
encode_file
(
tokenizer
,
os
.
path
.
join
(
data_dir
,
type_path
+
".target"
),
max_target_length
)
if
n_obs
is
not
None
:
if
n_obs
is
not
None
:
self
.
source
=
self
.
source
[:
n_obs
]
self
.
source
=
self
.
source
[:
n_obs
]
self
.
target
=
self
.
target
[:
n_obs
]
self
.
target
=
self
.
target
[:
n_obs
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment