Commit adbb02f6
Authored Apr 11, 2023 by Zeqiang Lai, committed by zhe chen on Apr 11, 2023
Parent: 1c6361d8

Support ema for main_deepspeed.py, fix torch.distribute.launch (#88)

3 changed files, with 134 additions and 9 deletions:

    classification/ema_deepspeed.py    +99  -0
    classification/main.py              +1  -1
    classification/main_deepspeed.py   +34  -8
classification/ema_deepspeed.py (new file, mode 100644)

import torch
import torch.nn as nn
import deepspeed
from deepspeed.runtime.zero import GatheredParameters
from contextlib import contextmanager


class EMADeepspeed(nn.Module):
    """ migrated from https://github.com/microsoft/DeepSpeed/issues/2056
    """

    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.decay = decay
        self.num_updates = 0 if use_num_updates else -1

        with GatheredParameters(model.parameters(), fwd_module=self):
            for name, p in model.named_parameters():
                if p.requires_grad:
                    # remove as '.'-character is not allowed in buffers
                    s_name = name.replace('.', '')
                    self.m_name2s_name.update({name: s_name})
                    self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        shadow_params = dict(self.named_buffers())
        with torch.no_grad():
            with GatheredParameters(model.parameters()):
                if deepspeed.comm.get_rank() == 0:
                    m_param = dict(model.named_parameters())
                    for key in m_param:
                        if m_param[key].requires_grad:
                            sname = self.m_name2s_name[key]
                            shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                            shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                        else:
                            assert not key in self.m_name2s_name

    def copy_to(self, model):
        shadow_params = dict(self.named_buffers())
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                m_param = dict(model.named_parameters())
                for key in m_param:
                    if m_param[key].requires_grad:
                        m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
                    else:
                        assert not key in self.m_name2s_name

    def store(self, model):
        """
        Save the current parameters for restoring later.
        Args:
            model: A model whose parameters will be stored
        """
        with GatheredParameters(model.parameters()):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                self.collected_params = [param.clone() for param in parameters]

    def restore(self, model):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            model: A model whose parameters are to be restored.
        """
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                for c_param, param in zip(self.collected_params, parameters):
                    param.data.copy_(c_param.data)

    @contextmanager
    def activate(self, model):
        try:
            self.store(model)
            self.copy_to(model)
            yield
        finally:
            self.restore(model)
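For orientation, here is a minimal usage sketch (not part of the commit) of how this wrapper is meant to be driven: construct it once from the live model, call it after each optimizer step to fold the new weights into the shadow buffers, and wrap evaluation in `activate`, which chains `store`, `copy_to`, and `restore`. The `engine`, `loader`, `criterion`, and `evaluate` names below are placeholders for whatever the surrounding training script provides.

    from ema_deepspeed import EMADeepspeed

    # `engine` stands in for an already-initialized DeepSpeed model engine;
    # `loader`, `criterion`, and `evaluate` are likewise placeholders.
    ema = EMADeepspeed(engine, decay=0.9999)

    for images, targets in loader:
        loss = criterion(engine(images), targets)
        engine.backward(loss)
        engine.step()
        ema(engine)                 # blend the freshly updated weights into the shadow buffers

    with ema.activate(engine):      # store -> copy_to -> (yield) -> restore
        evaluate(engine)            # runs with EMA weights; the raw weights return on exit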
classification/main.py

@@ -131,7 +131,7 @@ def parse_option():
                         help="whether to use ZeroRedundancyOptimizer (ZeRO) to save memory")

     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    parser.add_argument("--local-rank", type=int, required=True, help='local rank for DistributedDataParallel')
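The underscore-to-hyphen rename tracks the flag spelling that newer PyTorch launchers pass to each worker process (`--local-rank=N`), which is what the commit title means by fixing `torch.distribute.launch`. As a rough illustration only (not part of this commit), a parser can be made tolerant of both spellings and of the `LOCAL_RANK` environment variable that `torchrun` exports:

    import argparse
    import os

    parser = argparse.ArgumentParser()
    # Accept both the old and new spellings; fall back to the LOCAL_RANK env var
    # when neither flag is supplied (e.g. when launching with torchrun).
    parser.add_argument("--local-rank", "--local_rank", dest="local_rank", type=int,
                        default=int(os.environ.get("LOCAL_RANK", 0)),
                        help="local rank for DistributedDataParallel")
    args, _ = parser.parse_known_args()
    print(args.local_rank)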
classification/main_deepspeed.py

@@ -27,7 +27,7 @@ from optimizer import set_weight_decay_and_lr
 from logger import create_logger
 from utils import load_pretrained, reduce_tensor, MyAverageMeter
 from ddp_hooks import fp16_compress_hook
+from ema_deepspeed import EMADeepspeed


 def parse_option():
     parser = argparse.ArgumentParser(

@@ -57,7 +57,7 @@ def parse_option():
     parser.add_argument('--accumulation-steps', type=int, default=1, help="gradient accumulation steps")

     # distributed training
-    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+    parser.add_argument("--local-rank", type=int, required=True, help='local rank for DistributedDataParallel')
     parser.add_argument('--disable-grad-scalar', action='store_true', help='disable Grad Scalar')

     args, unparsed = parser.parse_known_args()

@@ -211,7 +211,7 @@ def throughput(data_loader, model, logger):
     return


-def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
+def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=None):
     model.train()

     num_steps = len(data_loader)

@@ -237,6 +237,9 @@ def train_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_f
             model.backward(loss)
             model.step()

+            if model_ema is not None:
+                model_ema(model)
+
             if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                 lr_scheduler.step_update(epoch * num_steps + idx)

@@ -348,9 +351,14 @@ def train(config, ds_config):
     lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
     criterion = build_criterion(config)

+    model_ema = None
+    if config.TRAIN.EMA.ENABLE:
+        model_ema = EMADeepspeed(model, config.TRAIN.EMA.DECAY)
+
     # -------------- resume ---------------- #
     max_accuracy = 0.0
+    max_accuracy_ema = 0.0
     client_state = {}
     if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME:
         if os.path.exists(os.path.join(config.OUTPUT, 'latest')):

@@ -368,6 +376,10 @@ def train(config, ds_config):
         lr_scheduler.load_state_dict(client_state['custom_lr_scheduler'])
         max_accuracy = client_state['max_accuracy']

+        if model_ema is not None:
+            max_accuracy_ema = client_state.get('max_accuracy_ema', 0.0)
+            model_ema.load_state_dict((client_state['model_ema']))
+
     # -------------- training ---------------- #
     logger.info(f"Creating model: {config.MODEL.TYPE}/{config.MODEL.NAME}")

@@ -378,9 +390,11 @@ def train(config, ds_config):
     log_model_statistic(model_without_ddp)

     start_time = time.time()
-    for epoch in range(client_state.get('epoch', config.TRAIN.START_EPOCH), config.TRAIN.EPOCHS):
+    start_epoch = client_state['epoch'] + 1 if 'epoch' in client_state else config.TRAIN.START_EPOCH
+    for epoch in range(start_epoch, config.TRAIN.EPOCHS):
         data_loader_train.sampler.set_epoch(epoch)

-        train_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
+        train_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=model_ema)

         if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
             model.save_checkpoint(

@@ -390,13 +404,16 @@ def train(config, ds_config):
                 'custom_lr_scheduler': lr_scheduler.state_dict(),
                 'max_accuracy': max_accuracy,
                 'epoch': epoch,
-                'config': config
+                'config': config,
+                'max_accuracy_ema': max_accuracy_ema if model_ema is not None else 0.0,
+                'model_ema': model_ema.state_dict() if model_ema is not None else None,
             })

         if epoch % config.EVAL_FREQ == 0:
             acc1, _, _ = eval_epoch(config, data_loader_val, model, epoch)
             logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
             if acc1 > max_accuracy:
                 model.save_checkpoint(save_dir=config.OUTPUT,

@@ -405,13 +422,22 @@ def train(config, ds_config):
                     'custom_lr_scheduler': lr_scheduler.state_dict(),
                     'max_accuracy': max_accuracy,
                     'epoch': epoch,
-                    'config': config
+                    'config': config,
+                    'max_accuracy_ema': max_accuracy_ema if model_ema is not None else 0.0,
+                    'model_ema': model_ema.state_dict() if model_ema is not None else None,
                 })
             max_accuracy = max(max_accuracy, acc1)
             logger.info(f'Max accuracy: {max_accuracy:.2f}%')

+            if model_ema is not None:
+                with model_ema.activate(model):
+                    acc1_ema, _, _ = eval_epoch(config, data_loader_val, model, epoch)
+                    logger.info(f"[EMA] Accuracy of the network on the {len(dataset_val)} test images: {acc1_ema:.1f}%")
+                    max_accuracy_ema = max(max_accuracy_ema, acc1_ema)
+                    logger.info(f'[EMA] Max accuracy: {max_accuracy_ema:.2f}%')
+
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     logger.info('Training time {}'.format(total_time_str))

@@ -453,7 +479,7 @@ if __name__ == '__main__':
     args, config = parse_option()

     # init distributed env
-    if 'SLURM_PROCID' in os.environ:
+    if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1:
         print("\nDist init: SLURM")
         rank = int(os.environ['SLURM_PROCID'])
         gpu = rank % torch.cuda.device_count()
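The new code path is gated on `config.TRAIN.EMA.ENABLE` and `config.TRAIN.EMA.DECAY`, whose definitions are not part of this diff. Purely as a hedged sketch, assuming a yacs-style `CfgNode` config (an assumption, not something shown in this commit), the matching defaults could look like:

    from yacs.config import CfgNode as CN

    _C = CN()
    _C.TRAIN = CN()
    _C.TRAIN.EMA = CN()
    _C.TRAIN.EMA.ENABLE = False   # EMADeepspeed is only constructed when this is True
    _C.TRAIN.EMA.DECAY = 0.9999   # per-update decay handed to EMADeepspeed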