Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
eb7742a4
Unverified
Commit
eb7742a4
authored
Nov 29, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 29, 2022
Browse files
[Gemini] more tests for Gemini (#2038)
* [Gemini] more tests for Gemini * polish code
parent
537e1817
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
26 deletions
+35
-26
tests/components_to_test/bert.py
tests/components_to_test/bert.py
+1
-1
tests/test_gemini/update/test_optim.py
tests/test_gemini/update/test_optim.py
+28
-21
tests/test_gemini/update/test_zeroddp_state_dict.py
tests/test_gemini/update/test_zeroddp_state_dict.py
+6
-4
No files found.
tests/components_to_test/bert.py
View file @
eb7742a4
...
...
@@ -40,7 +40,7 @@ def get_training_components():
num_layer
=
2
vocab_size
=
32
def
bert_model_builder
(
checkpoint
):
def
bert_model_builder
(
checkpoint
:
bool
=
False
):
config
=
BertConfig
(
vocab_size
=
vocab_size
,
gradient_checkpointing
=
checkpoint
,
hidden_size
=
hidden_dim
,
...
...
tests/test_gemini/update/test_optim.py
View file @
eb7742a4
...
...
@@ -18,8 +18,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
debug_print
,
set_seed
,
tensor_equal
,
tensor_shard_equal
from
tests.test_tensor.common_utils
import
set_seed
def
check_param
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
):
...
...
@@ -37,19 +38,16 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert
torch
.
allclose
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-2
),
"parameter '{}' has problem."
.
format
(
key
)
def
run_fwd_bwd
(
model
,
criterion
,
optimizer
,
input_ids
):
optimizer
.
zero_grad
()
logits
=
model
(
input_ids
)
logits
=
logits
.
float
()
loss
=
criterion
(
logits
,
input_ids
)
optimizer
.
backward
(
loss
)
return
logits
# 'gpt2', 'bert',
TEST_MODELS
=
[
'gpt2'
,
'bert'
]
# TEST_MODELS = ['simple_net']
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
])
def
exam_gpt_fwd_bwd
(
placement_policy
):
@
parameterize
(
'model_name'
,
TEST_MODELS
)
def
exam_model_step
(
placement_policy
,
model_name
:
str
):
set_seed
(
42
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
...
...
@@ -87,9 +85,13 @@ def exam_gpt_fwd_bwd(placement_policy):
if
i
>
2
:
break
zero_logits
=
run_fwd_bwd
(
model
,
criterion
,
zero_optim
,
input_ids
)
torch_logits
=
run_fwd_bwd
(
torch_model
,
criterion
,
torch_optim
,
input_ids
)
assert
torch
.
allclose
(
zero_logits
,
torch_logits
,
rtol
=
1e-3
,
atol
=
1e-2
)
zero_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
assert
torch
.
allclose
(
torch_loss
,
loss
,
rtol
=
1e-3
,
atol
=
1e-2
),
f
"
{
torch_loss
}
vs
{
loss
}
"
# debug_print([0], zero_logits, torch_logits)
zero_optim
.
step
()
...
...
@@ -99,9 +101,10 @@ def exam_gpt_fwd_bwd(placement_policy):
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
def
exam_tiny_example
(
placement_policy
):
@
parameterize
(
'model_name'
,
TEST_MODELS
)
def
exam_tiny_example
(
placement_policy
,
model_name
:
str
):
set_seed
(
42
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
...
...
@@ -131,9 +134,13 @@ def exam_tiny_example(placement_policy):
if
i
>
2
:
break
zero_logits
=
run_fwd_bwd
(
model
,
criterion
,
zero_optim
,
input_ids
)
torch_logits
=
run_fwd_bwd
(
torch_model
,
criterion
,
torch_optim
,
input_ids
)
assert
torch
.
allclose
(
zero_logits
,
torch_logits
,
rtol
=
1e-3
,
atol
=
1e-2
)
zero_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
assert
torch
.
allclose
(
torch_loss
,
loss
,
rtol
=
1e-3
,
atol
=
1e-2
),
f
"
{
torch_loss
}
vs
{
loss
}
"
# debug_print([0], zero_logits, torch_logits)
zero_optim
.
step
()
...
...
@@ -145,17 +152,17 @@ def exam_tiny_example(placement_policy):
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
exam_
gpt_fwd_bwd
()
exam_
model_step
()
exam_tiny_example
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
def
test_
g
pt
(
world_size
):
def
test_
o
pt
im
(
world_size
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_
g
pt
(
2
)
test_
o
pt
im
(
2
)
tests/test_gemini/update/test_zeroddp_state_dict.py
View file @
eb7742a4
...
...
@@ -19,9 +19,10 @@ from tests.test_tensor.common_utils import debug_print, set_seed
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
def
exam_state_dict
(
placement_policy
,
keep_gathered
):
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
def
exam_state_dict
(
placement_policy
,
keep_gathered
,
model_name
:
str
):
set_seed
(
431
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
...
...
@@ -53,9 +54,10 @@ def exam_state_dict(placement_policy, keep_gathered):
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
def
exam_load_state_dict
(
placement_policy
,
keep_gathered
):
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
def
exam_load_state_dict
(
placement_policy
,
keep_gathered
,
model_name
:
str
):
set_seed
(
431
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment