OpenDAS / ColossalAI · Commits

Commit b28991dd (unverified), authored Oct 09, 2022 by HELSON, committed via GitHub on Oct 09, 2022.

[feature] A new ZeRO implementation (#1644)

Parent: b1be5b88
Changes: 27 · Showing 7 changed files with 475 additions and 124 deletions (+475 / -124) on this page.
Changed files on this page:

tests/test_gemini/update/test_fwd_bwd.py               +109  -0
tests/test_gemini/update/test_optim.py                 +118  -0
tests/test_gemini/update/test_search.py                +6    -7
tests/test_gemini/update/test_zeroddp_state_dict.py    +111  -0
tests/test_gemini/update/test_zerooptim_state_dict.py  +97   -0
tests/test_tensor/test_chunk.py                        +0    -86
tests/test_tensor/test_tp_with_zero.py                 +34   -31
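Taken together, the new tests exercise one setup flow for the new ZeRO implementation. The following is a minimal sketch distilled from those tests (it is not part of the commit itself); it assumes torch.distributed is already initialized via colossalai.launch and that model_builder is any callable returning an nn.Module, as in the test fixtures.

    import torch.distributed as dist
    from colossalai.utils.model.colo_init_context import ColoInitContext
    from colossalai.utils.cuda import get_current_device
    from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
    from colossalai.gemini.gemini_mgr import GeminiManager
    from colossalai.nn.parallel import ZeroDDP
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import ZeroOptimizer

    # build the model under ColoInitContext so its parameters become ColoTensors
    with ColoInitContext(device=get_current_device()):
        model = model_builder()    # model_builder: any nn.Module factory (assumption)

    # search a chunk configuration; the tests then pin the values explicitly
    world_size = dist.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False

    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager('auto', chunk_manager)    # 'cuda', 'cpu' or 'auto'
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optimizer, model, initial_scale=32)

    # a training step then runs through the wrapper, as in the tests:
    # logits = model(input_ids, attn_mask); loss = criterion(logits, input_ids)
    # optim.backward(loss); optim.step()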
tests/test_gemini/update/test_fwd_bwd.py (new file, 0 → 100644)
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager


def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
    chunk_manager = model.chunk_manager
    param_list = [p for p in model.parameters()]
    chunk_list = chunk_manager.get_chunks(param_list)
    for chunk in chunk_list:
        chunk_manager.access_chunk(chunk)

    for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(
            torch.max(torch.abs(p0 - p1.grad)).item())


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    pg = ProcessGroup()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    model.eval()
    torch_model.eval()

    set_seed(pg.dp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 0:
            break

        logits = model(input_ids, attn_mask)
        logits = logits.float()
        loss = criterion(logits, input_ids)
        model.backward(loss)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
            torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)

        check_grad(model, torch_model)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_gpt_fwd_bwd()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(1)
tests/test_gemini/update/test_optim.py (new file, 0 → 100644)
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 2:
            break

        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
        # debug_print([0], zero_logits, torch_logits)

        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_gpt_fwd_bwd()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(1)
tests/test_gemini/update/test_search.py
@@ -8,7 +8,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.gemini.update import search_chunk_configuration
+from colossalai.gemini.chunk import search_chunk_configuration
 from colossalai.utils import free_port, get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
@@ -35,8 +35,7 @@ def exam_search_chunk_size():
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
     init_1d_row_spec(model, pg_tp)
-    config_dict = search_chunk_configuration(model,
+    config_dict = search_chunk_configuration(model,
                                              search_range_mb=1,
                                              search_interval_byte=16,
                                              min_chunk_size_mb=0,
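The only functional change in this file is the module move: search_chunk_configuration now lives in colossalai.gemini.chunk rather than colossalai.gemini.update. Code written against this commit would import it as the new tests do (a one-line sketch):

    from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager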
tests/test_gemini/update/test_zeroddp_state_dict.py (new file, 0 → 100644)
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    torch_dict = torch_model.state_dict()
    model.load_state_dict(torch_dict, strict=False)
    zero_dict = model.state_dict(only_rank_0=False)

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()
    exam_load_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_ddp(1)
tests/test_zero/test_zero_optim_state_dict.py → tests/test_gemini/update/test_zerooptim_state_dict.py (renamed)
@@ -2,99 +2,96 @@ import pytest
 import colossalai
 import torch
 import torch.multiprocessing as mp
 import torch.distributed as dist
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.gemini import ChunkManager
 from functools import partial
 from tests.test_tensor.common_utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
 from colossalai.nn.parallel import ZeroDDP
-from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ProcessGroup
-
-
-def check_state(s1, s2):
-    for v1, v2 in zip(s1.values(), s2.values()):
-        if isinstance(v1, torch.Tensor):
-            v1 = v1.to(v2.device)
-            assert torch.equal(v1, v2), f'{torch.sum((v1 - v2).abs())}'
-        else:
-            assert v1 == v2
-
-
-def check_load_state_dict(optim, torch_optim):
-    for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
-        for p, torch_p in zip(group['params'], torch_group['params']):
-            state = optim.optim.state[p]
-            torch_state = torch_optim.state[torch_p]
-            if p.storage().size() == 0:
-                assert len(state) == 0
-            check_state(state, torch_state)
-
-
-def check_state_dict(state_dict, torch_state_dict):
-    for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
-        assert k1 == k2
-        check_state(s1, s2)
+from tests.test_tensor.common_utils import debug_print
+from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
 
 
-@parameterize('use_chunk', [False, True])
-@parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
-@parameterize('only_rank_0', [False, True])
-def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
+@parameterize('keep_gathered', [True, False])
+def exam_zero_optim_state_dict(placement_policy, keep_gathered):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
-        model = model.cuda()
-    torch_model = model_builder().cuda()
-    pg = ProcessGroup()
-    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size,
-                                 pg,
-                                 enable_distributed_storage=use_zero,
-                                 init_device=GeminiManager.get_default_device(placement_policy))
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager)
-    optim = HybridAdam(model.parameters(), lr=1e-3)
-    optim = ZeroOptimizer(optim, model, initial_scale=1)
-    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
 
-    for p in torch_model.parameters():
-        p.grad = torch.rand_like(p)
+    set_seed(451)
+    torch_model = model_builder()    # get a different model
 
-    torch_optim.step()
-    torch_state_dict = torch_optim.state_dict()
-    optim.load_state_dict(torch_state_dict)
-    check_load_state_dict(optim, torch_optim)
+    world_size = torch.distributed.get_world_size()
+    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[world_size]['chunk_size'] = 5000
+    config_dict[world_size]['keep_gathered'] = keep_gathered
 
-    state_dict = optim.state_dict(only_rank_0)
-    if not only_rank_0 or pg.rank() == 0:
-        check_state_dict(state_dict, torch_state_dict)
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    optimizer = HybridAdam(model.parameters())
+    optim = ZeroOptimizer(optimizer, model, initial_scale=32)    # initialize the link between chunk16 and chunk32
+
+    set_seed(dist.get_rank() * 3 + 128)
+    model.train()
+    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+        if i > 0:
+            break
+        optim.zero_grad()
+        logits = model(input_ids, attn_mask)
+        logits = logits.float()
+        loss = criterion(logits, input_ids)
+        optim.backward(loss)
+        optim.step()
+
+    optim_state_dict = optim.state_dict()
+    optim.load_state_dict(optim_state_dict)
+    new_state = optim.state_dict()['state']
+    org_state = optim_state_dict['state']
+
+    for k, v in org_state.items():
+        w = new_state[k]
+        for n, m in v.items():
+            if isinstance(m, torch.Tensor):
+                o = w[n]
+                if m.device != o.device:
+                    o = o.to(m.device)
+                assert torch.equal(m, o)
+            else:
+                assert m == w[n]
 
 
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_zero_optim_state_dict()
+    exam_zero_optim_state_dict()
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
-def test_zero_optim_state_dict(world_size):
+def test_zero_optim(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_zero_optim_state_dict(2)
+    test_zero_optim(1)
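The rewritten test validates the in-memory round trip optim.state_dict() → optim.load_state_dict(). A hypothetical on-disk checkpoint helper built from the same two calls (torch.save/torch.load are plain PyTorch and an assumption here, not part of this commit):

    import torch

    CKPT_PATH = 'zero_optim.pt'    # hypothetical path

    def save_optim(optim):
        # optim is a ZeroOptimizer; state_dict() is the call the test exercises
        torch.save(optim.state_dict(), CKPT_PATH)

    def load_optim(optim):
        optim.load_state_dict(torch.load(CKPT_PATH))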
tests/test_tensor/test_chunk.py (deleted, 100644 → 0). Removed file content:
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from typing import List
from functools import partial
from colossalai.gemini import ChunkManager
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup as ColoProcessGroup


def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
    for p, has_tensor in zip(params, has_tensors):
        if has_tensor:
            assert p.storage().size() > 0
            assert p.device.type == 'cuda'
        else:
            assert p.storage().size() == 0


# HAS_TENSORS[use_chunk][use_zero]
HAS_TENSORS = {
    True: {
        True: [[True, True, False], [False, False, True]],
        False: [[True, True, True], [True, True, True]]
    },
    False: {
        True: [[True, False, True], [False, True, False]],
        False: [[True, True, True], [True, True, True]]
    }
}

TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}


@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero):
    pg = ColoProcessGroup()
    rank = pg.rank()
    if rank == 0:
        print(f'use_chunk={use_chunk}, use_zero={use_zero}')
    params = [torch.rand(8, 8) for _ in range(3)]
    chunk_size = 128 if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
    chunk_manager.create_group('param')
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == 0

    for p in params:
        chunk_manager.append_tensor(p, 'param')
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]

    chunks = chunk_manager.get_chunks(params)
    for chunk in chunks:
        chunk_manager.access_chunk(chunk)
    check_has_params(params, [True, True, True])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]

    for chunk in chunks:
        chunk_manager.release_chunk(chunk)
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']

    for chunk in chunks:
        chunk_manager.move_chunk(chunk, torch.device('cpu'))
    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
    assert chunk_manager.total_mem['cuda'] == 0


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_chunk_zero()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_mapping(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_mapping(2)
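This deletion removes the last test of the old ChunkManager constructor, which took a raw chunk size, a process group, and an enable_distributed_storage flag; in the tests added by this commit, both decisions move into a per-world-size config dict. A schematic contrast only, with names (model, pg, use_zero, world_size, init_device) borrowed from the surrounding tests rather than meant to run standalone:

    # Before this commit (deleted test above): size and storage mode passed directly.
    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)

    # After this commit (new tests): a config dict from search_chunk_configuration
    # drives the manager, with per-world-size overrides.
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    chunk_manager = ChunkManager(config_dict, init_device=init_device)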
tests/test_tensor/test_zero_optim.py → tests/test_tensor/test_tp_with_zero.py (renamed)
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.gemini import ChunkManager
+from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
 from functools import partial
 from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec
 
 
-def check_param_equal(model, torch_model, pg: ProcessGroup):
-    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
-        if p.storage().size() > 0:
-            assert p.dtype == torch.float16
-            assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
-                                      pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'
+def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
+    zero_dict = model.state_dict(only_rank_0=False)
+    torch_dict = torch_model.state_dict()
 
-
-def check_grad_equal(model, torch_model, pg: ProcessGroup):
-    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
-        if p.grad is not None:
-            assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
-                                      pg.tp_local_rank(), pg.tp_world_size()), \
-                f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
+    for key, value in torch_dict.items():
+        # key is 'module.model.PARAMETER', so we truncate it
+        key = key[7:]
+        if key == 'model.lm_head.weight':
+            continue
+        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
+        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
+        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
+        assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
+            "parameter '{}' has problem.".format(key)
 
 
 def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup):
     p.set_tensor_spec(*spec)
 
 
-@parameterize('use_chunk', [False, True])
-@parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu'])
-def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
+def run_gpt(placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -89,15 +87,20 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     if tp_init_spec_func:
         tp_init_spec_func(model, pg)
 
-    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size,
-                                 pg,
-                                 enable_distributed_storage=use_zero,
-                                 init_device=GeminiManager.get_default_device(placement_policy))
+    dp_world_size = pg.dp_world_size()
+    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[dp_world_size]['chunk_size'] = 5000
+    config_dict[dp_world_size]['keep_gathered'] = False
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager)
-    optim = HybridAdam(model.parameters(), lr=1e-3)
-    optim = ZeroOptimizer(optim, model, initial_scale=1)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
 
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
@@ -105,7 +108,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
 
     print(chunk_manager)
-    check_param_equal(model, torch_model, pg)
+    check_param(model, torch_model, pg)
 
     model.eval()
     torch_model.eval()
@@ -115,13 +118,13 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         if i > 2:
             break
         input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
         torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
-        assert tensor_equal(logits, torch_logits)
-        check_grad_equal(model, torch_model, pg)
-        optim.step()
+        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
+        zero_optim.step()
         torch_optim.step()
-        check_param_equal(model, torch_model, pg)
+        check_param(model, torch_model, pg)
 
 
 def run_dist(rank, world_size, port):