OpenDAS / ColossalAI · Commits

Commit 27155b85
Authored Mar 02, 2022 by Frank Lee

added unit test for sharded optimizer (#293)

* added unit test for sharded optimizer
* refactor for elegance

Parent: e17e54e3

Showing 1 changed file with 178 additions and 0 deletions.

tests/test_zero_data_parallel/test_sharded_optim.py (new file, 0 → 100644, +178 −0)
import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.zero import ShardedOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.utils import free_port
from functools import partial


def check_equal(a, b):
    """
    This function checks if two tensors are equal within tolerance
    """
    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'


def check_completely_equal(a, b):
    """
    This function checks if two tensors are completely equal
    """
    assert torch.all(a == b), f'a = {a}, b = {b}'


def check_sharded_param_consistency():
    """
    In this test, we want to test whether zero stage 1 and 2
    deliver the same numerical results despite different communication
    patterns.

    We use these prefixes to differentiate the zero stages:
        oss: partition optimizer states
        pg: partition gradients and optimizer states
    """
    # create layers
    oss_linear1 = nn.Linear(128, 256)
    oss_linear2 = nn.Linear(256, 512)

    # create model
    oss_model = nn.Sequential(oss_linear1, oss_linear2)
    pg_model = copy.deepcopy(oss_model)

    oss_model = oss_model.cuda().half()
    pg_model = pg_model.cuda().half()

    # create optimizer
    oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
    pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)

    oss_optimizer = ShardedOptimizer(oss_optimizer,
                                     overlap_communication=True,
                                     initial_scale=1,
                                     clip_grad_norm=0.0)
    pg_optimizer = ShardedOptimizer(pg_optimizer,
                                    overlap_communication=True,
                                    partition_grad=True,
                                    initial_scale=1,
                                    clip_grad_norm=0.0)

    # create input data
    input_data = torch.rand(32, 128).cuda().half()

    # forward
    oss_output = oss_model(input_data)
    pg_output = pg_model(input_data)
    check_completely_equal(oss_output, pg_output)

    # backward
    oss_optimizer.backward(oss_output.mean().float())
    pg_optimizer.backward(pg_output.mean().float())

    # check grad
    # as this param is small, the backward reduction
    # will not be fired
    oss_linear1_grad = oss_model[0].weight.grad
    oss_linear2_grad = oss_model[1].weight.grad
    pg_linear1_grad = pg_model[0].weight.grad
    pg_linear2_grad = pg_model[1].weight.grad
    check_completely_equal(oss_linear1_grad, pg_linear1_grad)
    check_completely_equal(oss_linear2_grad, pg_linear2_grad)

    # sync grad
    oss_optimizer.sync_grad()
    pg_optimizer.sync_grad()

    # step
    oss_optimizer.step()
    pg_optimizer.step()

    # check updated param
    check_completely_equal(oss_model[0].weight, pg_model[0].weight)
    check_completely_equal(oss_model[1].weight, pg_model[1].weight)


def check_sharded_optim_against_torch_ddp():
    """
    In this test, two pairs of model and optimizers are created.
    1. zero: use sharded optimizer and fp16 parameters
    2. torch: use torch DDP and fp32 parameters

    We feed these two sets of models with the same input and check if the
    differences in model output and updated parameters are within tolerance.
    """
    # create layer
    zero_linear1 = nn.Linear(128, 256)
    zero_linear2 = nn.Linear(256, 512)

    # create model
    zero_model = nn.Sequential(zero_linear1, zero_linear2)
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda().half()
    torch_model = DDP(torch_model.cuda())

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)

    # we only test stage 1 here
    # in `check_sharded_param_consistency`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = ShardedOptimizer(zero_optimizer,
                                      overlap_communication=True,
                                      initial_scale=1,
                                      clip_grad_norm=0.0)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

    # create input data
    input_data = torch.rand(32, 128).cuda()

    # zero-dp forward
    zero_output = zero_model(input_data.half())

    # torch-ddp forward
    torch_output = torch_model(input_data)
    check_equal(zero_output, torch_output)

    # zero-dp backward
    zero_optimizer.backward(zero_output.mean().float())

    # torch-ddp backward
    torch_output.mean().backward()

    # check grad
    zero_linear1_grad = zero_model[0].weight.grad
    zero_linear2_grad = zero_model[1].weight.grad
    torch_linear1_grad = torch_model.module[0].weight.grad
    torch_linear2_grad = torch_model.module[1].weight.grad
    check_equal(zero_linear1_grad, torch_linear1_grad)
    check_equal(zero_linear2_grad, torch_linear2_grad)

    # zero-dp step
    zero_optimizer.sync_grad()
    zero_optimizer.step()

    # torch ddp step
    torch_optimizer.step()

    # check updated param
    check_equal(zero_model[0].weight, torch_model.module[0].weight)
    check_equal(zero_model[1].weight, torch_model.module[1].weight)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    check_sharded_optim_against_torch_ddp()
    check_sharded_param_consistency()


@pytest.mark.dist
def test_sharded_optim():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim()
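A minimal usage sketch, not part of the commit itself: because the test is decorated with `@pytest.mark.dist` and also guarded by `if __name__ == '__main__'`, it can be selected through pytest's marker filter or run as a plain script. Either way it assumes a machine with at least two CUDA devices, since `world_size = 2` processes are spawned and the models are moved to GPU.

    pytest -m dist tests/test_zero_data_parallel/test_sharded_optim.py
    python tests/test_zero_data_parallel/test_sharded_optim.py

Note that `mp.spawn` passes the process rank as the first positional argument, which is why `partial(run_dist, world_size=..., port=...)` binds only the remaining keyword arguments.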