OpenDAS / Uni-Core · Commits

Commit 21cb6b39 (parent 8da5eaaf), authored Jul 24, 2022 by Guolin Ke

    warning about missing keys in loading model

Showing 1 changed file with 76 additions and 47 deletions.

unicore/trainer.py (+76, -47)
```diff
@@ -27,6 +27,7 @@ from unicore.utils import tensor_tree_map
 logger = logging.getLogger(__name__)
 
 
 class ExponentialMovingAverage:
     """
     Maintains moving averages of parameters with exponential decay
```
```diff
@@ -164,7 +165,7 @@ class Trainer(object):
         else:
             self.cuda_env = None
             self.cuda_env_arr = None
 
         # add ema
         if args.ema_decay > 0 and self.data_parallel_rank == 0:
             self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay)
```
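For context, `ExponentialMovingAverage` keeps a decayed running copy of the model weights. A minimal sketch of the idea (illustrative only, with hypothetical names and decay; the class in this file also handles state_dict I/O and device placement):

```python
import torch


class EmaSketch:
    """Keep a decayed running average of a model's parameters.

    Illustrative stand-in for the ExponentialMovingAverage in
    unicore/trainer.py; names and the default decay are assumptions.
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        # Detached copies, so the averaged weights never receive gradients.
        self.shadow = {
            name: p.detach().clone() for name, p in model.named_parameters()
        }

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # shadow = decay * shadow + (1 - decay) * current
        for name, p in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(p, alpha=1.0 - self.decay)
```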
```diff
@@ -207,9 +208,7 @@ class Trainer(object):
     @property
     def use_distributed_wrapper(self) -> bool:
-        return (
-            self.data_parallel_world_size > 1
-        )
+        return self.data_parallel_world_size > 1
 
     @property
     def should_save_checkpoint_on_current_rank(self) -> bool:
```
```diff
@@ -224,10 +223,7 @@ class Trainer(object):
     @property
     def loss(self):
         if self._wrapped_loss is None:
-            if (
-                utils.has_parameters(self._loss)
-                and self.use_distributed_wrapper
-            ):
+            if utils.has_parameters(self._loss) and self.use_distributed_wrapper:
                 self._wrapped_loss = models.DistributedUnicoreModel(
                     self.args,
                     self._loss,
```
```diff
@@ -281,7 +277,7 @@ class Trainer(object):
                 "please switch to FP32 which is likely to be faster"
             )
             self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
             if self.args.allreduce_fp32_grad:
                 assert self.args.ddp_backend == "no_c10d"
             if self.args.per_sample_clip_norm > 0:
```
```diff
@@ -290,7 +286,7 @@ class Trainer(object):
             if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                 logger.info("NOTE: your device may support faster training with --fp16")
             self._optimizer = optim.build_optimizer(self.args, params)
 
         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
         self._lr_scheduler = lr_scheduler.build_lr_scheduler(
```
```diff
@@ -305,8 +301,7 @@ class Trainer(object):
             "args": self.args,
             "model": self.model.state_dict(),
             "loss": (
                 self.loss.state_dict()
                 if utils.has_parameters(self.loss)
                 else None
             ),
             "optimizer_history": (self._optim_history or [])
             + [
```
```diff
@@ -321,7 +316,7 @@ class Trainer(object):
             "extra_state": {
                 "metrics": metrics.state_dict(),
                 "previous_training_time": self.cumulative_training_time(),
-            }
+            },
         }
         if not self.args.no_save_optimizer_state:
             state_dict["last_optimizer_state"] = self.optimizer.state_dict()
```
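For orientation, the checkpoint assembled here is a plain dict. A sketch of its top-level layout, restricted to the fields visible in this diff (the full file may add more):

```python
# Top-level layout of the checkpoint dict built by Trainer.state_dict(),
# reconstructed from the fields visible in this diff only.
checkpoint = {
    "args": ...,                 # training args namespace
    "model": ...,                # self.model.state_dict()
    "loss": ...,                 # self.loss.state_dict() if it has parameters, else None
    "optimizer_history": [...],  # (self._optim_history or []) + [latest entry]
    "extra_state": {
        "metrics": ...,
        "previous_training_time": ...,
    },
    # present only when not args.no_save_optimizer_state:
    "last_optimizer_state": ...,
}
```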
```diff
@@ -375,7 +370,7 @@ class Trainer(object):
         state = None
         if is_master:
             state = checkpoint_utils.load_checkpoint_to_cpu(
                 filename,
             )
         if is_distributed:
             logger.info("Broadcast checkpoint from rank_0")
```
```diff
@@ -392,19 +387,28 @@ class Trainer(object):
             try:
                 if self.args.load_from_ema:
                     logger.info("loading ema state to model")
-                    self.model.load_state_dict(
+                    errors = self.model.load_state_dict(
                         ema_state["params"], strict=False, model_args=self.args
                     )
                 else:
-                    self.model.load_state_dict(
+                    errors = self.model.load_state_dict(
                         state["model"], strict=False, model_args=self.args
                     )
                     # save memory for later steps
                     del state["model"]
+                if errors.missing_keys:
+                    logger.warning(
+                        "Error in loading model state, missing_keys "
+                        + str(errors.missing_keys)
+                    )
+                if errors.unexpected_keys:
+                    logger.warning(
+                        "Error in loading model state, unexpected_keys "
+                        + str(errors.unexpected_keys)
+                    )
                 if utils.has_parameters(self.get_loss()):
-                    self.get_loss().load_state_dict(
-                        state["loss"], strict=True
-                    )
+                    self.get_loss().load_state_dict(state["loss"], strict=True)
                     del state["loss"]
             except Exception:
```
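This hunk is the substance of the commit: `load_state_dict` is now called for its return value. In plain PyTorch, `nn.Module.load_state_dict(..., strict=False)` does not raise on a key mismatch; it returns a named tuple with `missing_keys` and `unexpected_keys` fields, which is what the new warnings report (the `model_args` keyword in the diff is a Uni-Core extension, not standard PyTorch). A minimal, self-contained sketch of the same pattern:

```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

model = nn.Linear(4, 2)  # state_dict keys: "weight", "bias"

# A checkpoint that is missing "bias" and carries an extra "gamma" key.
state = {"weight": torch.zeros(2, 4), "gamma": torch.ones(2)}

# With strict=False this returns NamedTuple(missing_keys, unexpected_keys)
# instead of raising, mirroring the warnings added in this commit.
errors = model.load_state_dict(state, strict=False)
if errors.missing_keys:
    logger.warning("Error in loading model state, missing_keys " + str(errors.missing_keys))
if errors.unexpected_keys:
    logger.warning("Error in loading model state, unexpected_keys " + str(errors.unexpected_keys))
# Here: missing_keys == ["bias"], unexpected_keys == ["gamma"]
```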
```diff
@@ -413,13 +417,21 @@ class Trainer(object):
                     "please ensure that the architectures match.".format(filename)
                 )
 
         extra_state = state["extra_state"] if "extra_state" in state else None
-        self._optim_history = state["optimizer_history"] if "optimizer_history" in state else None
-        if ema_state is not None and self.ema is not None and not self.args.load_from_ema:
+        self._optim_history = (
+            state["optimizer_history"] if "optimizer_history" in state else None
+        )
+        if (
+            ema_state is not None
+            and self.ema is not None
+            and not self.args.load_from_ema
+        ):
             logger.info(f"Loading EMA state...")
             self.ema.load_state_dict(ema_state)
         elif self.ema is not None:
-            logger.info(f"Cannot find EMA state in checkpoint, load model weight to ema directly")
+            logger.info(
+                f"Cannot find EMA state in checkpoint, load model weight to ema directly"
+            )
             self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay)
 
         if last_optim_state is not None and not reset_optimizer:
```
```diff
@@ -437,7 +449,7 @@ class Trainer(object):
             if not reset_lr_scheduler:
                 self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
 
             self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
             self.set_num_updates(last_optim["num_updates"])
```
```diff
@@ -452,7 +464,10 @@ class Trainer(object):
             # self.lr_step(epoch)
-            if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0:
+            if (
+                itr_state.get("version", 1) >= 2
+                and itr_state["iterations_in_epoch"] == 0
+            ):
                 # reset meters at start of epoch
                 reset_meters = True
```
```diff
@@ -511,10 +526,12 @@ class Trainer(object):
     def init_total_train_steps(self, epoch_itr):
         if self.args.max_epoch > 0:
-            self._total_train_steps = (len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch
+            self._total_train_steps = (
+                (len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch
+            )
         else:
             self._total_train_steps = self.args.max_update
 
     def get_valid_iterator(
         self,
         subset,
```
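As a worked example of the wrapped expression (numbers hypothetical): with `len(epoch_itr) = 999` batches per epoch, `update_freq[0] = 4` gradient-accumulation steps, and `max_epoch = 10`, the planned total is `(999 + 1) // 4 * 10 = 2500` optimizer updates; the `+ 1` rounds a partial accumulation window at the end of an epoch up to a full update.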
```diff
@@ -589,7 +606,9 @@ class Trainer(object):
             try:
                 with maybe_no_sync():
                     # use different seed for different rank in training, otherwise the dropout will be the same in different workers.
-                    with utils.torch_seed(self.args.seed, self.get_num_updates(), self.data_parallel_rank):
+                    with utils.torch_seed(
+                        self.args.seed, self.get_num_updates(), self.data_parallel_rank
+                    ):
                         # forward and backward
                         loss, sample_size_i, logging_output = self.task.train_step(
                             sample=sample,
```
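The comment in this hunk carries the key reasoning: if every rank drew from the same RNG stream, dropout masks would coincide across workers. A minimal sketch of rank-aware seeding, assuming a context manager shaped like `utils.torch_seed` (the key-mixing below is illustrative, not Uni-Core's exact scheme):

```python
import contextlib

import torch


@contextlib.contextmanager
def torch_seed_sketch(*keys):
    """Seed torch from several integers (seed, num_updates, rank), then
    restore the previous RNG state on exit. Illustrative stand-in for
    utils.torch_seed; the key combination is an assumption."""
    state = torch.random.get_rng_state()
    torch.manual_seed(hash(keys) % (2**31))  # any stable mix of the keys works
    try:
        yield
    finally:
        torch.random.set_rng_state(state)


# Different ranks now draw different dropout masks at the same step:
with torch_seed_sketch(1, 100, 0):
    mask_rank0 = torch.rand(4)
with torch_seed_sketch(1, 100, 1):
    mask_rank1 = torch.rand(4)
assert not torch.equal(mask_rank0, mask_rank1)
```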
```diff
@@ -601,7 +620,9 @@ class Trainer(object):
                         )
                         del loss
                 if self.args.per_sample_clip_norm > 0:
-                    self.optimizer.per_sample_clip_grad_norm(self.args.per_sample_clip_norm)
+                    self.optimizer.per_sample_clip_grad_norm(
+                        self.args.per_sample_clip_norm
+                    )
                 logging_outputs.append(logging_output)
                 sample_size += sample_size_i
```
```diff
@@ -647,7 +668,12 @@ class Trainer(object):
                     ooms,
                     total_train_time,
                 ) = self._aggregate_logging_outputs(
-                    logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, is_train=True,
+                    logging_outputs,
+                    sample_size,
+                    ooms,
+                    train_time,
+                    ignore=is_dummy_batch,
+                    is_train=True,
                 )
                 self._cumulative_training_time = (
                     total_train_time / self.data_parallel_world_size
```
```diff
@@ -670,11 +696,7 @@ class Trainer(object):
             # (Debugging note: Some optimizers perform this scaling on the
             # fly, so inspecting model.parameters() or optimizer.params may
             # still show the original, unscaled gradients.)
-            numer = (
-                self.data_parallel_world_size
-                if self._sync_stats()
-                else 1
-            )
+            numer = self.data_parallel_world_size if self._sync_stats() else 1
             self.optimizer.multiply_grads(numer / (sample_size or 1.0))
             # Note: (sample_size or 1.0) handles the case of a zero gradient, in a
```
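As a worked example of the scaling (numbers hypothetical): with `data_parallel_world_size = 8` and an aggregated `sample_size = 512`, the gradients, which DDP has already averaged over the 8 workers, are multiplied by `8 / 512 = 0.015625`, turning the per-worker mean into a per-sample mean; and with `sample_size = 0` (an all-dummy batch) the `or 1.0` guard avoids a division by zero.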
```diff
@@ -695,7 +717,9 @@ class Trainer(object):
                 with utils.torch_seed(self.args.seed, self.get_num_updates(), -1):
                     # take an optimization step
                     self.task.optimizer_step(
-                        self.optimizer, model=self.model, update_num=self.get_num_updates()
+                        self.optimizer,
+                        model=self.model,
+                        update_num=self.get_num_updates(),
                     )
                 if self.ema is not None:
                     with torch.autograd.profiler.record_function("ema"):
```
```diff
@@ -719,7 +743,9 @@ class Trainer(object):
                 raise
             except OverflowError as e:
                 overflow = True
-                logger.info(f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}")
+                logger.info(
+                    f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
+                )
                 grad_norm = torch.tensor(0.0).cuda()
                 self.zero_grad()
             except RuntimeError as e:
```
```diff
@@ -737,13 +763,13 @@ class Trainer(object):
                 gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
                 torch.cuda.reset_peak_memory_stats()
                 gb_free = self.cuda_env.total_memory_in_GB - gb_used
                 metrics.log_scalar(
                     "gb_free", gb_free, priority=1500, round=1, weight=0
                 )
 
             # log stats
             logging_output = self._reduce_and_log_stats(
                 logging_outputs, sample_size, grad_norm,
             )
 
             # clear CUDA cache to reduce memory fragmentation
```
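For the unit conversion here (numbers hypothetical): `torch.cuda.max_memory_allocated()` returns bytes, so dividing by `1024` three times yields GiB; a peak of 25,769,803,776 bytes gives `gb_used = 24.0`, and on a device with `total_memory_in_GB = 40` the logged `gb_free` is `16.0`.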
```diff
@@ -865,9 +891,7 @@ class Trainer(object):
         metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
 
     def clip_grad_norm(self, clip_norm):
-        return self.optimizer.clip_grad_norm(
-            clip_norm
-        )
+        return self.optimizer.clip_grad_norm(clip_norm)
 
     def cumulative_training_time(self):
         if self._cumulative_training_time is None:
```
```diff
@@ -908,7 +932,7 @@ class Trainer(object):
                 return t.to(dtype=torch.bfloat16)
             return t
 
         # Please manually convert data type by yourself.
         # if self.args.fp16:
         #     sample = utils.apply_to_sample(apply_half, sample)
```
```diff
@@ -942,7 +966,9 @@ class Trainer(object):
         ignore=False,
         is_train=False,
     ):
-        if self.task.__class__.logging_outputs_can_be_summed(self.get_loss(), is_train=is_train):
+        if self.task.__class__.logging_outputs_can_be_summed(
+            self.get_loss(), is_train=is_train
+        ):
             return self._fast_stat_sync_sum(
                 logging_outputs, *extra_stats_to_sum, ignore=ignore
             )
```
```diff
@@ -978,7 +1004,10 @@ class Trainer(object):
         return logging_outputs, extra_stats_to_sum
 
     def _fast_stat_sync_sum(
-        self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False,
+        self,
+        logging_outputs: List[Dict[str, Any]],
+        *extra_stats_to_sum,
+        ignore=False,
     ):
         """
         Sync logging outputs across workers. fast_stat_sync_sum is
```