jerrrrry / mlperf_transformer_v0.7 / Commits

Commit 9e8a8c05, authored Oct 14, 2024 by jerrrrry

Initial commit

Changes: 209
Showing 20 changed files with 1733 additions and 0 deletions (+1733, -0)
implementations/pytorch/fairseq/optim/fairseq_optimizer.py  +79 -0
implementations/pytorch/fairseq/optim/lr_scheduler/__init__.py  +39 -0
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-310.pyc  +0 -0
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-310.pyc  +0 -0
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-310.pyc  +0 -0
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc  +0 -0
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc  +0 -0
implementations/pytorch/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py  +44 -0
implementations/pytorch/fairseq/optim/lr_scheduler/fixed_schedule.py  +57 -0
implementations/pytorch/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py  +75 -0
implementations/pytorch/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py  +46 -0
implementations/pytorch/fairseq/optim/nag.py  +77 -0
implementations/pytorch/fairseq/optim/sgd.py  +31 -0
implementations/pytorch/fairseq/options.py  +413 -0
implementations/pytorch/fairseq/progress_bar.py  +205 -0
implementations/pytorch/fairseq/sequence_generator.py  +536 -0
implementations/pytorch/fairseq/sequence_scorer.py  +88 -0
implementations/pytorch/fairseq/tasks/__init__.py  +43 -0
implementations/pytorch/fairseq/tasks/__pycache__/__init__.cpython-310.pyc  +0 -0
implementations/pytorch/fairseq/tasks/__pycache__/fairseq_task.cpython-310.pyc  +0 -0
implementations/pytorch/fairseq/optim/fairseq_optimizer.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.optim


class FairseqOptimizer(object):

    def __init__(self, args, params):
        super().__init__()
        self.args = args
        self.params = params

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        pass

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, '_optimizer'):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
        return self._optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]['lr']

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        # override learning rate, momentum, etc. with latest values
        for group in self.optimizer.param_groups:
            group.update(self.optimizer_config)

    def step(self, closure=None):
        """Performs a single optimization step."""
        return self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                p.grad = None
        return self.optimizer.zero_grad()
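A minimal sketch of how a concrete wrapper is expected to fill in this interface (it is not part of the commit; the concrete wrappers the commit actually adds are nag.py and sgd.py below). It assumes the fairseq package from this commit is importable; the Adam wrapper and the Namespace fields are illustrative.

# Hypothetical example: a concrete optimizer sets self._optimizer and exposes
# optimizer_config, mirroring the pattern used by nag.py and sgd.py in this commit.
import argparse
import torch
import torch.optim

from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class AdamWrapper(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = torch.optim.Adam(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # only 'lr' here; a real wrapper would also forward betas, eps, weight decay, ...
        return {'lr': self.args.lr[0]}


if __name__ == '__main__':
    args = argparse.Namespace(lr=[1e-3])
    model = torch.nn.Linear(4, 2)
    opt = AdamWrapper(args, model.parameters())
    print(opt.get_lr())   # 0.001, read from param_groups[0]
    opt.set_lr(5e-4)
    print(opt.get_lr())   # 0.0005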
implementations/pytorch/fairseq/optim/lr_scheduler/__init__.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import importlib
import os

from .fairseq_lr_scheduler import FairseqLRScheduler


LR_SCHEDULER_REGISTRY = {}


def build_lr_scheduler(args, optimizer):
    return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)


def register_lr_scheduler(name):
    """Decorator to register a new LR scheduler."""

    def register_lr_scheduler_cls(cls):
        if name in LR_SCHEDULER_REGISTRY:
            raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
        if not issubclass(cls, FairseqLRScheduler):
            raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
        LR_SCHEDULER_REGISTRY[name] = cls
        return cls

    return register_lr_scheduler_cls


# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('fairseq.optim.lr_scheduler.' + module)
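A hedged sketch (not in the commit) of how the registry above is meant to be used: decorate a subclass of FairseqLRScheduler with @register_lr_scheduler and build it by name through build_lr_scheduler. The scheduler name and behavior below are illustrative, and the snippet assumes the fairseq package from this commit is importable.

# Hypothetical example of the registry pattern defined above.
from fairseq.optim.lr_scheduler import (
    LR_SCHEDULER_REGISTRY, build_lr_scheduler, register_lr_scheduler,
)
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler


@register_lr_scheduler('constant_example')
class ConstantExampleSchedule(FairseqLRScheduler):
    """Toy scheduler that keeps the LR fixed at args.lr[0]."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)   # optimizer must be a FairseqOptimizer
        self.optimizer.set_lr(args.lr[0])

    def step_update(self, num_updates):
        return self.optimizer.get_lr()


# build_lr_scheduler dispatches on args.lr_scheduler, e.g.:
#   args.lr_scheduler = 'constant_example'
#   scheduler = build_lr_scheduler(args, fairseq_optimizer_instance)
assert 'constant_example' in LR_SCHEDULER_REGISTRY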
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-310.pyc (new file, mode 100644): binary file added
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-310.pyc (new file, mode 100644): binary file added
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-310.pyc (new file, mode 100644): binary file added
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc (new file, mode 100644): binary file added
implementations/pytorch/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc (new file, mode 100644): binary file added
implementations/pytorch/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from .. import FairseqOptimizer


class FairseqLRScheduler(object):

    def __init__(self, args, optimizer):
        super().__init__()
        if not isinstance(optimizer, FairseqOptimizer):
            raise ValueError('optimizer must be an instance of FairseqOptimizer')
        self.args = args
        self.optimizer = optimizer
        self.best = None

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        pass

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {'best': self.best}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.best = state_dict['best']

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        if val_loss is not None:
            if self.best is None:
                self.best = val_loss
            else:
                self.best = min(self.best, val_loss)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.get_lr()
implementations/pytorch/fairseq/optim/lr_scheduler/fixed_schedule.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from . import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)

        # set defaults
        args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0

        self.lr = args.lr[0]
        if args.warmup_updates > 0:
            self.warmup_factor = 1. / args.warmup_updates
        else:
            self.warmup_factor = 1

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                            help='force annealing at specified epoch')
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')

    def get_next_lr(self, epoch):
        lrs = self.args.lr
        if self.args.force_anneal is None or epoch < self.args.force_anneal:
            # use fixed LR schedule
            next_lr = lrs[min(epoch, len(lrs) - 1)]
        else:
            # anneal based on lr_shrink
            next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
        return next_lr

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
            self.warmup_factor = num_updates / float(self.args.warmup_updates)
            self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()
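A small standalone sketch (not in the commit) that restates get_next_lr() above to show how the fixed schedule anneals once --force-anneal is reached; the argument values are made up for illustration.

# Pure-Python restatement of get_next_lr() with illustrative values:
# lr=[0.25, 0.1], force_anneal at epoch 3, lr_shrink=0.5.
def next_lr(epoch, lrs=(0.25, 0.1), force_anneal=3, lr_shrink=0.5):
    if force_anneal is None or epoch < force_anneal:
        return lrs[min(epoch, len(lrs) - 1)]                      # fixed part of the schedule
    return lrs[-1] * lr_shrink ** (epoch + 1 - force_anneal)      # annealed part

print([round(next_lr(e), 4) for e in range(6)])
# [0.25, 0.1, 0.1, 0.05, 0.025, 0.0125]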
implementations/pytorch/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from . import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('inverse_sqrt')
class InverseSquareRootSchedule(FairseqLRScheduler):
    """Decay the LR based on the inverse square root of the update number.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (`--warmup-init-lr`) until the configured
    learning rate (`--lr`). Thereafter we decay proportional to the number of
    updates, with a decay factor set to align with the configured learning rate.

    During warmup:

        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]

    After warmup:

        lr = decay_factor / sqrt(update_num)

    where

        decay_factor = args.lr * sqrt(args.warmup_updates)
    """

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        if len(args.lr) > 1:
            raise ValueError(
                'Cannot use a fixed learning rate schedule with inverse_sqrt.'
                ' Consider --lr-scheduler=fixed instead.'
            )
        warmup_end_lr = args.lr[0]
        if args.warmup_init_lr < 0:
            args.warmup_init_lr = warmup_end_lr

        # linearly warmup for the first args.warmup_updates
        self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates

        # then, decay prop. to the inverse square root of the update number
        self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5

        # initial learning rate
        self.lr = args.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
                            help='initial learning rate during warmup phase; default is args.lr')

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.args.warmup_updates:
            self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
        else:
            self.lr = self.decay_factor * num_updates ** -0.5
        self.optimizer.set_lr(self.lr)
        return self.lr
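A standalone sketch (not in the commit) that evaluates the warmup/decay formula from the docstring and step_update() above; the settings (end LR 0.25, 4000 warmup updates, warmup_init_lr 1e-7) are illustrative.

# Pure-Python restatement of step_update() with illustrative values.
warmup_updates, warmup_init_lr, end_lr = 4000, 1e-7, 0.25
lr_step = (end_lr - warmup_init_lr) / warmup_updates
decay_factor = end_lr * warmup_updates ** 0.5

def lr_at(num_updates):
    if num_updates < warmup_updates:
        return warmup_init_lr + num_updates * lr_step   # linear warmup
    return decay_factor * num_updates ** -0.5           # inverse-sqrt decay

for u in (0, 2000, 4000, 16000, 64000):
    print(u, round(lr_at(u), 6))
# 0 -> 1e-07, 2000 -> ~0.125, 4000 -> 0.25, 16000 -> 0.125, 64000 -> 0.0625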
implementations/pytorch/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.optim.lr_scheduler

from . import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
    """Decay the LR by a factor every time the validation loss plateaus."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        if len(args.lr) > 1:
            raise ValueError(
                'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
                ' Consider --lr-scheduler=fixed instead.'
            )
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer.optimizer, patience=0, factor=args.lr_shrink)

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {
            'best': self.lr_scheduler.best,
            'last_epoch': self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.lr_scheduler.best = state_dict['best']
        if 'last_epoch' in state_dict:
            self.lr_scheduler.last_epoch = state_dict['last_epoch']

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        if val_loss is not None:
            self.lr_scheduler.step(val_loss, epoch)
        else:
            self.lr_scheduler.last_epoch = epoch
        return self.optimizer.get_lr()
implementations/pytorch/fairseq/optim/nag.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from torch.optim.optimizer import Optimizer, required

from . import FairseqOptimizer, register_optimizer


@register_optimizer('nag')
class FairseqNAG(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = NAG(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'momentum': self.args.momentum,
            'weight_decay': self.args.weight_decay,
        }


class NAG(Optimizer):
    def __init__(self, params, lr=required, momentum=0, weight_decay=0):
        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
        super(NAG, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            lr = group['lr']
            lr_old = group.get('lr_old', lr)
            lr_correct = lr / lr_old

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad.data
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = d_p.clone().zero_()

                buf = param_state['momentum_buffer']

                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)
                p.data.add_(momentum * momentum * lr_correct, buf)
                p.data.add_(-(1 + momentum) * lr, d_p)

                buf.mul_(momentum * lr_correct).add_(-lr, d_p)

            group['lr_old'] = lr

        return loss
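For reference, a transcription (not part of the commit) of the update that NAG.step() above performs, written with momentum $\mu$, weight decay $\lambda$, learning rate $\mathrm{lr}_t$, gradient $g_t$, momentum buffer $v_t$, and the learning-rate correction $c_t = \mathrm{lr}_t / \mathrm{lr}_{t-1}$:

\begin{align}
p_{t+1} &= (1 - \mathrm{lr}_t \lambda)\, p_t + \mu^2 c_t\, v_t - (1 + \mu)\,\mathrm{lr}_t\, g_t \\
v_{t+1} &= \mu c_t\, v_t - \mathrm{lr}_t\, g_t
\end{align}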
implementations/pytorch/fairseq/optim/sgd.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.optim

from . import FairseqOptimizer, register_optimizer


@register_optimizer('sgd')
class SGD(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'momentum': self.args.momentum,
            'weight_decay': self.args.weight_decay,
        }
implementations/pytorch/fairseq/options.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import os

import torch

from fairseq.criterions import CRITERION_REGISTRY
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
from fairseq.tasks import TASK_REGISTRY


def get_training_parser(default_task='translation'):
    parser = get_parser('Trainer', default_task)
    add_dataset_args(parser, train=True, gen=True)
    add_distributed_training_args(parser)
    add_model_args(parser)
    add_optimization_args(parser)
    add_checkpoint_args(parser)
    add_generation_args(parser)
    add_perf_args(parser)
    return parser


def get_generation_parser(interactive=False, default_task='translation'):
    parser = get_parser('Generation', default_task)
    add_dataset_args(parser, gen=True)
    add_generation_args(parser)
    add_perf_args(parser)
    if interactive:
        add_interactive_args(parser)
    return parser


def get_eval_lm_parser(default_task='language_modeling'):
    parser = get_parser('Evaluate Language Model', default_task)
    add_dataset_args(parser, gen=True)
    add_eval_lm_args(parser)
    return parser


def eval_str_list(x, type=float):
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    try:
        return list(map(type, x))
    except TypeError:
        return [type(x)]


def eval_bool(x, default=False):
    if x is None:
        return default
    try:
        return bool(eval(x))
    except TypeError:
        return default


def parse_args_and_arch(parser, input_args=None, parse_known=False):
    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, 'arch'):
        model_specific_group = parser.add_argument_group(
            'Model-specific configuration',
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

    # Add *-specific args to parser.
    if hasattr(args, 'criterion'):
        CRITERION_REGISTRY[args.criterion].add_args(parser)
    if hasattr(args, 'optimizer'):
        OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
    if hasattr(args, 'lr_scheduler'):
        LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
    if hasattr(args, 'task'):
        TASK_REGISTRY[args.task].add_args(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args.
    if hasattr(args, 'lr'):
        args.lr = eval_str_list(args.lr, type=float)
    if hasattr(args, 'update_freq'):
        args.update_freq = eval_str_list(args.update_freq, type=int)
    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences

    # Apply architecture configuration.
    if hasattr(args, 'arch'):
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args


def get_parser(desc, default_task='translation'):
    parser = argparse.ArgumentParser(
        description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
    parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format', default=None, help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument('--profile', type=int, default=None)

    # Task definitions can be found under fairseq/tasks/
    parser.add_argument('--task', metavar='TASK', default=default_task, choices=TASK_REGISTRY.keys(),
                        help='task: {} (default: {})'.format(', '.join(TASK_REGISTRY.keys()), default_task))

    return parser


def add_dataset_args(parser, train=False, gen=False):
    group = parser.add_argument_group('Dataset and data loading')
    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                       help='ignore too long or too short lines in valid and test set')
    group.add_argument('--max-tokens', type=int, metavar='N',
                       help='maximum number of tokens in a batch')
    group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                       help='maximum number of sentences in a batch')
    group.add_argument('--source_lang', '--source_lang', type=str, metavar='N',
                       help='Source language')
    group.add_argument('--target_lang', '--target_lang', type=str, metavar='N',
                       help='Target language')
    group.add_argument('--bucket_growth_factor', '--bucket_growth_factor', type=float, metavar='N',
                       help='Bucket growth factor')
    group.add_argument('--raw_text', action='store_true', help='raw text')
    group.add_argument('--batching_scheme', default='reference', help='Batching Scheme',
                       choices=['v0p5', 'v0p5_better', 'v0p6', 'reference'])
    group.add_argument('--batch_multiple_strategy', default='mult_of_sequences',
                       help='The strategy to achieve a batch size that is multiple of some number.',
                       choices=['mult_of_sequences', 'pad_sequence_to_mult', 'dynamic'])
    if train:
        group.add_argument('--train-subset', default='train', metavar='SPLIT',
                           choices=['train', 'valid', 'test'],
                           help='data subset to use for training (train, valid, test)')
        group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                           help='comma separated list of data subsets to use for validation'
                                ' (train, valid, valid1, test, test1)')
        group.add_argument('--max-sentences-valid', type=int, metavar='N',
                           help='maximum number of sentences in a validation batch'
                                ' (defaults to --max-sentences)')
    if gen:
        group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                           help='data subset to generate (train, valid, test)')
        group.add_argument('--num-shards', default=1, type=int, metavar='N',
                           help='shard generation over N shards')
        group.add_argument('--shard-id', default=0, type=int, metavar='ID',
                           help='id of the shard to generate (id < num_shards)')
    return group


def add_distributed_training_args(parser):
    group = parser.add_argument_group('Distributed training')
    group.add_argument('--distributed-world-size', type=int, metavar='N',
                       default=torch.cuda.device_count(),
                       help='total number of GPUs across all nodes (default: all visible GPUs)')
    group.add_argument('--distributed-rank', default=0, type=int,
                       help='rank of the current worker')
    group.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0), type=int,
                       help='rank of the current worker')
    group.add_argument('--distributed-backend', default='nccl', type=str,
                       help='distributed backend')
    group.add_argument('--distributed-init-method', default=None, type=str,
                       help='typically tcp://hostname:port that will be used to '
                            'establish initial connection')
    group.add_argument('--distributed-port', default=-1, type=int,
                       help='port number (not required if using --distributed-init-method)')
    group.add_argument('--device-id', default=0, type=int,
                       help='which GPU to use (usually configured automatically)')
    group.add_argument('--enable-global-stats', action='store_true',
                       help='enable global reduction of logging statistics for debugging')
    return group


def add_optimization_args(parser):
    group = parser.add_argument_group('Optimization')
    group.add_argument('--max-epoch', '--me', default=-1, type=int, metavar='N',
                       help='force stop training at specified epoch')
    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                       help='force stop training at specified update')
    group.add_argument('--target-bleu', default=0.0, type=float, metavar='TARGET',
                       help='force stop training after reaching target bleu')
    group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                       help='clip threshold of gradients')
    group.add_argument('--sentence-avg', action='store_true',
                       help='normalize gradients by the number of sentences in a batch'
                            ' (default is to normalize by number of tokens)')
    group.add_argument('--update-freq', default='1', metavar='N',
                       help='update parameters every N_i batches, when in epoch i')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
                       choices=OPTIMIZER_REGISTRY.keys(),
                       help='optimizer: {} (default: nag)'.format(', '.join(OPTIMIZER_REGISTRY.keys())))
    group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR_1,LR_2,...,LR_N',
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                       help='momentum factor')
    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                       help='weight decay')

    # Distributed weight update parameters
    group.add_argument('--distributed-weight-update', '--dwu', default=0, type=int, metavar='DWU',
                       help='select distributed weight update strategy')
    group.add_argument('--dwu-group-size', '--dwugs', default=0, type=int, metavar='DWUGS',
                       help='distributed weight update group size. If arg is 0, defaults to one node')
    group.add_argument('--dwu-num-blocks', '--dwunb', default=8, type=int, metavar='DWUNB',
                       help='number of blocks in dwu scheme')
    group.add_argument('--dwu-num-chunks', '--dwunc', default=8, type=int, metavar='DWUNC',
                       help='number of chunks in dwu scheme')
    group.add_argument('--dwu-num-rs-pg', '--dwurspg', default=2, type=int, metavar='DWURSPG',
                       help='number of reduction-scatter streams in dwu scheme')
    group.add_argument('--dwu-num-ar-pg', '--dwuarpg', default=4, type=int, metavar='DWUARPG',
                       help='number of all-reduce streams in dwu scheme')
    group.add_argument('--dwu-num-ag-pg', '--dwuagpg', default=2, type=int, metavar='DWUAGPG',
                       help='number of all-gather streams in dwu scheme')
    group.add_argument('--dwu-full-pipeline', action='store_true',
                       help='whether to do full or partial pipeline')
    group.add_argument('--dwu-overlap-reductions', action='store_true',
                       help='whether to overlap reductions with backprop')
    group.add_argument('--dwu-compute-L2-grad-norm', action='store_true',
                       help='whether to compute L2 grad norm')
    group.add_argument('--dwu-flat-mt', action='store_true',
                       help='whether to flatten gradients with multi tensor scale')
    group.add_argument('--dwu-e5m2-allgather', action='store_true',
                       help='do allgather with e5m2 floats')
    group.add_argument('--dwu-do-not-flatten-model', action='store_true',
                       help='whether it is allowed to flatten model parameters')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
                       help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
                           ', '.join(LR_SCHEDULER_REGISTRY.keys())))
    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                       help='minimum learning rate')
    group.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
                       help='minimum loss scale (for FP16 training)')

    # Parallel backward + all-reduce optimization
    group.add_argument('--enable-parallel-backward-allred-opt', action='store_true',
                       help='enable all-reduce of w-gradients in parallel with backward propagation (only for FP16 training)')
    group.add_argument('--parallel-backward-allred-cuda-nstreams', type=int, default=1, metavar='N',
                       help='num of CUDA streams used for parallel all-reduce')
    group.add_argument('--parallel-backward-allred-opt-threshold', type=int, default=0, metavar='N',
                       help='min num of contiguous gradient elements before all-reduce is triggered')
    group.add_argument('--enable-parallel-backward-allred-opt-correctness-check', action='store_true',
                       help='compare w-gradient values obtained doing all-reduce in parallel vs. at the end')
    group.add_argument('--dataloader-num-workers', type=int, default=1, metavar='N',
                       help='num subprocesses for train data loader')
    group.add_argument('--enable-dataloader-pin-memory', action='store_true',
                       help='enable pin_memory for train data loader')
    return group


def add_checkpoint_args(parser):
    group = parser.add_argument_group('Checkpointing')
    group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
                       help='path to save checkpoints')
    group.add_argument('--restore-file', default='checkpoint_last.pt',
                       help='filename in save-dir from which to load checkpoint')
    group.add_argument('--save-interval', type=int, default=1, metavar='N',
                       help='save a checkpoint every N epochs')
    group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
                       help='save a checkpoint (and validate) every N updates')
    group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
                       help='keep last N checkpoints saved with --save-interval-updates')
    group.add_argument('--no-save', action='store_true',
                       help='don\'t save models or checkpoints')
    group.add_argument('--no-epoch-checkpoints', action='store_true',
                       help='only store last and best checkpoints')
    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
                       help='validate every N epochs')
    return group


def add_common_eval_args(group):
    group.add_argument('--path', metavar='FILE',
                       help='path(s) to model file(s), colon separated')
    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                       help='remove BPE tokens before scoring')
    group.add_argument('--cpu', action='store_true', help='generate on CPU')
    group.add_argument('--quiet', action='store_true', help='only print final scores')


def add_eval_lm_args(parser):
    group = parser.add_argument_group('LM Evaluation')
    add_common_eval_args(group)
    group.add_argument('--output-word-probs', action='store_true',
                       help='if set, outputs words and their predicted log probabilities to standard output')


def add_generation_args(parser):
    group = parser.add_argument_group('Generation')
    add_common_eval_args(group)
    group.add_argument('--beam', default=4, type=int, metavar='N', help='beam size')
    group.add_argument('--nbest', default=1, type=int, metavar='N',
                       help='number of hypotheses to output')
    group.add_argument('--max-len-a', default=0, type=float, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--min-len', default=1, type=float, metavar='N',
                       help=('minimum generation length'))
    group.add_argument('--no-early-stop', action='store_true',
                       help=('continue searching even after finalizing k=beam '
                             'hypotheses; this is more correct, but increases '
                             'generation time by 50%%'))
    group.add_argument('--unnormalized', action='store_true',
                       help='compare unnormalized hypothesis scores')
    group.add_argument('--no-beamable-mm', action='store_true',
                       help='don\'t use BeamableMM in attention layers')
    group.add_argument('--lenpen', default=1, type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--unkpen', default=0, type=float,
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                       help='perform unknown replacement (optionally with alignment dictionary)')
    group.add_argument('--score-reference', action='store_true',
                       help='just score the reference translation')
    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                       help='initialize generation by target prefix of given length')
    group.add_argument('--sampling', action='store_true',
                       help='sample hypotheses instead of using beam search')
    group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
                       help='sample from top K likely next words instead of all words')
    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                       help='temperature for random sampling')
    group.add_argument('--print-alignment', action='store_true',
                       help='if set, uses attention feedback to compute and print alignment to source tokens')
    group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                       help='a dictionary used to override model args at generation that were used during model training')
    group.add_argument('--online-eval', action='store_true',
                       help='score model at the end of epoch')
    group.add_argument('--log-translations', action='store_true',
                       help='save translations generated by online eval')
    group.add_argument('--ignore-case', action='store_true',
                       help='ignore case during online eval')
    return group


def add_interactive_args(parser):
    group = parser.add_argument_group('Interactive')
    group.add_argument('--buffer-size', default=0, type=int, metavar='N',
                       help='read this many sentences into a buffer before processing them')


def add_model_args(parser):
    group = parser.add_argument_group('Model configuration')

    # Model definitions can be found under fairseq/models/
    #
    # The model architecture can be specified in several ways.
    # In increasing order of priority:
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    # 3) --encoder/decoder-* arguments (highest priority)
    group.add_argument(
        '--arch', '-a', default='fconv', metavar='ARCH', required=True,
        choices=ARCH_MODEL_REGISTRY.keys(),
        help='model architecture: {} (default: fconv)'.format(', '.join(ARCH_MODEL_REGISTRY.keys())),
    )

    # Criterion definitions can be found under fairseq/criterions/
    group.add_argument(
        '--criterion', default='cross_entropy', metavar='CRIT',
        choices=CRITERION_REGISTRY.keys(),
        help='training criterion: {} (default: cross_entropy)'.format(', '.join(CRITERION_REGISTRY.keys())),
    )

    return group


def add_perf_args(parser):
    group = parser.add_argument_group('Performance')
    group.add_argument('--multihead-attn-impl', default='default',
                       choices=['default', 'fast', 'fast_with_lyrnrm_and_dropoutadd'],
                       help='Multihead Attention implementations.')
    parser.add_argument('--time-step', action='store_true',
                        help='Time the performance of a step.')
    return group
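A hedged usage sketch (not in the commit) of the two-pass parsing flow above, as a training entry point would drive it. It assumes the fairseq package from this commit is importable with its registries populated, that 'transformer' is a registered architecture, and that the chosen task adds a positional data-directory argument; adjust for your setup.

# Hypothetical driver for the option helpers defined above.
from fairseq import options

parser = options.get_training_parser()
# The first pass reads --arch/--optimizer/--lr-scheduler/--task; the second
# pass then picks up the registry-specific flags those choices contribute.
args = options.parse_args_and_arch(parser, input_args=[
    '/path/to/databin',               # assumed positional data dir added by the task
    '--arch', 'transformer',          # must name an entry in ARCH_MODEL_REGISTRY
    '--optimizer', 'nag',
    '--lr-scheduler', 'inverse_sqrt',
    '--lr', '0.25',
])
print(args.lr)           # post-processed by eval_str_list -> [0.25]
print(args.update_freq)  # -> [1]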
implementations/pytorch/fairseq/progress_bar.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Wrapper around various loggers and progress bars (e.g., tqdm).
"""

from collections import OrderedDict
import json
from numbers import Number
import sys

from tqdm import tqdm

from fairseq.meters import AverageMeter


def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
    if args.log_format is None:
        args.log_format = no_progress_bar if args.no_progress_bar else default

    if args.log_format == 'tqdm' and not sys.stderr.isatty():
        args.log_format = 'simple'

    if args.log_format == 'json':
        bar = json_progress_bar(iterator, epoch, prefix, args.log_interval)
    elif args.log_format == 'none':
        bar = noop_progress_bar(iterator, epoch, prefix)
    elif args.log_format == 'simple':
        bar = simple_progress_bar(iterator, epoch, prefix, args.log_interval)
    elif args.log_format == 'tqdm':
        bar = tqdm_progress_bar(iterator, epoch, prefix)
    else:
        raise ValueError('Unknown log format: {}'.format(args.log_format))
    return bar


class progress_bar(object):
    """Abstract class for progress bars."""

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        self.epoch = epoch
        self.prefix = ''
        if epoch is not None:
            self.prefix += '| epoch {:03d}'.format(epoch)
        if prefix is not None:
            self.prefix += ' | {}'.format(prefix)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def _str_commas(self, stats):
        return ', '.join(key + '=' + stats[key].strip() for key in stats.keys())

    def _str_pipes(self, stats):
        return ' | '.join(key + ' ' + stats[key].strip() for key in stats.keys())

    def _format_stats(self, stats):
        postfix = OrderedDict(stats)
        # Preprocess stats according to datatype
        for key in postfix.keys():
            # Number: limit the length of the string
            if isinstance(postfix[key], Number):
                postfix[key] = '{:g}'.format(postfix[key])
            # Meter: display both current and average value
            elif isinstance(postfix[key], AverageMeter):
                postfix[key] = '{:.2f} ({:.2f})'.format(postfix[key].val, postfix[key].avg)
            # Else for any other type, try to get the string conversion
            elif not isinstance(postfix[key], str):
                postfix[key] = str(postfix[key])
            # Else if it's a string, don't need to preprocess anything
        return postfix


class json_progress_bar(progress_bar):
    """Log output in JSON format."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.stats = None

    def __iter__(self):
        size = float(len(self.iterable))
        for i, obj in enumerate(self.iterable):
            yield obj
            if self.stats is not None and i > 0 and \
                    self.log_interval is not None and i % self.log_interval == 0:
                update = self.epoch - 1 + float(i / size) if self.epoch is not None else None
                stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
                print(json.dumps(stats), flush=True)

    def log(self, stats):
        """Log intermediate stats according to log_interval."""
        self.stats = stats

    def print(self, stats):
        """Print end-of-epoch stats."""
        self.stats = stats
        stats = self._format_stats(self.stats, epoch=self.epoch)
        print(json.dumps(stats), flush=True)

    def _format_stats(self, stats, epoch=None, update=None):
        postfix = OrderedDict()
        if epoch is not None:
            postfix['epoch'] = epoch
        if update is not None:
            postfix['update'] = update
        # Preprocess stats according to datatype
        for key in stats.keys():
            # Meter: display both current and average value
            if isinstance(stats[key], AverageMeter):
                postfix[key] = stats[key].val
                postfix[key + '_avg'] = stats[key].avg
            else:
                postfix[key] = stats[key]
        return postfix


class noop_progress_bar(progress_bar):
    """No logging."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        for obj in self.iterable:
            yield obj

    def log(self, stats):
        """Log intermediate stats according to log_interval."""
        pass

    def print(self, stats):
        """Print end-of-epoch stats."""
        pass


class simple_progress_bar(progress_bar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.stats = None

    def __iter__(self):
        size = len(self.iterable)
        for i, obj in enumerate(self.iterable):
            yield obj
            if self.stats is not None and i > 0 and \
                    self.log_interval is not None and i % self.log_interval == 0:
                postfix = self._str_commas(self.stats)
                print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix), flush=True)

    def log(self, stats):
        """Log intermediate stats according to log_interval."""
        self.stats = self._format_stats(stats)

    def print(self, stats):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        print('{} | {}'.format(self.prefix, postfix), flush=True)


class tqdm_progress_bar(progress_bar):
    """Log to tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        self.tqdm = tqdm(iterable, self.prefix, leave=False)

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats):
        """Log intermediate stats according to log_interval."""
        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)

    def print(self, stats):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix))
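A small sketch (not in the commit) showing which fields build_progress_bar reads off args and how the json format behaves; the Namespace values are illustrative, and the snippet assumes the fairseq package from this commit is importable (progress_bar pulls in fairseq.meters).

# Hypothetical use of build_progress_bar() defined above.
import argparse
from fairseq.progress_bar import build_progress_bar

args = argparse.Namespace(log_format='json', no_progress_bar=False, log_interval=100)
bar = build_progress_bar(args, range(1000), epoch=1, prefix='train')
for i, _batch in enumerate(bar):
    bar.log({'loss': 4.2, 'step': i})   # emitted as one JSON line every 100 batches
bar.print({'loss': 4.2})                # end-of-epoch summary line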
implementations/pytorch/fairseq/sequence_generator.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch

from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder


class SequenceGenerator(object):
    def __init__(
        self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
        normalize_scores=True, len_penalty=1, retain_dropout=False,
        sampling=False, sampling_topk=-1, sampling_temperature=1,
    ):
        """Generates translations of a given source sentence.

        Args:
            min/maxlen: The length of the generated output will be bounded by
                minlen and maxlen (not including the end-of-sentence marker).
            stop_early: Stop generation immediately after we finalize beam_size
                hypotheses, even though longer hypotheses might have better
                normalized scores.
            normalize_scores: Normalize scores by the length of the output.
        """
        self.models = models
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        max_decoder_len = min(m.max_decoder_positions() for m in self.models)
        max_decoder_len -= 1  # we define maxlen not including the EOS marker
        self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.retain_dropout = retain_dropout
        self.sampling = sampling
        self.sampling_topk = sampling_topk
        self.sampling_temperature = sampling_temperature

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def generate_batched_itr(
        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
        cuda=False, timer=None, prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if 'net_input' not in s:
                continue
            input = s['net_input']
            srclen = input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    input['src_tokens'],
                    input['src_lengths'],
                    beam_size=beam_size,
                    maxlen=int(maxlen_a * srclen + maxlen_b),
                    prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                )
            if timer is not None:
                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
            for i, id in enumerate(s['id'].data):
                # remove padding
                src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                yield id, src, ref, hypos[i]

    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        """Generate a batch of translations."""
        with torch.no_grad():
            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(
                src_tokens.repeat(1, beam_size).view(-1, srclen),
                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
            )
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn, attn_buf = None, None
        nonpad_idxs = None

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= ((maxlen + 5) / 6) ** self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (((step + 1) + 5) / 6) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                def get_hypo():
                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                        _, alignment = hypo_attn.max(dim=0)
                    else:
                        hypo_attn = None
                        alignment = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
                scores = scores.type_as(probs)
                scores_buf = scores_buf.type_as(probs)
            elif not self.sampling:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores[:, step - 1].view(-1, 1))

            probs[:, self.pad] = -math.inf  # never select pad

            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
                    attn_buf = attn.clone()
                    nonpad_idxs = src_tokens.ne(self.pad)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice, dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                    cand_beams.resize_as_(cand_indices).fill_(0)
                elif self.sampling:
                    assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'

                    if self.sampling_topk > 0:
                        values, indices = probs[:, 2:].topk(self.sampling_topk)
                        exp_probs = values.div_(self.sampling_temperature).exp()
                        if step == 0:
                            torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
                        else:
                            torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
                        torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
                        cand_indices.add_(2)
                    else:
                        exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)

                        if step == 0:
                            # we exclude the first two vocab items, one of which is pad
                            torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
                        else:
                            torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)

                        cand_indices.add_(2)
                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)

                    cand_scores.log_()
                    cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
                    cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)

                    if step == 0:
                        cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
                    else:
                        cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
                        # make scores cumulative
                        cand_scores.add_(
                            torch.gather(
                                scores[:, step - 1].view(bsz, beam_size), dim=1,
                                index=cand_beams,
                            )
                        )
                else:
                    # take the best 2 x beam_size predictions. We'll choose the first
                    # beam_size of these which don't predict eos to continue with.
                    torch.topk(
                        probs.view(bsz, -1),
                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
                        out=(cand_scores, cand_indices),
                    )
                    torch.div(cand_indices, self.vocab_size, out=cand_beams, rounding_mode="trunc")
                    cand_indices.fmod_(self.vocab_size)
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    probs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(finalize_hypos(step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents
=
set
()
if
step
>=
self
.
minlen
:
# only consider eos when it's among the top beam_size indices
torch
.
masked_select
(
cand_bbsz_idx
[:,
:
beam_size
],
mask
=
eos_mask
[:,
:
beam_size
],
out
=
eos_bbsz_idx
,
)
if
eos_bbsz_idx
.
numel
()
>
0
:
torch
.
masked_select
(
cand_scores
[:,
:
beam_size
],
mask
=
eos_mask
[:,
:
beam_size
],
out
=
eos_scores
,
)
finalized_sents
=
finalize_hypos
(
step
,
eos_bbsz_idx
,
eos_scores
,
cand_scores
)
num_remaining_sent
-=
len
(
finalized_sents
)
assert
num_remaining_sent
>=
0
if
num_remaining_sent
==
0
:
break
assert
step
<
maxlen
if
len
(
finalized_sents
)
>
0
:
new_bsz
=
bsz
-
len
(
finalized_sents
)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask
=
torch
.
ones
(
bsz
).
type_as
(
cand_indices
)
batch_mask
[
cand_indices
.
new
(
finalized_sents
)]
=
0
batch_idxs
=
batch_mask
.
nonzero
().
squeeze
(
-
1
)
eos_mask
=
eos_mask
[
batch_idxs
]
cand_beams
=
cand_beams
[
batch_idxs
]
bbsz_offsets
.
resize_
(
new_bsz
,
1
)
cand_bbsz_idx
=
cand_beams
.
add
(
bbsz_offsets
)
cand_scores
=
cand_scores
[
batch_idxs
]
cand_indices
=
cand_indices
[
batch_idxs
]
if
prefix_tokens
is
not
None
:
prefix_tokens
=
prefix_tokens
[
batch_idxs
]
scores
=
scores
.
view
(
bsz
,
-
1
)[
batch_idxs
].
view
(
new_bsz
*
beam_size
,
-
1
)
scores_buf
.
resize_as_
(
scores
)
tokens
=
tokens
.
view
(
bsz
,
-
1
)[
batch_idxs
].
view
(
new_bsz
*
beam_size
,
-
1
)
tokens_buf
.
resize_as_
(
tokens
)
if
attn
is
not
None
:
attn
=
attn
.
view
(
bsz
,
-
1
)[
batch_idxs
].
view
(
new_bsz
*
beam_size
,
attn
.
size
(
1
),
-
1
)
attn_buf
.
resize_as_
(
attn
)
bsz
=
new_bsz
else
:
batch_idxs
=
None
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
active_mask
=
buffer
(
'active_mask'
)
torch
.
add
(
eos_mask
.
type_as
(
cand_offsets
)
*
cand_size
,
cand_offsets
[:
eos_mask
.
size
(
1
)],
out
=
active_mask
,
)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos
,
_ignore
=
buffer
(
'active_hypos'
),
buffer
(
'_ignore'
)
torch
.
topk
(
active_mask
,
k
=
beam_size
,
dim
=
1
,
largest
=
False
,
out
=
(
_ignore
,
active_hypos
)
)
active_bbsz_idx
=
buffer
(
'active_bbsz_idx'
)
torch
.
gather
(
cand_bbsz_idx
,
dim
=
1
,
index
=
active_hypos
,
out
=
active_bbsz_idx
,
)
active_scores
=
torch
.
gather
(
cand_scores
,
dim
=
1
,
index
=
active_hypos
,
out
=
scores
[:,
step
].
view
(
bsz
,
beam_size
),
)
active_bbsz_idx
=
active_bbsz_idx
.
view
(
-
1
)
active_scores
=
active_scores
.
view
(
-
1
)
# copy tokens and scores for active hypotheses
torch
.
index_select
(
tokens
[:,
:
step
+
1
],
dim
=
0
,
index
=
active_bbsz_idx
,
out
=
tokens_buf
[:,
:
step
+
1
],
)
torch
.
gather
(
cand_indices
,
dim
=
1
,
index
=
active_hypos
,
out
=
tokens_buf
.
view
(
bsz
,
beam_size
,
-
1
)[:,
:,
step
+
1
],
)
if
step
>
0
:
torch
.
index_select
(
scores
[:,
:
step
],
dim
=
0
,
index
=
active_bbsz_idx
,
out
=
scores_buf
[:,
:
step
],
)
torch
.
gather
(
cand_scores
,
dim
=
1
,
index
=
active_hypos
,
out
=
scores_buf
.
view
(
bsz
,
beam_size
,
-
1
)[:,
:,
step
],
)
# copy attention for active hypotheses
if
attn
is
not
None
:
torch
.
index_select
(
attn
[:,
:,
:
step
+
2
],
dim
=
0
,
index
=
active_bbsz_idx
,
out
=
attn_buf
[:,
:,
:
step
+
2
],
)
# swap buffers
tokens
,
tokens_buf
=
tokens_buf
,
tokens
scores
,
scores_buf
=
scores_buf
,
scores
if
attn
is
not
None
:
attn
,
attn_buf
=
attn_buf
,
attn
# reorder incremental state in decoder
reorder_state
=
active_bbsz_idx
# sort by score descending
for
sent
in
range
(
len
(
finalized
)):
finalized
[
sent
]
=
sorted
(
finalized
[
sent
],
key
=
lambda
r
:
r
[
'score'
],
reverse
=
True
)
return
finalized
def
_decode
(
self
,
tokens
,
encoder_outs
,
incremental_states
):
if
len
(
self
.
models
)
==
1
:
return
self
.
_decode_one
(
tokens
,
self
.
models
[
0
],
encoder_outs
[
0
],
incremental_states
,
log_probs
=
True
)
avg_probs
=
None
avg_attn
=
None
for
model
,
encoder_out
in
zip
(
self
.
models
,
encoder_outs
):
probs
,
attn
=
self
.
_decode_one
(
tokens
,
model
,
encoder_out
,
incremental_states
,
log_probs
=
False
)
if
avg_probs
is
None
:
avg_probs
=
probs
else
:
avg_probs
.
add_
(
probs
)
if
attn
is
not
None
:
if
avg_attn
is
None
:
avg_attn
=
attn
else
:
avg_attn
.
add_
(
attn
)
avg_probs
.
div_
(
len
(
self
.
models
))
avg_probs
.
log_
()
if
avg_attn
is
not
None
:
avg_attn
.
div_
(
len
(
self
.
models
))
return
avg_probs
,
avg_attn
def
_decode_one
(
self
,
tokens
,
model
,
encoder_out
,
incremental_states
,
log_probs
):
with
torch
.
no_grad
():
if
incremental_states
[
model
]
is
not
None
:
decoder_out
=
list
(
model
.
decoder
(
tokens
,
encoder_out
,
incremental_state
=
incremental_states
[
model
]))
else
:
decoder_out
=
list
(
model
.
decoder
(
tokens
,
encoder_out
))
decoder_out
[
0
]
=
decoder_out
[
0
][:,
-
1
,
:]
attn
=
decoder_out
[
1
]
if
attn
is
not
None
:
attn
=
attn
[:,
-
1
,
:]
probs
=
model
.
get_normalized_probs
(
decoder_out
,
log_probs
=
log_probs
)
return
probs
,
attn
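A note on the ensemble branch of `_decode` above: per-model distributions are accumulated in probability space and the log is taken only after the average, which corresponds to a uniform mixture of the ensemble members. A minimal standalone sketch of that averaging, with random tensors and made-up sizes standing in for real decoder outputs:

import torch

# hypothetical per-model next-step distributions: (batch * beam, vocab); sizes are made up
model_probs = [torch.softmax(torch.randn(6, 10), dim=-1) for _ in range(3)]

avg_probs = None
for probs in model_probs:
    # accumulate in probability space, mirroring avg_probs.add_(probs) above
    avg_probs = probs.clone() if avg_probs is None else avg_probs.add_(probs)
avg_probs.div_(len(model_probs))  # arithmetic mean over the ensemble members
avg_probs.log_()                  # convert to log-probabilities only after averaging

print(avg_probs.shape)  # torch.Size([6, 10])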
implementations/pytorch/fairseq/sequence_scorer.py
0 → 100644
View file @
9e8a8c05
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from fairseq import utils


class SequenceScorer(object):
    """Scores the target for a given source sentence."""

    def __init__(self, models, tgt_dict):
        self.models = models
        self.pad = tgt_dict.pad()

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def score_batched_itr(self, data_itr, cuda=False, timer=None):
        """Iterate over a batched dataset and yield scored translations."""
        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if timer is not None:
                timer.start()
            pos_scores, attn = self.score(s)
            for i, id in enumerate(s['id'].data):
                # remove padding from ref
                src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                tgt_len = ref.numel()
                pos_scores_i = pos_scores[i][:tgt_len]
                score_i = pos_scores_i.sum() / tgt_len
                if attn is not None:
                    attn_i = attn[i]
                    _, alignment = attn_i.max(dim=0)
                else:
                    attn_i = alignment = None
                hypos = [{
                    'tokens': ref,
                    'score': score_i,
                    'attention': attn_i,
                    'alignment': alignment,
                    'positional_scores': pos_scores_i,
                }]
                if timer is not None:
                    timer.stop(s['ntokens'])
                # return results in the same format as SequenceGenerator
                yield id, src, ref, hypos

    def score(self, sample):
        """Score a batch of translations."""
        net_input = sample['net_input']

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in self.models:
            with torch.no_grad():
                model.eval()
                decoder_out = model.forward(**net_input)
                attn = decoder_out[1]
            probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn.data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
        avg_probs = avg_probs.gather(
            dim=2,
            index=sample['target'].data.unsqueeze(-1),
        )
        return avg_probs.squeeze(2), avg_attn
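The final `gather` in `score` above selects, at each target position, the probability the averaged model assigned to the reference token. A self-contained sketch of that indexing, with random numbers and made-up sizes in place of real model output:

import torch

bsz, tgt_len, vocab = 2, 5, 11  # made-up sizes, for illustration only
avg_probs = torch.softmax(torch.randn(bsz, tgt_len, vocab), dim=-1)
target = torch.randint(0, vocab, (bsz, tgt_len))

# pick p(reference token) at every position, mirroring gather(dim=2, ...) above
tok_probs = avg_probs.gather(dim=2, index=target.unsqueeze(-1)).squeeze(2)
print(tok_probs.shape)  # torch.Size([2, 5]) -- one probability per reference token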
implementations/pytorch/fairseq/tasks/__init__.py
0 → 100644
View file @
9e8a8c05
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import importlib
import os

from .fairseq_task import FairseqTask


TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()


def setup_task(args):
    return TASK_REGISTRY[args.task].setup_task(args)


def register_task(name):
    """Decorator to register a new task."""

    def register_task_cls(cls):
        if name in TASK_REGISTRY:
            raise ValueError('Cannot register duplicate task ({})'.format(name))
        if not issubclass(cls, FairseqTask):
            raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_task_cls


# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('fairseq.tasks.' + module)
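For context, tasks make themselves known to this module through the `register_task` decorator. A minimal, hypothetical registration (the task name and class body below are placeholders for illustration, not part of this commit) could look like:

from fairseq.tasks import FairseqTask, register_task


@register_task('dummy_task')  # hypothetical task name, for illustration only
class DummyTask(FairseqTask):

    @classmethod
    def setup_task(cls, args):
        # real tasks would build dictionaries and load datasets here
        return cls(args)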
implementations/pytorch/fairseq/tasks/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
9e8a8c05
File added
implementations/pytorch/fairseq/tasks/__pycache__/fairseq_task.cpython-310.pyc
0 → 100644
View file @
9e8a8c05
File added