OpenDAS / ColossalAI · Commits

Commit 830d3bca (Unverified)
Authored May 13, 2022 by Ziyue Jiang; committed by GitHub on May 13, 2022

[Tensor] add optimizer to bert test (#933)

* add optimizer to bert test
* polish

Parent: 7edb3819
Showing 1 changed file with 40 additions and 9 deletions (+40 −9).
tests/test_tensor/test_model.py @ 830d3bca
```diff
@@ -96,6 +96,15 @@ def run_1d_hybrid_tp(model_name):
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
+    if rank == 0:
+        model_torch = model_builder(checkpoint=True)
+        model_torch = model_torch.cuda()
+        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
+
+        # Make two models have the same init params
+        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
+            p2.data.copy_(p1.data)
+
     if 'bert' == model_name:
         parallel_action_list_row = [
             ParallelAction(priority=1,
```
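This first hunk builds a plain PyTorch copy of the model on rank 0, wraps SGD for it via ColoOptimizer, and copies the Colo model's initial parameters into the copy so the two runs are numerically comparable. A minimal self-contained sketch of that init-sync pattern, with stand-in `nn.Linear` modules rather than the repo's `model_builder` (the shapes and module choice here are invented for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(1)
model = nn.Linear(8, 4)        # stands in for the Colo-initialized model
model_torch = nn.Linear(8, 4)  # stands in for the rank-0 reference copy

# Make two models have the same init params (same pattern as the hunk).
with torch.no_grad():
    for p1, p2 in zip(model.parameters(), model_torch.parameters()):
        p2.data.copy_(p1.data)

# Both models now start from identical weights.
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
    assert torch.equal(p1, p2)
```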
```diff
@@ -176,14 +185,15 @@ def run_1d_hybrid_tp(model_name):
         if 'classifier' in name and ('weight' in name or 'bias' in name):
             p.set_spec(spec_classifier_col)
 
     set_seed(1)
-    if rank == 0:
-        model_torch = model_builder(checkpoint=True)
-        model_torch = model_torch.cuda()
     model = model.cuda()
+    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
     for i, (data, label) in enumerate(train_dataloader):
         model.eval()
+        colo_optimizer.zero_grad()
+        if rank == 0:
+            model_torch.eval()
+            colo_optimizer_torch.zero_grad()
+
         data = data.to(get_current_device())
         label = label.to(get_current_device())
```
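The second hunk removes the reference-model construction here (it moved up into the first hunk) and drives both models in lock-step: each iteration zeroes both optimizers before the forward passes. A sketch of the same lock-step pattern, assuming plain `torch.optim.SGD` where the test wraps SGD via `ColoOptimizer`:

```python
import torch
import torch.nn as nn

torch.manual_seed(1)
model = nn.Linear(8, 4)
model_torch = nn.Linear(8, 4)
with torch.no_grad():
    for p1, p2 in zip(model.parameters(), model_torch.parameters()):
        p2.data.copy_(p1.data)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)

data = torch.randn(2, 8)
for _ in range(3):
    optimizer.zero_grad()
    optimizer_torch.zero_grad()
    loss = model(data).sum()
    loss_torch = model_torch(data).sum()
    loss.backward()
    loss_torch.backward()
    optimizer.step()
    optimizer_torch.step()
    # Same init, same data, same update rule => the runs stay in sync.
    assert torch.allclose(loss, loss_torch)
```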
```diff
@@ -210,12 +220,33 @@ def run_1d_hybrid_tp(model_name):
         if rank == 0:
             # print(loss.torch_tensor().item())
             # print('loss torch', loss_torch.item())
             with torch.no_grad():
                 assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
 
         loss.backward()
+        colo_optimizer.step()
 
         if rank == 0:
             loss_torch.backward()
+            colo_optimizer_torch.step()
+
+            with torch.no_grad():
+                # check param
+                for p1, p2 in zip(model.parameters(), model_torch.parameters()):
+                    if p1.size() == p2.size():
+                        assert torch.allclose(p1, p2)
+                    else:
+                        # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
+                        if p1.size(-1) < p2.size(-1):    # col
+                            world_size = p2.size(-1) // p1.size(-1)
+                            split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
+                        elif p1.size(0) < p2.size(0):    # row
+                            world_size = p2.size(0) // p1.size(0)
+                            split_p2 = torch.chunk(p2, world_size, dim=0)[0]
+                        assert torch.allclose(p1, split_p2)
 
         if i > 5:
             break
```
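The third hunk is the core of the change: after each `step()`, every tensor-parallel parameter is checked against the rank-0 reference (and only for the first few iterations, per `if i > 5: break`). Replicated parameters must match the reference exactly; 1D-sharded parameters are compared against the matching `torch.chunk` slice of the full reference tensor, with the world size inferred from the size ratio. Comparing against chunk index 0 is valid because the check runs under `if rank == 0:`, so `p1` is rank 0's shard. A standalone sketch of that shard check; the shapes and the simulated shard are invented for illustration:

```python
import torch

world_size = 4
p2 = torch.randn(8, 16)                      # full reference weight held on rank 0
p1 = torch.chunk(p2, world_size, dim=-1)[0]  # simulated column shard owned by rank 0

# Same recovery logic as the hunk: infer world size from the size ratio,
# then compare against the first chunk along the sharded dimension.
if p1.size(-1) < p2.size(-1):    # column-sharded
    ws = p2.size(-1) // p1.size(-1)
    split_p2 = torch.chunk(p2, ws, dim=-1)[0]
elif p1.size(0) < p2.size(0):    # row-sharded
    ws = p2.size(0) // p1.size(0)
    split_p2 = torch.chunk(p2, ws, dim=0)[0]
assert torch.allclose(p1, split_p2)
```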
```diff
@@ -428,5 +459,5 @@ def _test_pretrain_load(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    # test_model()
-    _test_pretrain_load(4)
+    test_model(4)
+    # _test_pretrain_load(4)
```