OpenDAS / deepspeed · Commits · 7a9fbe67

Commit 7a9fbe67 (unverified) · authored Feb 03, 2020 by Olatunji Ruwase, committed via GitHub Feb 03, 2020 · parent dc226fdb

    Add files via upload

Showing 2 changed files with 1675 additions and 0 deletions:

    deepspeed/pt/deepspeed_light.py          +1021  −0
    deepspeed/pt/deepspeed_lr_schedules.py    +654  −0

deepspeed/pt/deepspeed_light.py (new file, mode 100644)
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import logging
import torch
import os
import torch.distributed as dist
from torch.nn.modules import Module

from tensorboardX import SummaryWriter

from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
    ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import ROUTE_TRAIN, ROUTE_PREDICT, \
    ROUTE_EVAL
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor

from apex import amp
from apex.optimizers.fused_adam import FusedAdam

MEMORY_OPT_ALLREDUCE_SIZE = 500000000
SUMMARY_WRITER_DIR_NAME = "JobId"

try:
    from apex_C import flatten
    from apex_C import unflatten
except ImportError:
    try:
        _ = warned_flatten
    except NameError:
        print(
            "Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten


def split_half_float_double_csr(tensors):
    dtypes = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor",
        CSRTensor.type()
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append((dtype, bucket))
    return buckets

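
# --- Editor's note (not part of this commit): a minimal sketch of the
# dtype-bucketing pattern above. CPU type strings are used here so the
# example runs without a GPU; the committed function keys on the CUDA
# type strings plus CSRTensor.type().
#
#   >>> mixed = [torch.zeros(2).half(), torch.zeros(3), torch.zeros(4).double()]
#   >>> [t.type() for t in mixed]
#   ['torch.HalfTensor', 'torch.FloatTensor', 'torch.DoubleTensor']
#   >>> # each distinct type string becomes one (dtype, tensors) bucket
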
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    if parameter_parallel_size is None:
        parameter_parallel_size = int(data_parallel_size)
    print(data_parallel_size, parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(dist.get_world_size() // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size,
                      (i + 1) * parameter_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group

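
# --- Editor's note (not part of this commit): the rank partitioning
# computed above, shown without torch.distributed. With a world size of 8
# and a parameter parallel size of 4, ranks split into two groups and each
# rank keeps the group containing it:
#
#   >>> world_size, parameter_parallel_size = 8, 4
#   >>> [list(range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size))
#   ...  for i in range(world_size // parameter_parallel_size)]
#   [[0, 1, 2, 3], [4, 5, 6, 7]]
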
def print_configuration(args, name):
    print('{}:'.format(name), flush=True)
    for arg in sorted(vars(args)):
        dots = '.' * (29 - len(arg))
        print('  {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)


class DeepSpeedLight(Module):
    r"""DeepSpeed engine for training.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer=None,
                 model_parameters=None,
                 training_data=None,
                 lr_scheduler=None,
                 mpu=None,
                 dist_init_required=True,
                 collate_fn=None):
        super(DeepSpeedLight, self).__init__()

        logging.basicConfig(level=logging.INFO,
                            format="[%(levelname)s %(asctime)s] %(message)s",
                            datefmt="%Y-%m-%d %H:%M:%S")

        self.client_optimizer = optimizer
        self.client_model_parameters = model_parameters
        self.client_lr_scheduler = lr_scheduler
        self.training_data = training_data
        self.collate_fn = collate_fn
        self.mpu = mpu
        self.data_parallel_group = None
        self.global_steps = 0
        self.micro_steps = 0
        self.skipped_steps = 0
        self.gradient_predivide_factor = 1.0
        self.gradient_average = True
        self.warn_unscaled_loss = True

        if dist_init_required:
            dist.init_process_group(backend="nccl")

        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()

        self.sample_count = 0
        if self.tensorboard_enabled():
            self.summary_writer = self.get_summary_writer()

        self._init_distributed(dist_init_required)

        # Throughput timer
        self.tput_timer = ThroughputTimer(
            batch_size=self.train_micro_batch_size_per_gpu(),
            num_workers=self.world_size,
            monitor_memory=False)

        self.training_dataloader = self.deepspeed_io(
            training_data) if training_data else None

        # Configure distributed model
        self._configure_distributed_model(model)

        # Configure optimizer and scheduler
        self.optimizer = None
        self.lr_scheduler = None
        if model_parameters or optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler(lr_scheduler)
            self._report_progress(0)

        # Configure wall clock timer
        self.timers = SynchronizedWallClockTimer()

        # Bookkeeping for csr support
        self.csr_tensor_module_names = set()
        if self.sparse_gradients_enabled():
            for name, module in self.module.named_modules():
                if isinstance(module, torch.nn.Embedding):
                    self.csr_tensor_module_names.add(name)
                    logging.info("Will convert {} to sparse (csr) "
                                 "tensor during training".format(name))

        self.save_non_zero_checkpoint = False
        self.save_zero_checkpoint = False
        self._configure_checkpointing(dist_init_required)

        if self.global_rank == 0:
            self._config.print('DeepSpeedLight configuration')
            if self.dump_state():
                print_configuration(self, 'DeepSpeedLight')

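
    # --- Editor's note (not part of this commit): a hedged construction
    # sketch. Per _do_args_sanity_check below, `args` must carry an integer
    # `local_rank` and a `deepspeed_config` path; the model here is an
    # illustrative placeholder and the JSON config file must already exist.
    #
    #   >>> args = argparse.Namespace(local_rank=0,
    #   ...                           deepspeed_config='ds_config.json')
    #   >>> model = torch.nn.Linear(10, 10)
    #   >>> engine = DeepSpeedLight(args, model,
    #   ...                         model_parameters=model.parameters())
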
    def tensorboard_enabled(self):
        return self._config.tensorboard_enabled

    def tensorboard_output_path(self):
        return self._config.tensorboard_output_path

    def tensorboard_job_name(self):
        return self._config.tensorboard_job_name

    def get_summary_writer(self,
                           name="DeepSpeedJobName",
                           base=os.environ["HOME"] + "/tensorboard"):
        if self.tensorboard_job_name():
            name = self.tensorboard_job_name()
        if self.tensorboard_output_path():
            return SummaryWriter(log_dir=self.tensorboard_output_path())
        # Start from the module-level default and override from DLWS_JOB_ID
        # when present. (The original assigned to the global name here, which
        # made it a local and raised UnboundLocalError whenever DLWS_JOB_ID
        # was unset.)
        summary_writer_dir_name = SUMMARY_WRITER_DIR_NAME
        if 'DLWS_JOB_ID' in os.environ:
            summary_writer_dir_name = os.environ['DLWS_JOB_ID'] + "/logs"
        return SummaryWriter(
            log_dir=os.path.join(base, summary_writer_dir_name, name))

    def wall_clock_breakdown(self):
        return self._config.wall_clock_breakdown

    def sparse_gradients_enabled(self):
        return self._config.sparse_gradients_enabled

    def train_batch_size(self):
        return self._config.train_batch_size

    def train_micro_batch_size_per_gpu(self):
        return self._config.train_micro_batch_size_per_gpu

    def optimizer_name(self):
        return self._config.optimizer_name

    def optimizer_params(self):
        return self._config.optimizer_params

    def scheduler_name(self):
        return self._config.scheduler_name

    def scheduler_params(self):
        return self._config.scheduler_params

    def zero_optimization(self):
        return self._config.zero_enabled

    def allgather_size(self):
        return self._config.allgather_size

    def fp16_enabled(self):
        return self._config.fp16_enabled

    def loss_scale(self):
        return self._config.loss_scale

    def gradient_accumulation_steps(self):
        return self._config.gradient_accumulation_steps

    def allreduce_always_fp32(self):
        return self._config.allreduce_always_fp32

    def postscale_gradients(self):
        return not self._config.prescale_gradients

    def steps_per_print(self):
        return self._config.steps_per_print

    def disable_allgather(self):
        return self._config.disable_allgather

    def dump_state(self):
        return self._config.dump_state

    def gradient_clipping(self):
        return self._config.gradient_clipping

    def dynamic_loss_scale(self):
        return self._config.loss_scale == 0

    def initial_dynamic_scale(self):
        return self._config.initial_dynamic_scale

    def dynamic_loss_scale_args(self):
        return self._config.dynamic_loss_scale_args

    def _configure_lr_scheduler(self, client_lr_scheduler):
        # First check for scheduler in json configuration
        lr_scheduler = self._scheduler_from_config(self.optimizer)
        if lr_scheduler:
            logging.info(
                f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
            self.lr_scheduler = lr_scheduler
        else:
            logging.warning('DeepSpeed using client LR scheduler')
            self.lr_scheduler = client_lr_scheduler
        logging.info(f'DeepSpeed LR Scheduler = {self.lr_scheduler}')

    def _configure_checkpointing(self, dist_init_required):
        dp_rank = torch.distributed.get_rank(
        ) if self.mpu is None else self.mpu.get_data_parallel_rank()

        # only the first data parallel process needs to store the model checkpoint
        self.save_non_zero_checkpoint = (dp_rank == 0)

        if self.zero_optimization():
            pp_rank = torch.distributed.get_rank(
                group=self.optimizer.dp_process_group)
            # only the first parameter parallel process needs to store the
            # optimizer state checkpoints for zero
            self.save_zero_checkpoint = (pp_rank == dp_rank)

    def _scheduler_from_config(self, optimizer):
        scheduler_name = self.scheduler_name()
        if scheduler_name is not None:
            if hasattr(lr_schedules, scheduler_name):
                scheduler = getattr(lr_schedules, scheduler_name)
            else:
                assert hasattr(torch.optim.lr_scheduler, scheduler_name), \
                    f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
                scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)

            scheduler_params = self.scheduler_params()
            instantiated_scheduler = scheduler(optimizer, **scheduler_params)
            return instantiated_scheduler
        else:
            return None

    def _init_distributed(self, dist_init_required):
        if self.local_rank >= 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
            self.world_size = dist.get_world_size()
            self.global_rank = dist.get_rank()
            logging.info("Set device to local rank {} within node.".format(
                self.local_rank))
        else:
            self.world_size = 1
            self.global_rank = 0
            self.device = torch.device("cuda")

    # Configure based on command line arguments
    def _configure_with_arguments(self, args, mpu):
        self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0

        self._config = DeepSpeedConfig(args.deepspeed_config, mpu)

    # Validate command line arguments
    def _do_args_sanity_check(self, args):
        assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
            'DeepSpeed requires integer command line parameter --local_rank'

        assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
            'DeepSpeed requires --deepspeed_config to specify configuration file'

        assert os.path.isfile(args.deepspeed_config), \
            'DeepSpeed configuration file: {} is not an existing file'.format(
                args.deepspeed_config)

    def _is_supported_optimizer(self, optimizer_name):
        return optimizer_name in DEEPSPEED_OPTIMIZERS or \
            getattr(torch.optim, optimizer_name, None) is not None

    # Validate configuration based on command line arguments
    def _do_sanity_check(self):
        if not self.client_optimizer:
            assert self._is_supported_optimizer(self.optimizer_name()), \
                '{} is not a supported DeepSpeed Optimizer'.format(
                    self.optimizer_name())
            assert self.client_model_parameters, \
                'DeepSpeed {} optimizer requires parameters in initialize() call'.format(
                    self.optimizer_name())

        if self.optimizer_name() == LAMB_OPTIMIZER:
            assert self.dynamic_loss_scale(), \
                'DeepSpeed {} optimizer requires dynamic loss scaling'.format(
                    self.optimizer_name())

    def _configure_distributed_model(self, model):
        self.module = model
        if self.fp16_enabled():
            self.module.half()
        self.module.to(self.device)

        if self.mpu is None:
            self.data_parallel_group = _initialize_parameter_parallel_groups()
            self.dp_world_size = dist.get_world_size()
            src_rank = 0
        else:
            self.data_parallel_group = self.mpu.get_data_parallel_group()
            self.dp_world_size = self.mpu.get_data_parallel_world_size()
            src_rank = self.mpu.get_model_parallel_rank()

        for p in self.module.parameters():
            if torch.is_tensor(p):
                dist.broadcast(p, src_rank, group=self.data_parallel_group)

        # TODO: support new AMP optimizer
        # self.module.half()
        # self.module.to(self.local_rank)
        # self.module, self.optimizer = amp.initialize(self.module, self.optimizer, opt_level="O2")

    # Configure optimizer
    def _configure_optimizer(self, client_optimizer, model_parameters):
        if client_optimizer is not None:
            basic_optimizer = client_optimizer
            logging.info('Using client Optimizer as basic optimizer')
        else:
            basic_optimizer = self._configure_basic_optimizer(model_parameters)
            logging.info(
                'Using DeepSpeed Optimizer param name {} as basic optimizer'.format(
                    self.optimizer_name()))

        logging.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))

        if self.zero_optimization() and self.optimizer_name() == ADAM_OPTIMIZER:
            self.optimizer = self._configure_zero_optimizer(basic_optimizer)
        elif self.fp16_enabled():
            self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
        else:
            self.optimizer = basic_optimizer
        # logging.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))

    def _configure_basic_optimizer(self, model_parameters):
        optimizer_parameters = self.optimizer_params()
        if self.fp16_enabled() and 'max_grad_norm' in optimizer_parameters.keys():
            optimizer_parameters['max_grad_norm'] = 0.0

        if self.optimizer_name() == ADAM_OPTIMIZER:
            optimizer = FusedAdam(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == LAMB_OPTIMIZER:
            optimizer = FusedLamb(model_parameters, **optimizer_parameters)
        else:
            torch_optimizer = getattr(torch.optim, self.optimizer_name())
            optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
        return optimizer

    def _configure_fp16_optimizer(self, optimizer):
        initial_dynamic_scale = self.initial_dynamic_scale()
        dynamic_loss_args = self.dynamic_loss_scale_args()
        clip_grad = self.gradient_clipping()
        if self.optimizer_name() == ADAM_OPTIMIZER:
            if self.dynamic_loss_scale():
                logging.info('Creating fp16 optimizer with dynamic loss scale')
                optimizer = FP16_Optimizer(
                    optimizer,
                    dynamic_loss_scale=True,
                    initial_dynamic_scale=initial_dynamic_scale,
                    dynamic_loss_args=dynamic_loss_args,
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=True)
            else:
                logging.info(
                    'Creating fp16 optimizer with static loss scale: {}'.format(
                        self.loss_scale()))
                optimizer = FP16_Optimizer(
                    optimizer,
                    static_loss_scale=self.loss_scale(),
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=True)
        else:
            logging.info('Creating fp16 unfused optimizer with dynamic loss scale')
            optimizer = FP16_UnfusedOptimizer(
                optimizer,
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=dynamic_loss_args,
                mpu=self.mpu,
                clip_grad=clip_grad,
                fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER)
        return optimizer

    def _configure_zero_optimizer(self, optimizer):
        logging.info('Creating fp16 zero optimizer')
        optimizer = FP16_DeepSpeedZeroOptimizer(
            optimizer,
            static_loss_scale=self.loss_scale(),
            dynamic_loss_scale=self.dynamic_loss_scale(),
            dynamic_loss_args=self.dynamic_loss_scale_args(),
            dp_process_group=self.data_parallel_group,
            clip_grad=self.gradient_clipping(),
            all_gather_partitions=not self.disable_allgather(),
            allgather_size=self.allgather_size(),
            mpu=self.mpu)
        return optimizer

    def deepspeed_io(self,
                     dataset,
                     batch_size=None,
                     route=ROUTE_TRAIN,
                     pin_memory=True,
                     data_sampler=None,
                     collate_fn=None,
                     num_local_io_workers=None):
        if not isinstance(dataset, torch.utils.data.Dataset):
            raise ValueError("Training data must be a torch Dataset")

        if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
            data_sampler = torch.utils.data.SequentialSampler(dataset)

        if batch_size is None:
            batch_size = self.train_micro_batch_size_per_gpu()

        if collate_fn is None:
            collate_fn = self.collate_fn

        # Currently we only use timer in train route
        deepspeed_io_timer = None
        if route == ROUTE_TRAIN:
            deepspeed_io_timer = self.tput_timer

        return DeepSpeedDataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   pin_memory=pin_memory,
                                   collate_fn=collate_fn,
                                   local_rank=self.local_rank,
                                   tput_timer=deepspeed_io_timer,
                                   num_local_io_workers=num_local_io_workers,
                                   data_sampler=data_sampler)

    def train(self):
        r"""Put the wrapped module in training mode.
        """
        self.warn_unscaled_loss = True
        self.module.train()

    def eval(self):
        r"""Put the wrapped module in evaluation mode.
        """
        self.warn_unscaled_loss = True
        self.module.train(False)

    def _scale_loss(self, loss):
        if isinstance(loss, torch.Tensor):
            loss = loss / self.gradient_accumulation_steps()
        elif isinstance(loss, tuple) and isinstance(loss[0], torch.Tensor):
            # Build a tuple here; the original parenthesized comprehension
            # produced a generator rather than a tuple.
            loss = tuple(l / self.gradient_accumulation_steps() for l in loss)
        elif isinstance(loss, list) and isinstance(loss[0], torch.Tensor):
            loss = [l / self.gradient_accumulation_steps() for l in loss]
        else:
            if self.warn_unscaled_loss:
                logging.warning(
                    f'DeepSpeed unable to scale loss because of type: {type(loss)}')
                self.warn_unscaled_loss = False

        return loss

    def forward(self, *inputs, **kwargs):
        r"""Execute forward propagation

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """
        if self.wall_clock_breakdown():
            self.timers('forward_microstep').start()
            self.timers('forward').start()

        if self.training_dataloader is None:
            self.tput_timer.start()
        loss = self.module(*inputs, **kwargs)

        # scale loss w.r.t. gradient accumulation if needed
        if self.gradient_accumulation_steps() > 1:
            loss = self._scale_loss(loss)

        if self.wall_clock_breakdown():
            self.timers('forward').stop()
            self.timers('forward_microstep').stop()

        return loss

    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
        if self.is_gradient_accumulation_boundary():
            self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

    def backward(self, loss, allreduce_gradients=True):
        r"""Execute backward pass on the loss

        Arguments:
            loss: Torch tensor on which to execute backward propagation
            allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
        """
        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
        ) and torch.distributed.get_rank() == 0:
            # deepspeed tensorboard support for loss
            self.sample_count += (self.train_micro_batch_size_per_gpu() *
                                  torch.distributed.get_world_size() *
                                  self.gradient_accumulation_steps())
            self.summary_events = [
                ('Train/Samples/train_loss',
                 loss.mean().item() * self.gradient_accumulation_steps(),
                 self.sample_count)
            ]
            for event in self.summary_events:  # write_summary_events
                self.summary_writer.add_scalar(event[0], event[1], event[2])
            self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('backward_microstep').start()
            self.timers('backward').start()

        assert self.optimizer is not None, "must provide optimizer during " \
            "init in order to use backward"

        if self.wall_clock_breakdown():
            self.timers('backward_inner_microstep').start()
            self.timers('backward_inner').start()

        if self.zero_optimization():
            self.optimizer.backward(loss)
        elif self.fp16_enabled():
            self.optimizer.backward(loss)
            # TODO: Use new AMP semantics as below
            # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            #     scaled_loss.backward()
        else:
            loss.backward()

        if self.wall_clock_breakdown():
            self.timers('backward_inner').stop()
            self.timers('backward_inner_microstep').stop()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce_microstep').start()
            self.timers('backward_allreduce').start()

        if allreduce_gradients:
            self.allreduce_gradients()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce').stop()
            self.timers('backward_allreduce_microstep').stop()
            self.timers('backward').stop()
            self.timers('backward_microstep').stop()

    def is_gradient_accumulation_boundary(self):
        return (self.micro_steps + 1) % \
            self.gradient_accumulation_steps() == 0

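
    # --- Editor's note (not part of this commit): the boundary arithmetic
    # above, worked for gradient_accumulation_steps = 4. micro_steps counts
    # completed micro batches, so the boundary fires on the 4th, 8th, ...
    # micro step:
    #
    #   >>> gas = 4
    #   >>> [(m + 1) % gas == 0 for m in range(8)]
    #   [False, False, False, True, False, False, False, True]
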
    def step(self):
        r"""Execute the weight update step after forward and backward propagation on effective_train_batch
        """
        if self.wall_clock_breakdown():
            self.timers('step_microstep').start()
            self.timers('step').start()

        assert self.optimizer is not None, "must provide optimizer during " \
            "init in order to use step"
        report_progress = self.global_rank == 0
        if self.is_gradient_accumulation_boundary():
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Check overflow here since in DS fp16 optimizer, the overflow is
            # updated in the above step() function.
            overflow = False
            if hasattr(self.optimizer, 'overflow'):
                overflow = self.optimizer.overflow

            if overflow:
                self.skipped_steps += 1
            else:
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

            if report_progress and (self.global_steps +
                                    1) % self.steps_per_print() == 0:
                self._report_progress(self.global_steps + 1)

            self.global_steps += 1

        self.tput_timer.stop(report_progress)

        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
        ) and torch.distributed.get_rank() == 0:
            # deepspeed tensorboard support for lr
            self.summary_events = [('Train/Samples/lr',
                                    self.get_lr()[0],
                                    self.sample_count)]
            for event in self.summary_events:  # write_summary_events
                self.summary_writer.add_scalar(event[0], event[1], event[2])
            self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('step').stop()
            self.timers('step_microstep').stop()
            self.timers.log([
                'forward_microstep',
                'backward_microstep',
                'backward_inner_microstep',
                'backward_allreduce_microstep',
                'step_microstep'
            ])
            if self.is_gradient_accumulation_boundary():
                if self.tensorboard_enabled() and torch.distributed.get_rank() == 0:
                    # this is done before the log because log resets timers
                    self.summary_events = [
                        ('Train/Samples/elapsed_time_ms_forward',
                         self.timers('forward').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        ('Train/Samples/elapsed_time_ms_backward',
                         self.timers('backward').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        ('Train/Samples/elapsed_time_ms_backward_inner',
                         self.timers('backward_inner').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        ('Train/Samples/elapsed_time_ms_backward_allreduce',
                         self.timers('backward_allreduce').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        ('Train/Samples/elapsed_time_ms_step',
                         self.timers('step').elapsed(reset=False) * 1000.0,
                         self.sample_count)
                    ]
                    for event in self.summary_events:  # write_summary_events
                        self.summary_writer.add_scalar(event[0], event[1], event[2])
                    self.summary_writer.flush()
                self.timers.log([
                    'forward',
                    'backward',
                    'backward_inner',
                    'backward_allreduce',
                    'step'
                ])

        self.micro_steps += 1

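
    # --- Editor's note (not part of this commit): the training loop implied
    # by forward/backward/step above, as a hedged sketch; `engine` and
    # `batch` are illustrative placeholders. backward() accumulates and
    # all-reduces gradients, and step() only updates weights (and the LR
    # schedule) on a gradient accumulation boundary.
    #
    #   >>> for batch in engine.training_dataloader:
    #   ...     loss = engine(batch)      # forward, scaled for accumulation
    #   ...     engine.backward(loss)
    #   ...     engine.step()
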
    def _get_optimizer_param(self, param_name):
        result = []
        if not self.optimizer:
            return result
        for group in self.optimizer.param_groups:
            if param_name in group:
                result.append(group[param_name])
            else:
                result.append(0.0)
        return result

    def get_lr(self):
        return self._get_optimizer_param('lr')

    def get_mom(self):
        return self._get_optimizer_param('betas')

    def _report_progress(self, step):
        lr = self.get_lr()
        mom = self.get_mom()
        logging.info('rank:{} step={}, skipped={}, lr={}, mom={}'.format(
            self.global_rank, step, self.skipped_steps, lr, mom))

    def allreduce_bucket(self, bucket):
        tensor = flatten(bucket)

        tensor_to_allreduce = tensor

        if self.allreduce_always_fp32():
            tensor_to_allreduce = tensor.float()

        if self.postscale_gradients():
            if self.gradient_predivide_factor != 1.0:
                tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

            if self.gradient_average:
                if self.gradient_predivide_factor != self.dp_world_size:
                    tensor_to_allreduce.mul_(self.gradient_predivide_factor /
                                             self.dp_world_size)
        else:
            tensor_to_allreduce.div_(self.dp_world_size)
            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

        if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor

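
    # --- Editor's note (not part of this commit): the post-scale averaging
    # arithmetic above, worked numerically. With gradient_predivide_factor
    # f = 1.0 and dp_world_size N = 4, each rank's gradient g is summed by
    # all_reduce and then multiplied by f / N, which is the plain mean:
    #
    #   >>> f, N, per_rank_grad = 1.0, 4, 2.0
    #   >>> summed = N * per_rank_grad        # what all_reduce produces
    #   >>> summed * (f / N)                  # post-scale average
    #   2.0
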
    def allreduce_and_copy(self, small_bucket):
        allreduced = self.allreduce_bucket(small_bucket)
        for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
            buf.copy_(synced)

    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
        small_bucket = []
        numel = 0
        for tensor in bucket:
            small_bucket.append(tensor)
            numel = numel + tensor.numel()
            if numel > numel_per_bucket:
                self.allreduce_and_copy(small_bucket)
                small_bucket = []
                # Reset the running element count along with the bucket;
                # without this, every tensor after the first flush would
                # trigger its own allreduce.
                numel = 0
        if len(small_bucket) > 0:
            self.allreduce_and_copy(small_bucket)

    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
        grads = []
        for param_name, param in self.module.named_parameters():
            if param.grad is not None:
                grad_data = param.grad.data
                param_name_root = param_name.split('.', 1)[0]
                if self.sparse_gradients_enabled(
                ) and param_name_root in self.csr_tensor_module_names:
                    grads.append(CSRTensor(grad_data))
                else:
                    grads.append(grad_data)

        split_buckets = split_half_float_double_csr(grads)

        for i, bucket_tuple in enumerate(split_buckets):
            bucket_type, bucket = bucket_tuple
            if bucket_type == CSRTensor.type():
                self.csr_allreduce_no_retain(bucket)
            else:
                self.allreduce_no_retain(bucket,
                                         numel_per_bucket=elements_per_buffer)

    def csr_allreduce_no_retain(self, bucket):
        allreduced_csrs = self.csr_allreduce_bucket(bucket)
        # Densify csr tensor and copy back to original location
        for csr in allreduced_csrs:
            dense_tensor = csr.to_dense()
            csr.orig_dense_tensor.copy_(dense_tensor)

    def csr_allreduce_bucket(self, bucket):
        csr_list = []
        for csr in bucket:
            csr_list.append(self.csr_allreduce(csr))
        return csr_list

    def csr_allreduce(self, csr):
        # Pre-divide for fp16 stability
        csr.values.div_(self.dp_world_size)

        indices_device_list = self.csr_all_gather(csr.indices)
        values_device_list = self.csr_all_gather(csr.values)

        csr.indices = torch.cat(indices_device_list)
        csr.values = torch.cat(values_device_list)
        return csr

    def csr_all_gather(self, value):
        my_size = torch.LongTensor([value.size()[0]]).cuda()
        all_sizes = self.all_gather_scalar(my_size)
        max_size = torch.cat(all_sizes).max()
        fill_size = (max_size - my_size)

        assert value.dim() in [1, 2]
        if value.dim() == 1:
            if fill_size > 0:
                value = torch.cat([value, value.new_zeros(fill_size)])
            tensor_list = [
                value.new_zeros(max_size) for _ in range(dist.get_world_size())
            ]
        else:
            if fill_size > 0:
                value = torch.cat(
                    [value, value.new_zeros(fill_size, value.size()[1])])
            tensor_list = [
                value.new_zeros(max_size, value.size()[1])
                for _ in range(dist.get_world_size())
            ]

        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        tensors = []
        for dev_idx, t in enumerate(tensor_list):
            size = all_sizes[dev_idx][0]
            tensors.append(
                t.index_select(0, torch.LongTensor(range(size)).cuda()))

        return tensors

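
    # --- Editor's note (not part of this commit): the pad-and-trim pattern
    # above, sketched for two hypothetical ranks holding 2 and 3 indices.
    # Each rank zero-pads to the gathered max length so all_gather sees
    # equal shapes, then the padding is sliced off using the true sizes:
    #
    #   >>> sizes = [2, 3]                    # from all_gather_scalar
    #   >>> max_size = max(sizes)             # pad both ranks to length 3
    #   >>> gathered = [torch.tensor([5, 7, 0]), torch.tensor([1, 2, 3])]
    #   >>> [t[:s] for t, s in zip(gathered, sizes)]
    #   [tensor([5, 7]), tensor([1, 2, 3])]
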
    def all_gather_scalar(self, value):
        tensor_list = [
            value.new_zeros(value.size()) for _ in range(self.dp_world_size)
        ]
        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        return tensor_list

    def module_state_dict(self, destination=None, prefix='', keep_vars=False):
        sd = self.module.state_dict(destination, prefix, keep_vars)
        return sd

    def load_module_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

    def _get_zero_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        pp_rank = torch.distributed.get_rank(
            group=self.optimizer.dp_process_group)
        filename = 'zero_pp_rank_{}'.format(pp_rank)
        zero_ckpt_name = os.path.join(
            checkpoints_path,
            str(tag),
            filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
        return zero_ckpt_name

    def _get_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        ckpt_name = os.path.join(
            checkpoints_path,
            str(tag),
            'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
        return ckpt_name

    def _ensure_directory_exists(self, filename):
        dirname = os.path.dirname(filename)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

    def load_checkpoint(self, load_dir, tag):
        r"""Load training checkpoint

        Arguments:
            load_dir: Required. Directory to load the checkpoint from
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
        Return:
            load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
            client_state: State dictionary used for loading required training states in the client code.
        """
        load_path, client_states = self._load_checkpoint(load_dir, tag)
        if self.zero_optimization() and load_path is not None:
            self._load_zero_checkpoint(load_dir, tag)
        return load_path, client_states

    def _load_checkpoint(self, load_dir, tag):
        load_path = self._get_ckpt_name(load_dir, tag)

        if not os.path.exists(load_path):
            logging.warning(
                'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
                .format(load_path))
            return None, None

        logging.info('Loading checkpoint: {}'.format(load_path))
        checkpoint = torch.load(load_path,
                                map_location=lambda storage, loc: storage)

        self.load_module_state_dict(checkpoint['module'])
        if not self.zero_optimization():
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
        self.global_steps = checkpoint['global_steps']
        self.skipped_steps = checkpoint['skipped_steps']

        deepspeed_states = [
            'module',
            'optimizer',
            'csr_tensor_module_names',
            'skipped_steps',
            'global_steps'  # checkpoints are saved under 'global_steps', not 'global_step'
        ]
        client_state = {
            key: value
            for key, value in checkpoint.items() if key not in deepspeed_states
        }

        return load_path, client_state

    def _load_zero_checkpoint(self, load_dir, tag):
        zero_checkpoint_name = self._get_zero_ckpt_name(load_dir, tag)
        if not os.path.exists(zero_checkpoint_name):
            logging.warning(
                'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
                .format(zero_checkpoint_name))
            return None

        zero_sd = torch.load(zero_checkpoint_name, map_location='cpu')
        self.optimizer.load_state_dict(zero_sd['optimizer_state_dict'])
        logging.info('loading zero checkpoint {}'.format(zero_checkpoint_name))

    def save_checkpoint(self, save_dir, tag, client_state={}):
        r"""Save training checkpoint

        Arguments:
            save_dir: Required. Directory for saving the checkpoint
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
            client_state: Optional. State dictionary used for saving required training states in the client code.
        """
        # This is to make sure the checkpoint names are created without collision
        # There seems to be an issue creating them in parallel
        self._create_checkpoint_files(save_dir, tag)

        try:
            if self.save_non_zero_checkpoint:
                self._save_checkpoint(save_dir, tag, client_state=client_state)

            if self.save_zero_checkpoint:
                self._save_zero_checkpoint(save_dir, tag)
        except Exception:
            logging.error(
                f'Failed saving model checkpoint to {save_dir} with tag {tag}')
            return False

        return True

    def _create_checkpoint_files(self, save_dir, tag):
        # checkpoint files are created sequentially
        for rank in range(dist.get_world_size()):
            if rank == dist.get_rank():
                try:
                    if self.save_non_zero_checkpoint:
                        checkpoint_name = self._get_ckpt_name(save_dir, tag)
                        self._ensure_directory_exists(checkpoint_name)

                    if self.save_zero_checkpoint:
                        checkpoint_name = self._get_zero_ckpt_name(save_dir, tag)
                        self._ensure_directory_exists(checkpoint_name)
                except Exception:
                    logging.error(
                        f'Failed saving model checkpoint to {save_dir} with tag {tag}')
                    return False
            dist.barrier()

    def _save_checkpoint(self, save_dir, tag, client_state={}):
        save_path = self._get_ckpt_name(save_dir, tag)
        # self._ensure_directory_exists(save_path)

        state = {
            'module': self.module_state_dict(),
            'optimizer': self.optimizer.state_dict()
            if self.optimizer and not self.zero_optimization() else None,
            'lr_scheduler': self.lr_scheduler.state_dict()
            if self.lr_scheduler is not None else None,
            'csr_tensor_module_names': self.csr_tensor_module_names,
            'skipped_steps': self.skipped_steps,
            'global_steps': self.global_steps,
        }
        state.update(client_state)

        logging.info('Saving model checkpoint: {}'.format(save_path))
        torch.save(state, save_path)

    def _save_zero_checkpoint(self, save_path, tag):
        try:
            zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
            # self._ensure_directory_exists(zero_checkpoint_name)
        except Exception:
            logging.error(
                f'Failed saving Zero model checkpoint to {save_path} with tag {tag}')
            # Bail out: zero_checkpoint_name is undefined if name creation failed.
            return

        zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
        torch.save(zero_sd, zero_checkpoint_name)
        logging.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
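
# --- Editor's note (not part of this commit): a hedged checkpoint
# round-trip sketch for the save/load methods above; `engine` and the
# client entries are illustrative placeholders. Keys not claimed by
# DeepSpeed come back in client_state.
#
#   >>> engine.save_checkpoint('/tmp/ckpts', tag=100,
#   ...                        client_state={'epoch': 3})
#   True
#   >>> load_path, client_state = engine.load_checkpoint('/tmp/ckpts', tag=100)
#   >>> client_state['epoch']
#   3
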
deepspeed/pt/deepspeed_lr_schedules.py (new file, mode 100644)
"""
Copyright 2019 The Microsoft DeepSpeed Team
Implementation of learning rate schedules.
Taken and modified from PyTorch v1.0.1 source
https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
"""
import
argparse
from
torch.optim
import
Optimizer
from
typing
import
Union
,
List
import
math
from
deepspeed.pt.deepspeed_constants
import
*
LR_SCHEDULE
=
'lr_schedule'
LR_RANGE_TEST
=
'LRRangeTest'
ONE_CYCLE
=
'OneCycle'
WARMUP_LR
=
'WarmupLR'
VALID_LR_SCHEDULES
=
[
LR_RANGE_TEST
,
ONE_CYCLE
,
WARMUP_LR
]
LR_RANGE_TEST_MIN_LR
=
'lr_range_test_min_lr'
LR_RANGE_TEST_STEP_RATE
=
'lr_range_test_step_rate'
LR_RANGE_TEST_STEP_SIZE
=
'lr_range_test_step_size'
LR_RANGE_TEST_STAIRCASE
=
'lr_range_test_staircase'
EDGE_VALUE
=
'edge_value'
MID_VALUE
=
'mid_value'
CYCLE_FIRST_STEP_SIZE
=
'cycle_first_step_size'
CYCLE_FIRST_STAIR_COUNT
=
'cycle_first_stair_count'
CYCLE_SECOND_STEP_SIZE
=
'cycle_second_step_size'
CYCLE_SECOND_STAIR_COUNT
=
'cycle_second_stair_count'
DECAY_STEP_SIZE
=
'decay_step_size'
CYCLE_MIN_LR
=
'cycle_min_lr'
CYCLE_MAX_LR
=
'cycle_max_lr'
DECAY_LR_RATE
=
'decay_lr_rate'
CYCLE_MIN_MOM
=
'cycle_min_mom'
CYCLE_MAX_MOM
=
'cycle_max_mom'
DECAY_MOM_RATE
=
'decay_mom_rate'
WARMUP_MIN_LR
=
'warmup_min_lr'
WARMUP_MAX_LR
=
'warmup_max_lr'
WARMUP_NUM_STEPS
=
'warmup_num_steps'
def add_tuning_arguments(parser):
    group = parser.add_argument_group('Convergence Tuning',
                                      'Convergence tuning configurations')

    # LR scheduler
    group.add_argument('--lr_schedule',
                       type=str,
                       default=None,
                       help='LR schedule for training.')

    # Learning rate range test
    group.add_argument("--lr_range_test_min_lr",
                       type=float,
                       default=0.001,
                       help='Starting lr value.')
    group.add_argument("--lr_range_test_step_rate",
                       type=float,
                       default=1.0,
                       help='scaling rate for LR range test.')
    group.add_argument("--lr_range_test_step_size",
                       type=int,
                       default=1000,
                       help='training steps per LR change.')
    group.add_argument("--lr_range_test_staircase",
                       type=bool,
                       default=False,
                       help='use staircase scaling for LR range test.')

    # OneCycle schedule
    group.add_argument("--cycle_first_step_size",
                       type=int,
                       default=1000,
                       help='size of first step of 1Cycle schedule (training steps).')
    group.add_argument("--cycle_first_stair_count",
                       type=int,
                       default=-1,
                       help='first stair count for 1Cycle schedule.')
    group.add_argument("--cycle_second_step_size",
                       type=int,
                       default=-1,
                       help='size of second step of 1Cycle schedule (default first_step_size).')
    group.add_argument("--cycle_second_stair_count",
                       type=int,
                       default=-1,
                       help='second stair count for 1Cycle schedule.')
    group.add_argument("--decay_step_size",
                       type=int,
                       default=1000,
                       help='size of intervals for applying post cycle decay (training steps).')

    # 1Cycle LR
    group.add_argument("--cycle_min_lr",
                       type=float,
                       default=0.01,
                       help='1Cycle LR lower bound.')
    group.add_argument("--cycle_max_lr",
                       type=float,
                       default=0.1,
                       help='1Cycle LR upper bound.')
    group.add_argument("--decay_lr_rate",
                       type=float,
                       default=0.0,
                       help='post cycle LR decay rate.')

    # 1Cycle Momentum
    group.add_argument('--cycle_momentum',
                       default=False,
                       action='store_true',
                       help='Enable 1Cycle momentum schedule.')
    group.add_argument("--cycle_min_mom",
                       type=float,
                       default=0.8,
                       help='1Cycle momentum lower bound.')
    group.add_argument("--cycle_max_mom",
                       type=float,
                       default=0.9,
                       help='1Cycle momentum upper bound.')
    group.add_argument("--decay_mom_rate",
                       type=float,
                       default=0.0,
                       help='post cycle momentum decay rate.')

    # Warmup LR
    group.add_argument('--warmup_min_lr',
                       type=float,
                       default=0,
                       help='WarmupLR minimum/initial LR value')
    group.add_argument('--warmup_max_lr',
                       type=float,
                       default=0.001,
                       help='WarmupLR maximum LR value.')
    group.add_argument('--warmup_num_steps',
                       type=int,
                       default=1000,
                       help='WarmupLR step count for LR warmup.')

    return parser

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser = add_tuning_arguments(parser)

    lr_sched_args, unknown_args = parser.parse_known_args()
    return lr_sched_args, unknown_args

def override_lr_range_test_params(args, params):
    if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
        params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr

    if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
        params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate

    if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
        params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size

    if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
        params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase

def override_1cycle_params(args, params):
    if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
        params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size

    if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
        params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count

    if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
        params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size

    if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
        params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count

    if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
        params[DECAY_STEP_SIZE] = args.decay_step_size

    # 1Cycle LR params
    if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
        params[CYCLE_MIN_LR] = args.cycle_min_lr

    if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
        params[CYCLE_MAX_LR] = args.cycle_max_lr

    if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
        params[DECAY_LR_RATE] = args.decay_lr_rate

    # 1Cycle MOM params
    if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
        params[CYCLE_MIN_MOM] = args.cycle_min_mom

    if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
        params[CYCLE_MAX_MOM] = args.cycle_max_mom

    if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
        params[DECAY_MOM_RATE] = args.decay_mom_rate

def override_warmupLR_params(args, params):
    if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
        params[WARMUP_MIN_LR] = args.warmup_min_lr

    if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
        params[WARMUP_MAX_LR] = args.warmup_max_lr

    if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
        params[WARMUP_NUM_STEPS] = args.warmup_num_steps


def override_params(args, params):
    # LR range test params
    override_lr_range_test_params(args, params)

    # 1Cycle params
    override_1cycle_params(args, params)

    # WarmupLR params
    override_warmupLR_params(args, params)

def get_config_from_args(args):
    if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
        return None, '--{} not specified on command line'.format(LR_SCHEDULE)

    if args.lr_schedule not in VALID_LR_SCHEDULES:
        return None, '{} is not a supported LR schedule'.format(args.lr_schedule)

    config = {}
    config['type'] = args.lr_schedule
    config['params'] = {}

    if args.lr_schedule == LR_RANGE_TEST:
        override_lr_range_test_params(args, config['params'])
    elif args.lr_schedule == ONE_CYCLE:
        override_1cycle_params(args, config['params'])
    else:
        override_warmupLR_params(args, config['params'])

    return config, None

def get_lr_from_config(config):
    if 'type' not in config:
        return None, 'LR schedule type not defined in config'

    if 'params' not in config:
        return None, 'LR schedule params not defined in config'

    lr_schedule = config['type']
    lr_params = config['params']

    if lr_schedule not in VALID_LR_SCHEDULES:
        return None, '{} is not a valid LR schedule'.format(lr_schedule)

    if lr_schedule == LR_RANGE_TEST:
        return lr_params[LR_RANGE_TEST_MIN_LR], ''
    elif lr_schedule == ONE_CYCLE:
        return lr_params[CYCLE_MAX_LR], ''
    else:
        # Warmup LR
        return lr_params[WARMUP_MAX_LR], ''

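
# --- Editor's note (not part of this commit): a runnable sketch of the
# config round trip above, using the WarmupLR branch:
#
#   >>> config = {'type': 'WarmupLR', 'params': {'warmup_max_lr': 0.001}}
#   >>> get_lr_from_config(config)
#   (0.001, '')
#   >>> get_lr_from_config({'params': {}})
#   (None, 'LR schedule type not defined in config')
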
class LRRangeTest(object):
    """Sets the learning rate of each parameter group according to
    learning rate range test (LRRT) policy. The policy increases learning
    rate starting from a base value with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters: Part 1`_.

    LRRT policy is used for finding the maximum LR that trains a model without
    divergence, and can be used to configure the LR boundaries for Cyclic LR
    schedules.

    LRRT changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_range_test_min_lr (float or list): Initial learning rate which is the
            lower boundary in the range test for each parameter group.
        lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
        lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
        lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = LRRangeTest(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
        https://arxiv.org/abs/1803.09820
    """
    def __init__(self,
                 optimizer: Optimizer,
                 lr_range_test_min_lr: float = 1e-3,
                 lr_range_test_step_size: int = 2000,
                 lr_range_test_step_rate: float = 1.0,
                 lr_range_test_staircase: bool = False,
                 last_batch_iteration: int = -1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))

        self.optimizer = optimizer

        if isinstance(lr_range_test_min_lr, list) or isinstance(
                lr_range_test_min_lr, tuple):
            if len(lr_range_test_min_lr) != len(optimizer.param_groups):
                raise ValueError(
                    "expected {} lr_range_test_min_lr, got {}".format(
                        len(optimizer.param_groups), len(lr_range_test_min_lr)))
            self.min_lr = list(lr_range_test_min_lr)
        else:
            self.min_lr = [lr_range_test_min_lr] * len(optimizer.param_groups)

        self.step_size = lr_range_test_step_size
        self.step_rate = lr_range_test_step_rate
        self.last_batch_iteration = last_batch_iteration
        self.staircase = lr_range_test_staircase
        self.interval_fn = self._staircase_interval if lr_range_test_staircase \
            else self._continuous_interval

        if last_batch_iteration == -1:
            self._update_optimizer(self.min_lr)

    def _staircase_interval(self):
        return math.floor(float(self.last_batch_iteration) / self.step_size)

    def _continuous_interval(self):
        return float(self.last_batch_iteration) / self.step_size

    def _get_increase(self):
        return (1 + self.step_rate * self.interval_fn())

    def get_lr(self):
        lr_increase = self._get_increase()
        return [
            lr_range_test_min_lr * lr_increase
            for lr_range_test_min_lr in self.min_lr
        ]

    def _update_optimizer(self, group_lrs):
        for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
            param_group['lr'] = lr

    def step(self, batch_iteration=None):
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        self._update_optimizer(self.get_lr())

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

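
# --- Editor's note (not part of this commit): the continuous LRRT growth
# rule above, worked numerically. With min_lr 1e-3, step_rate 1.0 and
# step_size 2000, the lr at batch n is min_lr * (1 + n / 2000):
#
#   >>> min_lr, step_rate, step_size = 1e-3, 1.0, 2000
#   >>> [min_lr * (1 + step_rate * n / step_size) for n in (0, 2000, 4000)]
#   [0.001, 0.002, 0.003]
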
class OneCycle(object):
    """Sets the learning rate of each parameter group according to
    1Cycle learning rate policy (1CLR). 1CLR is a variation of the
    Cyclical Learning Rate (CLR) policy that involves one cycle followed by
    decay. The policy simultaneously cycles the learning rate (and momentum)
    between two boundaries with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters`_.

    1CLR policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This implementation was adapted from the github repo: `pytorch/pytorch`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        cycle_min_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
            The lr at any cycle is the sum of cycle_min_lr
            and some scaling of the amplitude; therefore
            cycle_max_lr may not actually be reached depending on
            scaling function.
        decay_lr_rate (float): Decay rate for learning rate. Default: 0.
        cycle_first_step_size (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        cycle_second_step_size (int): Number of training iterations in the
            decreasing half of a cycle. If cycle_second_step_size is None,
            it is set to cycle_first_step_size. Default: None
        cycle_first_stair_count (int): Number of stairs in first half of cycle phase. This means
            lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
        cycle_second_stair_count (int): Number of stairs in second half of cycle phase. This means
            lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
        decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
            Default: True
        cycle_min_mom (float or list): Initial momentum which is the
            lower boundary in the cycle for each parameter group.
            Default: 0.8
        cycle_max_mom (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
            The momentum at any cycle is the difference of cycle_max_mom
            and some scaling of the amplitude; therefore
            cycle_min_mom may not actually be reached depending on
            scaling function. Default: 0.9
        decay_mom_rate (float): Decay rate for momentum. Default: 0.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = OneCycle(optimizer, cycle_min_lr=0.01, cycle_max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
    """
    def __init__(self,
                 optimizer,
                 cycle_min_lr,
                 cycle_max_lr,
                 decay_lr_rate=0.,
                 cycle_first_step_size=2000,
                 cycle_second_step_size=None,
                 cycle_first_stair_count=0,
                 cycle_second_stair_count=None,
                 decay_step_size=0,
                 cycle_momentum=True,
                 cycle_min_mom=0.8,
                 cycle_max_mom=0.9,
                 decay_mom_rate=0.,
                 last_batch_iteration=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
        if last_batch_iteration == -1:
            for lr, group in zip(self.min_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)

        cycle_first_step_size = float(cycle_first_step_size)
        cycle_second_step_size = float(
            cycle_second_step_size
        ) if cycle_second_step_size is not None else cycle_first_step_size

        self.total_size = cycle_first_step_size + cycle_second_step_size
        self.step_ratio = cycle_first_step_size / self.total_size
        self.first_stair_count = cycle_first_stair_count
        self.second_stair_count = cycle_first_stair_count \
            if cycle_second_stair_count is None else cycle_second_stair_count

        self.decay_lr_rate = decay_lr_rate
        self.decay_mom_rate = decay_mom_rate
        self.decay_step_size = decay_step_size

        self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
        self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
        self.cycle_momentum = cycle_momentum
        self.last_batch_iteration = last_batch_iteration

        if cycle_momentum:
            if 'betas' not in optimizer.defaults:
                raise ValueError(
                    'optimizer must support betas with `cycle_momentum` option enabled')
            if last_batch_iteration == -1:
                for momentum, group in zip(self.min_moms, optimizer.param_groups):
                    group['betas'] = momentum

    def _get_cycle_lr(self):
        cycle = math.floor(1 + self.last_batch_iteration / self.total_size)
        x = 1. + self.last_batch_iteration / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
            base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
            lr = cycle_min_lr + base_height
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_betas, max_betas in zip(self.min_moms, self.max_moms):
                cycle_min_mom = base_betas[0]
                cycle_max_mom = max_betas[0]
                base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
                momentum = cycle_max_mom - base_height
                momentums.append((momentum, base_betas[1]))
            for param_group, momentum in zip(self.optimizer.param_groups,
                                             momentums):
                param_group['betas'] = momentum

        return lrs

    def _get_decay_lr(self, decay_batch_iteration):
        """Calculates the learning rate at batch index. This function is used
        after the cycle completes and post cycle decaying of lr/mom is enabled.
        This function treats `self.last_batch_iteration` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        decay_interval = decay_batch_iteration / self.decay_step_size

        lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
        lrs = [cycle_min_lr * lr_decay_factor for cycle_min_lr in self.min_lrs]

        if self.cycle_momentum:
            mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
            momentums = [(beta0 * mom_decay_factor, beta1)
                         for beta0, beta1 in self.max_moms]
            for param_group, momentum in zip(self.optimizer.param_groups,
                                             momentums):
                param_group['betas'] = momentum

        return lrs

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_batch_iteration` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        if self.last_batch_iteration <= self.total_size:
            return self._get_cycle_lr()
        else:
            return self._get_decay_lr(self.last_batch_iteration - self.total_size)

    def step(self, batch_iteration=None):
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

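
# --- Editor's note (not part of this commit): the cycle position math in
# _get_cycle_lr above, worked for cycle_first_step_size =
# cycle_second_step_size = 2000 at batch 1000 (halfway up the ramp):
#
#   >>> total_size, step_ratio, it = 4000.0, 0.5, 1000
#   >>> cycle = math.floor(1 + it / total_size)        # 1
#   >>> x = 1. + it / total_size - cycle               # 0.25
#   >>> x / step_ratio                                 # scale_factor
#   0.5
#   >>> # so lr = cycle_min_lr + 0.5 * (cycle_max_lr - cycle_min_lr)
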
class WarmupLR(object):
    """Increase the learning rate of each parameter group from min lr to max lr
    over warmup_num_steps steps, and then fix at max lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup_min_lr (float or list): minimum learning rate. Default: 0
        warmup_max_lr (float or list): maximum learning rate. Default: 0.001
        warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
        last_batch_iteration (int): The index of the last batch. Default: -1.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = WarmupLR(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()
    """
    def __init__(self,
                 optimizer: Optimizer,
                 warmup_min_lr: float = 0.0,
                 warmup_max_lr: float = 0.001,
                 warmup_num_steps: int = 1000,
                 last_batch_iteration: int = -1):
        self.optimizer = optimizer
        self.min_lrs = self._format_param(optimizer, warmup_min_lr, "min_lr")
        self.max_lrs = self._format_param(optimizer, warmup_max_lr, "max_lr")
        self.delta_lrs = [
            big - small for big, small in zip(self.max_lrs, self.min_lrs)
        ]
        self.warmup_num_steps = warmup_num_steps
        self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
        self.last_batch_iteration = last_batch_iteration

    def get_lr(self):
        gamma = self._get_gamma()
        return [
            min_lr + (delta_lr * gamma)
            for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)
        ]

    def step(self, last_batch_iteration=None):
        if last_batch_iteration is None:
            last_batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = last_batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

    def _get_gamma(self):
        if self.last_batch_iteration < self.warmup_num_steps:
            return self.inverse_log_warm_up * \
                math.log(self.last_batch_iteration + 1)
        else:
            return 1.0

    def _format_param(self, optimizer, param_value, param_name):
        if isinstance(param_value, list) or isinstance(param_value, tuple):
            if len(param_value) != len(optimizer.param_groups):
                # Report the mismatched length; the original wrapped
                # param_value in FileNotFoundError here, which was clearly
                # unintended.
                raise ValueError("expected {} value for {}, got {}".format(
                    len(optimizer.param_groups), param_name, len(param_value)))
            return list(param_value)
        else:
            return [param_value] * len(optimizer.param_groups)
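
# --- Editor's note (not part of this commit): the logarithmic warmup curve
# in _get_gamma above, worked for warmup_num_steps = 1000. gamma climbs from
# 0 at step 0 to 1.0 at step 999, after which the lr stays at max:
#
#   >>> inv = 1.0 / math.log(1000)
#   >>> [round(inv * math.log(t + 1), 3) for t in (0, 99, 999)]
#   [0.0, 0.667, 1.0]
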