Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
a4fe8c99
Commit
a4fe8c99
authored
Sep 24, 2018
by
Sergey Edunov
Committed by
Myle Ott
Sep 25, 2018
Browse files
Add back secondary set
parent
535ca991
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
24 deletions
+30
-24
fairseq/tasks/translation.py
fairseq/tasks/translation.py
+30
-24
No files found.
fairseq/tasks/translation.py
View file @
a4fe8c99
...
...
@@ -45,7 +45,7 @@ class TranslationTask(FairseqTask):
@
staticmethod
def
add_args
(
parser
):
"""Add task-specific arguments to the parser."""
parser
.
add_argument
(
'data'
,
help
=
'path to data director
y
'
)
parser
.
add_argument
(
'data'
,
nargs
=
'+'
,
help
=
'path
(s)
to data director
ie(s)
'
)
parser
.
add_argument
(
'-s'
,
'--source-lang'
,
default
=
None
,
metavar
=
'SRC'
,
help
=
'source language'
)
parser
.
add_argument
(
'-t'
,
'--target-lang'
,
default
=
None
,
metavar
=
'TARGET'
,
...
...
@@ -80,13 +80,13 @@ class TranslationTask(FairseqTask):
# find language pair automatically
if
args
.
source_lang
is
None
or
args
.
target_lang
is
None
:
args
.
source_lang
,
args
.
target_lang
=
data_utils
.
infer_language_pair
(
args
.
data
)
args
.
source_lang
,
args
.
target_lang
=
data_utils
.
infer_language_pair
(
args
.
data
[
0
]
)
if
args
.
source_lang
is
None
or
args
.
target_lang
is
None
:
raise
Exception
(
'Could not infer language pair, please provide it explicitly'
)
# load dictionaries
src_dict
=
Dictionary
.
load
(
os
.
path
.
join
(
args
.
data
,
'dict.{}.txt'
.
format
(
args
.
source_lang
)))
tgt_dict
=
Dictionary
.
load
(
os
.
path
.
join
(
args
.
data
,
'dict.{}.txt'
.
format
(
args
.
target_lang
)))
src_dict
=
Dictionary
.
load
(
os
.
path
.
join
(
args
.
data
[
0
]
,
'dict.{}.txt'
.
format
(
args
.
source_lang
)))
tgt_dict
=
Dictionary
.
load
(
os
.
path
.
join
(
args
.
data
[
0
]
,
'dict.{}.txt'
.
format
(
args
.
target_lang
)))
assert
src_dict
.
pad
()
==
tgt_dict
.
pad
()
assert
src_dict
.
eos
()
==
tgt_dict
.
eos
()
assert
src_dict
.
unk
()
==
tgt_dict
.
unk
()
...
...
@@ -102,8 +102,8 @@ class TranslationTask(FairseqTask):
split (str): name of the split (e.g., train, valid, test)
"""
def
split_exists
(
split
,
src
,
tgt
,
lang
):
filename
=
os
.
path
.
join
(
self
.
args
.
data
,
'{}.{}-{}.{}'
.
format
(
split
,
src
,
tgt
,
lang
))
def
split_exists
(
split
,
src
,
tgt
,
lang
,
data_path
):
filename
=
os
.
path
.
join
(
data_path
,
'{}.{}-{}.{}'
.
format
(
split
,
src
,
tgt
,
lang
))
if
self
.
args
.
raw_text
and
IndexedRawTextDataset
.
exists
(
filename
):
return
True
elif
not
self
.
args
.
raw_text
and
IndexedInMemoryDataset
.
exists
(
filename
):
...
...
@@ -120,28 +120,34 @@ class TranslationTask(FairseqTask):
src_datasets
=
[]
tgt_datasets
=
[]
for
k
in
itertools
.
count
():
split_k
=
split
+
(
str
(
k
)
if
k
>
0
else
''
)
# infer langcode
src
,
tgt
=
self
.
args
.
source_lang
,
self
.
args
.
target_lang
if
split_exists
(
split_k
,
src
,
tgt
,
src
):
prefix
=
os
.
path
.
join
(
self
.
args
.
data
,
'{}.{}-{}.'
.
format
(
split_k
,
src
,
tgt
))
elif
split_exists
(
split_k
,
tgt
,
src
,
src
):
prefix
=
os
.
path
.
join
(
self
.
args
.
data
,
'{}.{}-{}.'
.
format
(
split_k
,
tgt
,
src
))
else
:
if
k
>
0
:
b
re
ak
data_paths
=
self
.
args
.
data
for
data_path
in
data_paths
:
for
k
in
itertools
.
count
():
split_k
=
split
+
(
str
(
k
)
if
k
>
0
else
''
)
# infer langcode
src
,
tgt
=
self
.
args
.
source_lang
,
self
.
args
.
target_lang
if
split_exists
(
split_k
,
src
,
tgt
,
src
,
data_path
):
prefix
=
os
.
path
.
join
(
data_path
,
'{}.{}-{}.'
.
format
(
split_k
,
src
,
tgt
))
el
if
split_exists
(
split_k
,
tgt
,
src
,
src
,
data_path
)
:
p
re
fix
=
os
.
path
.
join
(
data_path
,
'{}.{}-{}.'
.
format
(
split_k
,
tgt
,
src
))
else
:
raise
FileNotFoundError
(
'Dataset not found: {} ({})'
.
format
(
split
,
self
.
args
.
data
))
if
k
>
0
:
break
else
:
raise
FileNotFoundError
(
'Dataset not found: {} ({})'
.
format
(
split
,
data_path
))
src_datasets
.
append
(
indexed_dataset
(
prefix
+
src
,
self
.
src_dict
))
tgt_datasets
.
append
(
indexed_dataset
(
prefix
+
tgt
,
self
.
tgt_dict
))
print
(
'| {} {} {} examples'
.
format
(
data_path
,
split_k
,
len
(
src_datasets
[
-
1
])))
if
not
combine
:
break
src_datasets
.
append
(
indexed_dataset
(
prefix
+
src
,
self
.
src_dict
))
tgt_datasets
.
append
(
indexed_dataset
(
prefix
+
tgt
,
self
.
tgt_dict
))
print
(
'| {} {} {} examples'
.
format
(
self
.
args
.
data
,
split_k
,
len
(
src_datasets
[
-
1
])))
if
not
combine
:
break
assert
len
(
src_datasets
)
==
len
(
tgt_datasets
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment