Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c031d010
Unverified
Commit
c031d010
authored
Sep 30, 2020
by
Amanpreet Singh
Committed by
GitHub
Sep 30, 2020
Browse files
Seq2SeqDataset: avoid passing src_lang everywhere (#7470)
Co-authored-by:
Sam Shleifer
<
sshleifer@gmail.com
>
parent
08939cfd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
23 deletions
+50
-23
examples/seq2seq/test_datasets.py
examples/seq2seq/test_datasets.py
+33
-0
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+17
-23
No files found.
examples/seq2seq/test_datasets.py
View file @
c031d010
...
@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
...
@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
ids1
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
0
,
add_extra_examples
=
False
))
ids1
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
0
,
add_extra_examples
=
False
))
ids2
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
1
,
add_extra_examples
=
False
))
ids2
=
set
(
DistributedSortishSampler
(
ds
,
256
,
num_replicas
=
2
,
rank
=
1
,
add_extra_examples
=
False
))
assert
ids1
.
intersection
(
ids2
)
==
set
()
assert
ids1
.
intersection
(
ids2
)
==
set
()
@
pytest
.
mark
.
parametrize
(
"tok_name"
,
[
MBART_TINY
,
MARIAN_TINY
,
T5_TINY
,
BART_TINY
,
PEGASUS_XSUM
,
],
)
def
test_dataset_kwargs
(
tok_name
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tok_name
)
if
tok_name
==
MBART_TINY
:
train_dataset
=
Seq2SeqDataset
(
tokenizer
,
data_dir
=
make_test_data_dir
(),
type_path
=
"train"
,
max_source_length
=
4
,
max_target_length
=
8
,
src_lang
=
"EN"
,
tgt_lang
=
"FR"
,
)
kwargs
=
train_dataset
.
dataset_kwargs
assert
"src_lang"
in
kwargs
and
"tgt_lang"
in
kwargs
else
:
train_dataset
=
Seq2SeqDataset
(
tokenizer
,
data_dir
=
make_test_data_dir
(),
type_path
=
"train"
,
max_source_length
=
4
,
max_target_length
=
8
)
kwargs
=
train_dataset
.
dataset_kwargs
assert
"add_prefix_space"
not
in
kwargs
if
tok_name
!=
BART_TINY
else
"add_prefix_space"
in
kwargs
assert
len
(
kwargs
)
==
1
if
tok_name
==
BART_TINY
else
len
(
kwargs
)
==
0
examples/seq2seq/utils.py
View file @
c031d010
...
@@ -52,19 +52,6 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
...
@@ -52,19 +52,6 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
return
loss
,
nll_loss
return
loss
,
nll_loss
def
encode_line
(
tokenizer
,
line
,
max_length
,
pad_to_max_length
=
True
,
return_tensors
=
"pt"
):
"""Only used by LegacyDataset"""
extra_kw
=
{
"add_prefix_space"
:
True
}
if
isinstance
(
tokenizer
,
BartTokenizer
)
else
{}
return
tokenizer
(
[
line
],
max_length
=
max_length
,
padding
=
"max_length"
if
pad_to_max_length
else
None
,
truncation
=
True
,
return_tensors
=
return_tensors
,
**
extra_kw
,
)
def
lmap
(
f
:
Callable
,
x
:
Iterable
)
->
List
:
def
lmap
(
f
:
Callable
,
x
:
Iterable
)
->
List
:
"""list(map(f, x))"""
"""list(map(f, x))"""
return
list
(
map
(
f
,
x
))
return
list
(
map
(
f
,
x
))
...
@@ -97,9 +84,8 @@ class AbstractSeq2SeqDataset(Dataset):
...
@@ -97,9 +84,8 @@ class AbstractSeq2SeqDataset(Dataset):
max_target_length
,
max_target_length
,
type_path
=
"train"
,
type_path
=
"train"
,
n_obs
=
None
,
n_obs
=
None
,
src_lang
=
None
,
tgt_lang
=
None
,
prefix
=
""
,
prefix
=
""
,
**
dataset_kwargs
):
):
super
().
__init__
()
super
().
__init__
()
self
.
src_file
=
Path
(
data_dir
).
joinpath
(
type_path
+
".source"
)
self
.
src_file
=
Path
(
data_dir
).
joinpath
(
type_path
+
".source"
)
...
@@ -120,9 +106,8 @@ class AbstractSeq2SeqDataset(Dataset):
...
@@ -120,9 +106,8 @@ class AbstractSeq2SeqDataset(Dataset):
if
n_obs
is
not
None
:
if
n_obs
is
not
None
:
self
.
src_lens
=
self
.
src_lens
[:
n_obs
]
self
.
src_lens
=
self
.
src_lens
[:
n_obs
]
self
.
pad_token_id
=
self
.
tokenizer
.
pad_token_id
self
.
pad_token_id
=
self
.
tokenizer
.
pad_token_id
self
.
src_lang
=
src_lang
self
.
dataset_kwargs
=
dataset_kwargs
self
.
tgt_lang
=
tgt_lang
dataset_kwargs
.
update
({
"add_prefix_space"
:
True
}
if
isinstance
(
self
.
tokenizer
,
BartTokenizer
)
else
{})
self
.
add_prefix_space
=
isinstance
(
self
.
tokenizer
,
BartTokenizer
)
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
src_lens
)
return
len
(
self
.
src_lens
)
...
@@ -182,8 +167,8 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
...
@@ -182,8 +167,8 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
tgt_line
=
linecache
.
getline
(
str
(
self
.
tgt_file
),
index
).
rstrip
(
"
\n
"
)
tgt_line
=
linecache
.
getline
(
str
(
self
.
tgt_file
),
index
).
rstrip
(
"
\n
"
)
assert
source_line
,
f
"empty source line for index
{
index
}
"
assert
source_line
,
f
"empty source line for index
{
index
}
"
assert
tgt_line
,
f
"empty tgt line for index
{
index
}
"
assert
tgt_line
,
f
"empty tgt line for index
{
index
}
"
source_inputs
=
encode_line
(
self
.
tokenizer
,
source_line
,
self
.
max_source_length
)
source_inputs
=
self
.
encode_line
(
self
.
tokenizer
,
source_line
,
self
.
max_source_length
)
target_inputs
=
encode_line
(
self
.
tokenizer
,
tgt_line
,
self
.
max_target_length
)
target_inputs
=
self
.
encode_line
(
self
.
tokenizer
,
tgt_line
,
self
.
max_target_length
)
source_ids
=
source_inputs
[
"input_ids"
].
squeeze
()
source_ids
=
source_inputs
[
"input_ids"
].
squeeze
()
target_ids
=
target_inputs
[
"input_ids"
].
squeeze
()
target_ids
=
target_inputs
[
"input_ids"
].
squeeze
()
...
@@ -194,6 +179,17 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
...
@@ -194,6 +179,17 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
"labels"
:
target_ids
,
"labels"
:
target_ids
,
}
}
def
encode_line
(
self
,
tokenizer
,
line
,
max_length
,
pad_to_max_length
=
True
,
return_tensors
=
"pt"
):
"""Only used by LegacyDataset"""
return
tokenizer
(
[
line
],
max_length
=
max_length
,
padding
=
"max_length"
if
pad_to_max_length
else
None
,
truncation
=
True
,
return_tensors
=
return_tensors
,
**
self
.
dataset_kwargs
,
)
def
collate_fn
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
def
collate_fn
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
input_ids
=
torch
.
stack
([
x
[
"input_ids"
]
for
x
in
batch
])
input_ids
=
torch
.
stack
([
x
[
"input_ids"
]
for
x
in
batch
])
masks
=
torch
.
stack
([
x
[
"attention_mask"
]
for
x
in
batch
])
masks
=
torch
.
stack
([
x
[
"attention_mask"
]
for
x
in
batch
])
...
@@ -224,13 +220,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
...
@@ -224,13 +220,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
"""Call prepare_seq2seq_batch."""
"""Call prepare_seq2seq_batch."""
batch_encoding
:
Dict
[
str
,
torch
.
Tensor
]
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
batch_encoding
:
Dict
[
str
,
torch
.
Tensor
]
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
[
x
[
"src_texts"
]
for
x
in
batch
],
[
x
[
"src_texts"
]
for
x
in
batch
],
src_lang
=
self
.
src_lang
,
tgt_texts
=
[
x
[
"tgt_texts"
]
for
x
in
batch
],
tgt_texts
=
[
x
[
"tgt_texts"
]
for
x
in
batch
],
tgt_lang
=
self
.
tgt_lang
,
max_length
=
self
.
max_source_length
,
max_length
=
self
.
max_source_length
,
max_target_length
=
self
.
max_target_length
,
max_target_length
=
self
.
max_target_length
,
return_tensors
=
"pt"
,
return_tensors
=
"pt"
,
add_prefix_space
=
self
.
add_prefix_space
,
**
self
.
dataset_kwargs
,
).
data
).
data
batch_encoding
[
"ids"
]
=
torch
.
tensor
([
x
[
"id"
]
for
x
in
batch
])
batch_encoding
[
"ids"
]
=
torch
.
tensor
([
x
[
"id"
]
for
x
in
batch
])
return
batch_encoding
return
batch_encoding
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment