Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7977422a
Commit
7977422a
authored
Mar 09, 2022
by
jiaruifang
Committed by
Frank Lee
Mar 11, 2022
Browse files
add bert for unitest and sharded model is not able to pass the bert case
parent
3d5d64bd
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
104 additions
and
14 deletions
+104
-14
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+5
-3
tests/components_to_test/__init__.py
tests/components_to_test/__init__.py
+1
-1
tests/components_to_test/bert.py
tests/components_to_test/bert.py
+69
-0
tests/test_engine/test_engine.py
tests/test_engine/test_engine.py
+4
-3
tests/test_zero_data_parallel/common.py
tests/test_zero_data_parallel/common.py
+1
-1
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+24
-6
No files found.
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
7977422a
...
@@ -17,8 +17,9 @@ from colossalai.zero.sharded_param import ShardedParamV2
...
@@ -17,8 +17,9 @@ from colossalai.zero.sharded_param import ShardedParamV2
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
from
._zero3_utils
import
(
cast_tensor_to_fp32
,
chunk_and_pad
,
get_gradient_predivide_factor
)
get_gradient_predivide_factor
)
# from ._zero3_utils import cast_float_arguments, cast_tensor_to_fp16
class
ShardedModelV2
(
nn
.
Module
):
class
ShardedModelV2
(
nn
.
Module
):
...
@@ -79,7 +80,8 @@ class ShardedModelV2(nn.Module):
...
@@ -79,7 +80,8 @@ class ShardedModelV2(nn.Module):
self
.
_require_backward_grad_sync
:
bool
=
True
self
.
_require_backward_grad_sync
:
bool
=
True
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
# TODO args can be Long!
# args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
return
outputs
return
outputs
...
...
tests/components_to_test/__init__.py
View file @
7977422a
from
.
import
repeated_computed_layer
,
resnet
,
nested_model
from
.
import
repeated_computed_layer
,
resnet
,
nested_model
,
bert
tests/components_to_test/bert.py
0 → 100644
View file @
7977422a
import
torch
import
transformers
from
transformers
import
BertConfig
,
BertForSequenceClassification
from
packaging
import
version
from
torch.utils.data
import
SequentialSampler
from
.registry
import
non_distributed_component_funcs
def
get_bert_data_loader
(
batch_size
,
total_samples
,
sequence_length
,
device
=
torch
.
device
(
'cpu:0'
),
is_distrbuted
=
False
,
):
train_data
=
torch
.
randint
(
low
=
0
,
high
=
1000
,
size
=
(
total_samples
,
sequence_length
),
device
=
device
,
dtype
=
torch
.
long
,
)
train_label
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
total_samples
,),
device
=
device
,
dtype
=
torch
.
long
)
train_dataset
=
torch
.
utils
.
data
.
TensorDataset
(
train_data
,
train_label
)
if
is_distrbuted
:
sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
train_dataset
)
else
:
sampler
=
SequentialSampler
(
train_dataset
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
batch_size
,
sampler
=
sampler
)
return
train_loader
@
non_distributed_component_funcs
.
register
(
name
=
'bert'
)
def
get_training_components
():
hidden_dim
=
8
num_head
=
4
sequence_length
=
12
num_layer
=
2
def
bert_model_builder
(
checkpoint
):
config
=
BertConfig
(
gradient_checkpointing
=
checkpoint
,
hidden_size
=
hidden_dim
,
intermediate_size
=
hidden_dim
*
4
,
num_attention_heads
=
num_head
,
max_position_embeddings
=
sequence_length
,
num_hidden_layers
=
num_layer
,
)
print
(
'building BertForSequenceClassification model'
)
model
=
BertForSequenceClassification
(
config
)
if
checkpoint
and
version
.
parse
(
transformers
.
__version__
)
>=
version
.
parse
(
"4.11.0"
):
model
.
gradient_checkpointing_enable
()
return
model
trainloader
=
get_bert_data_loader
(
batch_size
=
2
,
total_samples
=
10000
,
sequence_length
=
sequence_length
,
is_distrbuted
=
True
)
testloader
=
get_bert_data_loader
(
batch_size
=
2
,
total_samples
=
10000
,
sequence_length
=
sequence_length
,
is_distrbuted
=
True
)
def
get_optim
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
None
return
bert_model_builder
,
trainloader
,
testloader
,
get_optim
,
criterion
tests/test_engine/test_engine.py
View file @
7977422a
...
@@ -15,6 +15,7 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
...
@@ -15,6 +15,7 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
def
run_train
():
def
run_train
():
assert
non_distributed_component_funcs
.
get_callable
(
'bert'
)
for
get_components_func
in
non_distributed_component_funcs
:
for
get_components_func
in
non_distributed_component_funcs
:
model_builder
,
train_dataloader
,
_
,
optimizer_builder
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
optimizer_builder
,
criterion
=
get_components_func
()
...
@@ -71,9 +72,9 @@ def run_engine(rank, world_size, port):
...
@@ -71,9 +72,9 @@ def run_engine(rank, world_size, port):
# init dist env
# init dist env
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_with_no_amp
()
run_with_no_amp
()
run_with_torch_amp
()
#
run_with_torch_amp()
run_with_apex_amp
()
#
run_with_apex_amp()
run_with_naive_amp
()
#
run_with_naive_amp()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_zero_data_parallel/common.py
View file @
7977422a
...
@@ -75,7 +75,7 @@ def check_grads_padding(model, zero_model, loose=False):
...
@@ -75,7 +75,7 @@ def check_grads_padding(model, zero_model, loose=False):
if
zero_grad
.
size
(
0
)
>
grad
.
size
(
0
):
if
zero_grad
.
size
(
0
)
>
grad
.
size
(
0
):
zero_grad
=
zero_grad
[:
grad
.
size
(
0
)]
zero_grad
=
zero_grad
[:
grad
.
size
(
0
)]
assert
grad
.
dtype
==
zero_grad
.
dtype
assert
grad
.
dtype
==
zero_grad
.
dtype
assert
allclose
(
grad
,
zero_grad
,
loose
=
loose
),
f
'
{
grad
}
vs
{
zero_grad
}
'
assert
allclose
(
grad
,
zero_grad
,
loose
=
loose
),
f
'
diff:
{
grad
-
zero_grad
}
'
def
check_params_padding
(
model
,
zero_model
,
loose
=
False
):
def
check_params_padding
(
model
,
zero_model
,
loose
=
False
):
...
...
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
7977422a
...
@@ -31,14 +31,25 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
...
@@ -31,14 +31,25 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
loss
.
backward
()
loss
.
backward
()
def
run_bert_fwd_bwd
(
model
,
data
,
label
,
enable_autocast
=
False
):
model
.
train
()
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
output
=
model
(
input_ids
=
data
,
labels
=
label
)
loss
=
output
[
0
]
if
isinstance
(
model
,
ShardedModelV2
):
model
.
backward
(
loss
)
else
:
loss
.
backward
()
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
]
shard_strategy
=
TensorShardStrategy
()
for
model_name
in
test_models
:
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
shard_strategy
=
TensorShardStrategy
()
model
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
model
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
model
=
model
().
half
().
cuda
()
model
=
model
(
checkpoint
=
True
).
half
().
cuda
()
zero_model
=
ShardedModelV2
(
copy
.
deepcopy
(
model
),
shard_strategy
)
zero_model
=
ShardedModelV2
(
copy
.
deepcopy
(
model
),
shard_strategy
)
if
dist
.
get_world_size
()
>
1
:
if
dist
.
get_world_size
()
>
1
:
model
=
DDP
(
model
)
model
=
DDP
(
model
)
...
@@ -46,9 +57,16 @@ def run_dist(rank, world_size, port):
...
@@ -46,9 +57,16 @@ def run_dist(rank, world_size, port):
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
if
i
>
2
:
break
break
data
,
label
=
data
.
half
().
cuda
(),
label
.
cuda
()
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
False
)
if
model_name
==
'bert'
:
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
False
)
data
,
label
=
data
.
cuda
(),
label
.
cuda
()
run_bert_fwd_bwd
(
model
,
data
,
label
,
False
)
run_bert_fwd_bwd
(
zero_model
,
data
,
label
,
False
)
else
:
data
,
label
=
data
.
half
().
cuda
(),
label
.
cuda
()
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
False
)
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
False
)
if
dist
.
get_world_size
()
>
1
:
if
dist
.
get_world_size
()
>
1
:
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment