OpenDAS / deepspeed · Commits

Commit 7a9fbe67 (unverified) · "Add files via upload"
Authored Feb 03, 2020 by Olatunji Ruwase, committed by GitHub on Feb 03, 2020
Parent: dc226fdb

Showing 2 changed files with 1675 additions and 0 deletions:
  deepspeed/pt/deepspeed_light.py         +1021  −0
  deepspeed/pt/deepspeed_lr_schedules.py   +654  −0

deepspeed/pt/deepspeed_light.py  (new file, 0 → 100644)  @ 7a9fbe67
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import logging
import torch
import os
import torch.distributed as dist
from torch.nn.modules import Module

from tensorboardX import SummaryWriter

from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
    ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import ROUTE_TRAIN, ROUTE_PREDICT, \
    ROUTE_EVAL
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor

from apex import amp
from apex.optimizers.fused_adam import FusedAdam
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
SUMMARY_WRITER_DIR_NAME = "JobId"

try:
    from apex_C import flatten
    from apex_C import unflatten
except ImportError:
    try:
        _ = warned_flatten
    except NameError:
        print(
            "Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten
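
# flatten packs a list of same-dtype tensors into one contiguous 1-D buffer and
# unflatten splits such a buffer back into views shaped like the originals, so a
# whole bucket can be all-reduced with a single collective. Illustrative sketch
# (mirrors allreduce_and_copy below):
#   flat = flatten([t1, t2])                 # one contiguous buffer
#   for buf, synced in zip([t1, t2], unflatten(flat, [t1, t2])):
#       buf.copy_(synced)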
def split_half_float_double_csr(tensors):
    dtypes = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor",
        CSRTensor.type()
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append((dtype, bucket))
    return buckets
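
# Example (illustrative): given tensors [h0, f0, h1] where h* are
# torch.cuda.HalfTensor and f0 is a torch.cuda.FloatTensor, the result is
# [("torch.cuda.HalfTensor", [h0, h1]), ("torch.cuda.FloatTensor", [f0])],
# so each bucket can be flattened and reduced with one homogeneous kernel.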
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    if parameter_parallel_size is None:
        parameter_parallel_size = int(data_parallel_size)
    print(data_parallel_size, parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(dist.get_world_size() // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
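
# Example (illustrative): with a world size of 4 and parameter_parallel_size=2,
# the groups are ranks {0, 1} and {2, 3}; every rank creates both groups (as
# torch.distributed.new_group requires) but keeps only the one it belongs to.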
def print_configuration(args, name):
    print('{}:'.format(name), flush=True)
    for arg in sorted(vars(args)):
        dots = '.' * (29 - len(arg))
        print('  {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
class DeepSpeedLight(Module):
    r"""DeepSpeed engine for training.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer=None,
                 model_parameters=None,
                 training_data=None,
                 lr_scheduler=None,
                 mpu=None,
                 dist_init_required=True,
                 collate_fn=None):
        super(DeepSpeedLight, self).__init__()

        logging.basicConfig(level=logging.INFO,
                            format="[%(levelname)s %(asctime)s] %(message)s",
                            datefmt="%Y-%m-%d %H:%M:%S")

        self.client_optimizer = optimizer
        self.client_model_parameters = model_parameters
        self.client_lr_scheduler = lr_scheduler
        self.training_data = training_data
        self.collate_fn = collate_fn
        self.mpu = mpu
        self.data_parallel_group = None
        self.global_steps = 0
        self.micro_steps = 0
        self.skipped_steps = 0
        self.gradient_predivide_factor = 1.0
        self.gradient_average = True
        self.warn_unscaled_loss = True

        if dist_init_required:
            dist.init_process_group(backend="nccl")

        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()

        self.sample_count = 0
        if self.tensorboard_enabled():
            self.summary_writer = self.get_summary_writer()

        self._init_distributed(dist_init_required)

        # Throughput timer
        self.tput_timer = ThroughputTimer(
            batch_size=self.train_micro_batch_size_per_gpu(),
            num_workers=self.world_size,
            monitor_memory=False)

        self.training_dataloader = self.deepspeed_io(
            training_data) if training_data else None

        # Configure distributed model
        self._configure_distributed_model(model)

        # Configure optimizer and scheduler
        self.optimizer = None
        self.lr_scheduler = None
        if model_parameters or optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler(lr_scheduler)
            self._report_progress(0)

        # Configure wall clock timer
        self.timers = SynchronizedWallClockTimer()

        # Bookkeeping for csr support
        self.csr_tensor_module_names = set()
        if self.sparse_gradients_enabled():
            for name, module in self.module.named_modules():
                if isinstance(module, torch.nn.Embedding):
                    self.csr_tensor_module_names.add(name)
                    logging.info("Will convert {} to sparse (csr) "
                                 "tensor during training".format(name))

        self.save_non_zero_checkpoint = False
        self.save_zero_checkpoint = False
        self._configure_checkpointing(dist_init_required)

        if self.global_rank == 0:
            self._config.print('DeepSpeedLight configuration')
            if self.dump_state():
                print_configuration(self, 'DeepSpeedLight')
    def tensorboard_enabled(self):
        return self._config.tensorboard_enabled

    def tensorboard_output_path(self):
        return self._config.tensorboard_output_path

    def tensorboard_job_name(self):
        return self._config.tensorboard_job_name
    def get_summary_writer(self,
                           name="DeepSpeedJobName",
                           base=os.environ["HOME"] + "/tensorboard"):
        if self.tensorboard_job_name():
            name = self.tensorboard_job_name()
        if self.tensorboard_output_path():
            return SummaryWriter(log_dir=self.tensorboard_output_path())
        # Use a local copy: assigning to the module-level name inside this
        # function would make it local everywhere in the function and raise
        # UnboundLocalError when DLWS_JOB_ID is not set.
        summary_writer_dir_name = SUMMARY_WRITER_DIR_NAME
        if 'DLWS_JOB_ID' in os.environ:
            summary_writer_dir_name = os.environ['DLWS_JOB_ID'] + "/logs"
        return SummaryWriter(
            log_dir=os.path.join(base, summary_writer_dir_name, name))
    def wall_clock_breakdown(self):
        return self._config.wall_clock_breakdown

    def sparse_gradients_enabled(self):
        return self._config.sparse_gradients_enabled

    def train_batch_size(self):
        return self._config.train_batch_size

    def train_micro_batch_size_per_gpu(self):
        return self._config.train_micro_batch_size_per_gpu

    def optimizer_name(self):
        return self._config.optimizer_name

    def optimizer_params(self):
        return self._config.optimizer_params

    def scheduler_name(self):
        return self._config.scheduler_name

    def scheduler_params(self):
        return self._config.scheduler_params

    def zero_optimization(self):
        return self._config.zero_enabled

    def allgather_size(self):
        return self._config.allgather_size

    def fp16_enabled(self):
        return self._config.fp16_enabled

    def loss_scale(self):
        return self._config.loss_scale

    def gradient_accumulation_steps(self):
        return self._config.gradient_accumulation_steps

    def allreduce_always_fp32(self):
        return self._config.allreduce_always_fp32

    def postscale_gradients(self):
        return not self._config.prescale_gradients

    def steps_per_print(self):
        return self._config.steps_per_print

    def disable_allgather(self):
        return self._config.disable_allgather

    def dump_state(self):
        return self._config.dump_state

    def gradient_clipping(self):
        return self._config.gradient_clipping

    def dynamic_loss_scale(self):
        return self._config.loss_scale == 0

    def initial_dynamic_scale(self):
        return self._config.initial_dynamic_scale

    def dynamic_loss_scale_args(self):
        return self._config.dynamic_loss_scale_args
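
    # Note: loss_scale == 0 in the JSON configuration is the sentinel for
    # dynamic loss scaling (see dynamic_loss_scale() above); any positive
    # value is treated as a static scale.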
    def _configure_lr_scheduler(self, client_lr_scheduler):
        # First check for scheduler in json configuration
        lr_scheduler = self._scheduler_from_config(self.optimizer)
        if lr_scheduler:
            logging.info(
                f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
            self.lr_scheduler = lr_scheduler
        else:
            logging.warning('DeepSpeed using client LR scheduler')
            self.lr_scheduler = client_lr_scheduler
        logging.info(f'DeepSpeed LR Scheduler = {self.lr_scheduler}')
    def _configure_checkpointing(self, dist_init_required):
        dp_rank = torch.distributed.get_rank(
        ) if self.mpu is None else self.mpu.get_data_parallel_rank()

        # only the first data parallel process needs to store the model checkpoint
        self.save_non_zero_checkpoint = True if dp_rank == 0 else False

        if self.zero_optimization():
            pp_rank = torch.distributed.get_rank(
                group=self.optimizer.dp_process_group)
            # only the first parameter parallel process needs to store the
            # optimizer state checkpoints for zero
            self.save_zero_checkpoint = True if pp_rank == dp_rank else False
    def _scheduler_from_config(self, optimizer):
        scheduler_name = self.scheduler_name()
        if scheduler_name is not None:
            if hasattr(lr_schedules, scheduler_name):
                scheduler = getattr(lr_schedules, scheduler_name)
            else:
                assert hasattr(torch.optim.lr_scheduler, scheduler_name), \
                    f"DeepSpeed does not recognize LR scheduler {scheduler_name}"
                scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)

            scheduler_params = self.scheduler_params()
            instantiated_scheduler = scheduler(optimizer, **scheduler_params)
            return instantiated_scheduler
        else:
            return None
    def _init_distributed(self, dist_init_required):
        if self.local_rank >= 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
            self.world_size = dist.get_world_size()
            self.global_rank = dist.get_rank()
            logging.info("Set device to local rank {} within node.".format(
                self.local_rank))
        else:
            self.world_size = 1
            self.global_rank = 0
            self.device = torch.device("cuda")
    # Configure based on command line arguments
    def _configure_with_arguments(self, args, mpu):
        self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
        self._config = DeepSpeedConfig(args.deepspeed_config, mpu)
    # Validate command line arguments
    def _do_args_sanity_check(self, args):
        assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
            'DeepSpeed requires integer command line parameter --local_rank'

        assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
            'DeepSpeed requires --deepspeed_config to specify configuration file'

        assert os.path.isfile(args.deepspeed_config), \
            'DeepSpeed configuration file: {} is not an existing file'.format(
                args.deepspeed_config)
    def _is_supported_optimizer(self, optimizer_name):
        return optimizer_name in DEEPSPEED_OPTIMIZERS or \
            getattr(torch.optim, optimizer_name, None) is not None
    # Validate configuration based on command line arguments
    def _do_sanity_check(self):
        if not self.client_optimizer:
            assert self._is_supported_optimizer(self.optimizer_name()), \
                '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
            assert self.client_model_parameters, \
                'DeepSpeed {} optimizer requires parameters in initialize() call'.format(
                    self.optimizer_name())

        if self.optimizer_name() == LAMB_OPTIMIZER:
            assert self.dynamic_loss_scale(), \
                'DeepSpeed {} optimizer requires dynamic loss scaling'.format(
                    self.optimizer_name())
    def _configure_distributed_model(self, model):
        self.module = model
        if self.fp16_enabled():
            self.module.half()
        self.module.to(self.device)

        if self.mpu is None:
            self.data_parallel_group = _initialize_parameter_parallel_groups()
            self.dp_world_size = dist.get_world_size()
            src_rank = 0
        else:
            self.data_parallel_group = self.mpu.get_data_parallel_group()
            self.dp_world_size = self.mpu.get_data_parallel_world_size()
            src_rank = self.mpu.get_model_parallel_rank()

        for p in self.module.parameters():
            if torch.is_tensor(p):
                dist.broadcast(p, src_rank, group=self.data_parallel_group)

        # TODO: support new AMP optimizer
        # self.module.half()
        # self.module.to(self.local_rank)
        # self.module, self.optimizer = amp.initialize(self.module, self.optimizer, opt_level="O2")
    # Configure optimizer
    def _configure_optimizer(self, client_optimizer, model_parameters):
        if client_optimizer is not None:
            basic_optimizer = client_optimizer
            logging.info('Using client Optimizer as basic optimizer')
        else:
            basic_optimizer = self._configure_basic_optimizer(model_parameters)
            logging.info(
                'Using DeepSpeed Optimizer param name {} as basic optimizer'.format(
                    self.optimizer_name()))

        logging.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))

        if self.zero_optimization() and self.optimizer_name() == ADAM_OPTIMIZER:
            self.optimizer = self._configure_zero_optimizer(basic_optimizer)
        elif self.fp16_enabled():
            self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
        else:
            self.optimizer = basic_optimizer

        # logging.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
    def _configure_basic_optimizer(self, model_parameters):
        optimizer_parameters = self.optimizer_params()
        if self.fp16_enabled() and 'max_grad_norm' in optimizer_parameters.keys():
            optimizer_parameters['max_grad_norm'] = 0.0
        if self.optimizer_name() == ADAM_OPTIMIZER:
            optimizer = FusedAdam(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == LAMB_OPTIMIZER:
            optimizer = FusedLamb(model_parameters, **optimizer_parameters)
        else:
            torch_optimizer = getattr(torch.optim, self.optimizer_name())
            optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
        return optimizer
    def _configure_fp16_optimizer(self, optimizer):
        initial_dynamic_scale = self.initial_dynamic_scale()
        dynamic_loss_args = self.dynamic_loss_scale_args()
        clip_grad = self.gradient_clipping()
        if self.optimizer_name() == ADAM_OPTIMIZER:
            if self.dynamic_loss_scale():
                logging.info('Creating fp16 optimizer with dynamic loss scale')
                optimizer = FP16_Optimizer(
                    optimizer,
                    dynamic_loss_scale=True,
                    initial_dynamic_scale=initial_dynamic_scale,
                    dynamic_loss_args=dynamic_loss_args,
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=True)
            else:
                logging.info(
                    'Creating fp16 optimizer with static loss scale: {}'.format(
                        self.loss_scale()))
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=self.loss_scale(),
                                           mpu=self.mpu,
                                           clip_grad=clip_grad,
                                           fused_adam_legacy=True)
        else:
            logging.info('Creating fp16 unfused optimizer with dynamic loss scale')
            optimizer = FP16_UnfusedOptimizer(
                optimizer,
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=dynamic_loss_args,
                mpu=self.mpu,
                clip_grad=clip_grad,
                fused_lamb_legacy=True
                if self.optimizer_name() == LAMB_OPTIMIZER else False)

        return optimizer
    def _configure_zero_optimizer(self, optimizer):
        logging.info('Creating fp16 zero optimizer')
        optimizer = FP16_DeepSpeedZeroOptimizer(
            optimizer,
            static_loss_scale=self.loss_scale(),
            dynamic_loss_scale=self.dynamic_loss_scale(),
            dynamic_loss_args=self.dynamic_loss_scale_args(),
            dp_process_group=self.data_parallel_group,
            clip_grad=self.gradient_clipping(),
            all_gather_partitions=not self.disable_allgather(),
            allgather_size=self.allgather_size(),
            mpu=self.mpu)
        return optimizer
    def deepspeed_io(self,
                     dataset,
                     batch_size=None,
                     route=ROUTE_TRAIN,
                     pin_memory=True,
                     data_sampler=None,
                     collate_fn=None,
                     num_local_io_workers=None):
        if not isinstance(dataset, torch.utils.data.Dataset):
            raise ValueError("Training data must be a torch Dataset")

        if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
            data_sampler = torch.utils.data.SequentialSampler(dataset)

        if batch_size is None:
            batch_size = self.train_micro_batch_size_per_gpu()

        if collate_fn is None:
            collate_fn = self.collate_fn

        # Currently we only use timer in train route
        deepspeed_io_timer = None
        if route == ROUTE_TRAIN:
            deepspeed_io_timer = self.tput_timer

        return DeepSpeedDataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   pin_memory=pin_memory,
                                   collate_fn=collate_fn,
                                   local_rank=self.local_rank,
                                   tput_timer=deepspeed_io_timer,
                                   num_local_io_workers=num_local_io_workers,
                                   data_sampler=data_sampler)
    def train(self):
        r"""Put the wrapped module in training mode.
        """
        self.warn_unscaled_loss = True
        self.module.train()

    def eval(self):
        r"""Put the wrapped module in evaluation mode.
        """
        self.warn_unscaled_loss = True
        self.module.train(False)
    def _scale_loss(self, loss):
        if isinstance(loss, torch.Tensor):
            loss = loss / self.gradient_accumulation_steps()
        elif isinstance(loss, tuple) and isinstance(loss[0], torch.Tensor):
            # Materialize explicitly: a bare (l / ... for l in loss) would
            # produce a generator, not a tuple.
            loss = tuple(l / self.gradient_accumulation_steps() for l in loss)
        elif isinstance(loss, list) and isinstance(loss[0], torch.Tensor):
            loss = [l / self.gradient_accumulation_steps() for l in loss]
        else:
            if self.warn_unscaled_loss:
                logging.warning(
                    f'DeepSpeed unable to scale loss because of type: {type(loss)}')
                self.warn_unscaled_loss = False

        return loss
    def forward(self, *inputs, **kwargs):
        r"""Execute forward propagation

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """
        if self.wall_clock_breakdown():
            self.timers('forward_microstep').start()
            self.timers('forward').start()

        if self.training_dataloader is None:
            self.tput_timer.start()
        loss = self.module(*inputs, **kwargs)

        # scale loss w.r.t. gradient accumulation if needed
        if self.gradient_accumulation_steps() > 1:
            loss = self._scale_loss(loss)

        if self.wall_clock_breakdown():
            self.timers('forward').stop()
            self.timers('forward_microstep').stop()

        return loss
    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
        if self.is_gradient_accumulation_boundary():
            self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
    def backward(self, loss, allreduce_gradients=True):
        r"""Execute backward pass on the loss

        Arguments:
            loss: Torch tensor on which to execute backward propagation
            allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
        """
        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
        ) and torch.distributed.get_rank() == 0:
            # deepspeed tensorboard support for loss
            self.sample_count += (self.train_micro_batch_size_per_gpu() *
                                  torch.distributed.get_world_size() *
                                  self.gradient_accumulation_steps())
            self.summary_events = [
                (f'Train/Samples/train_loss',
                 loss.mean().item() * self.gradient_accumulation_steps(),
                 self.sample_count)
            ]
            for event in self.summary_events:  # write_summary_events
                self.summary_writer.add_scalar(event[0], event[1], event[2])
            self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('backward_microstep').start()
            self.timers('backward').start()

        assert self.optimizer is not None, "must provide optimizer during " \
            "init in order to use backward"

        if self.wall_clock_breakdown():
            self.timers('backward_inner_microstep').start()
            self.timers('backward_inner').start()

        if self.zero_optimization():
            self.optimizer.backward(loss)
        elif self.fp16_enabled():
            self.optimizer.backward(loss)
            # TODO: Use new AMP semantics as below
            # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            #     scaled_loss.backward()
        else:
            loss.backward()

        if self.wall_clock_breakdown():
            self.timers('backward_inner').stop()
            self.timers('backward_inner_microstep').stop()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce_microstep').start()
            self.timers('backward_allreduce').start()

        if allreduce_gradients:
            self.allreduce_gradients()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce').stop()
            self.timers('backward_allreduce_microstep').stop()
            self.timers('backward').stop()
            self.timers('backward_microstep').stop()
    def is_gradient_accumulation_boundary(self):
        return (self.micro_steps + 1) % \
            self.gradient_accumulation_steps() == 0
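
    # Example: with gradient_accumulation_steps() == 4, micro_steps values
    # 3, 7, 11, ... are boundaries since (micro_steps + 1) % 4 == 0, so the
    # optimizer steps once per four micro-batches.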
    def step(self):
        r"""Execute the weight update step after forward and backward propagation on effective_train_batch
        """
        if self.wall_clock_breakdown():
            self.timers('step_microstep').start()
            self.timers('step').start()

        assert self.optimizer is not None, "must provide optimizer during " \
            "init in order to use step"
        report_progress = self.global_rank == 0 if self.global_rank else True

        if self.is_gradient_accumulation_boundary():
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Check overflow here since in DS fp16 optimizer, the overflow is
            # updated in above step() function.
            overflow = False
            if hasattr(self.optimizer, 'overflow'):
                overflow = self.optimizer.overflow

            if overflow:
                self.skipped_steps += 1
            else:
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

            if report_progress and (self.global_steps +
                                    1) % self.steps_per_print() == 0:
                self._report_progress(self.global_steps + 1)

            self.global_steps += 1

        self.tput_timer.stop(report_progress)

        if self.is_gradient_accumulation_boundary() and self.tensorboard_enabled(
        ) and torch.distributed.get_rank() == 0:
            # deepspeed tensorboard support for lr
            self.summary_events = [(f'Train/Samples/lr',
                                    self.get_lr()[0],
                                    self.sample_count)]
            for event in self.summary_events:  # write_summary_events
                self.summary_writer.add_scalar(event[0], event[1], event[2])
            self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('step').stop()
            self.timers('step_microstep').stop()
            self.timers.log([
                'forward_microstep',
                'backward_microstep',
                'backward_inner_microstep',
                'backward_allreduce_microstep',
                'step_microstep'
            ])
            if self.is_gradient_accumulation_boundary():
                if self.tensorboard_enabled() and torch.distributed.get_rank() == 0:
                    # this is done before the log because log resets timers
                    self.summary_events = [
                        (f'Train/Samples/elapsed_time_ms_forward',
                         self.timers('forward').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        (f'Train/Samples/elapsed_time_ms_backward',
                         self.timers('backward').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        (f'Train/Samples/elapsed_time_ms_backward_inner',
                         self.timers('backward_inner').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        (f'Train/Samples/elapsed_time_ms_backward_allreduce',
                         self.timers('backward_allreduce').elapsed(reset=False) * 1000.0,
                         self.sample_count),
                        (f'Train/Samples/elapsed_time_ms_step',
                         self.timers('step').elapsed(reset=False) * 1000.0,
                         self.sample_count)
                    ]
                    for event in self.summary_events:  # write_summary_events
                        self.summary_writer.add_scalar(event[0], event[1], event[2])
                    self.summary_writer.flush()
                self.timers.log([
                    'forward',
                    'backward',
                    'backward_inner',
                    'backward_allreduce',
                    'step'
                ])

        self.micro_steps += 1
    def _get_optimizer_param(self, param_name):
        result = []
        if not self.optimizer:
            return result
        for group in self.optimizer.param_groups:
            if param_name in group:
                result.append(group[param_name])
            else:
                result.append(0.0)
        return result

    def get_lr(self):
        return self._get_optimizer_param('lr')

    def get_mom(self):
        return self._get_optimizer_param('betas')
    def _report_progress(self, step):
        lr = self.get_lr()
        mom = self.get_mom()
        logging.info('rank:{} step={}, skipped={}, lr={}, mom={}'.format(
            self.global_rank, step, self.skipped_steps, lr, mom))
    def allreduce_bucket(self, bucket):
        tensor = flatten(bucket)

        tensor_to_allreduce = tensor

        if self.allreduce_always_fp32():
            tensor_to_allreduce = tensor.float()

        if self.postscale_gradients():
            if self.gradient_predivide_factor != 1.0:
                tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

            if self.gradient_average:
                if self.gradient_predivide_factor != self.dp_world_size:
                    tensor_to_allreduce.mul_(self.gradient_predivide_factor /
                                             self.dp_world_size)
        else:
            tensor_to_allreduce.div_(self.dp_world_size)
            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

        if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor
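
    # Averaging note: in the postscale path the bucket is (optionally)
    # pre-divided by gradient_predivide_factor for fp16 range safety, summed by
    # the all-reduce, then rescaled by gradient_predivide_factor / dp_world_size;
    # in the prescale path it is divided by dp_world_size before the all-reduce.
    # Either way the net result is the data-parallel mean of the gradients when
    # gradient_average is True.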
    def allreduce_and_copy(self, small_bucket):
        allreduced = self.allreduce_bucket(small_bucket)
        for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
            buf.copy_(synced)
    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
        small_bucket = []
        numel = 0
        for tensor in bucket:
            small_bucket.append(tensor)
            numel = numel + tensor.numel()
            if numel > numel_per_bucket:
                self.allreduce_and_copy(small_bucket)
                small_bucket = []
                numel = 0  # reset the element count along with the bucket
        if len(small_bucket) > 0:
            self.allreduce_and_copy(small_bucket)
    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
        grads = []
        for param_name, param in self.module.named_parameters():
            if param.grad is not None:
                grad_data = param.grad.data
                param_name_root = param_name.split('.', 1)[0]
                if self.sparse_gradients_enabled(
                ) and param_name_root in self.csr_tensor_module_names:
                    grads.append(CSRTensor(grad_data))
                else:
                    grads.append(grad_data)

        split_buckets = split_half_float_double_csr(grads)

        for i, bucket_tuple in enumerate(split_buckets):
            bucket_type, bucket = bucket_tuple
            if bucket_type == CSRTensor.type():
                self.csr_allreduce_no_retain(bucket)
            else:
                self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer)
    def csr_allreduce_no_retain(self, bucket):
        allreduced_csrs = self.csr_allreduce_bucket(bucket)
        # Densify csr tensor and copy back to original location
        for csr in allreduced_csrs:
            dense_tensor = csr.to_dense()
            csr.orig_dense_tensor.copy_(dense_tensor)

    def csr_allreduce_bucket(self, bucket):
        csr_list = []
        for csr in bucket:
            csr_list.append(self.csr_allreduce(csr))
        return csr_list
    def csr_allreduce(self, csr):
        # Pre-divide for fp16 stability
        csr.values.div_(self.dp_world_size)

        indices_device_list = self.csr_all_gather(csr.indices)
        values_device_list = self.csr_all_gather(csr.values)

        csr.indices = torch.cat(indices_device_list)
        csr.values = torch.cat(values_device_list)
        return csr
    def csr_all_gather(self, value):
        my_size = torch.LongTensor([value.size()[0]]).cuda()
        all_sizes = self.all_gather_scalar(my_size)
        max_size = torch.cat(all_sizes).max()
        fill_size = (max_size - my_size)

        assert value.dim() in [1, 2]
        if value.dim() == 1:
            if fill_size > 0:
                value = torch.cat([value, value.new_zeros(fill_size)])
            tensor_list = [
                value.new_zeros(max_size) for _ in range(dist.get_world_size())
            ]
        else:
            if fill_size > 0:
                value = torch.cat([value, value.new_zeros(fill_size, value.size()[1])])
            tensor_list = [
                value.new_zeros(max_size,
                                value.size()[1]) for _ in range(dist.get_world_size())
            ]

        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        tensors = []
        for dev_idx, t in enumerate(tensor_list):
            size = all_sizes[dev_idx][0]
            tensors.append(t.index_select(0, torch.LongTensor(range(size)).cuda()))

        return tensors
    def all_gather_scalar(self, value):
        tensor_list = [value.new_zeros(value.size()) for _ in range(self.dp_world_size)]
        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        return tensor_list
    def module_state_dict(self, destination=None, prefix='', keep_vars=False):
        sd = self.module.state_dict(destination, prefix, keep_vars)
        return sd

    def load_module_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
    def _get_zero_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
        filename = 'zero_pp_rank_{}'.format(pp_rank)
        zero_ckpt_name = os.path.join(
            checkpoints_path,
            str(tag),
            filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
        return zero_ckpt_name
    def _get_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        ckpt_name = os.path.join(checkpoints_path,
                                 str(tag),
                                 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
        return ckpt_name
    def _ensure_directory_exists(self, filename):
        dirname = os.path.dirname(filename)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
    def load_checkpoint(self, load_dir, tag):
        r"""Load training checkpoint

        Arguments:
            load_dir: Required. Directory to load the checkpoint from
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
        Return:
            load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
            client_state: State dictionary used for loading required training states in the client code.
        """
        load_path, client_states = self._load_checkpoint(load_dir, tag)
        if self.zero_optimization() and load_path is not None:
            self._load_zero_checkpoint(load_dir, tag)
        return load_path, client_states
    def _load_checkpoint(self, load_dir, tag):
        load_path = self._get_ckpt_name(load_dir, tag)
        if not os.path.exists(load_path):
            logging.warn(
                'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
                .format(load_path))
            return None, None

        logging.info('Loading checkpoint: {}'.format(load_path))
        checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

        self.load_module_state_dict(checkpoint['module'])
        if not self.zero_optimization():
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
        self.global_steps = checkpoint['global_steps']
        self.skipped_steps = checkpoint['skipped_steps']

        # Keys written by _save_checkpoint; everything else belongs to the
        # client. ('lr_scheduler' and the plural 'global_steps' match the
        # keys actually saved.)
        deepspeed_states = [
            'module',
            'optimizer',
            'lr_scheduler',
            'csr_tensor_module_names',
            'skipped_steps',
            'global_steps'
        ]
        client_state = {
            key: value
            for key, value in checkpoint.items() if not key in deepspeed_states
        }

        return load_path, client_state
    def _load_zero_checkpoint(self, load_dir, tag):
        zero_checkpoint_name = self._get_zero_ckpt_name(load_dir, tag)
        if not os.path.exists(zero_checkpoint_name):
            logging.warn(
                'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
                .format(zero_checkpoint_name))
            return None

        zero_sd = torch.load(zero_checkpoint_name, map_location='cpu')
        self.optimizer.load_state_dict(zero_sd['optimizer_state_dict'])
        logging.info('loading zero checkpoint {}'.format(zero_checkpoint_name))
    def save_checkpoint(self, save_dir, tag, client_state={}):
        r"""Save training checkpoint

        Arguments:
            save_dir: Required. Directory for saving the checkpoint
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
            client_state: Optional. State dictionary used for saving required training states in the client code.
        """

        # This is to make sure the checkpoint names are created without collision
        # There seems to be issue creating them in parallel
        self._create_checkpoint_files(save_dir, tag)

        try:
            if self.save_non_zero_checkpoint:
                self._save_checkpoint(save_dir, tag, client_state=client_state)

            if self.save_zero_checkpoint:
                self._save_zero_checkpoint(save_dir, tag)
        except Exception:
            logging.error(f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
            return False

        return True
    def _create_checkpoint_files(self, save_dir, tag):
        # checkpoint files are created sequentially
        for rank in range(dist.get_world_size()):
            if rank == dist.get_rank():
                try:
                    if self.save_non_zero_checkpoint:
                        checkpoint_name = self._get_ckpt_name(save_dir, tag)
                        self._ensure_directory_exists(checkpoint_name)

                    if self.save_zero_checkpoint:
                        checkpoint_name = self._get_zero_ckpt_name(save_dir, tag)
                        self._ensure_directory_exists(checkpoint_name)
                except Exception:
                    logging.error(
                        f'Failed Saving model checkpoint to {save_dir} with tag {tag}')
                    return False
            dist.barrier()
    def _save_checkpoint(self, save_dir, tag, client_state={}):
        save_path = self._get_ckpt_name(save_dir, tag)
        # self._ensure_directory_exists(save_path)

        state = {
            'module': self.module_state_dict(),
            'optimizer': self.optimizer.state_dict()
            if self.optimizer and not self.zero_optimization() else None,
            'lr_scheduler': self.lr_scheduler.state_dict()
            if self.lr_scheduler is not None else None,
            'csr_tensor_module_names': self.csr_tensor_module_names,
            'skipped_steps': self.skipped_steps,
            'global_steps': self.global_steps,
        }
        state.update(client_state)

        logging.info('Saving model checkpoint: {}'.format(save_path))
        torch.save(state, save_path)
    def _save_zero_checkpoint(self, save_path, tag):
        try:
            zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
            # self._ensure_directory_exists(zero_checkpoint_name)
        except Exception:
            logging.error(
                f'Failed Saving Zero model checkpoint to {save_path} with tag {tag}')
            return  # zero_checkpoint_name is unset; nothing to save

        zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
        torch.save(zero_sd, zero_checkpoint_name)
        logging.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
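
A minimal usage sketch for the engine above (illustrative only: `MyModel`, `my_dataset`, and the argument values stand in for hypothetical client code, and `config.json` must be a valid DeepSpeed configuration file):

    import argparse

    model = MyModel()  # hypothetical client model (a torch.nn.Module)
    args = argparse.Namespace(local_rank=0, deepspeed_config='config.json')

    engine = DeepSpeedLight(args,
                            model,
                            model_parameters=model.parameters(),
                            training_data=my_dataset)  # my_dataset: a torch Dataset

    for batch in engine.training_dataloader:
        loss = engine(batch.to(engine.device))  # forward() returns the module output
        engine.backward(loss)                   # handles fp16/ZeRO and the all-reduce
        engine.step()                           # steps only at accumulation boundaries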
deepspeed/pt/deepspeed_lr_schedules.py  (new file, 0 → 100644)  @ 7a9fbe67
"""
Copyright 2019 The Microsoft DeepSpeed Team
Implementation of learning rate schedules.
Taken and modified from PyTorch v1.0.1 source
https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
"""
import
argparse
from
torch.optim
import
Optimizer
from
typing
import
Union
,
List
import
math
from
deepspeed.pt.deepspeed_constants
import
*
LR_SCHEDULE = 'lr_schedule'
LR_RANGE_TEST = 'LRRangeTest'
ONE_CYCLE = 'OneCycle'
WARMUP_LR = 'WarmupLR'
VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR]

LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size'
LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase'

EDGE_VALUE = 'edge_value'
MID_VALUE = 'mid_value'

CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size'
CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count'
CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size'
CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count'
DECAY_STEP_SIZE = 'decay_step_size'

CYCLE_MIN_LR = 'cycle_min_lr'
CYCLE_MAX_LR = 'cycle_max_lr'
DECAY_LR_RATE = 'decay_lr_rate'

CYCLE_MIN_MOM = 'cycle_min_mom'
CYCLE_MAX_MOM = 'cycle_max_mom'
DECAY_MOM_RATE = 'decay_mom_rate'

WARMUP_MIN_LR = 'warmup_min_lr'
WARMUP_MAX_LR = 'warmup_max_lr'
WARMUP_NUM_STEPS = 'warmup_num_steps'
def add_tuning_arguments(parser):
    group = parser.add_argument_group('Convergence Tuning',
                                      'Convergence tuning configurations')

    # LR scheduler
    group.add_argument('--lr_schedule',
                       type=str,
                       default=None,
                       help='LR schedule for training.')

    # Learning rate range test
    group.add_argument("--lr_range_test_min_lr",
                       type=float,
                       default=0.001,
                       help='Starting lr value.')
    group.add_argument("--lr_range_test_step_rate",
                       type=float,
                       default=1.0,
                       help='scaling rate for LR range test.')
    group.add_argument("--lr_range_test_step_size",
                       type=int,
                       default=1000,
                       help='training steps per LR change.')
    group.add_argument("--lr_range_test_staircase",
                       type=bool,
                       default=False,
                       help='use staircase scaling for LR range test.')

    # OneCycle schedule
    group.add_argument("--cycle_first_step_size",
                       type=int,
                       default=1000,
                       help='size of first step of 1Cycle schedule (training steps).')
    group.add_argument("--cycle_first_stair_count",
                       type=int,
                       default=-1,
                       help='first stair count for 1Cycle schedule.')
    group.add_argument("--cycle_second_step_size",
                       type=int,
                       default=-1,
                       help='size of second step of 1Cycle schedule (default first_step_size).')
    group.add_argument("--cycle_second_stair_count",
                       type=int,
                       default=-1,
                       help='second stair count for 1Cycle schedule.')
    group.add_argument("--decay_step_size",
                       type=int,
                       default=1000,
                       help='size of intervals for applying post cycle decay (training steps).')

    # 1Cycle LR
    group.add_argument("--cycle_min_lr",
                       type=float,
                       default=0.01,
                       help='1Cycle LR lower bound.')
    group.add_argument("--cycle_max_lr",
                       type=float,
                       default=0.1,
                       help='1Cycle LR upper bound.')
    group.add_argument("--decay_lr_rate",
                       type=float,
                       default=0.0,
                       help='post cycle LR decay rate.')

    # 1Cycle Momentum
    group.add_argument('--cycle_momentum',
                       default=False,
                       action='store_true',
                       help='Enable 1Cycle momentum schedule.')
    group.add_argument("--cycle_min_mom",
                       type=float,
                       default=0.8,
                       help='1Cycle momentum lower bound.')
    group.add_argument("--cycle_max_mom",
                       type=float,
                       default=0.9,
                       help='1Cycle momentum upper bound.')
    group.add_argument("--decay_mom_rate",
                       type=float,
                       default=0.0,
                       help='post cycle momentum decay rate.')

    # Warmup LR
    group.add_argument('--warmup_min_lr',
                       type=float,
                       default=0,
                       help='WarmupLR minimum/initial LR value')
    group.add_argument('--warmup_max_lr',
                       type=float,
                       default=0.001,
                       help='WarmupLR maximum LR value.')
    group.add_argument('--warmup_num_steps',
                       type=int,
                       default=1000,
                       help='WarmupLR step count for LR warmup.')

    return parser
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser = add_tuning_arguments(parser)
    lr_sched_args, unknown_args = parser.parse_known_args()
    return lr_sched_args, unknown_args
def override_lr_range_test_params(args, params):
    if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
        params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr

    if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None:
        params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate

    if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None:
        params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size

    if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None:
        params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
def override_1cycle_params(args, params):
    if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
        params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size

    if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None:
        params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count

    if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None:
        params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size

    if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None:
        params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count

    if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
        params[DECAY_STEP_SIZE] = args.decay_step_size

    # 1Cycle LR params
    if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
        params[CYCLE_MIN_LR] = args.cycle_min_lr

    if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
        params[CYCLE_MAX_LR] = args.cycle_max_lr

    if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
        params[DECAY_LR_RATE] = args.decay_lr_rate

    # 1Cycle MOM params
    if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
        params[CYCLE_MIN_MOM] = args.cycle_min_mom

    if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
        params[CYCLE_MAX_MOM] = args.cycle_max_mom

    if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
        params[DECAY_MOM_RATE] = args.decay_mom_rate
def override_warmupLR_params(args, params):
    if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
        params[WARMUP_MIN_LR] = args.warmup_min_lr

    if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
        params[WARMUP_MAX_LR] = args.warmup_max_lr

    if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
        params[WARMUP_NUM_STEPS] = args.warmup_num_steps
def override_params(args, params):
    # LR range test params
    override_lr_range_test_params(args, params)

    # 1Cycle params
    override_1cycle_params(args, params)

    # WarmupLR params
    override_warmupLR_params(args, params)
def get_config_from_args(args):
    if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
        return None, '--{} not specified on command line'.format(LR_SCHEDULE)

    if not args.lr_schedule in VALID_LR_SCHEDULES:
        return None, '{} is not supported LR schedule'.format(args.lr_schedule)

    config = {}
    config['type'] = args.lr_schedule
    config['params'] = {}

    if args.lr_schedule == LR_RANGE_TEST:
        override_lr_range_test_params(args, config['params'])
    elif args.lr_schedule == ONE_CYCLE:
        override_1cycle_params(args, config['params'])
    else:
        override_warmupLR_params(args, config['params'])

    return config, None
def get_lr_from_config(config):
    if not 'type' in config:
        return None, 'LR schedule type not defined in config'

    if not 'params' in config:
        return None, 'LR schedule params not defined in config'

    lr_schedule = config['type']
    lr_params = config['params']

    if not lr_schedule in VALID_LR_SCHEDULES:
        return None, '{} is not a valid LR schedule'.format(lr_schedule)

    if lr_schedule == LR_RANGE_TEST:
        return lr_params[LR_RANGE_TEST_MIN_LR], ''
    elif lr_schedule == ONE_CYCLE:
        return lr_params[CYCLE_MAX_LR], ''
    else:
        # Warmup LR
        return lr_params[WARMUP_MAX_LR], ''
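
# Example (illustrative):
#   get_lr_from_config({'type': 'WarmupLR', 'params': {'warmup_max_lr': 0.001}})
# returns (0.001, ''), while a config missing the 'type' key returns
# (None, 'LR schedule type not defined in config').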
class LRRangeTest(object):
    """Sets the learning rate of each parameter group according to
    learning rate range test (LRRT) policy. The policy increases learning
    rate starting from a base value with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters: Part 1`_.

    LRRT policy is used for finding the maximum LR that trains a model without
    divergence, and can be used to configure the LR boundaries for Cyclic LR
    schedules.

    LRRT changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_range_test_min_lr (float or list): Initial learning rate which is the
            lower boundary in the range test for each parameter group.
        lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
        lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
        lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = LRRangeTest(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
        https://arxiv.org/abs/1803.09820
    """
    def __init__(self,
                 optimizer: Optimizer,
                 lr_range_test_min_lr: float = 1e-3,
                 lr_range_test_step_size: int = 2000,
                 lr_range_test_step_rate: float = 1.0,
                 lr_range_test_staircase: bool = False,
                 last_batch_iteration: int = -1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))

        self.optimizer = optimizer

        if isinstance(lr_range_test_min_lr, list) or isinstance(lr_range_test_min_lr, tuple):
            if len(lr_range_test_min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} lr_range_test_min_lr, got {}".format(
                    len(optimizer.param_groups), len(lr_range_test_min_lr)))
            self.min_lr = list(lr_range_test_min_lr)
        else:
            self.min_lr = [lr_range_test_min_lr] * len(optimizer.param_groups)

        self.step_size = lr_range_test_step_size
        self.step_rate = lr_range_test_step_rate
        self.last_batch_iteration = last_batch_iteration
        self.staircase = lr_range_test_staircase
        self.interval_fn = self._staircase_interval if lr_range_test_staircase \
            else self._continous_interval

        if last_batch_iteration == -1:
            self._update_optimizer(self.min_lr)
    def _staircase_interval(self):
        return math.floor(float(self.last_batch_iteration) / self.step_size)

    def _continous_interval(self):
        return float(self.last_batch_iteration) / self.step_size
    def _get_increase(self):
        return (1 + self.step_rate * self.interval_fn())

    def get_lr(self):
        lr_increase = self._get_increase()
        return [
            lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr
        ]
    def _update_optimizer(self, group_lrs):
        for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
            param_group['lr'] = lr

    def step(self, batch_iteration=None):
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        self._update_optimizer(self.get_lr())

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']
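
# Growth of the range test (illustrative): in the default continuous mode,
# lr(i) = min_lr * (1 + step_rate * i / step_size) after batch i, so with
# min_lr=1e-3, step_rate=1.0 and step_size=2000 the LR has doubled by batch
# 2000; with staircase=True the LR instead jumps once per step_size batches.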
class OneCycle(object):
    """Sets the learning rate of each parameter group according to
    1Cycle learning rate policy (1CLR). 1CLR is a variation of the
    Cyclical Learning Rate (CLR) policy that involves one cycle followed by
    decay. The policy simultaneously cycles the learning rate (and momentum)
    between two boundaries with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters`_.

    1CLR policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This implementation was adapted from the github repo: `pytorch/pytorch`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        cycle_min_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
            The lr at any cycle is the sum of cycle_min_lr
            and some scaling of the amplitude; therefore
            cycle_max_lr may not actually be reached depending on
            scaling function.
        decay_lr_rate(float): Decay rate for learning rate. Default: 0.
        cycle_first_step_size (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        cycle_second_step_size (int): Number of training iterations in the
            decreasing half of a cycle. If cycle_second_step_size is None,
            it is set to cycle_first_step_size. Default: None
        cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means
            lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
        cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means
            lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
        decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
            Default: True
        cycle_min_mom (float or list): Initial momentum which is the
            lower boundary in the cycle for each parameter group.
            Default: 0.8
        cycle_max_mom (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
            The momentum at any cycle is the difference of cycle_max_mom
            and some scaling of the amplitude; therefore
            cycle_min_mom may not actually be reached depending on
            scaling function. Default: 0.9
        decay_mom_rate (float): Decay rate for momentum. Default: 0.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = OneCycle(optimizer, cycle_min_lr=0.01, cycle_max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
    """
    def __init__(self,
                 optimizer,
                 cycle_min_lr,
                 cycle_max_lr,
                 decay_lr_rate=0.,
                 cycle_first_step_size=2000,
                 cycle_second_step_size=None,
                 cycle_first_stair_count=0,
                 cycle_second_stair_count=None,
                 decay_step_size=0,
                 cycle_momentum=True,
                 cycle_min_mom=0.8,
                 cycle_max_mom=0.9,
                 decay_mom_rate=0.,
                 last_batch_iteration=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
        self.optimizer = optimizer

        self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
        if last_batch_iteration == -1:
            for lr, group in zip(self.min_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)

        cycle_first_step_size = float(cycle_first_step_size)
        cycle_second_step_size = float(
            cycle_second_step_size
        ) if cycle_second_step_size is not None else cycle_first_step_size

        self.total_size = cycle_first_step_size + cycle_second_step_size
        self.step_ratio = cycle_first_step_size / self.total_size
        self.first_stair_count = cycle_first_stair_count
        self.second_stair_count = cycle_first_stair_count \
            if cycle_second_stair_count is None else cycle_second_stair_count

        self.decay_lr_rate = decay_lr_rate
        self.decay_mom_rate = decay_mom_rate
        self.decay_step_size = decay_step_size

        self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
        self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
        self.cycle_momentum = cycle_momentum
        self.last_batch_iteration = last_batch_iteration

        if cycle_momentum:
            if 'betas' not in optimizer.defaults:
                raise ValueError(
                    'optimizer must support betas with `cycle_momentum` option enabled')
            if last_batch_iteration == -1:
                for momentum, group in zip(self.min_moms, optimizer.param_groups):
                    group['betas'] = momentum
    def _get_cycle_lr(self):
        cycle = math.floor(1 + self.last_batch_iteration / self.total_size)
        x = 1. + self.last_batch_iteration / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
            base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
            lr = cycle_min_lr + base_height
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_betas, max_betas in zip(self.min_moms, self.max_moms):
                cycle_min_mom = base_betas[0]
                cycle_max_mom = max_betas[0]
                base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
                momentum = cycle_max_mom - base_height
                momentums.append((momentum, base_betas[1]))
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['betas'] = momentum

        return lrs
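
    # Shape of the cycle (illustrative): the relative position x traverses
    # [0, step_ratio] on the way up and (step_ratio, 1] on the way down, so
    # scale_factor rises linearly from 0 to 1 and back, and
    # lr = cycle_min_lr + (cycle_max_lr - cycle_min_lr) * scale_factor peaks at
    # cycle_max_lr after cycle_first_step_size batches. Momentum moves in the
    # opposite direction, from cycle_max_mom down to cycle_min_mom.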
    def _get_decay_lr(self, decay_batch_iteration):
        """Calculates the learning rate at batch index. This function is used
        after the cycle completes and post cycle decaying of lr/mom is enabled.
        This function treats `self.last_batch_iteration` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        decay_interval = decay_batch_iteration / self.decay_step_size

        lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
        lrs = [cycle_min_lr * lr_decay_factor for cycle_min_lr in self.min_lrs]

        if self.cycle_momentum:
            mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
            momentums = [(beta0 * mom_decay_factor, beta1)
                         for beta0, beta1 in self.max_moms]
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['betas'] = momentum

        return lrs
    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_batch_iteration` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        if self.last_batch_iteration <= self.total_size:
            return self._get_cycle_lr()
        else:
            return self._get_decay_lr(self.last_batch_iteration - self.total_size)
    def step(self, batch_iteration=None):
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']
class WarmupLR(object):
    """Increase the learning rate of each parameter group from min lr to max lr
    over warmup_num_steps steps, and then fix at max lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup_min_lr (float or list): minimum learning rate. Default: 0
        warmup_max_lr (float or list): maximum learning rate. Default: 0.001
        warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
        last_batch_iteration (int): The index of the last batch. Default: -1.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = WarmupLR(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()
    """
    def __init__(self,
                 optimizer: Optimizer,
                 warmup_min_lr: float = 0.0,
                 warmup_max_lr: float = 0.001,
                 warmup_num_steps: int = 1000,
                 last_batch_iteration: int = -1):

        self.optimizer = optimizer
        self.min_lrs = self._format_param(optimizer, warmup_min_lr, "min_lr")
        self.max_lrs = self._format_param(optimizer, warmup_max_lr, "max_lr")
        self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
        self.warmup_num_steps = warmup_num_steps
        self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
        self.last_batch_iteration = last_batch_iteration
    def get_lr(self):
        gamma = self._get_gamma()
        return [
            min_lr + (delta_lr * gamma)
            for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)
        ]
    def step(self, last_batch_iteration=None):
        if last_batch_iteration is None:
            last_batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = last_batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']
    def _get_gamma(self):
        if self.last_batch_iteration < self.warmup_num_steps:
            return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
        else:
            return 1.0
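
    # Warmup curve (illustrative): for batch i < warmup_num_steps,
    # lr(i) = min_lr + (max_lr - min_lr) * log(i + 1) / log(warmup_num_steps),
    # a logarithmic ramp that reaches max_lr at i = warmup_num_steps - 1 and
    # stays there afterwards.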
    def _format_param(self, optimizer, param_value, param_name):
        if isinstance(param_value, list) or isinstance(param_value, tuple):
            if len(param_value) != len(optimizer.param_groups):
                # Report the mismatched count (the original passed a stray
                # FileNotFoundError(param_value) into the message here).
                raise ValueError("expected {} value for {}, got {}".format(
                    len(optimizer.param_groups), param_name, len(param_value)))
            return list(param_value)
        else:
            return [param_value] * len(optimizer.param_groups)
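
A short usage sketch for the schedules above (illustrative only: `model`, `data_loader`, and `train_batch` stand in for hypothetical client code):

    import torch

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Warm up logarithmically from 0 to 1e-3 over 1000 batches, then hold.
    scheduler = WarmupLR(optimizer,
                         warmup_min_lr=0.0,
                         warmup_max_lr=1e-3,
                         warmup_num_steps=1000)

    for batch in data_loader:
        train_batch(batch)  # hypothetical per-batch training step
        scheduler.step()    # invoked per batch, not per epoch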