Project: OpenDAS / apex

Commit e2af089c, authored Jun 13, 2022 by Tim Moon
Update dist Adam test to use updated API
Parent: 6e412916
Showing 1 changed file with 100 additions and 134 deletions:
tests/L0/run_optimizers/test_dist_adam.py (+100, -134)
Before (tests/L0/run_optimizers/test_dist_adam.py at parent 6e412916):

import argparse
import random
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim)
            for _ in range(args.layers)
        ])

    def forward(self, x):
        return self.linear(x)


def setup(args):
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()

    ## Optimizer
    # same hyperparameters
    ref_opt_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = ref_opt_args.copy()
    dist_opt_args.update({'overlap_reductions': False})
    dist_opt_args.update({'process_group_size': args.n_gpu})
    dist_opt_args.update({'dwu_group_size': args.dwu_group_size})
    dist_opt_args.update({'dwu_num_blocks': 1})
    dist_opt_args.update({'dwu_num_chunks': 1})
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = {'loss_scale': 'dynamic', 'opt_level': 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    ref_model = DDP(ref_model, device_ids=[args.rank])

    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()

    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    parser.add_argument('--dwu_group_size', type=float, default=1)
    args = parser.parse_args()
    return args


def setup_env(args):
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()
    seed = 42 + get_rank()
    random.seed(seed)
    torch.manual_seed(seed)
    return args


def get_rank():
    return torch.distributed.get_rank()


def main():
    args = parse_args()
    args = setup_env(args)
    tol_args = {'atol': args.atol, 'rtol': args.rtol}
    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    for i in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)

        if get_rank() == 0:
            print(f'[{i}] Checking input')
        #print("x_ref:", x_ref.flatten()[:10])
        #print("x_dist:", x_dist.flatten()[:10])
        assert(torch.allclose(x_ref, x_dist, **tol_args))

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)
        if get_rank() == 0:
            print(f'[{i}] Checking output')
        #print("y_ref:", y_ref.flatten()[:10])
        #print("y_dist:", y_dist.flatten()[:10])
        assert(torch.allclose(y_ref, y_dist, **tol_args))

        dy = torch.randn_like(y_ref)
        y_ref.backward(dy)
        y_dist.backward(dy)

        if get_rank() == 0:
            print(f'[{i}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert(torch.allclose(x_ref.grad, x_dist.grad, **tol_args))

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()
        if get_rank() == 0:
            print(f'[{i}] Stepping')
        ref_opt.step()
        dist_opt.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()

        print('Checking new weights')
        if get_rank() == 0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        for i, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
            if not torch.allclose(rp, dp, **tol_args):
                if get_rank() == 0:
                    print(f'Rank: {get_rank()}, Param: {i}')
                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                    print(rp)
                    print(dp)
                    print(torch.abs(rp - dp) > tol_args['atol'])
                sys.exit(0)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None


if __name__ == "__main__":
    main()


After (tests/L0/run_optimizers/test_dist_adam.py at commit e2af089c):

import argparse
import os
import random

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()
        self.linear = torch.nn.Sequential(*[
            torch.nn.Linear(args.dim, args.dim, bias=args.bias)
            for _ in range(args.layers)
        ])

    def forward(self, x):
        y = 0
        for l in self.linear:
            y += l(x)
        return y


def setup(args):

    # Construct models with same parameters
    ref_model = TestModel(args).float().cuda()
    dist_model = TestModel(args).float().cuda()
    with torch.no_grad():
        for ref_param, dist_param in zip(ref_model.parameters(),
                                         dist_model.parameters()):
            dist_param.data.copy_(ref_param.data)
    ref_model = torch.nn.parallel.DistributedDataParallel(
        ref_model,
        device_ids=[args.rank],
        output_device=args.rank,
    )

    # Construct optimizers with same hyperparameters
    optim_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_optim = torch.optim.Adam(
        [
            {'params': list(ref_model.parameters())[1::2], 'lr': 5e-3},
            {'params': list(ref_model.parameters())[0::2]},
        ],
        **optim_args,
    )
    dist_optim = DistributedFusedAdam(
        [
            {'params': list(dist_model.parameters())[1::2], 'lr': 5e-3},
            {'params': list(dist_model.parameters())[0::2]},
        ],
        bucket_cap_mb=71/(4*1024*1024),
        **optim_args,
    )

    return ref_model, ref_optim, dist_model, dist_optim


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=11)
    parser.add_argument('--batch', type=int, default=5)
    parser.add_argument('--dim', type=int, default=7)
    parser.add_argument('--layers', type=int, default=11)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1e-3)
    args = parser.parse_args()
    return args


def setup_env(args):

    # Initialize NCCL
    local_rank = args.local_rank
    if local_rank < 0:
        local_rank = int(os.getenv('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank % torch.cuda.device_count())
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.world_size = torch.distributed.get_world_size()

    # Initialize RNG
    seed = 42 + args.rank
    random.seed(seed)
    torch.manual_seed(seed)

    return args


def main():
    args = parse_args()
    args = setup_env(args)
    torch.set_printoptions(precision=16)

    def assert_allclose(ref_x, dist_x, message):
        message = (
            f'Rank {args.rank}: {message}\n'
            f'Reference Adam: {ref_x}\n'
            f'Distributed Adam: {dist_x}\n'
            f'Relative error: {torch.abs((ref_x-dist_x)/ref_x)}\n'
        )
        assert torch.allclose(ref_x, dist_x, atol=args.atol, rtol=args.rtol), message

    # Train model with data-parallelism and ZeRO
    ref_model, ref_optim, dist_model, dist_optim = setup(args)
    for step in range(args.steps):

        # Synthetic data
        x = torch.randn(args.batch, args.dim).cuda()
        dy = torch.randn_like(x).cuda()

        # Reference implementation
        ref_optim.zero_grad()
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = ref_model(x_ref)
        y_ref.backward(dy)
        ref_optim.step()

        # Distributed implementation
        dist_optim.zero_grad()
        x_dist = x.detach().clone().requires_grad_(True)
        y_dist = dist_model(x_dist)
        y_dist.backward(dy)
        dist_optim.step()

        # Check values
        torch.cuda.synchronize()
        torch.distributed.barrier()
        assert_allclose(
            y_ref,
            y_dist,
            f'inconsistent output in step {step}',
        )
        assert_allclose(
            x_ref.grad,
            x_dist.grad,
            f'inconsistent input grad in step {step}',
        )
        for i, (ref_param, dist_param) in enumerate(zip(ref_model.parameters(),
                                                        dist_model.parameters())):
            assert_allclose(
                ref_param,
                dist_param,
                f'inconsistent param {i} in step {step}',
            )


if __name__ == "__main__":
    main()
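For reference, a minimal sketch of the usage pattern the updated test now exercises, assuming torch.distributed has already been initialized with the NCCL backend and each process owns one GPU (as in setup_env above); the toy model and tensor shapes below are illustrative and not taken from the test:

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Hypothetical toy model; any set of CUDA parameters would do here.
model = torch.nn.Linear(16, 16).cuda()

# With the updated API the optimizer is driven like a standard torch.optim
# optimizer: no set_global_scale(), complete_reductions(), or dwu_* options,
# which the previous version of this test relied on.
optim = DistributedFusedAdam(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.01)

x = torch.randn(8, 16).cuda()
optim.zero_grad()
model(x).sum().backward()
optim.step()

The script itself expects one process per GPU; a launcher such as torchrun sets the LOCAL_RANK environment variable that setup_env reads when --local_rank is not given.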