Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0fab86b1
Unverified
Commit
0fab86b1
authored
May 06, 2022
by
Ziyue Jiang
Committed by
GitHub
May 06, 2022
Browse files
[Tensor] add a basic bert. (#911)
* add base bert test * Add bert test * polish * remove test_bert * polish
parent
ab95ec9a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
129 additions
and
3 deletions
+129
-3
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+129
-3
No files found.
tests/test_tensor/test_model.py
View file @
0fab86b1
...
...
@@ -17,6 +17,64 @@ import random
import
os
import
numpy
as
np
# HACK: patch HuggingFace's ModelOutput (used by the BERT model outputs)
# so that it treats our ColoTensor the same way it treats torch.Tensor.
from
transformers.file_utils
import
ModelOutput
from
dataclasses
import
fields
def post_init_colo(self):
    """Populate the mapping side of a dataclass-style output object.

    Replacement for ``ModelOutput.__post_init__`` whose tensor check also
    accepts ``ColoTensor``. ``self`` must be a dataclass instance that also
    supports ``self[key] = value``.

    Raises:
        ValueError: if the dataclass has no fields, or if any field after
            the first has a default other than ``None``.
    """

    def is_tensor_with_colo(x):
        """
        Tests if `x` is a `ColoTensor` or `torch.Tensor`.
        """
        return isinstance(x, torch.Tensor) or isinstance(x, ColoTensor)

    declared = fields(self)

    # Safety and consistency checks
    if len(declared) == 0:
        raise ValueError(f"{self.__class__.__name__} has no fields.")
    leading, trailing = declared[0], declared[1:]
    for spec in trailing:
        if spec.default is not None:
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

    leading_value = getattr(self, leading.name)
    trailing_all_none = all(getattr(self, spec.name) is None for spec in trailing)

    # Regular case: several fields are populated, or the first field is a
    # tensor. Register every non-None field under its own name.
    if not trailing_all_none or is_tensor_with_colo(leading_value):
        for spec in declared:
            value = getattr(self, spec.name)
            if value is not None:
                self[spec.name] = value
        return

    # Only the first field is populated and it is not a tensor: it may be a
    # dict or an iterable of (key, value) pairs describing the real fields.
    if isinstance(leading_value, dict):
        pairs = leading_value.items()
    else:
        try:
            pairs = iter(leading_value)
        except TypeError:
            # Not iterable: store it as a plain value (unless it is None).
            if leading_value is not None:
                self[leading.name] = leading_value
            return

    # Consume (str, value) pairs; stop silently at the first element that
    # does not have that shape.
    for item in pairs:
        looks_like_pair = (isinstance(item, (list, tuple)) and len(item) == 2 and isinstance(item[0], str))
        if not looks_like_pair:
            break
        key, value = item[0], item[1]
        setattr(self, key, value)
        if value is not None:
            self[key] = value
# Monkey-patch: replace ModelOutput.__post_init__ with the ColoTensor-aware
# version defined above.
ModelOutput.__post_init__ = post_init_colo
# complete the hack
def
set_seed
(
seed
):
random
.
seed
(
seed
)
...
...
@@ -64,7 +122,7 @@ def run_1d_col_tp():
model_torch
=
model_torch
.
cuda
()
# A naive way to set spec for all weights in Linear
for
name
,
p
in
named_params_with_colotensor
(
model
):
for
name
,
p
in
model
.
colo_named_parameters
(
):
if
not
isinstance
(
p
,
ColoTensor
):
continue
if
'proj1'
in
name
and
(
'weight'
in
name
or
'bias'
in
name
):
...
...
@@ -249,6 +307,60 @@ def run_1d_row_tp():
if
i
>
5
:
break
def run_bert_1d():
    """Run a few forward/backward passes of the 'bert' test component.

    Builds the model inside ``ColoInitContext``, assigns 1D column specs to
    the classifier weight/bias and to the ``*_embeddings`` weights, then runs
    forward + backward on the first six batches of the training dataloader.
    The optimizer class returned by the component function is not used, and
    no optimizer step is taken.
    """

    def _make_1d_spec(pattern):
        # One-action spec with priority 1 on the 1D parallel mode.
        action = ParallelAction(priority=1, compute_pattern=pattern, parallel_mode=ParallelMode.PARALLEL_1D)
        return TensorSpec([action])

    component_fn = non_distributed_component_funcs.get_callable('bert')
    model_builder, train_dataloader, _, _optimizer_class, criterion = component_fn()
    device = get_current_device()

    set_seed(1)
    with ColoInitContext(device=device):
        model = model_builder(checkpoint=True)

    # NOTE: a row-wise linear spec (ComputePattern.TP1DRow_Linear) was sketched
    # here in the original but left disabled.
    linear_col_spec = _make_1d_spec(ComputePattern.TP1DCol_Linear)
    embedding_col_spec = _make_1d_spec(ComputePattern.TP1DCol_Embedding)

    for name, param in model.colo_named_parameters():
        if isinstance(param, ColoTensor):
            has_weight = 'weight' in name
            if 'classifier' in name and (has_weight or 'bias' in name):
                param.set_spec(linear_col_spec)
            if '_embeddings' in name and has_weight:
                param.set_spec(embedding_col_spec)

    model = model.cuda()

    for step, (data, label) in enumerate(train_dataloader):
        # Steps 0..5 only.
        if step > 5:
            break
        data = data.to(device)
        label = label.to(device)

        model.train()
        # With a criterion the model returns logits; without one the model
        # is called with the labels and its output is used as the loss.
        if criterion:
            loss = criterion(model(data), label)
        else:
            loss = model(data, label)

        loss.backward()
def
run_dist
(
rank
,
world_size
,
port
):
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
world_size
),))
...
...
@@ -256,16 +368,30 @@ def run_dist(rank, world_size, port):
run_1d_row_tp
()
run_1d_col_tp
()
def run_dist_bert(rank, world_size, port):
    """Per-process entry point for the BERT 1D tensor-parallel run.

    Launches colossalai on ``localhost:port`` with the ``nccl`` backend and a
    1D tensor-parallel config of size ``world_size``, then calls
    ``run_bert_1d``.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
        world_size: total number of processes; also the tensor-parallel size.
        port: port forwarded to ``colossalai.launch``.
    """
    tensor_parallel = {'mode': "1d", 'size': world_size}
    config = {'parallel': {'tensor': tensor_parallel}}
    colossalai.launch(
        config=config,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    run_bert_1d()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_simple_net(world_size):
    """Spawn ``world_size`` worker processes that each execute ``run_dist``.

    The stale, garbled ``@parameterize('world_size', [1, 4])`` decorator left
    over from the diff is dropped: ``world_size`` is parametrized exactly
    once, through ``pytest.mark.parametrize``.

    Args:
        world_size: number of processes to spawn (1 or 4).
    """
    # Pick the port once so every spawned process receives the same value.
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    # presumably `mp` is torch.multiprocessing, which passes the process index
    # as the first positional argument (rank) — confirm against the imports.
    mp.spawn(run_func, nprocs=world_size)
@pytest.mark.dist
# Not registered through pytest.mark.parametrize yet; the author (jzy) plans
# to switch back once Classifier and Loss are finished.
#@pytest.mark.parametrize('world_size', [1, 4])
@parameterize('world_size', [1])
@rerun_if_address_is_in_use()
def test_bert(world_size):
    """Spawn ``world_size`` worker processes that each execute ``run_dist_bert``.

    Args:
        world_size: number of processes to spawn.
    """
    port = free_port()
    worker = partial(run_dist_bert, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)
if __name__ == '__main__':
    # Only the BERT test runs when the file is executed directly; the other
    # entry points are kept here, disabled, for quick manual toggling.
    # test_simple_net()
    # test_model_parameters()
    # test_colo_optimizer()
    # NOTE(review): called without arguments — presumably @parameterize
    # supplies world_size; confirm.
    test_bert()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment