Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
17a3c685
Unverified
Commit
17a3c685
authored
Nov 30, 2022
by
HELSON
Committed by
GitHub
Nov 30, 2022
Browse files
[zero] fix unit-tests (#2039)
parent
eb7742a4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
44 deletions
+44
-44
tests/components_to_test/utils/executor.py
tests/components_to_test/utils/executor.py
+3
-4
tests/test_gemini/test_mem_tracer.py
tests/test_gemini/test_mem_tracer.py
+1
-1
tests/test_gemini/update/test_fwd_bwd.py
tests/test_gemini/update/test_fwd_bwd.py
+7
-3
tests/test_gemini/update/test_optim.py
tests/test_gemini/update/test_optim.py
+33
-36
No files found.
tests/components_to_test/utils/executor.py
View file @
17a3c685
import
torch
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
use_init_ctx
=
Fals
e
)
->
torch
.
Tensor
:
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
optimizer
=
Non
e
)
->
torch
.
Tensor
:
"""run_fwd_bwd
run fwd and bwd for the model
...
...
@@ -10,7 +10,6 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
data (torch.Tensor): input data
label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion
use_init_ctx (bool, optional): whether the model is initialized under the contxt of ColoInitCtx. Defaults to False.
Returns:
torch.Tensor: loss of fwd
...
...
@@ -23,8 +22,8 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
loss
=
model
(
data
,
label
)
loss
=
loss
.
float
()
if
use_init_ctx
:
model
.
backward
(
loss
)
if
optimizer
:
optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
return
loss
tests/test_gemini/test_mem_tracer.py
View file @
17a3c685
...
...
@@ -33,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True):
data
=
data
.
cuda
()
label
=
label
.
cuda
()
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
use_init_ctx
=
False
)
run_fwd_bwd
(
model
,
data
,
label
,
criterion
)
model
.
_ophook_list
[
0
].
print_non_model_data
()
...
...
tests/test_gemini/update/test_fwd_bwd.py
View file @
17a3c685
...
...
@@ -10,6 +10,8 @@ import colossalai
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer.zero_optimizer
import
ZeroOptimizer
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.tensor
import
ProcessGroup
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
...
...
@@ -55,6 +57,8 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
1
)
pg
=
ProcessGroup
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
...
...
@@ -71,9 +75,9 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if
i
>
0
:
break
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert
torch
.
equal
(
torch_loss
,
loss
)
...
...
tests/test_gemini/update/test_optim.py
View file @
17a3c685
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.testing
import
assert_close
import
colossalai
from
colossalai.amp
import
convert_to_apex_amp
...
...
@@ -20,7 +21,7 @@ from colossalai.utils.cuda import get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
set_seed
from
tests.test_tensor.common_utils
import
debug_print
,
set_seed
def
check_param
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
):
...
...
@@ -35,27 +36,31 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert
key
in
zero_dict
,
"{} not in ZeRO dictionary."
.
format
(
key
)
temp_zero_value
=
zero_dict
[
key
].
to
(
device
=
value
.
device
,
dtype
=
value
.
dtype
)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert
torch
.
all
close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-2
)
,
"parameter '{}' has problem."
.
format
(
key
)
assert
_
close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-2
)
# 'gpt2', 'bert',
TEST_MODELS
=
[
'gpt2'
,
'bert'
]
# TEST
_MODELS = ['simple_net']
EXAMPLE
_MODELS
=
[
'simple_net'
]
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
])
@
parameterize
(
'model_name'
,
TEST_MODELS
)
def
exam_model_step
(
placement_policy
,
model_name
:
str
):
set_seed
(
42
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
torch_model
=
model_builder
().
cuda
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
128
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
torch_model
=
model_builder
().
cuda
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_
p
.
data
.
copy_
(
p
.
data
)
p
.
data
.
copy_
(
torch_
p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
_
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
...
...
@@ -70,12 +75,7 @@ def exam_model_step(placement_policy, model_name: str):
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
2
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
128
)
model
.
eval
()
torch_model
.
eval
()
...
...
@@ -84,15 +84,13 @@ def exam_model_step(placement_policy, model_name: str):
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
break
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
zero_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
assert
torch
.
allclose
(
torch_loss
,
loss
,
rtol
=
1e-3
,
atol
=
1e-2
),
f
"
{
torch_loss
}
vs
{
loss
}
"
# debug_print([0], zero_logits, torch_logits)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert_close
(
torch_loss
,
loss
)
zero_optim
.
step
()
torch_optim
.
step
()
...
...
@@ -101,31 +99,29 @@ def exam_model_step(placement_policy, model_name: str):
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
@
parameterize
(
'model_name'
,
TEST
_MODELS
)
@
parameterize
(
'model_name'
,
EXAMPLE
_MODELS
)
def
exam_tiny_example
(
placement_policy
,
model_name
:
str
):
set_seed
(
4
2
)
set_seed
(
2
008
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
torch_model
=
model_builder
().
cuda
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
2
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
torch_model
=
model_builder
().
cuda
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_
p
.
data
.
copy_
(
p
.
data
)
p
.
data
.
copy_
(
torch_
p
.
data
)
chunk_manager
=
init_chunk_manager
(
model
=
model
,
init_device
=
get_current_device
(),
search_range_mb
=
1
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
2
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
model
.
eval
()
torch_model
.
eval
()
...
...
@@ -134,14 +130,15 @@ def exam_tiny_example(placement_policy, model_name: str):
if
i
>
2
:
break
input_ids
=
input_ids
.
cuda
()
label
=
label
.
cuda
()
zero_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
assert
torch
.
allclose
(
torch_loss
,
loss
,
rtol
=
1e-3
,
atol
=
1e-2
),
f
"
{
torch_loss
}
vs
{
loss
}
"
# debug_print([0], zero_logits, torch_logits)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert_close
(
torch_loss
,
loss
)
zero_optim
.
step
()
torch_optim
.
step
()
...
...
@@ -165,4 +162,4 @@ def test_optim(world_size):
if
__name__
==
'__main__'
:
test_optim
(
2
)
test_optim
(
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment