Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b28991dd
Unverified
Commit
b28991dd
authored
Oct 09, 2022
by
HELSON
Committed by
GitHub
Oct 09, 2022
Browse files
[feature] A new ZeRO implementation (#1644)
parent
b1be5b88
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
475 additions
and
124 deletions
+475
-124
tests/test_gemini/update/test_fwd_bwd.py
tests/test_gemini/update/test_fwd_bwd.py
+109
-0
tests/test_gemini/update/test_optim.py
tests/test_gemini/update/test_optim.py
+118
-0
tests/test_gemini/update/test_search.py
tests/test_gemini/update/test_search.py
+6
-7
tests/test_gemini/update/test_zeroddp_state_dict.py
tests/test_gemini/update/test_zeroddp_state_dict.py
+111
-0
tests/test_gemini/update/test_zerooptim_state_dict.py
tests/test_gemini/update/test_zerooptim_state_dict.py
+97
-0
tests/test_tensor/test_chunk.py
tests/test_tensor/test_chunk.py
+0
-86
tests/test_tensor/test_tp_with_zero.py
tests/test_tensor/test_tp_with_zero.py
+34
-31
No files found.
tests/test_gemini/update/test_fwd_bwd.py
0 → 100644
View file @
b28991dd
import
pytest
import
colossalai
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
functools
import
partial
from
tests.test_tensor.common_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.testing
import
parameterize
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.tensor
import
ColoTensorSpec
,
ShardSpec
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
,
ColoTensor
from
tests.test_tensor.common_utils
import
debug_print
from
time
import
time
from
colossalai.gemini.chunk
import
search_chunk_configuration
,
ChunkManager
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
    """Compare the ZeRO model's parameters (which hold gradient data after
    `ZeroDDP.backward`) against the reference model's `.grad` tensors.

    Every chunk is accessed first so the compared tensors are materialized
    on this rank.
    """
    manager = model.chunk_manager
    zero_params = list(model.parameters())
    for chunk in manager.get_chunks(zero_params):
        manager.access_chunk(chunk)

    for zero_p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(zero_p, torch_p.grad, atol=1e-3, rtol=1e-5), \
            "{}".format(torch.max(torch.abs(zero_p - torch_p.grad)).item())
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    """Run one training step: zero grads, forward, loss, backward.

    Returns the float-cast logits so callers can compare model outputs.
    """
    optimizer.zero_grad()
    out = model(input_ids, attn_mask).float()
    optimizer.backward(criterion(out, input_ids))
    return out
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
    """Compare one forward/backward pass of a ZeroDDP-wrapped GPT2 against a
    torch-DDP + apex-AMP reference under the given Gemini placement policy.
    """
    set_seed(42)
    # GPT2 test components registered for non-distributed tests
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # the ZeRO model is built under ColoInitContext so its params are ColoTensors
    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder().cuda()
    # start both models from identical weights
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    # pin the searched configuration to a fixed, scattered chunk layout
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    pg = ProcessGroup()
    # loss_scale=1 so AMP loss scaling does not perturb the comparison
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    model.eval()
    torch_model.eval()

    # seed per data-parallel rank so each rank draws its own batch
    set_seed(pg.dp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        # only a single batch is checked
        if i > 0:
            break

        logits = model(input_ids, attn_mask)
        logits = logits.float()
        loss = criterion(logits, input_ids)
        model.backward(loss)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        # rtol=0: outputs must agree within the default absolute tolerance only
        assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
            torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
        check_grad(model, torch_model)
def run_dist(rank, world_size, port):
    """Per-process worker: bring up the NCCL backend, then run the check."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Spawn `world_size` worker processes and run the fwd/bwd comparison in each."""
    worker = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(1)
tests/test_gemini/update/test_optim.py
0 → 100644
View file @
b28991dd
import
pytest
import
colossalai
import
torch
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
functools
import
partial
from
tests.test_tensor.common_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.testing
import
parameterize
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
tests.test_tensor.common_utils
import
debug_print
from
time
import
time
from
colossalai.gemini.chunk
import
search_chunk_configuration
,
ChunkManager
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    """Verify the ZeRO model's full state dict matches the DDP reference model."""
    zero_state = model.state_dict(only_rank_0=False)
    for name, ref in torch_model.state_dict().items():
        # DDP prefixes every key with 'module.'; strip it before lookup
        name = name[7:]
        if name == 'model.lm_head.weight':
            continue
        assert name in zero_state, "{} not in ZeRO dictionary.".format(name)
        candidate = zero_state[name].to(device=ref.device, dtype=ref.dtype)
        assert torch.allclose(ref, candidate, rtol=1e-3, atol=1e-2), \
            "parameter '{}' has problem.".format(name)
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    """Run one training step: zero grads, forward, loss, backward.

    Returns the float-cast logits so callers can compare model outputs.
    """
    optimizer.zero_grad()
    out = model(input_ids, attn_mask).float()
    optimizer.backward(criterion(out, input_ids))
    return out
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
    """Train a ZeroDDP + ZeroOptimizer GPT2 for a few steps alongside a
    torch-DDP + apex-AMP reference and check logits and parameters agree.
    """
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # the ZeRO model is built under ColoInitContext so its params are ColoTensors
    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder().cuda()
    # start both models from identical weights
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    # pin the searched configuration to a fixed, scattered chunk layout
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    # non-'cuda' policies start their chunks on CPU
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)

    # loss_scale=1 so AMP loss scaling does not perturb the comparison
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model.eval()
    torch_model.eval()

    # seed per rank so each rank draws its own batches
    set_seed(dist.get_rank() * 3 + 128)
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        # check the first three batches only
        if i > 2:
            break
        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
        # debug_print([0], zero_logits, torch_logits)

        zero_optim.step()
        torch_optim.step()
        # parameters must still agree after each optimizer step
        check_param(model, torch_model)
def run_dist(rank, world_size, port):
    """Per-process worker: bring up the NCCL backend, then run the check."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Spawn `world_size` worker processes and run the optimizer comparison in each."""
    worker = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(1)
tests/test_gemini/update/test_search.py
View file @
b28991dd
...
@@ -8,7 +8,7 @@ import torch.distributed as dist
...
@@ -8,7 +8,7 @@ import torch.distributed as dist
import
colossalai
import
colossalai
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.gemini.
update
import
search_chunk_configuration
from
colossalai.gemini.
chunk
import
search_chunk_configuration
from
colossalai.utils
import
free_port
,
get_current_device
from
colossalai.utils
import
free_port
,
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ShardSpec
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
from
colossalai.tensor
import
ShardSpec
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
...
@@ -35,8 +35,7 @@ def exam_search_chunk_size():
...
@@ -35,8 +35,7 @@ def exam_search_chunk_size():
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
init_1d_row_spec
(
model
,
pg_tp
)
init_1d_row_spec
(
model
,
pg_tp
)
config_dict
=
search_chunk_configuration
(
config_dict
=
search_chunk_configuration
(
model
,
model
,
search_range_mb
=
1
,
search_range_mb
=
1
,
search_interval_byte
=
16
,
search_interval_byte
=
16
,
min_chunk_size_mb
=
0
,
min_chunk_size_mb
=
0
,
...
...
tests/test_gemini/update/test_zeroddp_state_dict.py
0 → 100644
View file @
b28991dd
import
pytest
import
colossalai
import
torch
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
functools
import
partial
from
tests.test_tensor.common_utils
import
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.testing
import
parameterize
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
tests.test_tensor.common_utils
import
debug_print
from
colossalai.gemini.chunk
import
search_chunk_configuration
,
ChunkManager
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
    """Check that ZeroDDP.state_dict reproduces the weights of an identically
    initialized plain torch model, for every placement policy and both
    gathered/scattered chunk layouts.
    """
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder()
    # start both models from identical weights
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model.train()

    # only_rank_0=False: every rank receives the full state dict
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # state dicts must match bit-for-bit, not just approximately
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
    """Check that ZeroDDP.load_state_dict faithfully installs the weights of a
    differently initialized torch model, round-tripping via state_dict.
    """
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    # non-'cuda' policies start their chunks on CPU
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    torch_dict = torch_model.state_dict()
    model.load_state_dict(torch_dict, strict=False)
    # only_rank_0=False: every rank receives the full state dict
    zero_dict = model.state_dict(only_rank_0=False)

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # loaded weights must match bit-for-bit
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
def run_dist(rank, world_size, port):
    """Per-process worker: bring up NCCL, then run both state-dict checks."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    exam_state_dict()
    exam_load_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
    """Spawn `world_size` workers to exercise ZeroDDP state-dict save/load."""
    worker = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    test_zero_ddp(1)
tests/test_
zero
/test_zero
_
optim_state_dict.py
→
tests/test_
gemini/update
/test_zerooptim_state_dict.py
View file @
b28991dd
...
@@ -2,99 +2,96 @@ import pytest
...
@@ -2,99 +2,96 @@ import pytest
import
colossalai
import
colossalai
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.gemini
import
ChunkManager
from
functools
import
partial
from
functools
import
partial
from
tests.test_tensor.common_utils
import
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.testing
import
parameterize
from
colossalai.testing
import
parameterize
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.tensor
import
ProcessGroup
from
tests.test_tensor.common_utils
import
debug_print
def
check_state
(
s1
,
s2
):
for
v1
,
v2
in
zip
(
s1
.
values
(),
s2
.
values
()):
if
isinstance
(
v1
,
torch
.
Tensor
):
v1
=
v1
.
to
(
v2
.
device
)
assert
torch
.
equal
(
v1
,
v2
),
f
'
{
torch
.
sum
((
v1
-
v2
).
abs
())
}
'
else
:
assert
v1
==
v2
def
check_load_state_dict
(
optim
,
torch_optim
):
for
group
,
torch_group
in
zip
(
optim
.
optim
.
param_groups
,
torch_optim
.
param_groups
):
for
p
,
torch_p
in
zip
(
group
[
'params'
],
torch_group
[
'params'
]):
state
=
optim
.
optim
.
state
[
p
]
torch_state
=
torch_optim
.
state
[
torch_p
]
if
p
.
storage
().
size
()
==
0
:
assert
len
(
state
)
==
0
check_state
(
state
,
torch_state
)
from
colossalai.gemini.chunk
import
search_chunk_configuration
,
ChunkManager
def
check_state_dict
(
state_dict
,
torch_state_dict
):
for
(
k1
,
s1
),
(
k2
,
s2
)
in
zip
(
state_dict
[
'state'
].
items
(),
torch_state_dict
[
'state'
].
items
()):
assert
k1
==
k2
check_state
(
s1
,
s2
)
@
parameterize
(
'use_chunk'
,
[
False
,
True
])
@
parameterize
(
'use_zero'
,
[
False
,
True
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
])
@
parameterize
(
'only_rank_0'
,
[
False
,
True
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
def
run_zero_optim_state_dict
(
use_chunk
,
use_zero
,
placement_policy
,
only_rank_0
):
def
exam_zero_optim_state_dict
(
placement_policy
,
keep_gathered
):
set_seed
(
431
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
model
=
model
.
cuda
()
torch_model
=
model_builder
().
cuda
()
pg
=
ProcessGroup
()
chunk_size
=
ChunkManager
.
search_chunk_size
(
model
,
8192
,
8
)
if
use_chunk
else
None
chunk_manager
=
ChunkManager
(
chunk_size
,
pg
,
enable_distributed_storage
=
use_zero
,
init_device
=
GeminiManager
.
get_default_device
(
placement_policy
))
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
)
optim
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optim
=
ZeroOptimizer
(
optim
,
model
,
initial_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
set_seed
(
451
)
torch_model
=
model_builder
()
# get a different model
for
p
in
torch_model
.
parameters
():
world_size
=
torch
.
distributed
.
get_world_size
()
p
.
grad
=
torch
.
rand_like
(
p
)
config_dict
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
torch_optim
.
step
()
if
placement_policy
!=
'cuda'
:
torch_state_dict
=
torch_optim
.
state_dict
()
init_device
=
torch
.
device
(
'cpu'
)
optim
.
load_state_dict
(
torch_state_dict
)
else
:
check_load_state_dict
(
optim
,
torch_optim
)
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
state_dict
=
optim
.
state_dict
(
only_rank_0
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
if
not
only_rank_0
or
pg
.
rank
()
==
0
:
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
check_state_dict
(
state_dict
,
torch_state_dict
)
optimizer
=
HybridAdam
(
model
.
parameters
())
optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
32
)
# initialize the link between chunk16 and chunk32
set_seed
(
dist
.
get_rank
()
*
3
+
128
)
model
.
train
()
for
i
,
(
input_ids
,
attn_mask
)
in
enumerate
(
train_dataloader
):
if
i
>
0
:
break
optim
.
zero_grad
()
logits
=
model
(
input_ids
,
attn_mask
)
logits
=
logits
.
float
()
loss
=
criterion
(
logits
,
input_ids
)
optim
.
backward
(
loss
)
optim
.
step
()
optim_state_dict
=
optim
.
state_dict
()
optim
.
load_state_dict
(
optim_state_dict
)
new_state
=
optim
.
state_dict
()[
'state'
]
org_state
=
optim_state_dict
[
'state'
]
for
k
,
v
in
org_state
.
items
():
w
=
new_state
[
k
]
for
n
,
m
in
v
.
items
():
if
isinstance
(
m
,
torch
.
Tensor
):
o
=
w
[
n
]
if
m
.
device
!=
o
.
device
:
o
=
o
.
to
(
m
.
device
)
assert
torch
.
equal
(
m
,
o
)
else
:
assert
m
==
w
[
n
]
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run
_zero_optim_state_dict
()
exam
_zero_optim_state_dict
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_zero_optim
_state_dict
(
world_size
):
def
test_zero_optim
(
world_size
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_zero_optim
_state_dict
(
2
)
test_zero_optim
(
1
)
tests/test_tensor/test_chunk.py
deleted
100644 → 0
View file @
b1be5b88
import
torch
import
colossalai
import
pytest
import
torch.multiprocessing
as
mp
from
typing
import
List
from
functools
import
partial
from
colossalai.gemini
import
ChunkManager
from
colossalai.testing
import
rerun_if_address_is_in_use
,
parameterize
from
colossalai.utils
import
free_port
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
    """Assert each tensor's storage presence matches the expected flags.

    A flagged tensor must have non-empty storage and live on CUDA; an
    unflagged one must have been freed (zero-sized storage).
    """
    for tensor, expected in zip(params, has_tensors):
        if not expected:
            assert tensor.storage().size() == 0
        else:
            assert tensor.storage().size() > 0
            assert tensor.device.type == 'cuda'
# HAS_TENSORS[use_chunk][use_zero]
# Expected per-rank storage presence for the three test parameters after
# appending them to the ChunkManager:
#   HAS_TENSORS[use_chunk][use_zero][rank] -> [bool, bool, bool]
# one flag per parameter, True when that rank keeps the tensor's storage.
HAS_TENSORS = {
    True: {
        True: [[True, True, False], [False, False, True]],
        False: [[True, True, True], [True, True, True]]
    },
    False: {
        True: [[True, False, True], [False, True, False]],
        False: [[True, True, True], [True, True, True]]
    }
}

# TOTAL_MEM[use_chunk][use_zero][rank] -> expected CUDA memory accounted by
# ChunkManager.total_mem after appending all three 8x8 parameters.
# NOTE(review): units are whatever ChunkManager.total_mem reports -- presumably
# bytes (8*8 fp32 = 256 bytes per tensor fits the 256/512/768/1024 values);
# confirm against the ChunkManager accounting.
TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero):
    """Exercise ChunkManager bookkeeping end to end: append, access, release
    and move parameter tensors, checking storage presence and the accounted
    memory against the HAS_TENSORS / TOTAL_MEM expectations for this rank.
    """
    pg = ColoProcessGroup()
    rank = pg.rank()
    if rank == 0:
        print(f'use_chunk={use_chunk}, use_zero={use_zero}')
    params = [torch.rand(8, 8) for _ in range(3)]
    chunk_size = 128 if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
    chunk_manager.create_group('param')
    # nothing registered yet: no memory accounted anywhere
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == 0

    for p in params:
        chunk_manager.append_tensor(p, 'param')
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]

    chunks = chunk_manager.get_chunks(params)
    for chunk in chunks:
        chunk_manager.access_chunk(chunk)
    # after access every tensor must be materialized on this rank
    check_has_params(params, [True, True, True])
    assert chunk_manager.total_mem['cpu'] == 0
    # accessed footprint equals the non-distributed (use_zero=False) total
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]

    for chunk in chunks:
        chunk_manager.release_chunk(chunk)
    # releasing restores the distributed layout
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']

    for chunk in chunks:
        chunk_manager.move_chunk(chunk, torch.device('cpu'))
    # fix: the failure message previously reported total_mem['cuda'] while
    # the assertion checks total_mem['cpu'] -- report the counter under test
    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cpu']
    assert chunk_manager.total_mem['cuda'] == 0
def run_dist(rank, world_size, port):
    """Per-process worker for the chunk-mapping test."""
    launch_cfg = {}
    colossalai.launch(config=launch_cfg,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    run_chunk_zero()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_mapping(world_size):
    """Spawn two workers and verify ChunkManager tensor/chunk bookkeeping."""
    worker = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_mapping(2)
tests/test_tensor/test_
zero_optim
.py
→
tests/test_tensor/test_
tp_with_zero
.py
View file @
b28991dd
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
...
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.gemini
import
ChunkManager
from
colossalai.gemini
.chunk
import
ChunkManager
,
search_chunk_configuration
from
functools
import
partial
from
functools
import
partial
from
tests.test_tensor.common_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
from
tests.test_tensor.common_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute
...
@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute
from
tests.test_tensor.model.test_gpt2
import
init_megatron_spec
from
tests.test_tensor.model.test_gpt2
import
init_megatron_spec
def
check_param_equal
(
model
,
torch_model
,
pg
:
ProcessGroup
):
def
check_param
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
,
pg
:
ProcessGroup
):
for
(
n
,
p
),
(
tn
,
tp
)
in
zip
(
model
.
named_parameters
(),
torch_model
.
named_parameters
()):
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
if
p
.
storage
().
size
()
>
0
:
torch_dict
=
torch_model
.
state_dict
()
assert
p
.
dtype
==
torch
.
float16
assert
tensor_shard_equal
(
tp
.
to
(
dtype
=
p
.
dtype
,
device
=
p
.
device
),
p
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()),
f
'
{
tp
}
vs
{
p
}
\n
{
n
}
:
\n\t
{
tp
.
shape
}
vs
{
p
.
shape
}
'
for
key
,
value
in
torch_dict
.
items
():
def
check_grad_equal
(
model
,
torch_model
,
pg
:
ProcessGroup
):
# key is 'module.model.PARAMETER', so we truncate it
for
(
n
,
p
),
(
tn
,
tp
)
in
zip
(
model
.
named_parameters
(),
torch_model
.
named_parameters
()):
key
=
key
[
7
:]
if
p
.
grad
is
not
None
:
if
key
==
'model.lm_head.weight'
:
assert
tensor_shard_equal
(
tp
.
grad
.
to
(
dtype
=
p
.
grad
.
dtype
,
device
=
p
.
grad
.
device
),
p
.
grad
,
continue
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()),
\
assert
key
in
zero_dict
,
"{} not in ZeRO dictionary."
.
format
(
key
)
f
'
{
tp
.
grad
}
vs
{
p
.
grad
}
\n
{
n
}
:
\n\t
{
tp
.
grad
.
shape
}
vs
{
p
.
grad
.
shape
}
in
{
pg
.
rank
()
}
'
temp_zero_value
=
zero_dict
[
key
].
to
(
device
=
value
.
device
,
dtype
=
value
.
dtype
)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert
tensor_shard_equal
(
value
,
temp_zero_value
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()),
\
"parameter '{}' has problem."
.
format
(
key
)
def
run_fwd_bwd
(
model
,
criterion
,
optimizer
,
input_ids
,
attn_mask
):
def
run_fwd_bwd
(
model
,
criterion
,
optimizer
,
input_ids
,
attn_mask
):
...
@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup):
...
@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup):
p
.
set_tensor_spec
(
*
spec
)
p
.
set_tensor_spec
(
*
spec
)
@
parameterize
(
'use_chunk'
,
[
False
,
True
])
@
parameterize
(
'use_zero'
,
[
False
,
True
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
def
run_gpt
(
use_chunk
,
use_zero
,
placement_policy
,
tp_init_spec_func
=
None
):
def
run_gpt
(
placement_policy
,
tp_init_spec_func
=
None
):
set_seed
(
42
)
set_seed
(
42
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
@@ -89,15 +87,20 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
...
@@ -89,15 +87,20 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
if
tp_init_spec_func
:
if
tp_init_spec_func
:
tp_init_spec_func
(
model
,
pg
)
tp_init_spec_func
(
model
,
pg
)
chunk_size
=
ChunkManager
.
search_chunk_size
(
model
,
8192
,
8
)
if
use_chunk
else
None
dp_world_size
=
pg
.
dp_world_size
()
chunk_manager
=
ChunkManager
(
chunk_size
,
config_dict
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
pg
,
config_dict
[
dp_world_size
][
'chunk_size'
]
=
5000
enable_distributed_storage
=
use_zero
,
config_dict
[
dp_world_size
][
'keep_gathered'
]
=
False
init_device
=
GeminiManager
.
get_default_device
(
placement_policy
))
if
placement_policy
!=
'cuda'
:
init_device
=
torch
.
device
(
'cpu'
)
else
:
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optim
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optim
=
ZeroOptimizer
(
optim
,
model
,
initial_scale
=
1
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
1
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
...
@@ -105,7 +108,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
...
@@ -105,7 +108,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
pg
.
rank
()],
process_group
=
pg
.
dp_process_group
())
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
pg
.
rank
()],
process_group
=
pg
.
dp_process_group
())
print
(
chunk_manager
)
print
(
chunk_manager
)
check_param
_equal
(
model
,
torch_model
,
pg
)
check_param
(
model
,
torch_model
,
pg
)
model
.
eval
()
model
.
eval
()
torch_model
.
eval
()
torch_model
.
eval
()
...
@@ -115,13 +118,13 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
...
@@ -115,13 +118,13 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
if
i
>
2
:
if
i
>
2
:
break
break
input_ids_colo
=
ColoTensor
.
from_torch_tensor
(
input_ids
,
ColoTensorSpec
(
pg
))
input_ids_colo
=
ColoTensor
.
from_torch_tensor
(
input_ids
,
ColoTensorSpec
(
pg
))
logits
=
run_fwd_bwd
(
model
,
criterion
,
optim
,
input_ids_colo
,
attn_mask
)
zero_
logits
=
run_fwd_bwd
(
model
,
criterion
,
zero_
optim
,
input_ids_colo
,
attn_mask
)
torch_logits
=
run_fwd_bwd
(
torch_model
,
criterion
,
torch_optim
,
input_ids
,
attn_mask
)
torch_logits
=
run_fwd_bwd
(
torch_model
,
criterion
,
torch_optim
,
input_ids
,
attn_mask
)
assert
t
ensor_equal
(
logits
,
torch_logits
)
assert
t
orch
.
allclose
(
zero_
logits
,
torch_logits
,
rtol
=
1e-3
,
atol
=
1e-2
)
check_grad_equal
(
model
,
torch_model
,
pg
)
optim
.
step
()
zero_
optim
.
step
()
torch_optim
.
step
()
torch_optim
.
step
()
check_param
_equal
(
model
,
torch_model
,
pg
)
check_param
(
model
,
torch_model
,
pg
)
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment