Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
17a3c685
Unverified
Commit
17a3c685
authored
Nov 30, 2022
by
HELSON
Committed by
GitHub
Nov 30, 2022
Browse files
[zero] fix unit-tests (#2039)
parent
eb7742a4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
44 deletions
+44
-44
tests/components_to_test/utils/executor.py
tests/components_to_test/utils/executor.py
+3
-4
tests/test_gemini/test_mem_tracer.py
tests/test_gemini/test_mem_tracer.py
+1
-1
tests/test_gemini/update/test_fwd_bwd.py
tests/test_gemini/update/test_fwd_bwd.py
+7
-3
tests/test_gemini/update/test_optim.py
tests/test_gemini/update/test_optim.py
+33
-36
No files found.
tests/components_to_test/utils/executor.py
View file @
17a3c685
import
torch
import
torch
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
use_init_ctx
=
Fals
e
)
->
torch
.
Tensor
:
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
optimizer
=
Non
e
)
->
torch
.
Tensor
:
"""run_fwd_bwd
"""run_fwd_bwd
run fwd and bwd for the model
run fwd and bwd for the model
...
@@ -10,7 +10,6 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
...
@@ -10,7 +10,6 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
data (torch.Tensor): input data
data (torch.Tensor): input data
label (torch.Tensor): label
label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion
criterion (Optional[Callable]): a function of criterion
use_init_ctx (bool, optional): whether the model is initialized under the contxt of ColoInitCtx. Defaults to False.
Returns:
Returns:
torch.Tensor: loss of fwd
torch.Tensor: loss of fwd
...
@@ -23,8 +22,8 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
...
@@ -23,8 +22,8 @@ def run_fwd_bwd(model, data, label, criterion, use_init_ctx=False) -> torch.Tens
loss
=
model
(
data
,
label
)
loss
=
model
(
data
,
label
)
loss
=
loss
.
float
()
loss
=
loss
.
float
()
if
use_init_ctx
:
if
optimizer
:
model
.
backward
(
loss
)
optimizer
.
backward
(
loss
)
else
:
else
:
loss
.
backward
()
loss
.
backward
()
return
loss
return
loss
tests/test_gemini/test_mem_tracer.py
View file @
17a3c685
...
@@ -33,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True):
...
@@ -33,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True):
data
=
data
.
cuda
()
data
=
data
.
cuda
()
label
=
label
.
cuda
()
label
=
label
.
cuda
()
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
use_init_ctx
=
False
)
run_fwd_bwd
(
model
,
data
,
label
,
criterion
)
model
.
_ophook_list
[
0
].
print_non_model_data
()
model
.
_ophook_list
[
0
].
print_non_model_data
()
...
...
tests/test_gemini/update/test_fwd_bwd.py
View file @
17a3c685
...
@@ -10,6 +10,8 @@ import colossalai
...
@@ -10,6 +10,8 @@ import colossalai
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer.zero_optimizer
import
ZeroOptimizer
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.tensor
import
ProcessGroup
from
colossalai.tensor
import
ProcessGroup
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
...
@@ -55,6 +57,8 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
...
@@ -55,6 +57,8 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
chunk_manager
=
ChunkManager
(
config_dict
)
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
1
)
pg
=
ProcessGroup
()
pg
=
ProcessGroup
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
...
@@ -71,9 +75,9 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
...
@@ -71,9 +75,9 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if
i
>
0
:
if
i
>
0
:
break
break
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert
torch
.
equal
(
torch_loss
,
loss
)
assert
torch
.
equal
(
torch_loss
,
loss
)
...
...
tests/test_gemini/update/test_optim.py
View file @
17a3c685
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.testing
import
assert_close
import
colossalai
import
colossalai
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.amp
import
convert_to_apex_amp
...
@@ -20,7 +21,7 @@ from colossalai.utils.cuda import get_current_device
...
@@ -20,7 +21,7 @@ from colossalai.utils.cuda import get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
set_seed
from
tests.test_tensor.common_utils
import
debug_print
,
set_seed
def
check_param
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
):
def
check_param
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
):
...
@@ -35,27 +36,31 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
...
@@ -35,27 +36,31 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert
key
in
zero_dict
,
"{} not in ZeRO dictionary."
.
format
(
key
)
assert
key
in
zero_dict
,
"{} not in ZeRO dictionary."
.
format
(
key
)
temp_zero_value
=
zero_dict
[
key
].
to
(
device
=
value
.
device
,
dtype
=
value
.
dtype
)
temp_zero_value
=
zero_dict
[
key
].
to
(
device
=
value
.
device
,
dtype
=
value
.
dtype
)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert
torch
.
all
close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-2
)
,
"parameter '{}' has problem."
.
format
(
key
)
assert
_
close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-2
)
# 'gpt2', 'bert',
# 'gpt2', 'bert',
TEST_MODELS
=
[
'gpt2'
,
'bert'
]
TEST_MODELS
=
[
'gpt2'
,
'bert'
]
# TEST
_MODELS = ['simple_net']
EXAMPLE
_MODELS
=
[
'simple_net'
]
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
])
@
parameterize
(
'model_name'
,
TEST_MODELS
)
@
parameterize
(
'model_name'
,
TEST_MODELS
)
def
exam_model_step
(
placement_policy
,
model_name
:
str
):
def
exam_model_step
(
placement_policy
,
model_name
:
str
):
set_seed
(
42
)
set_seed
(
42
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
torch_model
=
model_builder
().
cuda
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
128
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
torch_model
=
model_builder
().
cuda
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_
p
.
data
.
copy_
(
p
.
data
)
p
.
data
.
copy_
(
torch_
p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
_
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
config_dict
,
_
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
...
@@ -70,12 +75,7 @@ def exam_model_step(placement_policy, model_name: str):
...
@@ -70,12 +75,7 @@ def exam_model_step(placement_policy, model_name: str):
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
2
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
128
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
model
.
eval
()
model
.
eval
()
torch_model
.
eval
()
torch_model
.
eval
()
...
@@ -84,15 +84,13 @@ def exam_model_step(placement_policy, model_name: str):
...
@@ -84,15 +84,13 @@ def exam_model_step(placement_policy, model_name: str):
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
if
i
>
2
:
break
break
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
zero_optim
.
zero_grad
()
zero_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert_close
(
torch_loss
,
loss
)
assert
torch
.
allclose
(
torch_loss
,
loss
,
rtol
=
1e-3
,
atol
=
1e-2
),
f
"
{
torch_loss
}
vs
{
loss
}
"
# debug_print([0], zero_logits, torch_logits)
zero_optim
.
step
()
zero_optim
.
step
()
torch_optim
.
step
()
torch_optim
.
step
()
...
@@ -101,31 +99,29 @@ def exam_model_step(placement_policy, model_name: str):
...
@@ -101,31 +99,29 @@ def exam_model_step(placement_policy, model_name: str):
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
@
parameterize
(
'model_name'
,
TEST
_MODELS
)
@
parameterize
(
'model_name'
,
EXAMPLE
_MODELS
)
def
exam_tiny_example
(
placement_policy
,
model_name
:
str
):
def
exam_tiny_example
(
placement_policy
,
model_name
:
str
):
set_seed
(
4
2
)
set_seed
(
2
008
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
torch_model
=
model_builder
().
cuda
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
2
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
torch_model
=
model_builder
().
cuda
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_
p
.
data
.
copy_
(
p
.
data
)
p
.
data
.
copy_
(
torch_
p
.
data
)
chunk_manager
=
init_chunk_manager
(
model
=
model
,
init_device
=
get_current_device
(),
search_range_mb
=
1
)
chunk_manager
=
init_chunk_manager
(
model
=
model
,
init_device
=
get_current_device
(),
search_range_mb
=
1
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
2
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
2
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
model
.
eval
()
model
.
eval
()
torch_model
.
eval
()
torch_model
.
eval
()
...
@@ -134,14 +130,15 @@ def exam_tiny_example(placement_policy, model_name: str):
...
@@ -134,14 +130,15 @@ def exam_tiny_example(placement_policy, model_name: str):
if
i
>
2
:
if
i
>
2
:
break
break
input_ids
=
input_ids
.
cuda
()
label
=
label
.
cuda
()
zero_optim
.
zero_grad
()
zero_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
False
)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
.
cuda
(),
label
.
cuda
(),
criterion
,
use_init_ctx
=
True
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert_close
(
torch_loss
,
loss
)
assert
torch
.
allclose
(
torch_loss
,
loss
,
rtol
=
1e-3
,
atol
=
1e-2
),
f
"
{
torch_loss
}
vs
{
loss
}
"
# debug_print([0], zero_logits, torch_logits)
zero_optim
.
step
()
zero_optim
.
step
()
torch_optim
.
step
()
torch_optim
.
step
()
...
@@ -165,4 +162,4 @@ def test_optim(world_size):
...
@@ -165,4 +162,4 @@ def test_optim(world_size):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_optim
(
2
)
test_optim
(
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment