Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
534afb01
Unverified
Commit
534afb01
authored
May 09, 2022
by
Jiarui Fang
Committed by
GitHub
May 09, 2022
Browse files
test pretrain loading on multi-process (#922)
parent
c195d281
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
23 deletions
+41
-23
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+41
-23
No files found.
tests/test_tensor/test_model.py
View file @
534afb01
...
...
@@ -278,26 +278,6 @@ def test_colo_optimizer():
if
i
>
5
:
break
def _test_pretrained():
    """Check that a model built under ColoInitContext loads HF pretrained
    weights identically to a plain `from_pretrained` model.

    Downloads 'bert-base-uncased' twice (once normally, once inside
    ColoInitContext with eager memory allocation), moves both to the GPU,
    and asserts parameter-by-parameter equality via `check_equal`.
    """
    from _utils import check_equal
    from transformers import BertForMaskedLM

    set_seed(1)

    # Reference model: standard HuggingFace load path.
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')

    # Model under test: same checkpoint, but initialized through Colossal-AI's
    # context (eager allocation on the current device).
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    # named_parameters() yields (name, param) pairs, so dict() builds the
    # same mapping the original explicit loops produced.
    dict_pretrained = dict(model_pretrained.named_parameters())
    dict_col = dict(model.named_parameters())

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])
def
run_1d_row_tp
(
model_name
:
str
):
# A simple net with two stacked nn.Linear
...
...
@@ -376,7 +356,29 @@ def run_1d_row_tp(model_name: str):
break
def
run_dist
(
rank
,
world_size
,
port
):
def _run_pretrain_load():
    """Verify pretrained-weight loading inside ColoInitContext on this rank.

    Loads 'bert-base-uncased' once the standard way and once under
    ColoInitContext (eager allocation on the current device), then compares
    every named parameter of the two CUDA-resident models with `check_equal`.
    Intended to be called from a distributed worker after launch.
    """
    from _utils import check_equal
    from transformers import BertForMaskedLM

    set_seed(1)

    reference = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        colo_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    reference = reference.cuda()
    colo_model = colo_model.cuda()

    dict_pretrained = {}
    dict_col = {}
    # Collect both parameter sets in lockstep-style explicit loops.
    for name, param in reference.named_parameters():
        dict_pretrained[name] = param
    for name, param in colo_model.named_parameters():
        dict_col[name] = param

    # Every reference parameter must have an equal counterpart in the
    # ColoInitContext-built model (KeyError here would mean a missing param).
    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])

    # NOTE: local aliases keep the assignments back to the original names
    # out of the picture; behavior is identical to the two-variable original.
    model_pretrained = reference
    model = colo_model
def
run_model_dist
(
rank
,
world_size
,
port
):
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
for
name
in
[
'simple_net'
]:
...
...
@@ -390,7 +392,23 @@ def run_dist(rank, world_size, port):
#@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
    """Spawn `world_size` processes and run the model test suite in each.

    Each worker executes `run_model_dist`, which launches colossalai with a
    1D tensor-parallel config and runs the per-model checks.
    """
    # BUG FIX: the original assigned `run_func = partial(run_dist, ...)` and
    # then immediately overwrote it with `partial(run_model_dist, ...)` —
    # the first assignment was a dead store left over from renaming the
    # worker function, so it is removed here.
    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
def run_pretrain_load_dist(rank, world_size, port):
    """Per-process entry point for the pretrained-loading test.

    Initializes the colossalai runtime for this rank with a 1D
    tensor-parallel configuration, then runs `_run_pretrain_load`.

    Args:
        rank: this process's rank in the group.
        world_size: total number of spawned processes.
        port: TCP port for the rendezvous on localhost.
    """
    # Same structure as the dict(...)-built config in the original:
    # {'parallel': {'tensor': {'mode': '1d', 'size': world_size}}}
    config = {'parallel': {'tensor': {'mode': "1d", 'size': world_size}}}
    colossalai.launch(config=config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    _run_pretrain_load()
# The test case has to download huggingface pretrained models from the internet
# So we manually trigger the test.
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def _test_pretrain_load(world_size):
    """Multi-process pretrained-loading test (manually triggered).

    The leading underscore keeps pytest from collecting it automatically,
    since it needs network access to download HuggingFace checkpoints.
    """
    worker = partial(run_pretrain_load_dist,
                     world_size=world_size,
                     port=free_port())
    mp.spawn(worker, nprocs=world_size)
...
...
@@ -398,4 +416,4 @@ if __name__ == '__main__':
# test_model_parameters()
# test_colo_optimizer()
# test_model()
_test_pretrain
e
d
()
_test_pretrain
_loa
d
(
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment