OpenDAS / ColossalAI / Commits

Commit 2e9cbfca (unverified), authored Nov 24, 2022 by Jiarui Fang, committed via GitHub on Nov 24, 2022

[Gemini] add unitests to check gemini correctness (#2015)

parent 0b0d8f9e

Showing 13 changed files with 133 additions and 52 deletions (+133 -52)
tests/components_to_test/__init__.py                     +1   -0
tests/components_to_test/gpt.py                          +7   -5
tests/components_to_test/inline_op_model.py              +1   -1
tests/components_to_test/utils/__init__.py               +2   -1
tests/components_to_test/utils/executor.py (new)         +15  -0
tests/test_gemini/test_gemini_train.py (new)             +67  -0
tests/test_gemini/test_mem_tracer.py                     +3   -13
tests/test_gemini/test_param_op.py                       +1   -1
tests/test_gemini/update/test_fwd_bwd.py                 +5   -5
tests/test_gemini/update/test_optim.py                   +8   -8
tests/test_gemini/update/test_zerooptim_state_dict.py    +2   -2
tests/test_tensor/model/test_gpt2.py                     +16  -11
tests/test_tensor/test_tp_with_zero.py                   +5   -5
tests/components_to_test/__init__.py

 from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net
+from .utils import run_fwd_bwd
tests/components_to_test/gpt.py

 import torch
 import torch.nn as nn
-from .registry import non_distributed_component_funcs
 from transformers import GPT2Config, GPT2LMHeadModel
-from .utils.dummy_data_generator import DummyDataGenerator
+
+from colossalai.utils.cuda import get_current_device
+
+from .registry import non_distributed_component_funcs
+from .utils.dummy_data_generator import DummyDataGenerator


 class DummyDataLoader(DummyDataGenerator):
     vocab_size = 128

@@ -15,8 +17,7 @@ class DummyDataLoader(DummyDataGenerator):
         input_ids = torch.randint(0,
                                   DummyDataLoader.vocab_size,
                                   (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
                                   device=get_current_device())
-        attention_mask = torch.ones_like(input_ids)
-        return input_ids, attention_mask
+        return input_ids, input_ids


 class GPTLMModel(nn.Module):

@@ -43,8 +44,9 @@ class GPTLMModel(nn.Module):
         if checkpoint:
             self.model.gradient_checkpointing_enable()

-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids):
         # Only return lm_logits
+        attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
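With this change GPTLMModel builds its attention mask internally, so test callers pass only input_ids. A minimal sketch of the same pattern for any HF-style model; the wrapper name and setup below are ours, not part of the commit:

import torch
import torch.nn as nn

class MaskFreeWrapper(nn.Module):
    """Illustrative only (not from the commit): regenerate a full
    attention mask inside forward, as GPTLMModel.forward now does."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        # Every position attends; equivalent to the all-ones mask
        # callers previously constructed and passed in themselves.
        attention_mask = torch.ones_like(input_ids)
        return self.model(input_ids=input_ids, attention_mask=attention_mask)[0]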
tests/components_to_test/inline_op_model.py

@@ -38,7 +38,7 @@ class DummyDataLoader(DummyDataGenerator):
         return data, label


-@non_distributed_component_funcs.register(name='inline_op_module')
+@non_distributed_component_funcs.register(name='inline_op_model')
 def get_training_components():

     def model_builder(checkpoint=True):
tests/components_to_test/utils/__init__.py

 from .dummy_data_generator import DummyDataGenerator
+from .executor import run_fwd_bwd
tests/components_to_test/utils/executor.py (new file, 0 → 100644)

import torch


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if use_init_ctx:
        model.backward(loss)
    else:
        loss.backward()
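A short usage sketch of the new shared helper with a plain torch module; the toy model and data are hypothetical, and a CUDA device is assumed:

import torch
import torch.nn as nn

from tests.components_to_test.utils.executor import run_fwd_bwd

# Toy setup, purely illustrative.
model = nn.Linear(8, 4).cuda()
data = torch.randn(2, 8, device='cuda')
label = torch.randint(0, 4, (2,), device='cuda')
criterion = nn.CrossEntropyLoss()

# use_init_ctx=False -> plain loss.backward(); pass True only for wrappers
# such as ZeroDDP that expose their own model.backward(loss).
run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False)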
tests/test_gemini/test_gemini_train.py (new file, 0 → 100644)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs


def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
    PLACEMENT_POLICY = 'cuda'
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    # build torch model
    model_torch = model_builder(checkpoint=False).cuda()
    for i, (data, label) in enumerate(train_dataloader):
        if i >= iter_num:
            break
        run_fwd_bwd(model_torch, data.cuda(), label.cuda(), criterion, False, use_init_ctx=False)

    # build CAI model
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=False)

    from colossalai.gemini import ChunkManager, GeminiManager, search_chunk_configuration
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict, init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
    gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
    model = ZeroDDP(model, gemini_manager)

    model.train()
    for i, (data, label) in enumerate(train_dataloader):
        if i >= iter_num:
            break
        run_fwd_bwd(model, data.cuda(), label.cuda(), criterion, False, use_init_ctx=True)

    for p1, p2 in zip(model.parameters(), model_torch.parameters()):
        torch.allclose(p1.to(torch.float), p2.to(torch.float))

    print(f'pass test {model_name}')


@pytest.mark.parametrize("model_name", ['bert'])
@rerun_if_address_is_in_use()
def test_gemini_train(model_name, iter_num=2):
    run_func = partial(run_gemini_fwd_bwd, world_size=1, port=free_port(), model_name=model_name, iter_num=iter_num)
    mp.spawn(run_func, nprocs=1)


if __name__ == '__main__':
    # for model_name in ["bert", "resnet18", "inline_op_model"]:
    # bert, gpt, inline_op_model, nested_model, no_leaf_module,
    # repeated_computed_layer, resnet, simple_net
    for model_name in ["nested_model", "no_leaf_module"]:
        test_gemini_train(model_name=model_name, iter_num=4)
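Note that the final loop calls torch.allclose without asserting on its result, so a parameter mismatch would not actually fail this test. A stricter variant, borrowing the tolerances used elsewhere in this commit (an assumption about intent, not part of the change), might read:

# Hedged sketch: fail loudly when Gemini and torch parameters diverge.
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
    assert torch.allclose(p1.to(torch.float), p2.to(torch.float), rtol=1e-3, atol=1e-2), \
        f"max abs diff: {(p1.float() - p2.float()).abs().max().item()}"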
tests/test_gemini/test_mem_tracer.py

@@ -8,20 +8,10 @@ import colossalai
 from colossalai.gemini.memory_tracer import MemtracerWrapper
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs


-def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-    model.backward(loss)


 def run_tracer(rank, world_size, port, use_grad_check=True):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'bert']

@@ -43,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True):
         data = data.cuda()
         label = label.cuda()

-        run_fwd_bwd(model, data, label, criterion, False)
+        run_fwd_bwd(model, data, label, criterion, False, use_init_ctx=False)

         model._ophook_list[0].print_non_model_data()

@@ -58,4 +48,4 @@ def test_tracer(world_size, use_grad_check):
 if __name__ == '__main__':
-    test_tracer(1)
+    test_tracer(1, True)
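For reference, the tracer flow after this cleanup, sketched from names visible in the diff; the exact MemtracerWrapper call signature and registry component are assumed here, not confirmed by the hunks above:

from colossalai.gemini.memory_tracer import MemtracerWrapper
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs

# Hedged sketch: wrap one registered model, run a traced step, then dump
# the non-model memory statistics, mirroring run_tracer above.
model_builder, train_dataloader, _, _, criterion = \
    non_distributed_component_funcs.get_callable('resnet18')()
model = MemtracerWrapper(model_builder(checkpoint=True))
for data, label in train_dataloader:
    run_fwd_bwd(model, data.cuda(), label.cuda(), criterion, False, use_init_ctx=False)
    model._ophook_list[0].print_non_model_data()
    break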
tests/test_gemini/test_param_op.py

@@ -50,7 +50,7 @@ def run_model(model, inputs, label, criterion, use_param_hook=False):
 def test_base_param_hook():
-    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_module']
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_model']
     # test_models = ['bert']
     for model_name in test_models:
tests/test_gemini/update/test_fwd_bwd.py

@@ -30,9 +30,9 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
         assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())


-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)

@@ -71,16 +71,16 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather):
     torch_model.eval()
     set_seed(pg.dp_local_rank())
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 0:
             break
-        logits = model(input_ids, attn_mask)
+        logits = model(input_ids)
         logits = logits.float()
         loss = criterion(logits, input_ids)
         model.backward(loss)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
             torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
tests/test_gemini/update/test_optim.py

@@ -37,9 +37,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)


-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)

@@ -83,12 +83,12 @@ def exam_gpt_fwd_bwd(placement_policy):
     torch_model.eval()
     set_seed(dist.get_rank() * 3 + 128)
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
         # debug_print([0], zero_logits, torch_logits)

@@ -127,12 +127,12 @@ def exam_tiny_example(placement_policy):
     torch_model.eval()
     set_seed(dist.get_rank() * 3 + 128)
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
         # debug_print([0], zero_logits, torch_logits)
tests/test_gemini/update/test_zerooptim_state_dict.py

@@ -50,11 +50,11 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
     set_seed(dist.get_rank() * 3 + 128)
     model.train()
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 0:
             break
         optim.zero_grad()
-        logits = model(input_ids, attn_mask)
+        logits = model(input_ids)
         logits = logits.float()
         loss = criterion(logits, input_ids)
         optim.backward(loss)
tests/test_tensor/model/test_gpt2.py

-import pytest
 from functools import partial
-from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
+
+import pytest
 import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+
 import colossalai
+from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.data_parallel import ColoDDP
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print
+from tests.test_tensor.common_utils import (debug_print, set_seed, split_param_col_tp1d, split_param_row_tp1d,
+                                            tensor_equal, tensor_shard_equal)


 def init_1d_row_spec(model, pg: ProcessGroup):

@@ -107,10 +112,10 @@ def run_gpt(init_spec_func, use_ddp):
     torch_model.eval()
     set_seed(pg.dp_local_rank())
     torch.distributed.barrier()
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        logits = model(colo_input, attn_mask)
-        torch_logits = torch_model(input_ids, attn_mask)
+        logits = model(colo_input)
+        torch_logits = torch_model(input_ids)
         assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
         loss = criterion(logits, input_ids)
         torch_loss = criterion(torch_logits, input_ids)
tests/test_tensor/test_tp_with_zero.py

@@ -36,9 +36,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
         "parameter '{}' has problem.".format(key)


-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)

@@ -117,12 +117,12 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     torch_model.eval()
     set_seed(pg.dp_local_rank())
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
         input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
         zero_optim.step()