ModelZoo / Uni-Fold_pytorch

Commit a1c29028, authored Apr 17, 2023 by zhangqha

update uni-fold
Showing 20 changed files with 3677 additions and 0 deletions (+3677, -0)
Uni-Core-main/unicore/optim/fp16_optimizer.py  +326  -0
Uni-Core-main/unicore/optim/fused_adam.py  +145  -0
Uni-Core-main/unicore/optim/lr_scheduler/__init__.py  +34  -0
Uni-Core-main/unicore/optim/lr_scheduler/cosine_lr_scheduler.py  +132  -0
Uni-Core-main/unicore/optim/lr_scheduler/exponential_decay_schedule.py  +50  -0
Uni-Core-main/unicore/optim/lr_scheduler/fixed_schedule.py  +69  -0
Uni-Core-main/unicore/optim/lr_scheduler/inverse_square_root_schedule.py  +77  -0
Uni-Core-main/unicore/optim/lr_scheduler/pass_through.py  +32  -0
Uni-Core-main/unicore/optim/lr_scheduler/polynomial_decay_schedule.py  +79  -0
Uni-Core-main/unicore/optim/lr_scheduler/reduce_lr_on_plateau.py  +116  -0
Uni-Core-main/unicore/optim/lr_scheduler/tri_stage_lr_scheduler.py  +177  -0
Uni-Core-main/unicore/optim/lr_scheduler/triangular_lr_scheduler.py  +76  -0
Uni-Core-main/unicore/optim/lr_scheduler/unicore_lr_scheduler.py  +50  -0
Uni-Core-main/unicore/optim/sgd.py  +44  -0
Uni-Core-main/unicore/optim/unicore_optimizer.py  +191  -0
Uni-Core-main/unicore/options.py  +413  -0
Uni-Core-main/unicore/registry.py  +81  -0
Uni-Core-main/unicore/tasks/__init__.py  +86  -0
Uni-Core-main/unicore/tasks/unicore_task.py  +331  -0
Uni-Core-main/unicore/trainer.py  +1168  -0
Uni-Core-main/unicore/optim/fp16_optimizer.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict

import torch
from unicore import optim
from unicore import utils

from .dynamic_loss_scaler import DynamicLossScaler


def check_param_device(params):
    if len(params) <= 0:
        return True
    device = params[0].device
    for i in range(1, len(params)):
        assert device == params[i].device


def pad_numel(numel, multiplier=2):
    return (numel + multiplier - 1) // multiplier * multiplier


class _FP16OptimizerMixin(object):
    def __init__(self, args, **kwargs):
        # forward __init__ call to the next class in mro (method resolution order)
        super().__init__(args, **kwargs)
        self._multiply_factor = 1.0
        self.bf16_sr = getattr(args, "bf16_sr", False)

    @classmethod
    def build_fp32_params(cls, args, params):
        # create FP32 copy of parameters and grads
        total_param_size = sum([p.data.numel() for p in params])
        fp32_params = params[0].new(0).float().new(total_param_size)
        offset = 0
        for p in params:
            numel = p.data.numel()
            fp32_params[offset : offset + numel].copy_(p.data.view(-1))
            offset += numel
        fp32_params = torch.nn.Parameter(fp32_params)
        fp32_params.grad = fp32_params.data.new(total_param_size)
        return fp32_params

    @classmethod
    def flatten_fp16_parameters(cls, args, params):
        dtype_grouped_params = {}
        for p in params:
            if p.dtype not in dtype_grouped_params:
                dtype_grouped_params[p.dtype] = []
            dtype_grouped_params[p.dtype].append(p)
        flatten_params = {}
        for dtype in dtype_grouped_params:
            cur_params = dtype_grouped_params[dtype]
            total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
            flatten_params[dtype] = (
                cur_params[0].new(0).type(dtype).new(total_param_size)
            )
            offset = 0
            for p in cur_params:
                numel = p.data.numel()
                flatten_params[dtype][offset : offset + numel].copy_(p.data.view(-1))
                p.data = (
                    flatten_params[dtype].data[offset : offset + numel].view(*p.shape)
                )
                offset += pad_numel(numel)
            flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype])
            flatten_params[dtype].grad = flatten_params[dtype].data.new(
                total_param_size
            )
            offset = 0
            for p in cur_params:
                numel = p.data.numel()
                p.grad = (
                    flatten_params[dtype].grad[offset : offset + numel].view(*p.shape)
                )
                offset += pad_numel(numel)
        torch.cuda.empty_cache()
        return list(flatten_params.values())

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.fp32_optimizer.state_dict()
        if self.scaler is not None:
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]
        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`unicore.optim.UnicoreOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
        self._needs_sync = True

    def _sync_fp16_grads_to_fp32(self):
        with torch.no_grad():
            if self._needs_sync:
                offset = 0
                for p in self.fp16_params:
                    numel = p.numel()
                    self.fp32_params.grad.data[offset : offset + numel].copy_(
                        p.grad.data.view(-1)
                    )
                    offset += pad_numel(numel)
                self._needs_sync = False

    def _add_fp16_grads_to_fp32(self, mul=0.0):
        with torch.no_grad():
            offset = 0
            for p in self.fp16_params:
                numel = p.numel()
                self.fp32_params.grad.data[offset : offset + numel] += (
                    mul * p.grad.data.float().view(-1)
                )
                p.grad.zero_()
                offset += pad_numel(numel)
            self._needs_sync = False

    def _sync_fp32_params_to_fp16(self):
        # copy FP32 params back into FP16 model
        offset = 0
        for p in self.fp16_params:
            numel = p.numel()
            u = self.fp32_params.data[offset : offset + numel].view_as(p.data)
            if self.bf16_sr and p.dtype == torch.bfloat16:
                utils.fp32_to_bf16_sr(u, p)
            else:
                p.data.copy_(u)
            offset += pad_numel(numel)

    def _unscale_grads(self):
        self._sync_fp16_grads_to_fp32()
        if (
            # Skip the multiplication if it's a no-op (i.e., if _multiply_factor
            # is 1.0). At the same time, we want to avoid the device-to-host
            # transfer by comparing it to 1.0. Since _multiply_factor starts as
            # a Python float, we roughly assume that if it's a tensor then it's
            # probably not =1.0 anymore and we do the multiplication. Otherwise
            # we can safely check the value without a D2H transfer.
            torch.is_tensor(self._multiply_factor)
            or self._multiply_factor != 1.0
        ):
            self.fp32_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0

    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        if self._needs_sync:
            self._multiply_factor *= c
        else:
            # gradients already synced to fp32 parameters, update it directly
            self.fp32_optimizer.multiply_grads(c)

    def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        if max_norm <= 0.0:
            return 0.0
        grad_norm = self._multiply_factor * utils.clip_grad_norm_(
            self.fp16_params, 0, aggregate_norm_fn
        )
        # grad_norm = 1.0
        if grad_norm > max_norm > 0.0:
            clip_coef = max_norm / (grad_norm + 1e-6)
        else:
            clip_coef = 1.0
        self._add_fp16_grads_to_fp32(mul=clip_coef)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
            0,
            aggregate_norm_fn=aggregate_norm_fn,
        )

        if self.scaler is not None:
            if grad_norm > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm
            self.scaler.check_overflow(grad_norm)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None, groups=None):
        """Performs a single optimization step."""
        self._sync_fp16_grads_to_fp32()

        if getattr(self, "supports_step_with_scale", False):
            self.fp32_optimizer.step(
                closure, scale=(1.0 / self._multiply_factor), groups=groups
            )
        else:
            self._unscale_grads()
            self.fp32_optimizer.step(closure, groups=groups)

        if self.scaler is not None:
            self.scaler.update()

        self._sync_fp32_params_to_fp16()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad.zero_()
        if torch.is_tensor(self.fp32_params):
            self.fp32_params.grad.zero_()
        elif isinstance(self.fp32_params, dict):
            for fp32_params in self.fp32_params.values():
                fp32_params.grad.zero_()
        else:
            raise RuntimeError("self.fp32_params must be a tensor or dict")

        self._needs_sync = False
        if self.scaler is not None:
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
        else:
            self._multiply_factor = 1.0


class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """

    def __init__(self, args, params, fp32_optimizer, fp32_params, **kwargs):
        super().__init__(args)
        self.fp16_params = params
        self.fp32_optimizer = fp32_optimizer
        self.fp32_params = fp32_params
        self.allreduce_fp32_grad = getattr(args, "allreduce_fp32_grad", False)

        if getattr(args, "fp16_scale_window", None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(args.distributed_world_size)
            scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
        else:
            scale_window = args.fp16_scale_window

        if not getattr(args, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, args, params, **kwargs):
        """
        Args:
            args : unicore args
            params (iterable): iterable of parameters to optimize
        """
        flatten = not getattr(args, "fp16_no_flatten_grads", False)
        assert flatten
        check_param_device(params)
        params = cls.flatten_fp16_parameters(args, params)
        fp32_params = cls.build_fp32_params(args, params)
        fp32_optimizer = optim.build_optimizer(args, [fp32_params])
        return cls(args, params, fp32_optimizer, fp32_params, **kwargs)

    @property
    def optimizer(self):
        return self.fp32_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.fp32_optimizer.optimizer = optimizer

    @property
    def lr_scheduler(self):
        return getattr(self.fp32_optimizer, "lr_scheduler", None)

    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config

    def get_lr(self):
        return self.fp32_optimizer.get_lr()

    def set_lr(self, lr):
        self.fp32_optimizer.set_lr(lr)

    def all_reduce_grads(self, module):
        if self.allreduce_fp32_grad and hasattr(module, "all_reduce_params"):
            self._sync_fp16_grads_to_fp32()
            with torch.no_grad():
                params = [self.fp32_params]
                module.all_reduce_params(params)
        else:
            self.fp32_optimizer.all_reduce_grads(module)

    @property
    def supports_flat_params(self):
        return self.fp32_optimizer.supports_flat_params
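A small, self-contained sketch (not part of the commit) of how pad_numel() above lays parameters out in the flat fp16/fp32 buffers: each chunk is padded to a multiple of 2 elements, so every parameter starts at an even offset inside the shared buffer. The parameter sizes below are hypothetical.

# Sketch of pad_numel-based offsets in the flat buffer (hypothetical sizes).
def pad_numel(numel, multiplier=2):
    return (numel + multiplier - 1) // multiplier * multiplier

sizes = [3, 8, 5]
offset, offsets = 0, []
for n in sizes:
    offsets.append(offset)
    offset += pad_numel(n)
print(offsets, offset)   # [0, 4, 12] and a total flat-buffer size of 18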
Uni-Core-main/unicore/optim/fused_adam.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def get_fused_adam_class():
    try:
        global unicore_fused_adam
        import importlib

        unicore_fused_adam = importlib.import_module("unicore_fused_adam")
        return FusedAdam
    except ImportError:
        pass
    return None


class FusedAdam(torch.optim.Optimizer):
    """
    Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    Compared to the original version in Apex, the unicore version casts grads
    and params to FP32 internally to support ``--memory-efficient-fp16``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        eps_inside_sqrt (boolean, optional): in the "update parameters" step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)

    .. _Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.,
        amsgrad=False,
    ):
        global unicore_fused_adam
        import importlib

        unicore_fused_adam = importlib.import_module("unicore_fused_adam")

        if amsgrad:
            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
        defaults = {
            "lr": lr,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    @property
    def supports_step_with_scale(self):
        return True

    def step(self, closure=None, scale=1.):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # compute combined scale factor for this group
            combined_scale = scale
            bias_correction = 1 if group.get("bias_correction", 1) else 0

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "FusedAdam does not support sparse gradients, "
                        "please consider SparseAdam instead"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data, dtype=torch.float)
                else:
                    state["exp_avg"] = state["exp_avg"].to(dtype=torch.float)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(dtype=torch.float)

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                with torch.cuda.device(p.device):
                    unicore_fused_adam.adam(
                        p.data,
                        exp_avg,
                        exp_avg_sq,
                        grad,
                        group["lr"],
                        beta1,
                        beta2,
                        group["eps"],
                        combined_scale,
                        state["step"],
                        bias_correction,
                        group["weight_decay"],
                    )
        return loss
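A hedged usage sketch (not part of the commit), assuming a working unicore installation: get_fused_adam_class() above returns None when the unicore_fused_adam CUDA extension cannot be imported, so callers can fall back to the plain PyTorch Adam implementation.

# Fall back to torch.optim.Adam when the fused CUDA extension is unavailable.
import torch
from unicore.optim.fused_adam import get_fused_adam_class

fused_adam_cls = get_fused_adam_class()
optimizer_cls = fused_adam_cls if fused_adam_cls is not None else torch.optim.Adam
optimizer = optimizer_cls(
    torch.nn.Linear(8, 8).parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8
)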
Uni-Core-main/unicore/optim/lr_scheduler/__init__.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from unicore import registry
from unicore.optim.lr_scheduler.unicore_lr_scheduler import (  # noqa
    UnicoreLRScheduler,
)


(
    build_lr_scheduler_,
    register_lr_scheduler,
    LR_SCHEDULER_REGISTRY,
) = registry.setup_registry(
    "--lr-scheduler", base_class=UnicoreLRScheduler, default="fixed"
)


def build_lr_scheduler(args, optimizer, total_train_steps):
    return build_lr_scheduler_(args, optimizer, total_train_steps)


# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("unicore.optim.lr_scheduler." + file_name)
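A hedged sketch (not part of the commit) of how a new schedule plugs into the "--lr-scheduler" registry created by registry.setup_registry above; the scheduler name and class below are hypothetical, but the decorator and base class are the ones defined in this commit.

# Registering a hypothetical constant-LR schedule with the registry above.
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("my_constant")
class MyConstantLRSchedule(UnicoreLRScheduler):
    """Keep the optimizer's learning rate unchanged on every update."""

    def step_update(self, num_updates):
        return self.optimizer.get_lr()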
Uni-Core-main/unicore/optim/lr_scheduler/cosine_lr_scheduler.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections.abc import Collection
from typing import List

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("cosine")
class CosineLRSchedule(UnicoreLRScheduler):
    """Assign LR based on a cyclical schedule that follows the cosine function.

    See https://arxiv.org/pdf/1608.03983.pdf for details.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (``--warmup-init-lr``) until the configured
    max learning rate (``--lr``).

    During warmup::

        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]

    After warmup::

        lr = args.min_lr + 0.5*(args.lr - args.min_lr)*(1 + cos(pi * t_curr / t_i))

    where ``t_curr`` is the current number of updates within the current period
    and ``t_i`` is the current period length, which is scaled by ``t_mult``
    after every iteration.
    """

    def __init__(self, args, unicore_optimizer, total_train_steps):
        super().__init__(args, unicore_optimizer, total_train_steps)
        if isinstance(args.lr, Collection) and len(args.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with cosine."
                f" Consider --lr-scheduler=fixed instead. ({args.lr})"
            )

        self.max_lr = args.lr[0] if isinstance(args.lr, Collection) else args.lr
        assert (
            self.max_lr > args.min_lr
        ), f"max_lr (={args.lr}) must be more than min_lr (={args.min_lr})"

        warmup_end_lr = self.max_lr
        if args.warmup_init_lr < 0:
            args.warmup_init_lr = args.min_lr

        self.t_mult = args.t_mult
        self.period = args.lr_period_updates

        if self.period <= 0:
            assert (
                args.max_update > 0
            ), "Either --max_update or --lr-period-updates must be set"
            self.period = args.max_update - args.warmup_updates

        if args.warmup_updates > 0:
            # linearly warmup for the first args.warmup_updates
            self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
        else:
            self.lr_step = 1

        self.warmup_updates = args.warmup_updates
        self.lr_shrink = args.lr_shrink

        # initial learning rate
        self.lr = args.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
                            help='initial learning rate during warmup phase; default is args.lr')
        parser.add_argument('--max-lr', type=float, metavar='LR',
                            help='max learning rate, must be more than args.lr')
        parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
                            help='factor to grow the length of each period')
        parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR',
                            help='initial number of updates per period')
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing')
        # fmt: on

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.args.warmup_updates:
            self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
        else:
            curr_updates = num_updates - self.args.warmup_updates
            if self.t_mult != 1:
                i = math.floor(
                    math.log(
                        1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult
                    )
                )
                t_i = self.t_mult ** i * self.period
                t_curr = (
                    curr_updates
                    - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
                )
            else:
                i = math.floor(curr_updates / self.period)
                t_i = self.period
                t_curr = curr_updates - (self.period * i)

            lr_shrink = self.lr_shrink ** i
            min_lr = self.args.min_lr * lr_shrink
            max_lr = self.max_lr * lr_shrink

            self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
                1 + math.cos(math.pi * t_curr / t_i)
            )

        self.optimizer.set_lr(self.lr)
        return self.lr
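A worked example (not part of the commit) of the post-warmup cosine formula above, with hypothetical values min_lr=0.0, max_lr=1e-3 and a period of 1000 updates: the LR anneals from max_lr at the start of the period down to min_lr at its end.

# Evaluate the cosine annealing formula at a few points in one period.
import math

min_lr, max_lr, period = 0.0, 1e-3, 1000
for t_curr in (0, 250, 500, 1000):
    lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / period))
    print(t_curr, lr)   # 0 -> 1e-3, 250 -> ~8.5e-4, 500 -> 5e-4, 1000 -> 0.0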
Uni-Core-main/unicore/optim/lr_scheduler/exponential_decay_schedule.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("exponential_decay")
class ExponentialDecayLRSchedule(UnicoreLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)
        self.warmup_updates = args.warmup_updates
        self.lr = args.lr[0]
        if self.warmup_updates > 0:
            self.warmup_factor = 1.0 / self.warmup_updates
        else:
            self.warmup_factor = 1.0
        self.decay_ratio = args.decay_ratio
        self.decay_steps = args.decay_steps
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        self.stair_decay = getattr(args, "stair_decay", False)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument('--warmup-updates', default=1000, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--decay-ratio', default=0.95, type=float)
        parser.add_argument('--decay-steps', default=500, type=int)
        parser.add_argument('--stair-decay', action="store_true")

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            self.warmup_factor = num_updates / float(self.warmup_updates)
            lr = self.warmup_factor * self.lr
        else:
            if self.stair_decay:
                step = num_updates
                lr = self.lr * float(self.decay_ratio ** (int(step // self.decay_steps)))
            else:
                step = num_updates - self.warmup_updates
                lr = self.lr * float(self.decay_ratio ** (float(step / self.decay_steps)))
        self.optimizer.set_lr(lr)
        return self.optimizer.get_lr()
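A worked example (not part of the commit) of the smooth (non-stair) decay branch above, with hypothetical values lr=1e-3, decay_ratio=0.95, decay_steps=500 and warmup_updates=1000: after warmup, the LR shrinks by a factor of 0.95 for every additional 500 updates.

# Evaluate the post-warmup exponential decay at a few update counts.
lr, decay_ratio, decay_steps, warmup_updates = 1e-3, 0.95, 500, 1000
for num_updates in (1000, 3500, 6000):
    step = num_updates - warmup_updates
    print(num_updates, lr * decay_ratio ** (step / decay_steps))
    # 1000 -> 1e-3, 3500 -> ~7.74e-4, 6000 -> ~5.99e-4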
Uni-Core-main/unicore/optim/lr_scheduler/fixed_schedule.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("fixed")
class FixedLRSchedule(UnicoreLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)

        self.lr = args.lr[0]
        if args.warmup_updates > 0:
            self.warmup_factor = 1.0 / args.warmup_updates
        else:
            self.warmup_factor = 1

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                            help='force annealing at specified epoch')
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        # fmt: on

    def state_dict(self):
        return {"lr": self.lr}

    def load_state_dict(self, state_dict):
        if "lr" in state_dict:
            self.lr = state_dict["lr"]

    def get_next_lr(self, epoch):
        lrs = self.args.lr
        if self.args.force_anneal is None or epoch < self.args.force_anneal:
            # use fixed LR schedule
            next_lr = lrs[min(epoch - 1, len(lrs) - 1)]
        else:
            # anneal based on lr_shrink
            next_lr = lrs[-1] * self.args.lr_shrink ** (
                epoch + 1 - self.args.force_anneal
            )
        return next_lr

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates:
            self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates)
            self.optimizer.set_lr(self.warmup_factor * self.lr)
        else:
            self.optimizer.set_lr(self.lr)
        return self.optimizer.get_lr()
Uni-Core-main/unicore/optim/lr_scheduler/inverse_square_root_schedule.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Collection
from typing import List

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("inverse_sqrt")
class InverseSquareRootSchedule(UnicoreLRScheduler):
    """Decay the LR based on the inverse square root of the update number.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (``--warmup-init-lr``) until the configured
    learning rate (``--lr``). Thereafter we decay proportional to the number of
    updates, with a decay factor set to align with the configured learning rate.

    During warmup::

        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]

    After warmup::

        decay_factor = args.lr * sqrt(args.warmup_updates)
        lr = decay_factor / sqrt(update_num)
    """

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)
        if isinstance(args.lr, Collection) and len(args.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with inverse_sqrt."
                " Consider --lr-scheduler=fixed instead."
            )
        warmup_end_lr = args.lr[0] if isinstance(args.lr, Collection) else args.lr
        if args.warmup_init_lr < 0:
            args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr

        # linearly warmup for the first args.warmup_updates
        self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates

        # then, decay prop. to the inverse square root of the update number
        self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5

        # initial learning rate
        self.lr = args.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
                            help='initial learning rate during warmup phase; default is args.lr')
        # fmt: on

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.args.warmup_updates:
            self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
        else:
            self.lr = self.decay_factor * num_updates ** -0.5
        self.optimizer.set_lr(self.lr)
        return self.lr
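A worked example (not part of the commit) of the inverse-sqrt decay above, with hypothetical values lr=5e-4 and warmup_updates=4000: the LR equals the peak value at the end of warmup and halves every time the update count quadruples.

# Evaluate the post-warmup inverse-sqrt decay at a few update counts.
lr, warmup_updates = 5e-4, 4000
decay_factor = lr * warmup_updates ** 0.5
for num_updates in (4000, 16000, 64000):
    print(num_updates, decay_factor * num_updates ** -0.5)  # 5e-4, 2.5e-4, 1.25e-4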
Uni-Core-main/unicore/optim/lr_scheduler/pass_through.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("pass_through")
class PassThroughScheduleSchedule(UnicoreLRScheduler):
    """Delegate lr scheduling to the optimizer."""

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)
        assert (
            hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None
        ), "Pass-through schedule can only be used with optimizers with their own schedulers"

    def state_dict(self):
        return self.optimizer.lr_scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.lr_scheduler.load_state_dict(state_dict)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        return self.optimizer.lr_scheduler.step_begin_epoch(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.lr_scheduler.step_update(num_updates)
Uni-Core-main/unicore/optim/lr_scheduler/polynomial_decay_schedule.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("polynomial_decay")
class PolynomialDecayLRSchedule(UnicoreLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)
        if self.args.warmup_ratio > 0:
            # if warmup_ratio > 0, use external train steps
            assert total_train_steps is not None
            self.warmup_updates = int(self.args.warmup_ratio * total_train_steps)
            self.total_num_update = total_train_steps
        else:
            assert args.total_num_update > 0
            self.warmup_updates = args.warmup_updates
            self.total_num_update = args.total_num_update
        self.lr = args.lr[0]
        if self.warmup_updates > 0:
            self.warmup_factor = 1.0 / self.warmup_updates
        else:
            self.warmup_factor = 1
        self.end_learning_rate = args.end_learning_rate
        self.power = args.power
        self.optimizer.set_lr(self.warmup_factor * self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                            help='force annealing at specified epoch')
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--warmup-ratio', default=-1.0, type=float, metavar='N',
                            help='warmup the learning rate linearly for the first N-percent updates')
        parser.add_argument('--end-learning-rate', default=0.0, type=float)
        parser.add_argument('--power', default=1.0, type=float)
        parser.add_argument('--total-num-update', default=1000000, type=int)

    def get_next_lr(self, epoch):
        lrs = self.args.lr
        if self.args.force_anneal is None or epoch < self.args.force_anneal:
            # use fixed LR schedule
            next_lr = lrs[min(epoch, len(lrs) - 1)]
        else:
            # anneal based on lr_shrink
            next_lr = self.optimizer.get_lr()
        return next_lr

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            self.warmup_factor = num_updates / float(self.warmup_updates)
            lr = self.warmup_factor * self.lr
        elif num_updates >= self.total_num_update:
            lr = self.end_learning_rate
        else:
            warmup = self.warmup_updates
            lr_range = self.lr - self.end_learning_rate
            pct_remaining = 1 - (num_updates - warmup) / (
                self.total_num_update - warmup
            )
            lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate
        self.optimizer.set_lr(lr)
        return self.optimizer.get_lr()
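A worked example (not part of the commit) of the post-warmup polynomial decay above; with power=1.0 the decay is linear. Hypothetical values: lr=1e-3, end_learning_rate=0.0, warmup_updates=1000, total_num_update=10000.

# Evaluate the linear (power=1.0) decay branch at a few update counts.
lr, end_lr, warmup, total, power = 1e-3, 0.0, 1000, 10000, 1.0
for num_updates in (1000, 5500, 10000):
    pct_remaining = 1 - (num_updates - warmup) / (total - warmup)
    print(num_updates, (lr - end_lr) * pct_remaining ** power + end_lr)
    # 1000 -> 1e-3, 5500 -> 5e-4, 10000 -> 0.0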
Uni-Core-main/unicore/optim/lr_scheduler/reduce_lr_on_plateau.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch.optim.lr_scheduler

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("reduce_lr_on_plateau")
class ReduceLROnPlateauLRSchedule(UnicoreLRScheduler):
    """
    Decay the LR by a factor every time the validation loss plateaus.
    Also comes with optional warmup phase, where we linearly increase
    the learning rate from some initial learning rate
    (``--warmup-init-lr``) until the configured learning rate
    (``--lr``). Thereafter the lr is adjusted according to original
    reduce_on_plateau scheme.

    During warmup::

        lrs = torch.linspace(
            args.warmup_init_lr, args.lr, args.warmup_updates
        )
        lr = lrs[update_num]
    """

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)
        if len(args.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau."
                " Consider --lr-scheduler=fixed instead."
            )
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer.optimizer,
            patience=args.lr_patience,
            factor=args.lr_shrink,
            mode="max" if args.maximize_best_checkpoint_metric else "min",
            threshold=args.lr_threshold,
        )
        warmup_end_lr = args.lr[0]
        # if no warm up, sets initial lr to be args.lr[0]
        if args.warmup_init_lr < 0:
            args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr

        # linearly warmup for the first args.warmup_updates
        if args.warmup_updates > 0:
            self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates

        # this flag is either set from arg when no warm up, or set by
        # step_update() when warmup finishes
        self.warmup_end = True if args.warmup_updates <= 0 else False

        # initial learning rate
        # this self.lr is used only during init and/or warm up period
        self.lr = args.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
        parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT',
                            help='Threshold for measuring the new optimum, \
                            to only focus on significant changes')
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
                            help='initial learning rate during warmup phase; default is args.lr')
        # fmt: on

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {
            "best": self.lr_scheduler.best,
            "last_epoch": self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.lr_scheduler.best = state_dict["best"]
        if "last_epoch" in state_dict:
            self.lr_scheduler.last_epoch = state_dict["last_epoch"]

    def step(self, epoch, val_loss=None):
        """
        Update the learning rate at the end of the given epoch if warmup
        finishes otherwise no update of lr on epoch boundaries
        """
        if val_loss is not None and self.warmup_end is True:
            self.lr_scheduler.step(val_loss)
        else:
            self.lr_scheduler.last_epoch = epoch
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """
        Update the learning rate after each update."""
        # if there is warmup
        if self.args.warmup_updates > 0:
            if num_updates <= self.args.warmup_updates:
                self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
                self.optimizer.set_lr(self.lr)
            else:
                if self.warmup_end is False:
                    self.warmup_end = True
        # else do nothing
        return self.optimizer.get_lr()
Uni-Core-main/unicore/optim/lr_scheduler/tri_stage_lr_scheduler.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("tri_stage")
class TriStageLRSchedule(UnicoreLRScheduler):
    """Tri-stage learning rate scheduler.

    Implements the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf

    Similar to the inverse_square_root scheduler, but tri_stage learning rate employs
    three stages of LR scheduling:

        - warmup stage, starting from `lr` * `init_lr_scale`, linearly
          increased to `lr` in `warmup_steps` iterations

        - hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps`
          iterations

        - decay stage, after hold stage, decay LR exponentially to
          `lr` * `final_lr_scale` in `decay_steps`;
          after that LR is kept as `final_lr_scale` * `lr`

    During warmup::

        init_lr = args.init_lr_scale * args.lr
        lrs = torch.linspace(init_lr, args.lr, args.warmup_steps)
        lr = lrs[update_num]

    During hold::

        lr = args.lr

    During decay::

        decay_factor = - math.log(args.final_lr_scale) / args.decay_steps
        lr = args.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor)

    After that::

        lr = args.lr * args.final_lr_scale
    """

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)
        if len(args.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with tri-stage lr."
                " Consider --lr-scheduler=fixed instead."
            )

        # calculate LR at each point
        self.peak_lr = args.lr[0]
        self.init_lr = args.init_lr_scale * args.lr[0]
        self.final_lr = args.final_lr_scale * args.lr[0]

        if args.phase_ratio is not None:
            assert args.max_update > 0
            assert sum(args.phase_ratio) == 1, "phase ratios must add up to 1"
            self.warmup_steps = int(args.max_update * args.phase_ratio[0])
            self.hold_steps = int(args.max_update * args.phase_ratio[1])
            self.decay_steps = int(args.max_update * args.phase_ratio[2])
        else:
            self.warmup_steps = args.warmup_steps
            self.hold_steps = args.hold_steps
            self.decay_steps = args.decay_steps

        assert (
            self.warmup_steps + self.hold_steps + self.decay_steps > 0
        ), "please specify steps or phase_ratio"

        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps
            if self.warmup_steps != 0
            else 0
        )
        self.decay_factor = -math.log(args.final_lr_scale) / self.decay_steps

        # initial learning rate
        self.lr = self.init_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument('--warmup-steps', default=4000, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        parser.add_argument('--hold-steps', default=20000, type=int, metavar='N',
                            help='steps in hold stage.')
        parser.add_argument('--decay-steps', default=60000, type=int, metavar='N',
                            help='steps in decay stages')
        parser.add_argument('--init-lr-scale', default=0.01, type=float,
                            help="""initial learning rate scale during warmup phase; default is 0.01""")
        parser.add_argument('--final-lr-scale', default=0.01, type=float,
                            help="final learning rate scale; default to 0.01")
        # fmt: on

    def _decide_stage(self, update_step):
        """
        return stage, and the corresponding steps within the current stage
        """
        if update_step < self.warmup_steps:
            # warmup state
            return 0, update_step

        offset = self.warmup_steps

        if update_step < offset + self.hold_steps:
            # hold stage
            return 1, update_step - offset

        offset += self.hold_steps

        if update_step <= offset + self.decay_steps:
            # decay stage
            return 2, update_step - offset

        offset += self.decay_steps

        # still here ? constant lr stage
        return 3, update_step - offset

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        stage, steps_in_stage = self._decide_stage(num_updates)
        if stage == 0:
            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
        elif stage == 1:
            self.lr = self.peak_lr
        elif stage == 2:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 3:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")

        self.optimizer.set_lr(self.lr)

        return self.lr
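A worked example (not part of the commit) of how the three phase lengths and the decay constant above fit together, using hypothetical values max_update=100000, phase_ratio=(0.1, 0.4, 0.5) and final_lr_scale=0.01.

# Derive the phase boundaries and verify that the exponential decay reaches
# exactly final_lr_scale at the end of the decay stage.
import math

max_update, phase_ratio, final_lr_scale = 100000, (0.1, 0.4, 0.5), 0.01
warmup_steps = int(max_update * phase_ratio[0])   # 10000
hold_steps = int(max_update * phase_ratio[1])     # 40000
decay_steps = int(max_update * phase_ratio[2])    # 50000
decay_factor = -math.log(final_lr_scale) / decay_steps
print(math.exp(-decay_factor * decay_steps))      # 0.01 (up to float rounding)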
Uni-Core-main/unicore/optim/lr_scheduler/triangular_lr_scheduler.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List

from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler


@register_lr_scheduler("triangular")
class TriangularLRSchedule(UnicoreLRScheduler):
    """Assign LR based on a triangular cyclical schedule.

    See https://arxiv.org/pdf/1506.01186.pdf for details.
    """

    def __init__(self, args, optimizer, total_train_steps):
        super().__init__(args, optimizer, total_train_steps)
        if len(args.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with triangular."
                " Consider --lr-scheduler=fixed instead."
            )

        lr = args.lr[0]

        assert args.max_lr > lr, "max_lr must be more than lr"
        self.min_lr = lr
        self.max_lr = args.max_lr
        self.stepsize = args.lr_period_updates // 2
        self.lr_shrink = args.lr_shrink
        self.shrink_min = args.shrink_min

        # initial learning rate
        self.lr = self.min_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
                            help='max learning rate, must be more than args.lr')
        parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
                            help='initial number of updates per period (cycle length)')
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing')
        parser.add_argument('--shrink-min', action='store_true',
                            help='if set, also shrinks min lr')
        # fmt: on

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        cycle = math.floor(num_updates / (2 * self.stepsize))

        lr_shrink = self.lr_shrink ** cycle
        max_lr = self.max_lr * lr_shrink
        if self.shrink_min:
            min_lr = self.min_lr * lr_shrink
        else:
            min_lr = self.min_lr

        x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)
        self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))

        self.optimizer.set_lr(self.lr)
        return self.lr
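A worked example (not part of the commit) of the triangle wave in step_update above, with hypothetical values min_lr=1e-4, max_lr=1e-3 and lr_period_updates=5000 (so stepsize=2500): the LR ramps from min to max over the first half of each cycle and back down over the second half.

# Evaluate one triangular cycle (shrink factor ignored for simplicity).
import math

min_lr, max_lr, stepsize = 1e-4, 1e-3, 2500
for num_updates in (0, 1250, 2500, 3750, 5000):
    cycle = math.floor(num_updates / (2 * stepsize))
    x = abs(num_updates / stepsize - 2 * (cycle + 1) + 1)
    lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))
    print(num_updates, lr)   # 1e-4, 5.5e-4, 1e-3, 5.5e-4, 1e-4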
Uni-Core-main/unicore/optim/lr_scheduler/unicore_lr_scheduler.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace

from unicore.optim import UnicoreOptimizer


class UnicoreLRScheduler(object):
    def __init__(self, args, optimizer, total_train_steps):
        super().__init__()
        if optimizer is not None and not isinstance(optimizer, UnicoreOptimizer):
            raise ValueError("optimizer must be an instance of UnicoreOptimizer")
        self.args = args
        self.optimizer = optimizer
        self.total_train_steps = total_train_steps
        self.best = None

    @classmethod
    def add_args(cls, parser):
        """Add arguments to the parser for this LR scheduler."""
        pass

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {"best": self.best}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.best = state_dict["best"]

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        pass

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        if val_loss is not None:
            if self.best is None:
                self.best = val_loss
            else:
                self.best = min(self.best, val_loss)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.get_lr()
Uni-Core-main/unicore/optim/sgd.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import UnicoreOptimizer, register_optimizer


@register_optimizer("sgd")
class SGD(UnicoreOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
                            help='momentum factor')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "momentum": self.args.momentum,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return True
Uni-Core-main/unicore/optim/unicore_optimizer.py (new file, 0 → 100644)

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from unicore import utils


class UnicoreOptimizer(object):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self._grad_buffer = None
        self._need_sync_grad_buf = False

    @classmethod
    def add_args(cls, parser):
        """Add optimizer-specific arguments to the parser."""
        pass

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        """Reset optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        self._optimizer = optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def __getstate__(self):
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate."""
        return self.param_groups[0]["lr"]

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.param_groups:
            param_group["lr"] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with latest values
            for group in self.param_groups:
                group.update(optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def all_reduce_grads(self, module):
        """Manually all-reduce gradients (if required)."""
        self.__sync_grad_from_buf__()
        if hasattr(module, "all_reduce_grads"):
            module.all_reduce_grads()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                if torch.is_tensor(c):
                    c = c.to(p.grad.device)
                p.grad.data.mul_(c)

    def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        if max_norm <= 0.0:
            return 0.0
        if self._grad_buffer is None:
            self._grad_buffer = [torch.zeros_like(g) for g in self.params]
        gnorm = utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue
            self._grad_buffer[i] += p.grad
            p.grad = None
        self._need_sync_grad_buf = True
        return gnorm

    def __sync_grad_from_buf__(self):
        if self._need_sync_grad_buf:
            assert self._grad_buffer is not None
            for i, p in enumerate(self.params):
                p.grad = self._grad_buffer[i]
            self._need_sync_grad_buf = False

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        self.__sync_grad_from_buf__()
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)

    def step(self, closure=None, scale=1.0, groups=None):
        """Performs a single optimization step."""
        self.__sync_grad_from_buf__()
        if self.supports_step_with_scale:
            if self.supports_groups:
                self.optimizer.step(closure, scale=scale, groups=groups)
            else:
                self.optimizer.step(closure, scale=scale)
        else:
            if scale != 1.0:
                self.multiply_grads(1.0 / scale)
            if self.supports_groups:
                self.optimizer.step(closure, groups=groups)
            else:
                self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.params:
            p.grad = None
        self.optimizer.zero_grad()
        self._need_sync_grad_buf = False
        if self._grad_buffer is not None:
            for t in self._grad_buffer:
                t.zero_()

    @property
    def supports_memory_efficient_fp16(self):
        if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
            return self.optimizer.supports_memory_efficient_fp16
        return False

    @property
    def supports_step_with_scale(self):
        if hasattr(self.optimizer, "supports_step_with_scale"):
            return self.optimizer.supports_step_with_scale
        return False

    @property
    def supports_groups(self):
        if hasattr(self.optimizer, "supports_groups"):
            return self.optimizer.supports_groups
        return False

    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports collapsing of the model
        parameters/gradients into a single contiguous Tensor.
        """
        if hasattr(self.optimizer, "supports_flat_params"):
            return self.optimizer.supports_flat_params
        return False
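A hedged sketch (not part of the commit) of the per-sample clipping flow implied by the _grad_buffer logic above, assuming a working unicore installation; it uses the SGD wrapper defined in this commit, and the tiny model and Namespace below are hypothetical stand-ins. Each sample's gradients are clipped and accumulated into the buffer, and step() restores the accumulated gradients via __sync_grad_from_buf__ before updating the weights.

# Per-sample clipping with UnicoreOptimizer / SGD (hypothetical driver).
import argparse
import torch
from unicore.optim.sgd import SGD

model = torch.nn.Linear(4, 2)
args = argparse.Namespace(lr=[0.1], momentum=0.0, weight_decay=0.0)
optimizer = SGD(args, model.parameters())

for x in torch.randn(3, 4):                       # three "samples"
    loss = model(x).sum()
    optimizer.backward(loss)                      # plain loss.backward()
    optimizer.per_sample_clip_grad_norm(max_norm=1.0)  # clip, stash in _grad_buffer

optimizer.step()                                  # syncs buffered grads, then updates
optimizer.zero_grad()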
Uni-Core-main/unicore/options.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

import torch
from typing import Callable, List, Optional

# this import is for backward compatibility
from unicore.utils import (
    csv_str_list,
    eval_bool,
    eval_str_dict,
    eval_str_list,
    import_user_module,
)  # noqa


def get_training_parser(default_task="translation"):
    parser = get_parser("Trainer", default_task)
    add_dataset_args(parser, train=True)
    add_distributed_training_args(parser)
    add_model_args(parser)
    add_optimization_args(parser)
    add_checkpoint_args(parser)
    return parser


def get_validation_parser(default_task=None):
    parser = get_parser("Validation", default_task)
    add_dataset_args(parser, train=True)
    add_distributed_training_args(parser)
    group = parser.add_argument_group("Evaluation")
    add_common_eval_args(group)
    return parser
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: List[str] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """
    Args:
        parser (ArgumentParser): the parser
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values
    """
    if suppress_defaults:
        # Parse args without any default values. This requires us to parse
        # twice, once to identify all the necessary task/model args, and a second
        # time with all defaults set to None.
        args = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
        )
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()})
        args = suppressed_parser.parse_args(input_args)
        return argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )

    from unicore.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY

    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # The parser doesn't know about model/loss/optimizer-specific args, so
    # we parse twice. First we parse the model/loss/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        if args.arch in ARCH_MODEL_REGISTRY:
            ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        elif args.arch in MODEL_REGISTRY:
            MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        else:
            raise RuntimeError()

    if hasattr(args, "task"):
        from unicore.tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)

    # Add *-specific args to parser.
    from unicore.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            if hasattr(cls, "add_args"):
                cls.add_args(parser)

    # Modify the parser a second time, since defaults may have been reset
    if modify_parser is not None:
        modify_parser(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args.
    if (
        hasattr(args, "batch_size_valid") and args.batch_size_valid is None
    ) or not hasattr(args, "batch_size_valid"):
        args.batch_size_valid = args.batch_size
    args.bf16 = getattr(args, "bf16", False)
    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Apply architecture configuration.
    if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args
def get_parser(desc, default_task='test'):
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument('--user-dir', default=None)
    usr_args, _ = usr_parser.parse_known_args()
    import_user_module(usr_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # fmt: off
    parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format', default=None, help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
                        help='path to save logs for tensorboard, should match --logdir '
                             'of running tensorboard (default: no tensorboard logging)')
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument('--bf16', action='store_true', help='use BF16')
    parser.add_argument('--bf16-sr', action='store_true', help='use stochastic rounding for bf16')
    parser.add_argument('--allreduce-fp32-grad', action='store_true',
                        help='use fp32-grads in fp16/bf16 mode. --ddp-backend should be no_c10d')
    parser.add_argument('--fp16-no-flatten-grads', action='store_true',
                        help="don't flatten FP16 grads tensor")
    parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window', type=int,
                        help='number of updates before increasing loss scale')
    parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
                        help='pct of updates that can overflow before decreasing the loss scale')
    parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
                        help='minimum FP16 loss scale, after which training is stopped')
    parser.add_argument('--threshold-loss-scale', type=float,
                        help='threshold FP16 loss scale from below')
    parser.add_argument('--user-dir', default=None,
                        help='path to a python module containing custom extensions (tasks and/or architectures)')
    parser.add_argument('--empty-cache-freq', default=0, type=int,
                        help='how often to clear the PyTorch CUDA cache (0 to disable)')
    parser.add_argument('--all-gather-list-size', default=16384, type=int,
                        help='number of bytes reserved for gathering stats from workers')
    parser.add_argument('--suppress-crashes', action='store_true',
                        help="suppress crashes when training with the entry point so that the "
                             "main method can return a value (useful for sweeps)")
    parser.add_argument('--profile', action='store_true',
                        help="enable autograd profiler emit_nvtx")
    parser.add_argument('--ema-decay', default=-1.0, type=float,
                        help="enable moving average for model weights")

    from unicore.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            '--' + registry_name.replace('_', '-'),
            default=REGISTRY['default'],
            choices=REGISTRY['registry'].keys(),
        )

    # Task definitions can be found under unicore/tasks/
    from unicore.tasks import TASK_REGISTRY
    parser.add_argument('--task', metavar='TASK', default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
    # fmt: on
    return parser
def add_dataset_args(parser, train=False, gen=False):
    group = parser.add_argument_group('Dataset and data loading')
    # fmt: off
    group.add_argument('--num-workers', default=1, type=int, metavar='N',
                       help='how many subprocesses to use for data loading')
    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                       help='ignore too long or too short lines in valid and test set')
    group.add_argument('--batch-size', '--max-sentences', type=int, metavar='N',
                       help='maximum number of sentences in a batch')
    group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
                       help='batch size will be a multiple of this value')
    group.add_argument('--data-buffer-size', default=10, type=int,
                       help='number of batches to preload')
    group.add_argument('--train-subset', default='train', metavar='SPLIT',
                       choices=['train', 'valid', 'test', 'train.small'],
                       help='data subset to use for training (train, valid, test)')
    group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                       help='comma separated list of data subsets to use for validation'
                            ' (train, valid, valid1, test, test1)')
    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
                       help='validate every N epochs')
    group.add_argument('--validate-interval-updates', type=int, default=0, metavar='N',
                       help='validate every N updates')
    group.add_argument('--validate-after-updates', type=int, default=0, metavar='N',
                       help="don't validate until reaching this many updates")
    group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N',
                       help='specified random seed for validation')
    group.add_argument('--disable-validation', action='store_true',
                       help='disable validation')
    group.add_argument('--batch-size-valid', type=int, metavar='N',
                       help='maximum number of sentences in a validation batch'
                            ' (defaults to --max-sentences)')
    group.add_argument('--max-valid-steps', type=int, metavar='N',
                       help='how many batches to evaluate')
    group.add_argument('--curriculum', default=0, type=int, metavar='N',
                       help='don\'t shuffle batches for first N epochs')
    # fmt: on
    return group
def add_distributed_training_args(parser):
    group = parser.add_argument_group('Distributed training')
    # fmt: off
    group.add_argument('--distributed-world-size', type=int, metavar='N',
                       default=max(1, torch.cuda.device_count()),
                       help='total number of GPUs across all nodes (default: all visible GPUs)')
    group.add_argument('--distributed-rank', default=0, type=int,
                       help='rank of the current worker')
    group.add_argument('--distributed-backend', default='nccl', type=str,
                       help='distributed backend')
    group.add_argument('--distributed-init-method', default=None, type=str,
                       help='typically tcp://hostname:port that will be used to '
                            'establish initial connection')
    group.add_argument('--distributed-port', default=-1, type=int,
                       help='port number (not required if using --distributed-init-method)')
    group.add_argument('--device-id', '--local_rank', default=0, type=int,
                       help='which GPU to use (usually configured automatically)')
    group.add_argument('--distributed-no-spawn', action='store_true',
                       help='do not spawn multiple processes even if multiple GPUs are visible')
    group.add_argument('--ddp-backend', default='c10d', type=str,
                       choices=['c10d', 'apex', 'no_c10d'],
                       help='DistributedDataParallel backend')
    group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB',
                       help='bucket size for reduction')
    group.add_argument('--fix-batches-to-gpus', action='store_true',
                       help='don\'t shuffle batches between GPUs; this reduces overall '
                            'randomness and may affect precision but avoids the cost of '
                            're-reading the data')
    group.add_argument('--find-unused-parameters', default=False, action='store_true',
                       help='enable unused parameter detection (not applicable to '
                            'no_c10d ddp-backend)')
    group.add_argument('--fast-stat-sync', default=False, action='store_true',
                       help='Enable fast sync of stats between nodes, this hardcodes to '
                            'sync only some default stats from logging_output.')
    group.add_argument('--broadcast-buffers', default=False, action='store_true',
                       help='Copy non-trainable parameters between GPUs, such as '
                            'batchnorm population statistics')
    group.add_argument('--nprocs-per-node', default=max(1, torch.cuda.device_count()), type=int,
                       help='number of GPUs in each node. An allreduce operation across GPUs in '
                            'a node is very fast. Hence, we do allreduce across GPUs in a node, '
                            'and gossip across different nodes')
    # fmt: on
    return group
def add_optimization_args(parser):
    group = parser.add_argument_group('Optimization')
    # fmt: off
    group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                       help='force stop training at specified epoch')
    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                       help='force stop training at specified update')
    group.add_argument('--stop-time-hours', default=0, type=float,
                       help='force stop training after specified cumulative time (if >0)')
    group.add_argument('--clip-norm', default=0, type=float, metavar='NORM',
                       help='clip threshold of gradients')
    group.add_argument('--per-sample-clip-norm', default=0, type=float, metavar='PNORM',
                       help='clip threshold of gradients, before gradient sync over workers. '
                            'In fp16/bf16 mode, --allreduce-fp32-grad should be set, and '
                            '--ddp-backend should be no_c10d')
    group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
                       type=lambda uf: eval_str_list(uf, type=int),
                       help='update parameters every N_i batches, when in epoch i')
    group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list,
                       metavar='LR_1,LR_2,...,LR_N',
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    group.add_argument('--stop-min-lr', default=-1, type=float, metavar='LR',
                       help='stop training when the learning rate reaches this minimum')
    # fmt: on
    return group
def add_checkpoint_args(parser):
    group = parser.add_argument_group('Checkpointing')
    # fmt: off
    group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
                       help='path to save checkpoints')
    group.add_argument('--tmp-save-dir', metavar='DIR', default='./',
                       help='path to temporarily save checkpoints')
    group.add_argument('--restore-file', default='checkpoint_last.pt',
                       help='filename from which to load checkpoint '
                            '(default: <save-dir>/checkpoint_last.pt)')
    group.add_argument('--finetune-from-model', type=str,
                       help='finetune from a pretrained model; note that meters and lr scheduler will be reset')
    group.add_argument('--load-from-ema', action='store_true',
                       help='load model weights from the EMA state stored in the checkpoint')
    group.add_argument('--reset-dataloader', action='store_true',
                       help='if set, does not reload dataloader state from the checkpoint')
    group.add_argument('--reset-lr-scheduler', action='store_true',
                       help='if set, does not load lr scheduler state from the checkpoint')
    group.add_argument('--reset-meters', action='store_true',
                       help='if set, does not load meters from the checkpoint')
    group.add_argument('--reset-optimizer', action='store_true',
                       help='if set, does not load optimizer state from the checkpoint')
    group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
                       help='a dictionary used to override optimizer args when loading a checkpoint')
    group.add_argument('--save-interval', type=int, default=1, metavar='N',
                       help='save a checkpoint every N epochs')
    group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
                       help='save a checkpoint (and validate) every N updates')
    group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
                       help='keep the last N checkpoints saved with --save-interval-updates')
    group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
                       help='keep last N epoch checkpoints')
    group.add_argument('--keep-best-checkpoints', type=int, default=-1, metavar='N',
                       help='keep best N checkpoints based on scores')
    group.add_argument('--no-save', action='store_true',
                       help='don\'t save models or checkpoints')
    group.add_argument('--no-epoch-checkpoints', action='store_true',
                       help='only store last and best checkpoints')
    group.add_argument('--no-last-checkpoints', action='store_true',
                       help='don\'t store last checkpoints')
    group.add_argument('--no-save-optimizer-state', action='store_true',
                       help='don\'t save optimizer-state as part of checkpoint')
    group.add_argument('--best-checkpoint-metric', type=str, default='loss',
                       help='metric to use for saving "best" checkpoints')
    group.add_argument('--maximize-best-checkpoint-metric', action='store_true',
                       help='select the largest metric value for saving "best" checkpoints')
    group.add_argument('--patience', type=int, default=-1, metavar='N',
                       help="early stop training if valid performance doesn't "
                            "improve for N consecutive validation runs; note "
                            "that this is influenced by --validate-interval")
    group.add_argument('--checkpoint-suffix', type=str, default="",
                       help='suffix to add to the checkpoint file name')
    # fmt: on
    return group
def add_common_eval_args(group):
    # fmt: off
    group.add_argument('--path', metavar='FILE',
                       help='path(s) to model file(s), colon separated')
    group.add_argument('--quiet', action='store_true',
                       help='only print final scores')
    group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                       help='a dictionary used to override model args at generation '
                            'that were used during model training')
    group.add_argument('--results-path', metavar='RESDIR', type=str, default=None,
                       help='path to save eval results (optional)')
    # fmt: on


def add_model_args(parser):
    group = parser.add_argument_group('Model configuration')
    # fmt: off
    # Model definitions can be found under unicore/models/
    #
    # The model architecture can be specified in several ways.
    # In increasing order of priority:
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    # 3) --encoder/decoder-* arguments (highest priority)
    from unicore.models import ARCH_MODEL_REGISTRY
    group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
                       choices=ARCH_MODEL_REGISTRY.keys(),
                       help='model architecture')
    # fmt: on
    return group
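Taken together, these helpers are usually driven as sketched below. This is only an illustration: it assumes unicore is importable and that a task named "my_task" and an architecture named "my_arch" have been registered (for example via --user-dir), and that the --optimizer and --lr-scheduler choices likewise name registered components.

from unicore import options

parser = options.get_training_parser(default_task="my_task")
args = options.parse_args_and_arch(
    parser,
    input_args=[
        "--arch", "my_arch",
        "--batch-size", "8",
        "--lr", "1e-4",
        "--optimizer", "sgd",
        "--lr-scheduler", "fixed",
    ],
)
# batch_size_valid falls back to --batch-size when it is not given explicitly
print(args.batch_size, args.batch_size_valid)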
Uni-Core-main/unicore/registry.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse


REGISTRIES = {}


def setup_registry(registry_name: str, base_class=None, default=None):
    assert registry_name.startswith('--')
    registry_name = registry_name[2:].replace('-', '_')

    REGISTRY = {}
    REGISTRY_CLASS_NAMES = set()

    # maintain a registry of all registries
    if registry_name in REGISTRIES:
        return  # registry already exists
    REGISTRIES[registry_name] = {
        'registry': REGISTRY,
        'default': default,
    }

    def build_x(args, *extra_args, **extra_kwargs):
        choice = getattr(args, registry_name, None)
        if choice is None:
            return None
        cls = REGISTRY[choice]
        if hasattr(cls, 'build_' + registry_name):
            builder = getattr(cls, 'build_' + registry_name)
        else:
            builder = cls
        set_defaults(args, cls)
        return builder(args, *extra_args, **extra_kwargs)

    def register_x(name):
        def register_x_cls(cls):
            if name in REGISTRY:
                raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name))
            if cls.__name__ in REGISTRY_CLASS_NAMES:
                raise ValueError(
                    'Cannot register {} with duplicate class name ({})'.format(
                        registry_name, cls.__name__,
                    )
                )
            if base_class is not None and not issubclass(cls, base_class):
                raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__))
            REGISTRY[name] = cls
            REGISTRY_CLASS_NAMES.add(cls.__name__)
            return cls

        return register_x_cls

    return build_x, register_x, REGISTRY


def set_defaults(args, cls):
    """Helper to set default arguments based on *add_args*."""
    if not hasattr(cls, 'add_args'):
        return
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
    cls.add_args(parser)
    # copied from argparse.py:
    defaults = argparse.Namespace()
    for action in parser._actions:
        if action.dest is not argparse.SUPPRESS:
            if not hasattr(defaults, action.dest):
                if action.default is not argparse.SUPPRESS:
                    setattr(defaults, action.dest, action.default)
    for key, default_value in vars(defaults).items():
        if not hasattr(args, key):
            setattr(args, key, default_value)
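A minimal sketch of how setup_registry is intended to be used; the "--tokenizer" registry and the SimpleTokenizer class below are hypothetical examples for illustration, not part of Uni-Core.

import argparse

from unicore.registry import setup_registry

build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = setup_registry(
    "--tokenizer", default=None
)


@register_tokenizer("simple")
class SimpleTokenizer:
    @staticmethod
    def add_args(parser):
        parser.add_argument("--lowercase", action="store_true", default=False)

    def __init__(self, args):
        self.lowercase = args.lowercase


args = argparse.Namespace(tokenizer="simple")
tok = build_tokenizer(args)   # set_defaults() copies --lowercase's default onto args
print(tok.lowercase)          # False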
Uni-Core-main/unicore/tasks/__init__.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import argparse
import importlib
import os

from .unicore_task import UnicoreTask

# register dataclass
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()


def setup_task(args, **kwargs):
    return TASK_REGISTRY[args.task].setup_task(args, **kwargs)


def register_task(name):
    """
    New tasks can be added to unicore with the
    :func:`~unicore.tasks.register_task` function decorator.

    For example::

        @register_task('classification')
        class ClassificationTask(UnicoreTask):
            (...)

    .. note::

        All Tasks must implement the :class:`~unicore.tasks.UnicoreTask`
        interface.

    Args:
        name (str): the name of the task
    """

    def register_task_cls(cls):
        if name in TASK_REGISTRY:
            raise ValueError("Cannot register duplicate task ({})".format(name))
        if not issubclass(cls, UnicoreTask):
            raise ValueError(
                "Task ({}: {}) must extend UnicoreTask".format(name, cls.__name__)
            )
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError(
                "Cannot register task with duplicate class name ({})".format(
                    cls.__name__
                )
            )
        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_task_cls


# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
    path = os.path.join(tasks_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        task_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("unicore.tasks." + task_name)

        # expose `task_parser` for sphinx
        if task_name in TASK_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_task = parser.add_argument_group("Task name")
            # fmt: off
            group_task.add_argument('--task', metavar=task_name,
                                    help='Enable this task with: ``--task=' + task_name + '``')
            # fmt: on
            group_args = parser.add_argument_group("Additional command-line arguments")
            TASK_REGISTRY[task_name].add_args(group_args)
            globals()[task_name + "_parser"] = parser
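A minimal sketch of the registration flow this module implements. The "toy" task name and ToyTask class are hypothetical; in practice task modules live under unicore/tasks/ or in a --user-dir package so the import loop above can pick them up.

from argparse import Namespace

from unicore.tasks import UnicoreTask, register_task, setup_task


@register_task("toy")
class ToyTask(UnicoreTask):
    @classmethod
    def setup_task(cls, args, **kwargs):
        return cls(args)

    def load_dataset(self, split, combine=False, **kwargs):
        raise NotImplementedError  # a real task would populate self.datasets[split]


task = setup_task(Namespace(task="toy"))   # dispatches to ToyTask.setup_task
print(type(task).__name__)                 # ToyTask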
Uni-Core-main/unicore/tasks/unicore_task.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import warnings
from argparse import Namespace
from typing import Any, Callable, Dict, List

import torch
from unicore import metrics, utils
from unicore.data import UnicoreDataset, data_utils, iterators

logger = logging.getLogger(__name__)


class StatefulContainer(object):

    _state: Dict[str, Any] = dict()
    _factories: Dict[str, Callable[[], Any]] = dict()

    def add_factory(self, name, factory: Callable[[], Any]):
        self._factories[name] = factory

    def merge_state_dict(self, state_dict: Dict[str, Any]):
        self._state.update(state_dict)

    @property
    def state_dict(self) -> Dict[str, Any]:
        return self._state

    def __getattr__(self, name):
        if name not in self._state and name in self._factories:
            self._state[name] = self._factories[name]()
        if name in self._state:
            return self._state[name]
        raise AttributeError(f"Task state has no factory for attribute {name}")


class UnicoreTask(object):
    """
    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Loss and calculating the loss.

    Tasks have limited statefulness. In particular, state that needs to be
    saved to/loaded from checkpoints needs to be stored in the `self.state`
    :class:`StatefulContainer` object. For example::

        self.state.add_factory("dictionary", self.load_dictionary)
        print(self.state.dictionary)  # calls self.load_dictionary()

    This is necessary so that when loading checkpoints, we can properly
    recreate the task state after initializing the task instance.
    """

    @classmethod
    def add_args(cls, parser):
        """Add task-specific arguments to the parser."""
        pass

    @staticmethod
    def logging_outputs_can_be_summed(loss, is_train) -> bool:
        """
        Whether the logging outputs returned by `train_step` and `valid_step` can
        be summed across workers prior to calling `reduce_metrics`.
        Setting this to True improves distributed training speed.
        """
        return loss.logging_outputs_can_be_summed(is_train)

    args: Namespace
    datasets: Dict[str, UnicoreDataset]
    dataset_to_epoch_iter: Dict[UnicoreDataset, Any]
    state: StatefulContainer = None

    def __init__(self, args: Namespace, **kwargs):
        self.args = args
        self.datasets = dict()
        self.dataset_to_epoch_iter = dict()
        self.state = StatefulContainer()

    @classmethod
    def setup_task(cls, args: Namespace, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

    def has_sharded_data(self, split):
        return os.pathsep in getattr(self.args, "data", "")

    def load_dataset(self, split: str, combine: bool = False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            combine (bool): combines a split segmented into pieces into one dataset
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~unicore.data.UnicoreDataset` corresponding to *split*
        """
        from unicore.data import UnicoreDataset

        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        if not isinstance(self.datasets[split], UnicoreDataset):
            raise TypeError("Datasets are expected to be of type UnicoreDataset")
        return self.datasets[split]

    def can_reuse_epoch_itr(self, dataset):
        # We can reuse the epoch iterator across epochs as long as the dataset
        # hasn't disabled it. We default to ``False`` here, although in practice
        # this will be ``True`` for most datasets that inherit from
        # ``UnicoreDataset`` due to the base implementation there.
        return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False)

    def get_batch_iterator(
        self,
        dataset,
        batch_size=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~unicore.data.UnicoreDataset): dataset to batch
            batch_size (int, optional): max number of samples in each
                batch (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `UnicoreTask::can_reuse_epoch_itr`)
                (default: False).

        Returns:
            ~unicore.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
        """
        can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr(
            dataset
        )
        if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
            logger.info("reusing EpochBatchIterator for epoch {}".format(epoch))
            return self.dataset_to_epoch_iter[dataset]
        else:
            logger.info("get EpochBatchIterator for epoch {}".format(epoch))

        assert isinstance(dataset, UnicoreDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # create mini-batches with given size constraints
        batch_sampler = dataset.batch_by_size(
            indices,
            batch_size=batch_size,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
            disable_shuffling=self.disable_shuffling(),
        )

        if can_reuse_epoch_itr:
            self.dataset_to_epoch_iter[dataset] = epoch_iter

        return epoch_iter

    def build_model(self, args: Namespace):
        """
        Build the :class:`~unicore.models.BaseUnicoreModel` instance for this
        task.

        Returns:
            a :class:`~unicore.models.BaseUnicoreModel` instance
        """
        from unicore import models

        return models.build_model(args, self)

    def build_loss(self, args: Namespace):
        """
        Build the :class:`~unicore.losses.UnicoreLoss` instance for
        this task.

        Args:
            args (Namespace): configuration object

        Returns:
            a :class:`~unicore.losses.UnicoreLoss` instance
        """
        from unicore import losses

        return losses.build_loss(args, self)

    def train_step(
        self, sample, model, loss, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *loss*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~unicore.data.UnicoreDataset`.
            model (~unicore.models.BaseUnicoreModel): the model
            loss (~unicore.losses.UnicoreLoss): the loss
            optimizer (~unicore.optim.UnicoreOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            loss, sample_size, logging_output = loss(model, sample)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, loss, test=False):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = loss(model, sample)
        return loss, sample_size, logging_output

    def optimizer_step(self, optimizer, model, update_num):
        optimizer.step()

    def build_dataset_for_inference(
        self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
    ) -> torch.utils.data.Dataset:
        raise NotImplementedError

    def begin_epoch(self, epoch, model):
        """Hook function called before the start of each epoch."""
        pass

    def begin_valid_epoch(self, epoch, model):
        """Hook function called before the start of each validation epoch."""
        pass

    def reduce_metrics(self, logging_outputs, loss, split='train'):
        """Aggregate logging outputs from data parallel training."""
        if not any("bsz" in log for log in logging_outputs):
            warnings.warn(
                "bsz not found in Loss logging outputs, cannot log bsz"
            )
        else:
            bsz = sum(log.get("bsz", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", bsz, priority=190, round=1)
        loss.__class__.reduce_metrics(logging_outputs, split)

    def state_dict(self):
        if self.state is not None:
            return self.state.state_dict
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        if self.state is not None:
            self.state.merge_state_dict(state_dict)

    def disable_shuffling(self) -> bool:
        return False
\ No newline at end of file
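The lazy state-factory pattern described in the UnicoreTask docstring can be exercised on its own; a small sketch follows (the "dictionary" name and load_dictionary helper are illustrative assumptions, and unicore must be importable).

from unicore.tasks.unicore_task import StatefulContainer


def load_dictionary():
    print("loading dictionary ...")
    return {"<pad>": 0, "<unk>": 1}


state = StatefulContainer()
state.add_factory("dictionary", load_dictionary)
print(state.dictionary)   # first access runs load_dictionary() and caches the result
print(state.state_dict)   # the cached value now travels with the task's checkpoint state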
Uni-Core-main/unicore/trainer.py
0 → 100644
View file @
a1c29028
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""

import contextlib
import logging
import os
import sys
import time
from itertools import chain
from typing import Any, Dict, List

import torch
from unicore import checkpoint_utils, models, optim, utils
from unicore.distributed import utils as distributed_utils
from unicore.logging import meters, metrics
from unicore.nan_detector import NanDetector
from unicore.optim import lr_scheduler
from unicore.utils import tensor_tree_map

logger = logging.getLogger(__name__)
class ExponentialMovingAverage:
    """
    Maintains moving averages of parameters with exponential decay.

    At each step, the stored copy `copy` of each parameter `param` is
    updated as follows:

        `copy = decay * copy + (1 - decay) * param`

    where `decay` is an attribute of the ExponentialMovingAverage object.
    """

    def __init__(self, model: torch.nn.Module, decay: float):
        """
        Args:
            model:
                A torch.nn.Module whose parameters are to be tracked
            decay:
                A value (usually close to 1.) by which updates are
                weighted as part of the above formula
        """
        super(ExponentialMovingAverage, self).__init__()

        with torch.no_grad():
            clone_param = lambda t: t.clone().detach().float()
            self.params = tensor_tree_map(clone_param, model.state_dict())
        self.decay = decay

    def _update_state_dict_(self, update, state_dict):
        with torch.no_grad():
            for k, v in update.items():
                if state_dict[k].device != v.device:
                    state_dict[k] = state_dict[k].to(v.device)
                stored = state_dict[k]
                if not isinstance(v, torch.Tensor):
                    self._update_state_dict_(v, stored)
                else:
                    diff = stored - v.float()
                    diff *= 1 - self.decay
                    stored -= diff

    def update(self, model: torch.nn.Module) -> None:
        """
        Updates the stored parameters using the state dict of the provided
        module. The module should have the same structure as that used to
        initialize the ExponentialMovingAverage object.
        """
        self._update_state_dict_(model.state_dict(), self.params)

    def load_state_dict(self, state_dict: dict) -> None:
        self.params = state_dict["params"]
        self.decay = state_dict["decay"] if "decay" in state_dict else self.decay

    def state_dict(self) -> dict:
        return {
            "params": self.params,
            "decay": self.decay,
        }
class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
    """

    def __init__(self, args, task, model, loss):
        self.args = args
        self.task = task

        # catalog shared parameters
        shared_params = _catalog_shared_params(model)
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # copy model and loss to current device/dtype
        self._loss = loss
        self._model = model
        if args.fp16:
            self._loss = self._loss.half()
            self._model = self._model.half()
        elif args.bf16:
            self._loss = self._loss.bfloat16()
            self._model = self._model.bfloat16()
        if (
            # the DistributedUnicoreModel wrapper will handle moving to device,
            # so only handle cases which don't use the wrapper
            not self.use_distributed_wrapper
        ):
            self._loss = self._loss.to(device=self.device)
            self._model = self._model.to(device=self.device)

        # check that shared parameters are preserved after device transfer
        for shared_param in shared_params:
            ref = _get_module_by_path(self._model, shared_param[0])
            for path in shared_param[1:]:
                logger.info(
                    "detected shared parameter: {} <- {}".format(shared_param[0], path)
                )
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = None  # indicates we don't have a dummy batch at first
        self._total_train_steps = None
        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_loss = None
        self._wrapped_model = None

        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
        else:
            self._grad_norm_buf = None

        # get detailed cuda environment
        if self.cuda:
            self.cuda_env = utils.CudaEnvironment()
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(
                    self.cuda_env, group=distributed_utils.get_global_group()
                )
            else:
                self.cuda_env_arr = [self.cuda_env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
        else:
            self.cuda_env = None
            self.cuda_env_arr = None

        # add ema
        if args.ema_decay > 0 and self.data_parallel_rank == 0:
            self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay)
        else:
            self.ema = None

        metrics.log_start_time("wall", priority=790, round=2)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None

    def reinitialize(self):
        """Reinitialize the Trainer, typically after model params change."""
        self._lr_scheduler = None
        self._optimizer = None
        self._wrapped_loss = None
        self._wrapped_model = None

    @property
    def data_parallel_world_size(self):
        if self.args.distributed_world_size == 1:
            return 1
        return distributed_utils.get_data_parallel_world_size()

    @property
    def data_parallel_process_group(self):
        return distributed_utils.get_data_parallel_group()

    @property
    def data_parallel_rank(self):
        if self.args.distributed_world_size == 1:
            return 0
        return distributed_utils.get_data_parallel_rank()

    @property
    def is_data_parallel_master(self):
        # NOTE: this returns true for all model parallel replicas with data
        # parallel rank 0
        return self.data_parallel_rank == 0

    @property
    def use_distributed_wrapper(self) -> bool:
        return self.data_parallel_world_size > 1

    @property
    def should_save_checkpoint_on_current_rank(self) -> bool:
        """Indicates whether to save checkpoints on the current DDP rank."""
        return self.is_data_parallel_master

    @property
    def checkpoint_suffix(self) -> str:
        """Suffix to add to the checkpoint file name."""
        return self.args.checkpoint_suffix or ""

    @property
    def loss(self):
        if self._wrapped_loss is None:
            if utils.has_parameters(self._loss) and self.use_distributed_wrapper:
                self._wrapped_loss = models.DistributedUnicoreModel(
                    self.args,
                    self._loss,
                    process_group=self.data_parallel_process_group,
                    device=self.device,
                )
            else:
                self._wrapped_loss = self._loss
        return self._wrapped_loss

    @property
    def model(self):
        if self._wrapped_model is None:
            if self.use_distributed_wrapper:
                self._wrapped_model = models.DistributedUnicoreModel(
                    self.args,
                    self._model,
                    process_group=self.data_parallel_process_group,
                    device=self.device,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        params = list(
            filter(
                lambda p: p.requires_grad,
                chain(self.model.parameters(), self.loss.parameters()),
            )
        )
        if self.args.per_sample_clip_norm > 0:
            assert self.args.ddp_backend == "no_c10d"
            assert self.args.batch_size == 1
        if self.args.fp16 or self.args.bf16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                logger.info(
                    "NOTE: your device does NOT support faster training with --fp16, "
                    "please switch to FP32 which is likely to be faster"
                )
            self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
            if self.args.allreduce_fp32_grad:
                assert self.args.ddp_backend == "no_c10d"
            if self.args.per_sample_clip_norm > 0:
                assert self.args.allreduce_fp32_grad
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                logger.info("NOTE: your device may support faster training with --fp16")
            self._optimizer = optim.build_optimizer(self.args, params)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(
            self.args,
            self.optimizer,
            self._total_train_steps,
        )
        self._lr_scheduler.step_update(0)
    def state_dict(self):
        state_dict = {
            "args": self.args,
            "model": self.model.state_dict(),
            "loss": (
                self.loss.state_dict()
                if utils.has_parameters(self.loss)
                else None
            ),
            "optimizer_history": (self._optim_history or [])
            + [
                {
                    "loss_name": self.get_loss().__class__.__name__,
                    "optimizer_name": self.optimizer.__class__.__name__,
                    "lr_scheduler_state": self.lr_scheduler.state_dict(),
                    "num_updates": self.get_num_updates(),
                }
            ],
            "task_state": self.task.state_dict() if self.task is not None else {},
            "extra_state": {
                "metrics": metrics.state_dict(),
                "previous_training_time": self.cumulative_training_time(),
            },
        }
        if not self.args.no_save_optimizer_state:
            state_dict["last_optimizer_state"] = self.optimizer.state_dict()
        if self.ema is not None:
            state_dict["ema"] = self.ema.state_dict()
        return state_dict

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        logger.info(f"Saving checkpoint to {filename}")
        # call state_dict on all ranks in case it needs internal communication
        state_dict = utils.move_to_cpu(self.state_dict())
        state_dict["extra_state"].update(extra_state)
        if self.should_save_checkpoint_on_current_rank:
            checkpoint_utils.torch_persistent_save(
                state_dict,
                filename,
            )
        logger.info(f"Finished saving checkpoint to {filename}")

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """
        Load all training state from a checkpoint file.
        rank = 0 will load the checkpoint, and then broadcast it to all
        other ranks.
        """
        extra_state, self._optim_history, last_optim_state = None, [], None

        logger.info(f"Preparing to load checkpoint {filename}")
        is_distributed = self.data_parallel_world_size > 1
        is_master = self.data_parallel_rank == 0
        bexists = None
        if is_master:
            bexists = os.path.isfile(filename)
        if is_distributed:
            bexists = distributed_utils.broadcast_object(
                bexists,
                src_rank=0,
                group=self.data_parallel_process_group,
                dist_device=self.device,
            )

        had_loaded_model = False
        if bexists:
            state = None
            if is_master:
                state = checkpoint_utils.load_checkpoint_to_cpu(filename)
            if is_distributed:
                logger.info("Broadcast checkpoint from rank_0")
                state = distributed_utils.broadcast_object(
                    state,
                    src_rank=0,
                    group=self.data_parallel_process_group,
                    dist_device=self.device,
                )
            last_optim_state = state.get("last_optimizer_state", None)
            ema_state = state.get("ema", None)

            # load model parameters
            try:
                if self.args.load_from_ema:
                    logger.info("loading ema state to model")
                    errors = self.model.load_state_dict(
                        ema_state["params"], strict=False, model_args=self.args
                    )
                else:
                    errors = self.model.load_state_dict(
                        state["model"], strict=False, model_args=self.args
                    )
                    # save memory for later steps
                    del state["model"]
                had_loaded_model = True
                if errors.missing_keys:
                    logger.warning(
                        "Error in loading model state, missing_keys "
                        + str(errors.missing_keys)
                    )
                if errors.unexpected_keys:
                    logger.warning(
                        "Error in loading model state, unexpected_keys "
                        + str(errors.unexpected_keys)
                    )
                if utils.has_parameters(self.get_loss()):
                    self.get_loss().load_state_dict(state["loss"], strict=True)
                    del state["loss"]
            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(filename)
                )

            extra_state = state["extra_state"] if "extra_state" in state else None
            self._optim_history = (
                state["optimizer_history"] if "optimizer_history" in state else None
            )
            if (
                ema_state is not None
                and self.ema is not None
                and not self.args.load_from_ema
            ):
                logger.info("Loading EMA state...")
                self.ema.load_state_dict(ema_state)
            elif self.ema is not None:
                logger.info(
                    "Cannot find EMA state in checkpoint, load model weight to ema directly"
                )
                self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["loss_name"] == self.get_loss().__class__.__name__
            ), f"Loss does not match; please reset the optimizer (--reset-optimizer). {last_optim['loss_name']} vs {self.get_loss().__class__.__name__}"
            assert (
                last_optim["optimizer_name"] == self.optimizer.__class__.__name__
            ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            itr_state = extra_state["train_iterator"]
            epoch = itr_state["epoch"]

            if "previous_training_time" in extra_state:
                self._previous_training_time = extra_state["previous_training_time"]
                self._start_time = time.time()

            # self.lr_step(epoch)

            if (
                itr_state.get("version", 1) >= 2
                and itr_state["iterations_in_epoch"] == 0
            ):
                # reset meters at start of epoch
                reset_meters = True

            if "metrics" in extra_state and not reset_meters:
                metrics.load_state_dict(extra_state["metrics"])

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in metrics.get_meters("default"):
                    if isinstance(meter, meters.TimeMeter):
                        meter.reset()

            logger.info(
                "Loaded checkpoint {} (epoch {} @ {} updates)".format(
                    filename, epoch, self.get_num_updates()
                )
            )
        elif had_loaded_model:
            logger.info("Loaded checkpoint {}".format(filename))
        else:
            logger.info("No existing checkpoint found {}".format(filename))

        return extra_state
    def get_train_iterator(
        self,
        epoch,
        combine=True,
        load_dataset=True,
        data_selector=None,
        shard_batch_itr=True,
        disable_iterator_cache=False,
    ):
        """Return an EpochBatchIterator over the training set for a given epoch."""
        if load_dataset:
            logger.info("loading train data for epoch {}".format(epoch))
            self.task.load_dataset(
                self.args.train_subset,
                epoch=epoch,
                combine=combine,
                data_selector=data_selector,
            )
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.dataset(self.args.train_subset),
            batch_size=self.args.batch_size,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=self.args.required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
            shard_id=self.data_parallel_rank if shard_batch_itr else 0,
            num_workers=self.args.num_workers,
            epoch=epoch,
            data_buffer_size=self.args.data_buffer_size,
            disable_iterator_cache=disable_iterator_cache,
        )
        self.reset_dummy_batch(batch_iterator.first_batch)
        return batch_iterator

    def init_total_train_steps(self, epoch_itr):
        if self.args.max_epoch > 0:
            self._total_train_steps = (
                (len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch
            )
        else:
            self._total_train_steps = self.args.max_update

    def get_valid_iterator(
        self,
        subset,
        disable_iterator_cache=False,
    ):
        """Return an EpochBatchIterator over given validation subset for a given epoch."""
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.dataset(subset),
            batch_size=self.args.batch_size_valid,
            ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=self.args.required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.data_parallel_world_size,
            shard_id=self.data_parallel_rank,
            num_workers=self.args.num_workers,
            # always pass a fixed "epoch" to keep validation data consistent
            # across training epochs
            epoch=1,
            data_buffer_size=self.args.data_buffer_size,
            disable_iterator_cache=disable_iterator_cache,
        )
        self.reset_dummy_batch(batch_iterator.first_batch)
        return batch_iterator

    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch."""
        logger.info("begin training epoch {}".format(epoch))

        self.lr_step_begin_epoch(epoch)

        # task specific setup per epoch
        self.task.begin_epoch(epoch, self.get_model())

    def begin_valid_epoch(self, epoch):
        """Called at the beginning of each validation epoch."""
        # task specific setup per validation epoch
        self.task.begin_valid_epoch(epoch, self.get_model())

    def reset_dummy_batch(self, batch):
        self._dummy_batch = batch
@
metrics
.
aggregate
(
"train"
)
def
train_step
(
self
,
samples
,
raise_oom
=
False
):
"""Do forward, backward and parameter update."""
self
.
model
.
train
()
self
.
loss
.
train
()
self
.
zero_grad
()
metrics
.
log_start_time
(
"train_wall"
,
priority
=
800
,
round
=
2
)
# forward and backward pass
logging_outputs
,
sample_size
,
ooms
=
[],
0
,
0
for
i
,
sample
in
enumerate
(
samples
):
# delayed update loop
sample
,
is_dummy_batch
=
self
.
_prepare_sample
(
sample
)
def
maybe_no_sync
():
"""
Whenever *samples* contains more than one mini-batch, we
want to accumulate gradients locally and only call
all-reduce in the last backwards pass.
"""
if
(
self
.
data_parallel_world_size
>
1
and
hasattr
(
self
.
model
,
"no_sync"
)
and
i
<
len
(
samples
)
-
1
):
return
self
.
model
.
no_sync
()
else
:
return
contextlib
.
ExitStack
()
# dummy contextmanager
try
:
with
maybe_no_sync
():
# use different seed for different rank in training, otherwise the dropout will be the same in different workers.
with
utils
.
torch_seed
(
self
.
args
.
seed
,
self
.
get_num_updates
(),
i
,
self
.
data_parallel_rank
,
):
# forward and backward
loss
,
sample_size_i
,
logging_output
=
self
.
task
.
train_step
(
sample
=
sample
,
model
=
self
.
model
,
loss
=
self
.
loss
,
optimizer
=
self
.
optimizer
,
update_num
=
self
.
get_num_updates
(),
ignore_grad
=
is_dummy_batch
,
)
del
loss
if
self
.
args
.
per_sample_clip_norm
>
0
:
self
.
optimizer
.
per_sample_clip_grad_norm
(
self
.
args
.
per_sample_clip_norm
)
logging_outputs
.
append
(
logging_output
)
sample_size
+=
sample_size_i
# emptying the CUDA cache after the first step can
# reduce the chance of OOM
if
self
.
cuda
and
self
.
get_num_updates
()
==
0
:
torch
.
cuda
.
empty_cache
()
except
RuntimeError
as
e
:
if
"out of memory"
in
str
(
e
):
self
.
_log_oom
(
e
)
if
raise_oom
:
raise
e
logger
.
warning
(
"attempting to recover from OOM in forward/backward pass"
)
ooms
+=
1
self
.
zero_grad
()
if
self
.
cuda
:
torch
.
cuda
.
empty_cache
()
if
self
.
args
.
distributed_world_size
==
1
:
return
None
else
:
raise
e
if
is_dummy_batch
:
if
torch
.
is_tensor
(
sample_size
):
sample_size
.
zero_
()
else
:
sample_size
*=
0.0
if
torch
.
is_tensor
(
sample_size
):
sample_size
=
sample_size
.
float
()
else
:
sample_size
=
float
(
sample_size
)
local_sample_size
=
sample_size
# gather logging outputs from all replicas
if
self
.
_sync_stats
():
train_time
=
self
.
_local_cumulative_training_time
()
logging_outputs
,
(
sample_size
,
ooms
,
total_train_time
,
)
=
self
.
_aggregate_logging_outputs
(
logging_outputs
,
sample_size
,
ooms
,
train_time
,
ignore
=
is_dummy_batch
,
is_train
=
True
,
)
self
.
_cumulative_training_time
=
(
total_train_time
/
self
.
data_parallel_world_size
)
overflow
=
False
try
:
with
torch
.
autograd
.
profiler
.
record_function
(
"reduce-grads"
):
# reduce gradients across workers
self
.
optimizer
.
all_reduce_grads
(
self
.
model
)
if
utils
.
has_parameters
(
self
.
loss
):
self
.
optimizer
.
all_reduce_grads
(
self
.
loss
)
with
torch
.
autograd
.
profiler
.
record_function
(
"multiply-grads"
):
# multiply gradients by (data_parallel_size / sample_size) since
# DDP normalizes by the number of data parallel workers for
# improved fp16 precision.
# Thus we get (sum_of_gradients / sample_size) at the end.
# In case of fp16, this step also undoes loss scaling.
# (Debugging note: Some optimizers perform this scaling on the
# fly, so inspecting model.parameters() or optimizer.params may
# still show the original, unscaled gradients.)
numer
=
self
.
data_parallel_world_size
if
self
.
_sync_stats
()
else
1
self
.
optimizer
.
multiply_grads
(
numer
/
(
sample_size
or
1.0
))
# Note: (sample_size or 1.0) handles the case of a zero gradient, in a
# way that avoids CPU/device transfers in case sample_size is a GPU or
# TPU object. The assumption is that the gradient itself is also 0.
with
torch
.
autograd
.
profiler
.
record_function
(
"clip-grads"
):
# clip grads
grad_norm
=
self
.
clip_grad_norm
(
self
.
args
.
clip_norm
)
self
.
_check_grad_norms
(
grad_norm
)
if
not
torch
.
isfinite
(
grad_norm
).
all
():
# check local gradnorm single GPU case, trigger NanDetector
raise
FloatingPointError
(
"gradients are Nan/Inf"
)
with
torch
.
autograd
.
profiler
.
record_function
(
"optimizer"
):
# fixed the seed in case for the stochastic rounding in different ranks
with
utils
.
torch_seed
(
self
.
args
.
seed
,
self
.
get_num_updates
()):
# take an optimization step
self
.
task
.
optimizer_step
(
self
.
optimizer
,
model
=
self
.
model
,
update_num
=
self
.
get_num_updates
(),
)
if
self
.
ema
is
not
None
:
with
torch
.
autograd
.
profiler
.
record_function
(
"ema"
):
self
.
ema
.
update
(
self
.
model
)
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            self.zero_grad()
            with NanDetector(self.get_model()):
                for i, sample in enumerate(samples):
                    sample, _ = self._prepare_sample(sample)
                    with utils.torch_seed(
                        self.args.seed,
                        self.get_num_updates(),
                        i,
                        self.data_parallel_rank,
                    ):
                        self.task.train_step(
                            sample,
                            self.model,
                            self.loss,
                            self.optimizer,
                            self.get_num_updates(),
                            ignore_grad=False,
                        )
            raise
        except OverflowError as e:
            overflow = True
            logger.info(
                f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
            )
            grad_norm = torch.tensor(0.0).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e
        logging_output = None
        if not overflow:
            self.set_num_updates(self.get_num_updates() + 1)

            if self.cuda and self.cuda_env is not None:
                # log minimum free memory over the iteration
                gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
                torch.cuda.reset_peak_memory_stats()
                gb_free = self.cuda_env.total_memory_in_GB - gb_used
                metrics.log_scalar(
                    "gb_free", gb_free, priority=1500, round=1, weight=0
                )

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs,
                sample_size,
                grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (
                self.cuda
                and self.args.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
                    % self.args.empty_cache_freq
                )
                == 0
            ):
                torch.cuda.empty_cache()

        if self.args.fp16:
            metrics.log_scalar(
                "loss_scale",
                self.optimizer.scaler.loss_scale,
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")
        return logging_output
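After the all-reduce, train_step multiplies the gradients by data_parallel_world_size / sample_size. A standalone sketch of that arithmetic (not part of trainer.py, with made-up numbers) shows why this turns DDP's per-worker average back into the sum of gradients divided by the global sample size:

import torch

# Illustration only: four workers, each holding the *sum* of its local
# per-example gradients; the global batch contains 8 samples in total.
world_size = 4
per_worker_grad_sums = [torch.tensor(float(i + 1)) for i in range(world_size)]
global_sample_size = 8.0

ddp_averaged = sum(per_worker_grad_sums) / world_size        # what DDP leaves in .grad
rescaled = ddp_averaged * (world_size / global_sample_size)  # multiply_grads(numer / sample_size)
expected = sum(per_worker_grad_sums) / global_sample_size    # sum of gradients / sample_size
assert torch.isclose(rescaled, expected)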
    @metrics.aggregate("valid")
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.loss.eval()

            sample, is_dummy_batch = self._prepare_sample(sample)

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.loss
                )
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        logger.warning(
                            "ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            logging_outputs = [logging_output]
            if is_dummy_batch:
                if torch.is_tensor(sample_size):
                    sample_size.zero_()
                else:
                    sample_size *= 0.0

        # gather logging outputs from all replicas
        if self.data_parallel_world_size > 1:
            logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ignore=is_dummy_batch,
                is_train=False,
            )
        return logging_outputs
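valid_step recovers from a CUDA OOM by dropping cached gradients, emptying the allocator cache, and retrying the batch once. A standalone sketch of that pattern as a generic helper; the function name is hypothetical and it is not part of Uni-Core:

import torch

def run_with_oom_retry(fn, *args, **kwargs):
    """Run fn once; on a CUDA out-of-memory error, free the cache and retry once."""
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached blocks before the retry
        return fn(*args, **kwargs)    # a second failure propagates to the caller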
    def zero_grad(self):
        self.optimizer.zero_grad()

    def lr_step_begin_epoch(self, epoch):
        """Adjust the learning rate at the beginning of the epoch."""
        self.lr_scheduler.step_begin_epoch(epoch)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate at the end of the epoch."""
        self.lr_scheduler.step(epoch, val_loss)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_step_update(self):
        """Update the learning rate after each update."""
        new_lr = self.lr_scheduler.step_update(self.get_num_updates())
        if isinstance(new_lr, dict):
            for k, v in new_lr.items():
                metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
            new_lr = new_lr.get("default", next(iter(new_lr.values())))
        else:
            metrics.log_scalar("lr", new_lr, weight=0, priority=300)
        return new_lr
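lr_step_update accepts either a scalar learning rate or a dict of per-group rates from the scheduler. A small illustration of the dict case handled by the branch above; the group names are hypothetical:

# If the scheduler's step_update() returned this dict, lr_step_update would log
# lr_default, lr_encoder and lr_head, then return new_lr["default"].
new_lr = {"default": 1e-4, "encoder": 5e-5, "head": 2e-4}
scalar_lr = new_lr.get("default", next(iter(new_lr.values())))  # -> 1e-4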
    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_loss(self):
        """Get the (non-wrapped) loss instance."""
        return self._loss

    def get_num_updates(self):
        """Get the number of parameter updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self._num_updates = num_updates
        self.lr_step_update()
        metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
    def clip_grad_norm(self, clip_norm):
        return self.optimizer.clip_grad_norm(clip_norm)

    def cumulative_training_time(self):
        if self._cumulative_training_time is None:
            # single GPU
            return self._local_cumulative_training_time()
        else:
            return self._cumulative_training_time

    def _local_cumulative_training_time(self):
        """Aggregate training time in seconds."""
        return time.time() - self._start_time + self._previous_training_time
    def _prepare_sample(self, sample, is_dummy=False):
        if sample == "DUMMY":
            raise Exception(
                "Trying to use an uninitialized 'dummy' batch. This usually indicates "
                "that the total number of batches is smaller than the number of "
                "participating GPUs. Try reducing the batch size or using fewer GPUs."
            )

        if sample is None or len(sample) == 0:
            assert (
                self._dummy_batch is not None and len(self._dummy_batch) > 0
            ), "Invalid dummy batch: {}".format(self._dummy_batch)
            sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True)
            return sample, True

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        def apply_bfloat16(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.bfloat16)
            return t

        # Data-type conversion is intentionally left to the caller; convert the
        # sample manually if needed, e.g.:
        # if self.args.fp16:
        #     sample = utils.apply_to_sample(apply_half, sample)
        # if self.args.bf16:
        #     sample = utils.apply_to_sample(apply_bfloat16, sample)

        if self._dummy_batch == "DUMMY":
            self._dummy_batch = sample

        return sample, False
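_prepare_sample deliberately leaves dtype conversion to the caller (the commented-out apply_half / apply_bfloat16 calls above). A standalone sketch, assuming the sample is a nested dict/list of tensors, of how such a conversion could be done manually:

import torch

def to_half(obj):
    """Recursively cast float32 tensors in a nested sample to float16."""
    if torch.is_tensor(obj):
        return obj.half() if obj.dtype is torch.float32 else obj
    if isinstance(obj, dict):
        return {k: to_half(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_half(v) for v in obj)
    return obj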
    def _sync_stats(self):
        # Return True when training with more than one data-parallel worker,
        # in which case logging stats have to be synchronized across replicas.
        if self.data_parallel_world_size == 1:
            return False
        else:
            return True
    def _log_oom(self, exc):
        msg = "OOM: Ran out of memory with exception: {}".format(exc)
        logger.warning(msg)
        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
            for device_idx in range(torch.cuda.device_count()):
                logger.warning(torch.cuda.memory_summary(device=device_idx))
        sys.stderr.flush()
    def _aggregate_logging_outputs(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        ignore=False,
        is_train=False,
    ):
        if self.task.__class__.logging_outputs_can_be_summed(
            self.get_loss(), is_train=is_train
        ):
            return self._fast_stat_sync_sum(
                logging_outputs, *extra_stats_to_sum, ignore=ignore
            )
        else:
            return self._all_gather_list_sync(
                logging_outputs, *extra_stats_to_sum, ignore=ignore
            )
    def _all_gather_list_sync(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        ignore=False,
    ):
        """
        Sync logging outputs across workers. all_gather_list_sync is
        suitable when logging outputs are complex types.
        """
        if ignore:
            logging_outputs = []
        results = list(
            zip(
                *distributed_utils.all_gather_list(
                    [logging_outputs] + list(extra_stats_to_sum),
                    max_size=getattr(self.args, "all_gather_list_size", 16384),
                    group=self.data_parallel_process_group,
                )
            )
        )
        logging_outputs, extra_stats_to_sum = results[0], results[1:]
        logging_outputs = list(chain.from_iterable(logging_outputs))
        extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
        return logging_outputs, extra_stats_to_sum
    def _fast_stat_sync_sum(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        ignore=False,
    ):
        """
        Sync logging outputs across workers. fast_stat_sync_sum is
        faster than all_gather_list_sync, but is only suitable when
        logging outputs are scalars and can be summed. Note that
        *logging_outputs* cannot contain any nested dicts/lists.
        """
        data = {}
        for i, stat in enumerate(extra_stats_to_sum):
            data["extra_stats_" + str(i)] = stat
        if len(logging_outputs) > 0:
            log_keys = list(logging_outputs[0].keys())
            for k in log_keys:
                if not ignore:
                    v = sum(log[k] for log in logging_outputs if k in log)
                else:
                    v = logging_outputs[0][k]
                    v = torch.zeros_like(v) if torch.is_tensor(v) else 0
                data["logging_outputs_" + k] = v
        else:
            log_keys = None

        data = distributed_utils.all_reduce_dict(
            data, device=self.device, group=self.data_parallel_process_group
        )

        extra_stats_to_sum = [
            data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
        ]
        if log_keys is not None:
            logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
        else:
            logging_outputs = []
        return logging_outputs, extra_stats_to_sum
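The speed-up in _fast_stat_sync_sum comes from reducing all scalar stats in a single collective instead of gathering Python objects. A standalone sketch of the same idea using torch.distributed directly; the helper name is hypothetical and an initialized process group is assumed:

import torch
import torch.distributed as dist

def sum_stats_across_workers(stats, device="cuda"):
    """Sum a flat dict of scalar stats across all workers with one all_reduce."""
    keys = sorted(stats.keys())
    buf = torch.tensor([float(stats[k]) for k in keys], device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)  # one collective for every stat
    return {k: buf[i].item() for i, k in enumerate(keys)}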
    def _check_grad_norms(self, grad_norm):
        """Check that grad norms are consistent across workers."""
        if self._grad_norm_buf is not None:
            self._grad_norm_buf.zero_()
            self._grad_norm_buf[self.data_parallel_rank] = grad_norm
            distributed_utils.all_reduce(
                self._grad_norm_buf, group=self.data_parallel_process_group
            )

            def is_consistent(tensor):
                max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
                return (
                    torch.isfinite(tensor).all()
                    and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all()
                )

            if not is_consistent(self._grad_norm_buf):
                pretty_detail = "\n".join(
                    "rank {:3d} = {:.8f}".format(r, n)
                    for r, n in enumerate(self._grad_norm_buf.tolist())
                )
                error_detail = "grad_norm across the workers:\n{}\n".format(
                    pretty_detail
                )
                # use FloatingPointError to trigger NanDetector
                raise FloatingPointError(
                    "Fatal error: gradients are inconsistent between workers. "
                    "Try --ddp-backend=legacy_ddp. "
                    "Or are you mixing different generations of GPUs in training?"
                    + "\n"
                    + "-" * 80
                    + "\n{}\n".format(error_detail)
                    + "-" * 80
                )
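The consistency test above tolerates tiny floating-point differences between ranks but rejects real divergence. A small numeric illustration of the relative-tolerance check, with made-up per-rank norms:

import torch

norms = torch.tensor([1.2345678, 1.2345679, 1.2345678])  # one grad norm per rank
max_abs_diff = torch.max(torch.abs(norms - norms[0]))
consistent = bool(
    torch.isfinite(norms).all() and (max_abs_diff / (norms[0] + 1e-6) < 1e-6).all()
)
# consistent is True here; replacing one entry with 1.3 would make it False.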
    def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
        if grad_norm is not None and (
            not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
        ):
            metrics.log_speed("ups", 1.0, priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args.clip_norm > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.args.clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_loss())
                del logging_outputs

            # extra warning for losses that don't properly log a loss value
            if "loss" not in agg:
                if "loss" not in self._warn_once:
                    self._warn_once.add("loss")
                    logger.warning(
                        "Loss.reduce_metrics did not log a 'loss' value, "
                        "which may break some functionality"
                    )
                metrics.log_scalar("loss", -1)

            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output
def _catalog_shared_params(module, memo=None, prefix=""):
    if memo is None:
        first_call = True
        memo = {}
    else:
        first_call = False
    for name, param in module._parameters.items():
        if param is None:
            continue
        param_prefix = prefix + ("." if prefix else "") + name
        if param not in memo:
            memo[param] = []
        memo[param].append(param_prefix)
    for name, m in module._modules.items():
        if m is None:
            continue
        submodule_prefix = prefix + ("." if prefix else "") + name
        _catalog_shared_params(m, memo, submodule_prefix)
    if first_call:
        return [x for x in memo.values() if len(x) > 1]
def _get_module_by_path(module, path):
    path = path.split(".")
    for name in path:
        module = getattr(module, name)
    return module


def _set_module_by_path(module, path, value):
    path = path.split(".")
    for name in path[:-1]:
        module = getattr(module, name)
    setattr(module, path[-1], value)
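These module-level helpers are typically used together: catalog which parameter paths alias the same tensor, then re-point every alias at a single object after the model has been rebuilt or wrapped. A sketch of that usage, with a hypothetical function name, assuming model is a plain torch.nn.Module:

import torch

def retie_shared_parameters(model: torch.nn.Module):
    """Re-tie parameters that were shared in the original module graph."""
    for shared_paths in _catalog_shared_params(model):
        ref = _get_module_by_path(model, shared_paths[0])
        for path in shared_paths[1:]:
            _set_module_by_path(model, path, ref)  # point every alias at one object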