Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
6296de82
Commit
6296de82
authored
Aug 28, 2018
by
Myle Ott
Browse files
Add --upsample-primary
parent
5852d3a0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
1 deletion
+5
-1
fairseq/tasks/translation.py
fairseq/tasks/translation.py
+5
-1
No files found.
fairseq/tasks/translation.py
View file @
6296de82
...
@@ -41,6 +41,8 @@ class TranslationTask(FairseqTask):
...
@@ -41,6 +41,8 @@ class TranslationTask(FairseqTask):
help
=
'max number of tokens in the source sequence'
)
help
=
'max number of tokens in the source sequence'
)
parser
.
add_argument
(
'--max-target-positions'
,
default
=
1024
,
type
=
int
,
metavar
=
'N'
,
parser
.
add_argument
(
'--max-target-positions'
,
default
=
1024
,
type
=
int
,
metavar
=
'N'
,
help
=
'max number of tokens in the target sequence'
)
help
=
'max number of tokens in the target sequence'
)
parser
.
add_argument
(
'--upsample-primary'
,
default
=
1
,
type
=
int
,
help
=
'amount to upsample primary dataset'
)
def
__init__
(
self
,
args
,
src_dict
,
tgt_dict
):
def
__init__
(
self
,
args
,
src_dict
,
tgt_dict
):
super
().
__init__
(
args
)
super
().
__init__
(
args
)
...
@@ -120,12 +122,14 @@ class TranslationTask(FairseqTask):
...
@@ -120,12 +122,14 @@ class TranslationTask(FairseqTask):
src_sizes
=
src_dataset
.
sizes
src_sizes
=
src_dataset
.
sizes
tgt_sizes
=
tgt_dataset
.
sizes
tgt_sizes
=
tgt_dataset
.
sizes
else
:
else
:
if
self
.
args
.
upsample_primary
>
1
:
src_datasets
.
extend
([
src_datasets
[
0
]]
*
(
self
.
args
.
upsample_primary
-
1
))
tgt_datasets
.
extend
([
tgt_datasets
[
0
]]
*
(
self
.
args
.
upsample_primary
-
1
))
src_dataset
=
ConcatDataset
(
src_datasets
)
src_dataset
=
ConcatDataset
(
src_datasets
)
tgt_dataset
=
ConcatDataset
(
tgt_datasets
)
tgt_dataset
=
ConcatDataset
(
tgt_datasets
)
src_sizes
=
np
.
concatenate
([
ds
.
sizes
for
ds
in
src_datasets
])
src_sizes
=
np
.
concatenate
([
ds
.
sizes
for
ds
in
src_datasets
])
tgt_sizes
=
np
.
concatenate
([
ds
.
sizes
for
ds
in
tgt_datasets
])
tgt_sizes
=
np
.
concatenate
([
ds
.
sizes
for
ds
in
tgt_datasets
])
self
.
datasets
[
split
]
=
LanguagePairDataset
(
self
.
datasets
[
split
]
=
LanguagePairDataset
(
src_dataset
,
src_sizes
,
self
.
src_dict
,
src_dataset
,
src_sizes
,
self
.
src_dict
,
tgt_dataset
,
tgt_sizes
,
self
.
tgt_dict
,
tgt_dataset
,
tgt_sizes
,
self
.
tgt_dict
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment