Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a0e59716
Unverified
Commit
a0e59716
authored
Apr 27, 2022
by
Jiarui Fang
Committed by
GitHub
Apr 27, 2022
Browse files
[Tensor] test model check results for a simple net (#887)
parent
72cdc068
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
10 deletions
+44
-10
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+44
-10
No files found.
tests/test_tensor/test_
net_tp
.py
→
tests/test_tensor/test_
model
.py
View file @
a0e59716
...
...
@@ -2,6 +2,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
import
colossalai
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
...
...
@@ -9,15 +10,30 @@ from colossalai.utils import free_port
from
colossalai.utils
import
ColoInitContext
from
colossalai.tensor
import
named_params_with_colotensor
,
TensorSpec
,
ComputePattern
,
ParallelAction
,
ColoTensor
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
import
random
import
os
import
numpy
as
np
def
run_simple_net
():
def
set_seed
(
seed
):
random
.
seed
(
seed
)
os
.
environ
[
'PYTHONHASHSEED'
]
=
str
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
True
def
run_1d_row_tp
():
# A simple net with two stacked nn.Linear
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'simple_net'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_1D
)
set_seed
(
1
)
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
(
checkpoint
=
True
)
...
...
@@ -26,6 +42,11 @@ def run_simple_net():
]
spec
=
TensorSpec
(
parallel_action_list
)
set_seed
(
1
)
if
rank
==
0
:
model_torch
=
model_builder
(
checkpoint
=
True
)
model_torch
=
model_torch
.
cuda
()
# A naive way to set spec for all weights in Linear
for
name
,
p
in
named_params_with_colotensor
(
model
):
if
not
isinstance
(
p
,
ColoTensor
):
...
...
@@ -33,15 +54,16 @@ def run_simple_net():
if
'weight'
in
name
and
'LayerNorm'
not
in
name
and
'ln'
not
in
name
and
'embed'
not
in
name
:
p
.
set_spec
(
spec
)
model
.
cuda
()
for
param
in
named_params_with_colotensor
(
model
):
print
(
param
)
model
=
model
.
cuda
()
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
data
=
data
.
to
(
get_current_device
())
label
=
label
.
to
(
get_current_device
())
torch
.
distributed
.
broadcast
(
data
,
0
,
group
=
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
))
torch
.
distributed
.
broadcast
(
label
,
0
,
group
=
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
))
# Bcast rank0 data to all processes
if
criterion
:
output
=
model
(
data
)
loss
=
criterion
(
output
,
label
)
...
...
@@ -49,22 +71,34 @@ def run_simple_net():
output
=
model
(
data
,
label
)
loss
=
output
print
(
loss
.
torch_tensor
())
# For reference
if
rank
==
0
:
if
criterion
:
output_torch
=
model_torch
(
data
)
loss_torch
=
criterion
(
output_torch
,
label
)
else
:
output_torch
=
model_torch
(
data
,
label
)
loss_torch
=
output_torch
if
rank
==
0
:
# print(loss.torch_tensor().item())
# print('loss torch', loss_torch.item())
assert
torch
.
allclose
(
loss
.
torch_tensor
(),
loss_torch
,
rtol
=
1e-2
)
loss
.
backward
()
if
rank
==
0
:
loss_torch
.
backward
()
if
i
>
5
:
break
# TODO(jzy) check the results with col.nn.Linear?
def
run_dist
(
rank
,
world_size
,
port
):
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_
simple_net
()
run_
1d_row_tp
()
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
parameterize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment