OpenDAS / Fairseq / Commits

Commit 2fbfda0d authored Jul 25, 2018 by Myle Ott

Merge internal changes

parent 93fec886
Showing 9 changed files with 67 additions and 53 deletions (+67 -53):

    fairseq/data/__init__.py                      +1   -1
    fairseq/models/fconv.py                       +9   -9
    fairseq/models/fconv_self_att.py              +7   -7
    fairseq/models/lstm.py                        +7   -7
    fairseq/models/transformer.py                 +8   -8
    fairseq/optim/lr_scheduler/fixed_schedule.py  +1   -1
    fairseq/options.py                            +10  -3
    fairseq/tasks/fairseq_task.py                 +3   -3
    fairseq/trainer.py                            +21  -14
fairseq/data/__init__.py
@@ -7,7 +7,7 @@
 from .dictionary import Dictionary
 from .fairseq_dataset import FairseqDataset
-from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
+from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset  # noqa: F401
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .token_block_dataset import TokenBlockDataset
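The added IndexedDataset import exists only for re-export, and the `# noqa: F401` comment suppresses flake8's unused-import warning for it. A one-line sketch of what this enables for downstream code (the usage is illustrative, not part of this commit):

    # IndexedDataset can now be imported from the package root rather
    # than from fairseq.data.indexed_dataset directly.
    from fairseq.data import IndexedDataset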
fairseq/models/fconv.py
@@ -268,16 +268,16 @@ class FConvEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_out'] is not None:
-            encoder_out_dict['encoder_out'] = (
-                encoder_out_dict['encoder_out'][0].index_select(0, new_order),
-                encoder_out_dict['encoder_out'][1].index_select(0, new_order),
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out['encoder_out'] is not None:
+            encoder_out['encoder_out'] = (
+                encoder_out['encoder_out'][0].index_select(0, new_order),
+                encoder_out['encoder_out'][1].index_select(0, new_order),
             )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
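This hunk (and the matching ones in the other model files) only renames the parameter encoder_out_dict to encoder_out; behavior is unchanged. For readers unfamiliar with reorder_encoder_out: beam search reorders and duplicates batch entries between decoding steps, and the cached encoder states must follow along. A minimal self-contained sketch of the index_select pattern (the shapes are invented for the demo, not taken from fairseq):

    import torch

    # Two cached encoder state tensors of shape B x T x C, as in FConvEncoder.
    encoder_out = {
        'encoder_out': (torch.randn(4, 10, 8), torch.randn(4, 10, 8)),
        'encoder_padding_mask': None,
    }
    # Beam search decided beams 0 and 1 should now track old sentence 2, etc.
    new_order = torch.tensor([2, 2, 0, 1])

    # Reorder every cached tensor along the batch dimension (dim 0 here).
    encoder_out['encoder_out'] = tuple(
        eo.index_select(0, new_order) for eo in encoder_out['encoder_out']
    )
    assert encoder_out['encoder_out'][0].shape == (4, 10, 8)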
fairseq/models/fconv_self_att.py
@@ -226,18 +226,18 @@ class FConvEncoder(FairseqEncoder):
             'encoder_out': (x, y),
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
+    def reorder_encoder_out(self, encoder_out, new_order):
+        encoder_out['encoder_out'] = tuple(
+            eo.index_select(0, new_order) for eo in encoder_out['encoder_out']
         )

-        if 'pretrained' in encoder_out_dict:
-            encoder_out_dict['pretrained']['encoder_out'] = tuple(
+        if 'pretrained' in encoder_out:
+            encoder_out['pretrained']['encoder_out'] = tuple(
                 eo.index_select(0, new_order)
-                for eo in encoder_out_dict['pretrained']['encoder_out']
+                for eo in encoder_out['pretrained']['encoder_out']
             )

-        return encoder_out_dict
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
fairseq/models/lstm.py
@@ -237,15 +237,15 @@ class LSTMEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
+    def reorder_encoder_out(self, encoder_out, new_order):
+        encoder_out['encoder_out'] = tuple(
             eo.index_select(1, new_order)
-            for eo in encoder_out_dict['encoder_out']
+            for eo in encoder_out['encoder_out']
         )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
-        return encoder_out_dict
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(1, new_order)
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
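Worth noting across these model files: the dimension passed to index_select differs because the encoders lay out their outputs differently. The LSTM encoder's states are time-first (T x B x C), so the batch dimension is 1, while the fconv encoder's are batch-first, so it is 0. A tiny demo of the distinction (shapes invented for illustration):

    import torch

    t, b, c = 10, 4, 8
    batch_first = torch.randn(b, t, c)   # fconv-style: B x T x C
    time_first = torch.randn(t, b, c)    # lstm-style:  T x B x C
    new_order = torch.tensor([3, 1, 0, 2])

    # The same batch reordering selects along different dims per layout.
    assert batch_first.index_select(0, new_order).shape == (b, t, c)
    assert time_first.index_select(1, new_order).shape == (t, b, c)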
fairseq/models/transformer.py
@@ -225,14 +225,14 @@ class TransformerEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_out'] is not None:
-            encoder_out_dict['encoder_out'] = \
-                encoder_out_dict['encoder_out'].index_select(1, new_order)
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out['encoder_out'] is not None:
+            encoder_out['encoder_out'] = \
+                encoder_out['encoder_out'].index_select(1, new_order)
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
fairseq/optim/lr_scheduler/fixed_schedule.py
@@ -16,7 +16,7 @@ class FixedSchedule(FairseqLRScheduler):
         super().__init__(args, optimizer)

         # set defaults
-        args.warmup_updates = getattr(args, 'warmup_updates', 0)
+        args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0

         self.lr = args.lr[0]
         if args.warmup_updates > 0:
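The appended `or 0` guards against warmup_updates being present but set to None, which getattr's default does not cover (argparse, for instance, stores None for an option that was declared but not given). The distinction in isolation:

    from argparse import Namespace

    args = Namespace(warmup_updates=None)
    # getattr's default applies only when the attribute is missing...
    assert getattr(args, 'warmup_updates', 0) is None
    # ...so `or 0` is needed to normalize an explicit None to 0.
    assert (getattr(args, 'warmup_updates', 0) or 0) == 0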
fairseq/options.py
@@ -62,7 +62,7 @@ def eval_bool(x, default=False):
     return default


-def parse_args_and_arch(parser, input_args=None):
+def parse_args_and_arch(parser, input_args=None, parse_known=False):
     # The parser doesn't know about model/criterion/optimizer-specific args, so
     # we parse twice. First we parse the model/criterion/optimizer, then we
     # parse a second time after adding the *-specific arguments.
@@ -90,7 +90,11 @@ def parse_args_and_arch(parser, input_args=None):
         TASK_REGISTRY[args.task].add_args(parser)

     # Parse a second time.
-    args = parser.parse_args(input_args)
+    if parse_known:
+        args, extra = parser.parse_known_args(input_args)
+    else:
+        args = parser.parse_args(input_args)
+        extra = None

     # Post-process args.
     if hasattr(args, 'lr'):
@@ -104,7 +108,10 @@ def parse_args_and_arch(parser, input_args=None):
     if hasattr(args, 'arch'):
         ARCH_CONFIG_REGISTRY[args.arch](args)

-    return args
+    if parse_known:
+        return args, extra
+    else:
+        return args


 def get_parser(desc, default_task='translation'):
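With parse_known=True, parse_args_and_arch mirrors argparse's parse_known_args: unrecognized arguments are returned alongside args instead of causing a hard error, so wrapper scripts can forward them elsewhere. A hedged usage sketch (the caller code is hypothetical, not part of this commit):

    from fairseq import options

    parser = options.get_parser('Example', default_task='translation')
    # Strict parsing, as before:
    #   args = options.parse_args_and_arch(parser)
    # Lenient parsing; `extra` is the list of argv entries the parser
    # did not recognize, for the wrapper to handle itself:
    #   args, extra = options.parse_args_and_arch(parser, parse_known=True)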
fairseq/tasks/fairseq_task.py
@@ -5,9 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from fairseq import criterions, models
-from fairseq.data import FairseqDataset
-

 class FairseqTask(object):
     """
@@ -33,6 +30,7 @@ class FairseqTask(object):
     def dataset(self, split):
         """Return a dataset split."""
+        from fairseq.data import FairseqDataset
         if split not in self.datasets:
             raise KeyError('Dataset not loaded: ' + split)
         if not isinstance(self.datasets[split], FairseqDataset):
@@ -40,9 +38,11 @@ class FairseqTask(object):
         return self.datasets[split]

     def build_model(self, args):
+        from fairseq import models
         return models.build_model(args, self)

     def build_criterion(self, args):
+        from fairseq import criterions
         return criterions.build_criterion(args, self)

     def get_loss(self, model, criterion, sample):
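Moving these imports from module level into the method bodies is the standard way to break an import cycle: the import now runs on first call, after fairseq, fairseq.models, and fairseq.criterions have all finished loading, and after the first call it is just a sys.modules lookup. The pattern in isolation (build_model_sketch is a hypothetical stand-in, not fairseq code):

    def build_model_sketch(args, task):
        # Resolved at call time rather than when this module is first
        # imported; by then fairseq.models is fully initialized, so a
        # models -> tasks -> models cycle no longer blows up.
        from fairseq import models
        return models.build_model(args, task)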
fairseq/trainer.py
@@ -140,6 +140,11 @@ class Trainer(object):
         ooms_fwd = sum(ooms_fwd)
         ooms_bwd = sum(ooms_bwd)

+        if ooms_fwd == self.args.distributed_world_size:
+            print('| WARNING: OOM in all workers, skipping batch')
+            self.zero_grad()
+            return None
+
         # aggregate stats and logging outputs
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
         nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
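Because ooms_fwd is aggregated across workers just above, comparing it against distributed_world_size detects the case where every rank hit an out-of-memory error; the whole step is then skipped (with gradients zeroed) so the ranks stay in sync. The check reduced to plain Python (the variables here are stand-ins for illustration):

    ooms_fwd = [1, 1, 1, 1]          # one OOM flag gathered per worker
    distributed_world_size = 4
    if sum(ooms_fwd) == distributed_world_size:
        print('| WARNING: OOM in all workers, skipping batch')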
@@ -178,11 +183,6 @@ class Trainer(object):
         return None  # buffering updates

     def _forward(self, sample, eval=False):
-        # prepare model and optimizer
-        if eval:
-            self.model.eval()
-        else:
-            self.model.train()
-
         loss = None
         sample_size = 0
         logging_output = {
@@ -190,19 +190,26 @@
             'nsentences': sample['target'].size(0) if sample is not None else 0,
         }
         oom = 0
-        if sample is not None:
-            try:
+        try:
+            # prepare model and optimizer
+            if eval:
+                self.model.eval()
+            else:
+                self.model.train()
+                self.optimizer.zero_grad()
+
+            if sample is not None:
                 with torch.no_grad() if eval else contextlib.ExitStack():
                     # calculate loss and sample size
                     loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
                     logging_output.update(logging_output_)
-            except RuntimeError as e:
-                if not eval and 'out of memory' in str(e):
-                    print('| WARNING: ran out of memory, skipping batch')
-                    oom = 1
-                    loss = None
-                else:
-                    raise e
+        except RuntimeError as e:
+            if not eval and 'out of memory' in str(e):
+                print('| WARNING: ran out of memory, skipping batch')
+                oom = 1
+                loss = None
+            else:
+                raise e
         return loss, sample_size, logging_output, oom

     def _backward(self, loss):
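The restructured _forward wraps model preparation and the loss computation in a single try block, so an OOM raised anywhere inside it is caught. The `torch.no_grad() if eval else contextlib.ExitStack()` line deserves a note: an empty ExitStack is a no-op context manager, which lets one `with` statement disable gradient tracking only in eval mode. A minimal self-contained demo of the idiom:

    import contextlib
    import torch

    def forward(x, eval=False):
        # no_grad in eval mode; an empty ExitStack (a no-op) otherwise.
        with torch.no_grad() if eval else contextlib.ExitStack():
            return (x * 2).sum()

    x = torch.ones(3, requires_grad=True)
    assert forward(x, eval=False).requires_grad      # grads tracked
    assert not forward(x, eval=True).requires_grad   # grads disabled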