OpenDAS / dcnv3 / Commits / adbb02f6

Commit adbb02f6, authored Apr 11, 2023 by Zeqiang Lai; committed by zhe chen, Apr 11, 2023
Parent: 1c6361d8

Support ema for main_deepspeed.py, fix torch.distribute.launch (#88)

Showing 3 changed files with 134 additions and 9 deletions (+134 / -9):

    classification/ema_deepspeed.py    +99   -0
    classification/main.py              +1   -1
    classification/main_deepspeed.py   +34   -8
classification/ema_deepspeed.py  (new file, 0 → 100644)
import torch
import torch.nn as nn
import deepspeed
from deepspeed.runtime.zero import GatheredParameters
from contextlib import contextmanager


class EMADeepspeed(nn.Module):
    """ migrated from https://github.com/microsoft/DeepSpeed/issues/2056
    """

    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.decay = decay
        self.num_updates = 0 if use_num_updates else -1

        with GatheredParameters(model.parameters(), fwd_module=self):
            for name, p in model.named_parameters():
                if p.requires_grad:
                    # remove '.' as the character is not allowed in buffer names
                    s_name = name.replace('.', '')
                    self.m_name2s_name.update({name: s_name})
                    self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # warm up the effective decay so early updates track the live weights more closely
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay
        shadow_params = dict(self.named_buffers())

        with torch.no_grad():
            with GatheredParameters(model.parameters()):
                if deepspeed.comm.get_rank() == 0:
                    m_param = dict(model.named_parameters())
                    for key in m_param:
                        if m_param[key].requires_grad:
                            sname = self.m_name2s_name[key]
                            shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                            # shadow <- shadow - (1 - decay) * (shadow - param)
                            shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                        else:
                            assert not key in self.m_name2s_name

    def copy_to(self, model):
        shadow_params = dict(self.named_buffers())
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                m_param = dict(model.named_parameters())
                for key in m_param:
                    if m_param[key].requires_grad:
                        m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
                    else:
                        assert not key in self.m_name2s_name

    def store(self, model):
        """
        Save the current parameters for restoring later.
        Args:
            model: the model whose parameters will be stored
        """
        with GatheredParameters(model.parameters()):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                self.collected_params = [param.clone() for param in parameters]

    def restore(self, model):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            model: the model whose parameters will be restored
        """
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                for c_param, param in zip(self.collected_params, parameters):
                    param.data.copy_(c_param.data)

    @contextmanager
    def activate(self, model):
        # temporarily swap the EMA weights into the model, e.g. for evaluation
        try:
            self.store(model)
            self.copy_to(model)
            yield
        finally:
            self.restore(model)
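Taken together with the main_deepspeed.py changes below, the intended call pattern is roughly the following. This is a hedged usage sketch, not code from the commit; model_engine, criterion, loader and evaluate() are placeholders for the DeepSpeed engine, loss, data loader and evaluation helper that main_deepspeed.py actually builds.

# Usage sketch (assumes an already-initialized DeepSpeed engine and data loader).
ema = EMADeepspeed(model_engine, decay=0.9999)

for samples, targets in loader:
    loss = criterion(model_engine(samples), targets)
    model_engine.backward(loss)
    model_engine.step()
    ema(model_engine)              # EMA update of the shadow buffers after each step

with ema.activate(model_engine):   # store -> copy_to -> ... -> restore
    evaluate(model_engine)         # validation runs with the EMA weights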
classification/main.py
@@ -131,7 +131,7 @@ def parse_option():
                         help="whether to use ZeroRedundancyOptimizer (ZeRO) to save memory")

     # distributed training
-    parser.add_argument("--local_rank",
+    parser.add_argument("--local-rank",
                         type=int, required=True,
                         help='local rank for DistributedDataParallel')
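A side note on why this rename is safe for the rest of main.py (my reading, not stated in the commit): argparse turns the hyphen in an option name into an underscore for the parsed attribute, so code that reads args.local_rank keeps working, while recent torch.distributed.launch / torchrun versions that pass --local-rank are no longer rejected.

# Hedged illustration, not part of the diff: argparse maps "--local-rank" to args.local_rank.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local-rank", type=int, default=0,
                    help='local rank for DistributedDataParallel')

args = parser.parse_args(["--local-rank", "3"])
print(args.local_rank)  # -> 3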
classification/main_deepspeed.py
@@ -27,7 +27,7 @@ from optimizer import set_weight_decay_and_lr
 from logger import create_logger
 from utils import load_pretrained, reduce_tensor, MyAverageMeter
 from ddp_hooks import fp16_compress_hook
+from ema_deepspeed import EMADeepspeed

 def parse_option():
     parser = argparse.ArgumentParser(

@@ -57,7 +57,7 @@ def parse_option():
     parser.add_argument('--accumulation-steps', type=int, default=1, help="gradient accumulation steps")

     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    parser.add_argument("--local-rank", type=int, required=True, help='local rank for DistributedDataParallel')
     parser.add_argument('--disable-grad-scalar', action='store_true', help='disable Grad Scalar')

     args, unparsed = parser.parse_known_args()

@@ -211,7 +211,7 @@ def throughput(data_loader, model, logger):
     return


-def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
+def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=None):
     model.train()

     num_steps = len(data_loader)

@@ -237,6 +237,9 @@ def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_f
         model.backward(loss)
         model.step()

+        if model_ema is not None:
+            model_ema(model)
+
         if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
             lr_scheduler.step_update(epoch * num_steps + idx)
@@ -348,9 +351,14 @@ def train(config, ds_config):
     lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
     criterion = build_criterion(config)

+    model_ema = None
+    if config.TRAIN.EMA.ENABLE:
+        model_ema = EMADeepspeed(model, config.TRAIN.EMA.DECAY)
+
     # -------------- resume ---------------- #
     max_accuracy = 0.0
+    max_accuracy_ema = 0.0
     client_state = {}
     if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME:
         if os.path.exists(os.path.join(config.OUTPUT, 'latest')):
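The hunk above reads config.TRAIN.EMA.ENABLE and config.TRAIN.EMA.DECAY, so the config tree is expected to carry an EMA node. Below is a hedged sketch of what the matching defaults could look like, written with yacs, which this codebase's config appears to use; the actual placement in config.py is not part of this diff.

# Hypothetical defaults for the EMA options referenced in the hunk above.
from yacs.config import CfgNode as CN

_C = CN()
_C.TRAIN = CN()
_C.TRAIN.EMA = CN()
_C.TRAIN.EMA.ENABLE = False     # opt in to EMADeepspeed during training
_C.TRAIN.EMA.DECAY = 0.9999     # forwarded to EMADeepspeed(model, decay)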
@@ -367,6 +375,10 @@ def train(config, ds_config):
         logger.info(f'client_state={client_state.keys()}')
         lr_scheduler.load_state_dict(client_state['custom_lr_scheduler'])
         max_accuracy = client_state['max_accuracy']
+        if model_ema is not None:
+            max_accuracy_ema = client_state.get('max_accuracy_ema', 0.0)
+            model_ema.load_state_dict((client_state['model_ema']))
+
     # -------------- training ---------------- #

@@ -378,9 +390,11 @@ def train(config, ds_config):
     log_model_statistic(model_without_ddp)

     start_time = time.time()
-    for epoch in range(client_state.get('epoch', config.TRAIN.START_EPOCH), config.TRAIN.EPOCHS):
+    start_epoch = client_state['epoch'] + 1 if 'epoch' in client_state else config.TRAIN.START_EPOCH
+    for epoch in range(start_epoch, config.TRAIN.EPOCHS):
         data_loader_train.sampler.set_epoch(epoch)
-        train_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
+        train_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
+                    model_ema=model_ema)

         if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
             model.save_checkpoint(

@@ -390,13 +404,16 @@ def train(config, ds_config):
                 'custom_lr_scheduler': lr_scheduler.state_dict(),
                 'max_accuracy': max_accuracy,
                 'epoch': epoch,
-                'config': config
+                'config': config,
+                'max_accuracy_ema': max_accuracy_ema if model_ema is not None else 0.0,
+                'model_ema': model_ema.state_dict() if model_ema is not None else None,
                 }
             )

         if epoch % config.EVAL_FREQ == 0:
             acc1, _, _ = eval_epoch(config, data_loader_val, model, epoch)
             logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
             if acc1 > max_accuracy:
                 model.save_checkpoint(save_dir=config.OUTPUT,

@@ -405,13 +422,22 @@ def train(config, ds_config):
                     'custom_lr_scheduler': lr_scheduler.state_dict(),
                     'max_accuracy': max_accuracy,
                     'epoch': epoch,
-                    'config': config
+                    'config': config,
+                    'max_accuracy_ema': max_accuracy_ema if model_ema is not None else 0.0,
+                    'model_ema': model_ema.state_dict() if model_ema is not None else None,
                     }
                 )

             max_accuracy = max(max_accuracy, acc1)
             logger.info(f'Max accuracy: {max_accuracy:.2f}%')

+            if model_ema is not None:
+                with model_ema.activate(model):
+                    acc1_ema, _, _ = eval_epoch(config, data_loader_val, model, epoch)
+                logger.info(f"[EMA] Accuracy of the network on the {len(dataset_val)} test images: {acc1_ema:.1f}%")
+                max_accuracy_ema = max(max_accuracy_ema, acc1_ema)
+                logger.info(f'[EMA] Max accuracy: {max_accuracy_ema:.2f}%')
+
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     logger.info('Training time {}'.format(total_time_str))

@@ -453,7 +479,7 @@ if __name__ == '__main__':
     args, config = parse_option()

     # init distributed env
-    if 'SLURM_PROCID' in os.environ:
+    if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1:
        print("\nDist init: SLURM")
        rank = int(os.environ['SLURM_PROCID'])
        gpu = rank % torch.cuda.device_count()
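The last hunk is the "fix torch.distribute.launch" part of the commit message: with the added SLURM_TASKS_PER_NODE check, a single-task SLURM allocation no longer forces the SLURM-based rank assignment, presumably so the non-SLURM branch (not shown in this hunk) can pick up the rank and local rank that torch.distributed.launch or torchrun set. As a hedged reminder, not code from this repository, that fallback conventionally looks like:

# Conventional env:// initialization used with torch.distributed.launch / torchrun;
# the launcher is expected to set RANK, WORLD_SIZE, MASTER_ADDR and LOCAL_RANK.
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ.get('LOCAL_RANK', 0))   # or args.local_rank via --local-rank
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl', init_method='env://')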