chenpangpang / transformers · Commits

Unverified commit c031d010
Authored Sep 30, 2020 by Amanpreet Singh; committed by GitHub on Sep 30, 2020

Seq2SeqDataset: avoid passing src_lang everywhere (#7470)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

parent 08939cfd
Showing 2 changed files with 50 additions and 23 deletions:

examples/seq2seq/test_datasets.py  +33  -0
examples/seq2seq/utils.py          +17  -23
examples/seq2seq/test_datasets.py (view file @ c031d010)

@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
     ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
     ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
     assert ids1.intersection(ids2) == set()
+
+
+@pytest.mark.parametrize(
+    "tok_name",
+    [
+        MBART_TINY,
+        MARIAN_TINY,
+        T5_TINY,
+        BART_TINY,
+        PEGASUS_XSUM,
+    ],
+)
+def test_dataset_kwargs(tok_name):
+    tokenizer = AutoTokenizer.from_pretrained(tok_name)
+    if tok_name == MBART_TINY:
+        train_dataset = Seq2SeqDataset(
+            tokenizer,
+            data_dir=make_test_data_dir(),
+            type_path="train",
+            max_source_length=4,
+            max_target_length=8,
+            src_lang="EN",
+            tgt_lang="FR",
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "src_lang" in kwargs and "tgt_lang" in kwargs
+    else:
+        train_dataset = Seq2SeqDataset(
+            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
+        assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
examples/seq2seq/utils.py (view file @ c031d010)

@@ -52,19 +52,6 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
     return loss, nll_loss


-def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
-    """Only used by LegacyDataset"""
-    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
-    return tokenizer(
-        [line],
-        max_length=max_length,
-        padding="max_length" if pad_to_max_length else None,
-        truncation=True,
-        return_tensors=return_tensors,
-        **extra_kw,
-    )
-
-
 def lmap(f: Callable, x: Iterable) -> List:
     """list(map(f, x))"""
     return list(map(f, x))
@@ -97,9 +84,8 @@ class AbstractSeq2SeqDataset(Dataset):
         max_target_length,
         type_path="train",
         n_obs=None,
-        src_lang=None,
-        tgt_lang=None,
         prefix="",
+        **dataset_kwargs
     ):
         super().__init__()
         self.src_file = Path(data_dir).joinpath(type_path + ".source")
@@ -120,9 +106,8 @@ class AbstractSeq2SeqDataset(Dataset):
         if n_obs is not None:
             self.src_lens = self.src_lens[:n_obs]
         self.pad_token_id = self.tokenizer.pad_token_id
-        self.src_lang = src_lang
-        self.tgt_lang = tgt_lang
-        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
+        self.dataset_kwargs = dataset_kwargs
+        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

     def __len__(self):
         return len(self.src_lens)
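The mechanism in this hunk is easier to see in isolation: any keyword argument the constructor does not name explicitly is captured by **dataset_kwargs, stored on the instance, and later forwarded unchanged to the tokenizer, so no dedicated src_lang / tgt_lang / add_prefix_space attributes are needed. A minimal sketch of that flow, using hypothetical stand-ins (ToyDataset and toy_tokenize are illustration only, not code from this repository):

    # Minimal sketch of the **dataset_kwargs pattern above; ToyDataset and
    # toy_tokenize are hypothetical stand-ins, not code from this repository.
    def toy_tokenize(texts, **kwargs):
        # A real tokenizer would act on src_lang / add_prefix_space; here we just echo them.
        return {"texts": texts, "kwargs": kwargs}

    class ToyDataset:
        def __init__(self, is_bart_like=False, **dataset_kwargs):
            # Everything the caller passed beyond the named arguments lands here ...
            self.dataset_kwargs = dataset_kwargs
            # ... and tokenizer-specific defaults are folded into the same dict once, up front.
            dataset_kwargs.update({"add_prefix_space": True} if is_bart_like else {})

        def collate_fn(self, batch):
            # Forward the stored kwargs instead of keeping src_lang / tgt_lang attributes.
            return toy_tokenize(batch, **self.dataset_kwargs)

    ds = ToyDataset(is_bart_like=True, src_lang="EN", tgt_lang="FR")
    print(ds.collate_fn(["a source sentence"])["kwargs"])
    # {'src_lang': 'EN', 'tgt_lang': 'FR', 'add_prefix_space': True}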
@@ -182,8 +167,8 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
         tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
-        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
-        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
+        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
+        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
         source_ids = source_inputs["input_ids"].squeeze()
         target_ids = target_inputs["input_ids"].squeeze()
@@ -194,6 +179,17 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
             "labels": target_ids,
         }

+    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
+        """Only used by LegacyDataset"""
+        return tokenizer(
+            [line],
+            max_length=max_length,
+            padding="max_length" if pad_to_max_length else None,
+            truncation=True,
+            return_tensors=return_tensors,
+            **self.dataset_kwargs,
+        )
+
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
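For a BART checkpoint the stored kwargs make the tokenizer call above equivalent to passing add_prefix_space=True inline (exactly what the removed module-level encode_line did via extra_kw); for any other tokenizer the expansion is empty. A hedged sketch of that equivalence, assuming the slow BartTokenizer class (the one the isinstance check targets) and using a tiny public checkpoint for illustration:

    # Hedged sketch: what **self.dataset_kwargs expands to for a BART tokenizer.
    # "sshleifer/bart-tiny-random" is an illustrative tiny checkpoint, not necessarily
    # the one these tests pin; any BartTokenizer behaves the same way here.
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
    dataset_kwargs = {"add_prefix_space": True}  # what AbstractSeq2SeqDataset stores for BART

    batch = tokenizer(
        ["A source sentence."],
        max_length=8,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        **dataset_kwargs,  # identical to writing add_prefix_space=True inline
    )
    print(batch["input_ids"].shape)  # torch.Size([1, 8])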
@@ -224,13 +220,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
         """Call prepare_seq2seq_batch."""
         batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
             [x["src_texts"] for x in batch],
-            src_lang=self.src_lang,
             tgt_texts=[x["tgt_texts"] for x in batch],
-            tgt_lang=self.tgt_lang,
             max_length=self.max_source_length,
             max_target_length=self.max_target_length,
             return_tensors="pt",
-            add_prefix_space=self.add_prefix_space,
+            **self.dataset_kwargs,
         ).data
         batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
         return batch_encoding
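Putting the pieces together, a hedged sketch of the new calling convention from the consumer's side. The tiny mBART checkpoint, language codes, and data paths are illustrative assumptions rather than values from this commit; data_dir must hold the train.source / train.target files the examples expect, and the snippet assumes it is run from examples/seq2seq/ so utils is importable:

    # Hedged usage sketch: language codes now travel through **dataset_kwargs instead of
    # dedicated src_lang / tgt_lang attributes on the dataset.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer
    from utils import Seq2SeqDataset  # examples/seq2seq/utils.py

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir="path/to/data",   # expects train.source / train.target
        type_path="train",
        max_source_length=32,
        max_target_length=32,
        src_lang="en_XX",          # captured by **dataset_kwargs
        tgt_lang="ro_RO",
    )
    loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})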