Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a39a5c66
Unverified
Commit
a39a5c66
authored
Sep 04, 2023
by
Hongxin Liu
Committed by
GitHub
Sep 04, 2023
Browse files
Merge branch 'main' into feature/shardformer
parents
e79b1e80
aaeb520c
Changes
138
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
283 additions
and
450 deletions
+283
-450
tests/test_utils/test_norm_gradient_clipping.py
tests/test_utils/test_norm_gradient_clipping.py
+1
-0
tests/test_zero/test_gemini/test_chunk_mgrv2.py
tests/test_zero/test_gemini/test_chunk_mgrv2.py
+5
-5
tests/test_zero/test_gemini/test_chunkv2.py
tests/test_zero/test_gemini/test_chunkv2.py
+2
-2
tests/test_zero/test_gemini/test_fwd_bwd.py
tests/test_zero/test_gemini/test_fwd_bwd.py
+32
-73
tests/test_zero/test_gemini/test_gemini_use_rmt.py
tests/test_zero/test_gemini/test_gemini_use_rmt.py
+12
-12
tests/test_zero/test_gemini/test_get_torch_model.py
tests/test_zero/test_gemini/test_get_torch_model.py
+0
-52
tests/test_zero/test_gemini/test_grad_clip.py
tests/test_zero/test_gemini/test_grad_clip.py
+39
-16
tests/test_zero/test_gemini/test_inference.py
tests/test_zero/test_gemini/test_inference.py
+34
-30
tests/test_zero/test_gemini/test_optim.py
tests/test_zero/test_gemini/test_optim.py
+54
-27
tests/test_zero/test_gemini/test_runtime_mem_tracer.py
tests/test_zero/test_gemini/test_runtime_mem_tracer.py
+3
-3
tests/test_zero/test_gemini/test_search.py
tests/test_zero/test_gemini/test_search.py
+3
-55
tests/test_zero/test_gemini/test_zeroddp_state_dict.py
tests/test_zero/test_gemini/test_zeroddp_state_dict.py
+56
-24
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
+0
-56
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+31
-20
tests/test_zero/test_low_level/test_grad_acc.py
tests/test_zero/test_low_level/test_grad_acc.py
+9
-19
tests/test_zero/test_low_level/test_zero_ckpt.py
tests/test_zero/test_low_level/test_zero_ckpt.py
+1
-1
tests/test_zero/test_low_level/test_zero_init.py
tests/test_zero/test_low_level/test_zero_init.py
+0
-55
tests/test_zero/test_low_level/test_zero_tp.py
tests/test_zero/test_low_level/test_zero_tp.py
+1
-0
No files found.
tests/test_utils/test_norm_gradient_clipping.py
View file @
a39a5c66
...
@@ -66,6 +66,7 @@ def run_dist(rank, world_size, port):
...
@@ -66,6 +66,7 @@ def run_dist(rank, world_size, port):
run_grad_clip_norm
(
world_size
=
world_size
)
run_grad_clip_norm
(
world_size
=
world_size
)
@
pytest
.
mark
.
skip
(
"this need to be updated"
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
...
...
tests/test_zero/test_gemini/test_chunk_mgrv2.py
View file @
a39a5c66
import
pytest
import
pytest
import
torch
import
torch
from
torch.distributed.distributed_c10d
import
_get_default_group
import
colossalai
import
colossalai
from
colossalai.tensor
import
ColoTensor
,
ColoTensorSpec
,
ProcessGroup
from
colossalai.tensor
import
ColoTensor
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.zero.gemini.chunk
import
ChunkManager
from
colossalai.zero.gemini.chunk
import
ChunkManager
from
tests.test_tensor.common_utils
import
debug_print
from
tests.test_tensor.common_utils
import
debug_print
...
@@ -15,19 +16,18 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
...
@@ -15,19 +16,18 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'pin_memory'
,
[
True
,
False
])
@
parameterize
(
'pin_memory'
,
[
True
,
False
])
def
exam_chunk_memory
(
keep_gathered
,
pin_memory
):
def
exam_chunk_memory
(
keep_gathered
,
pin_memory
):
pg
=
ProcessGroup
()
debug_print
([
0
],
"keep_gathered: {}, pin_memory: {}"
.
format
(
keep_gathered
,
pin_memory
))
debug_print
([
0
],
"keep_gathered: {}, pin_memory: {}"
.
format
(
keep_gathered
,
pin_memory
))
params
=
[
ColoTensor
(
torch
.
rand
(
8
,
8
)
,
spec
=
ColoTensorSpec
(
pg
)
)
for
_
in
range
(
3
)]
params
=
[
ColoTensor
(
torch
.
rand
(
8
,
8
))
for
_
in
range
(
3
)]
config
=
{
2
:
dict
(
chunk_size
=
128
,
keep_gathered
=
keep_gathered
)}
config
=
{
2
:
dict
(
chunk_size
=
128
,
keep_gathered
=
keep_gathered
)}
chunk_manager
=
ChunkManager
(
config
)
chunk_manager
=
ChunkManager
(
config
)
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
0
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
0
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
0
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
0
process_group
=
_get_default_group
()
for
p
in
params
:
for
p
in
params
:
chunk_manager
.
register_tensor
(
p
,
'param'
,
2
,
pin_memory
=
pin_memory
)
chunk_manager
.
register_tensor
(
p
,
'param'
,
2
,
process_group
,
pin_memory
=
pin_memory
)
chunk_manager
.
close_all_groups
()
chunk_manager
.
close_all_groups
()
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
CPU_MEM
[
keep_gathered
][
pin_memory
]
assert
chunk_manager
.
total_mem
[
'cpu'
]
==
CPU_MEM
[
keep_gathered
][
pin_memory
]
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
CUDA_MEM_0
[
keep_gathered
]
assert
chunk_manager
.
total_mem
[
'cuda'
]
==
CUDA_MEM_0
[
keep_gathered
]
...
...
tests/test_zero/test_gemini/test_chunkv2.py
View file @
a39a5c66
import
pytest
import
pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed.distributed_c10d
import
_get_default_group
import
colossalai
import
colossalai
from
colossalai.tensor
import
ColoParameter
from
colossalai.tensor
import
ColoParameter
from
colossalai.tensor
import
ProcessGroup
as
ColoProcessGroup
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.zero.gemini
import
TensorState
from
colossalai.zero.gemini
import
TensorState
...
@@ -36,7 +36,7 @@ def check_equal(param, param_cp):
...
@@ -36,7 +36,7 @@ def check_equal(param, param_cp):
@
parameterize
(
'pin_memory'
,
[
True
,
False
])
@
parameterize
(
'pin_memory'
,
[
True
,
False
])
def
exam_chunk_basic
(
init_device
,
keep_gathered
,
pin_memory
):
def
exam_chunk_basic
(
init_device
,
keep_gathered
,
pin_memory
):
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
pg
=
ColoProcessG
roup
()
pg
=
_get_default_g
roup
()
my_chunk
=
Chunk
(
chunk_size
=
1024
,
my_chunk
=
Chunk
(
chunk_size
=
1024
,
process_group
=
pg
,
process_group
=
pg
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
...
tests/test_zero/test_gemini/test_fwd_bwd.py
View file @
a39a5c66
import
pytest
import
pytest
import
torch
import
torch
import
torch.distributed
as
dist
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.testing
import
assert_close
from
torch.testing
import
assert_close
import
colossalai
import
colossalai
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.tensor
import
ProcessGroup
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
,
ZeroOptimizer
from
colossalai.zero
import
GeminiDDP
,
GeminiOptimizer
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test
import
run_fwd
,
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
set_seed
from
tests.test_tensor.common_utils
import
set_seed
PLACEMENT_CONFIGS
=
[
def
check_grad
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
):
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
},
# zero2
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
1.0
},
# zero3
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.5
},
# zero3-half
{
'placement_policy'
:
'auto'
}
]
def
check_grad
(
model
:
GeminiDDP
,
torch_model
:
torch
.
nn
.
Module
):
chunk_manager
=
model
.
chunk_manager
chunk_manager
=
model
.
chunk_manager
param_list
=
[
p
for
p
in
model
.
parameters
()]
param_list
=
[
p
for
p
in
model
.
parameters
()]
chunk_list
=
chunk_manager
.
get_chunks
(
param_list
)
chunk_list
=
chunk_manager
.
get_chunks
(
param_list
)
...
@@ -28,12 +45,12 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
...
@@ -28,12 +45,12 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close
(
p0
,
p1
.
grad
,
rtol
=
1e-3
,
atol
=
5e-5
)
assert_close
(
p0
,
p1
.
grad
,
rtol
=
1e-3
,
atol
=
5e-5
)
@
parameterize
(
'placement_
policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
]
)
@
parameterize
(
'placement_
config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'keep_gather'
,
[
False
,
True
])
@
parameterize
(
'keep_gather'
,
[
False
,
True
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
,
'albert'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
,
'albert'
])
@
parameterize
(
'use_grad_checkpoint'
,
[
False
,
True
])
@
parameterize
(
'use_grad_checkpoint'
,
[
False
,
True
])
def
exam_gpt_fwd_bwd
(
def
exam_gpt_fwd_bwd
(
placement_
policy
,
placement_
config
,
keep_gather
,
keep_gather
,
model_name
:
str
,
model_name
:
str
,
use_grad_checkpoint
:
bool
=
False
,
use_grad_checkpoint
:
bool
=
False
,
...
@@ -43,8 +60,7 @@ def exam_gpt_fwd_bwd(
...
@@ -43,8 +60,7 @@ def exam_gpt_fwd_bwd(
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
set_seed
(
42
)
set_seed
(
42
)
with
ColoInitContext
(
device
=
init_device
):
model
=
model_builder
(
use_grad_checkpoint
)
model
=
model_builder
(
use_grad_checkpoint
)
set_seed
(
42
)
set_seed
(
42
)
torch_model
=
model_builder
(
use_grad_checkpoint
).
cuda
()
torch_model
=
model_builder
(
use_grad_checkpoint
).
cuda
()
...
@@ -55,19 +71,17 @@ def exam_gpt_fwd_bwd(
...
@@ -55,19 +71,17 @@ def exam_gpt_fwd_bwd(
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
model
=
GeminiDDP
(
model
,
config_dict
,
init_device
,
pin_memory
=
True
,
**
placement_config
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
Zero
Optimizer
(
optimizer
,
model
,
initial_scale
=
1
)
zero_optim
=
Gemini
Optimizer
(
optimizer
,
model
,
initial_scale
=
1
)
pg
=
ProcessGroup
()
rank
=
dist
.
get_rank
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
pg
.
rank
()],
process_group
=
pg
.
dp_process_group
()
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
rank
]
)
set_seed
(
pg
.
dp_local_
rank
()
)
set_seed
(
rank
)
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
# you can only test a single fwd + bwd.
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
...
@@ -89,65 +103,10 @@ def exam_gpt_fwd_bwd(
...
@@ -89,65 +103,10 @@ def exam_gpt_fwd_bwd(
check_grad
(
model
,
torch_model
)
check_grad
(
model
,
torch_model
)
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
@
parameterize
(
'keep_gather'
,
[
False
,
True
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
,
'albert'
])
@
parameterize
(
'scatter_after_inference'
,
[
False
,
True
])
def
exam_gpt_inference
(
placement_policy
,
keep_gather
,
model_name
:
str
,
scatter_after_inference
:
bool
=
False
,
):
init_device
=
get_current_device
()
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
set_seed
(
42
)
with
ColoInitContext
(
device
=
init_device
):
model
=
model_builder
()
set_seed
(
42
)
torch_model
=
model_builder
().
cuda
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_p
.
data
.
copy_
(
p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
,
scatter_after_inference
=
scatter_after_inference
)
pg
=
ProcessGroup
()
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
1
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
pg
.
rank
()],
process_group
=
pg
.
dp_process_group
())
set_seed
(
pg
.
dp_local_rank
())
model
.
eval
()
torch_model
.
eval
()
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if
i
>
0
:
break
with
torch
.
no_grad
():
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
torch_loss
=
run_fwd
(
torch_model
,
input_ids
,
label
,
criterion
)
loss
=
run_fwd
(
model
,
input_ids
,
label
,
criterion
)
assert
torch
.
equal
(
torch_loss
,
loss
)
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
exam_gpt_fwd_bwd
()
exam_gpt_fwd_bwd
()
exam_gpt_inference
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_zero/test_gemini/test_gemini_use_rmt.py
View file @
a39a5c66
import
pytest
import
pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
colossalai
import
colossalai
from
colossalai.tensor
import
ProcessGroup
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
from
colossalai.zero
import
GeminiDDP
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
colossalai.zero.gemini.memory_tracer.runtime_mem_tracer
import
RuntimeMemTracer
from
colossalai.zero.gemini.memory_tracer.runtime_mem_tracer
import
RuntimeMemTracer
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
@@ -24,8 +23,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
...
@@ -24,8 +23,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
'cpu'
):
model
=
model_builder
(
use_grad_checkpoint
).
cuda
()
model
=
model_builder
(
use_grad_checkpoint
)
print
(
f
'model_name
{
model_name
}
'
)
print
(
f
'model_name
{
model_name
}
'
)
runtime_mem_tracer
=
RuntimeMemTracer
(
model
)
runtime_mem_tracer
=
RuntimeMemTracer
(
model
)
...
@@ -59,12 +57,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
...
@@ -59,12 +57,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
model
=
GeminiDDP
(
model
,
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
,
memstats
)
chunk_config_dict
=
config_dict
,
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
placement_policy
=
placement_policy
,
pin_memory
=
True
,
memstats
=
memstats
)
pg
=
ProcessGroup
()
set_seed
(
dist
.
get_rank
())
set_seed
(
pg
.
dp_local_rank
())
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
# you can only test a single fwd + bwd.
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
...
@@ -76,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
...
@@ -76,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
set_seed
(
42
)
set_seed
(
42
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
model
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
model
)
gemini_non_model_data
=
gemini_manager
.
_mem_stats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)
gemini_non_model_data
=
model
.
gemini_manager
.
_mem_stats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)
# print('gemini non model data:', gemini_non_model_data)
# print('gemini non model data:', gemini_non_model_data)
...
@@ -90,6 +89,7 @@ def run_dist(rank, world_size, port):
...
@@ -90,6 +89,7 @@ def run_dist(rank, world_size, port):
run_gemini_use_rmt
()
run_gemini_use_rmt
()
@
pytest
.
mark
.
skip
(
"this is not used"
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
...
...
tests/test_zero/test_gemini/test_get_torch_model.py
deleted
100644 → 0
View file @
e79b1e80
import
pytest
import
torch
import
colossalai
from
colossalai.tensor
import
ColoParameter
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
ColoInitContext
,
GeminiDDP
from
colossalai.zero.gemini.utils
import
get_static_torch_model
from
tests.components_to_test.registry
import
non_distributed_component_funcs
@
parameterize
(
'model_name'
,
[
'hanging_param_model'
,
'resnet18'
,
'gpt2'
])
def
run_convert_torch_module
(
model_name
:
str
):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
with
ColoInitContext
(
device
=
torch
.
device
(
"cpu"
)):
model
=
model_builder
(
checkpoint
=
False
)
model
=
GeminiDDP
(
model
,
device
=
get_current_device
(),
placement_policy
=
'auto'
,
pin_memory
=
True
)
pytorch_model
=
get_static_torch_model
(
model
,
only_rank_0
=
False
)
for
n
,
p
in
pytorch_model
.
named_parameters
():
assert
type
(
p
)
==
torch
.
nn
.
Parameter
,
f
"type error:
{
n
}
is a
{
type
(
p
)
}
"
# get the static model should not change the original model
for
n
,
p
in
model
.
named_parameters
():
assert
isinstance
(
p
,
ColoParameter
)
for
(
pn
,
pm
),
(
cn
,
cm
)
in
zip
(
pytorch_model
.
named_modules
(),
model
.
named_modules
()):
assert
pn
==
cn
assert
id
(
pm
)
!=
id
(
cm
)
for
pp
,
cp
in
zip
(
pm
.
parameters
(
recurse
=
False
),
cm
.
parameters
(
recurse
=
False
)):
assert
id
(
pp
)
!=
id
(
cp
)
assert
pp
.
shape
==
cp
.
shape
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_convert_torch_module
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
def
test_convert_torch_module
(
world_size
):
spawn
(
run_dist
,
world_size
)
if
__name__
==
'__main__'
:
test_convert_torch_module
(
2
)
tests/test_zero/test_gemini/test_grad_clip.py
View file @
a39a5c66
...
@@ -8,16 +8,38 @@ import colossalai
...
@@ -8,16 +8,38 @@ import colossalai
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
GeminiDDP
,
GeminiOptimizer
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
,
ZeroOptimizer
from
colossalai.zero.gemini.chunk
import
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
set_seed
from
tests.test_tensor.common_utils
import
set_seed
PLACEMENT_CONFIGS
=
[
def
check_param
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
):
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
0.0
,
'offload_param_frac'
:
0.0
},
# zero2
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
1.0
,
'offload_param_frac'
:
0.0
},
# zero2-offload
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
0.5
,
'offload_param_frac'
:
0.0
},
# zero2-offload-half
{
'placement_policy'
:
'auto'
}
]
def
check_param
(
model
:
GeminiDDP
,
torch_model
:
torch
.
nn
.
Module
):
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
torch_dict
=
torch_model
.
state_dict
()
torch_dict
=
torch_model
.
state_dict
()
...
@@ -30,9 +52,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
...
@@ -30,9 +52,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
4e-3
)
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
4e-3
)
@
parameterize
(
'placement_
policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
]
)
@
parameterize
(
'placement_
config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'model_name'
,
[
'gpt2'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
])
def
exam_grad_clipping
(
placement_
policy
,
model_name
:
str
):
def
exam_grad_clipping
(
placement_
config
,
model_name
:
str
):
set_seed
(
1912
)
set_seed
(
1912
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
@@ -43,9 +65,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
...
@@ -43,9 +65,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
init_dev
=
get_current_device
()
model
=
model_builder
()
with
ColoInitContext
(
device
=
init_dev
):
model
=
model_builder
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
p
.
data
.
copy_
(
torch_p
.
data
)
p
.
data
.
copy_
(
torch_p
.
data
)
...
@@ -54,16 +74,19 @@ def exam_grad_clipping(placement_policy, model_name: str):
...
@@ -54,16 +74,19 @@ def exam_grad_clipping(placement_policy, model_name: str):
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
False
config_dict
[
world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
if
placement_config
[
'
placement_policy
'
]
!=
'cuda'
:
init_device
=
torch
.
device
(
'cpu'
)
init_device
=
torch
.
device
(
'cpu'
)
else
:
else
:
init_device
=
None
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
GeminiDDP
(
model
,
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
chunk_config_dict
=
config_dict
,
chunk_init_device
=
init_device
,
pin_memory
=
True
,
**
placement_config
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
Zero
Optimizer
(
optimizer
,
model
,
initial_scale
=
32
,
clipping_norm
=
1.0
)
zero_optim
=
Gemini
Optimizer
(
optimizer
,
model
,
initial_scale
=
32
,
clipping_norm
=
1.0
)
model
.
train
()
model
.
train
()
torch_model
.
train
()
torch_model
.
train
()
...
...
tests/test_zero/test_gemini/test_inference.py
View file @
a39a5c66
...
@@ -11,15 +11,32 @@ from colossalai.amp import convert_to_apex_amp
...
@@ -11,15 +11,32 @@ from colossalai.amp import convert_to_apex_amp
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
,
ZeroOptimizer
,
post_process_colo_init_ctx
,
zero_model_wrapper
from
colossalai.zero
import
GeminiDDP
,
GeminiOptimizer
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
init_chunk_manager
,
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
debug_print
,
set_seed
from
tests.test_tensor.common_utils
import
set_seed
PLACEMENT_CONFIGS
=
[
def
check_param
(
model
:
ZeroDDP
,
torch_model
:
torch
.
nn
.
Module
):
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
},
# zero2
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
1.0
},
# zero3
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.5
},
# zero3-half
{
'placement_policy'
:
'auto'
}
]
def
check_param
(
model
:
GeminiDDP
,
torch_model
:
torch
.
nn
.
Module
):
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
torch_dict
=
torch_model
.
state_dict
()
torch_dict
=
torch_model
.
state_dict
()
...
@@ -32,35 +49,24 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
...
@@ -32,35 +49,24 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
4e-3
)
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
4e-3
)
def
multi_chunk_init
(
model
:
torch
.
nn
.
Module
,
placement_
policy
:
str
):
def
multi_chunk_init
(
model
:
torch
.
nn
.
Module
,
placement_
config
:
dict
):
world_size
=
dist
.
get_world_size
()
world_size
=
dist
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
False
config_dict
[
world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
model
=
GeminiDDP
(
model
,
config_dict
,
pin_memory
=
True
,
**
placement_config
)
init_device
=
torch
.
device
(
'cpu'
)
else
:
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
return
model
return
model
def
single_chunk_init
(
model
:
torch
.
nn
.
Module
,
placement_policy
:
str
):
def
single_chunk_init
(
model
:
torch
.
nn
.
Module
,
placement_config
:
dict
):
gemini_config
=
dict
(
model
=
GeminiDDP
(
model
,
chunk_init_device
=
get_current_device
(),
pin_memory
=
True
,
**
placement_config
)
device
=
get_current_device
(),
placement_policy
=
placement_policy
,
pin_memory
=
True
,
)
model
=
zero_model_wrapper
(
model
=
model
,
zero_stage
=
3
,
gemini_config
=
gemini_config
)
return
model
return
model
@
parameterize
(
'placement_
policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
]
)
@
parameterize
(
'placement_
config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'model_name'
,
[
'gpt2'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
])
@
parameterize
(
'model_init_func'
,
[
single_chunk_init
,
multi_chunk_init
])
@
parameterize
(
'model_init_func'
,
[
single_chunk_init
,
multi_chunk_init
])
def
exam_inference
(
placement_
policy
:
str
,
model_name
:
str
,
model_init_func
:
Callable
):
def
exam_inference
(
placement_
config
:
dict
,
model_name
:
str
,
model_init_func
:
Callable
):
set_seed
(
19360226
)
set_seed
(
19360226
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
@@ -70,17 +76,15 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
...
@@ -70,17 +76,15 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
init_dev
=
get_current_device
()
init_dev
=
get_current_device
()
with
ColoInitContext
(
device
=
init_dev
):
model
=
model_builder
().
to
(
init_dev
)
model
=
model_builder
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
p
.
data
.
copy_
(
torch_p
.
data
)
p
.
data
.
copy_
(
torch_p
.
data
)
model
=
model_init_func
(
model
,
placement_
policy
)
model
=
model_init_func
(
model
,
placement_
config
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
Zero
Optimizer
(
optimizer
,
model
,
initial_scale
=
128
)
zero_optim
=
Gemini
Optimizer
(
optimizer
,
model
,
initial_scale
=
128
)
model
.
eval
()
model
.
eval
()
torch_model
.
eval
()
torch_model
.
eval
()
...
@@ -95,7 +99,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
...
@@ -95,7 +99,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
torch_optim
.
zero_grad
()
torch_optim
.
zero_grad
()
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert_close
(
torch_loss
,
loss
)
assert_close
(
torch_loss
,
loss
,
rtol
=
1e-5
,
atol
=
1e-5
)
zero_optim
.
step
()
zero_optim
.
step
()
torch_optim
.
step
()
torch_optim
.
step
()
check_param
(
model
,
torch_model
)
check_param
(
model
,
torch_model
)
...
...
tests/test_zero/test_gemini/test_optim.py
View file @
a39a5c66
...
@@ -9,12 +9,46 @@ from colossalai.amp import convert_to_apex_amp
...
@@ -9,12 +9,46 @@ from colossalai.amp import convert_to_apex_amp
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
,
ZeroOptimizer
,
post_process_colo_init_ctx
from
colossalai.zero
import
GeminiDDP
,
GeminiOptimizer
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
init_chunk_manager
,
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
debug_print
,
set_seed
from
tests.test_tensor.common_utils
import
set_seed
PLACEMENT_CONFIGS
=
[
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
0.0
},
# zero2
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
1.0
},
# zero2-offload
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
0.5
},
# zero2-offload-half
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
1.0
},
# zero3
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.5
},
# zero3-half
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
1.0
,
'offload_optim_frac'
:
1.0
,
'offload_param_frac'
:
1.0
},
# zero3-offload-all
{
'placement_policy'
:
'auto'
}
]
# this model is large enough to slice to chunks
# this model is large enough to slice to chunks
TEST_MODELS
=
[
'gpt2'
]
TEST_MODELS
=
[
'gpt2'
]
...
@@ -29,7 +63,7 @@ BF16_IGNORED_KEYS = [
...
@@ -29,7 +63,7 @@ BF16_IGNORED_KEYS = [
]
]
def
check_param
(
model
:
Zero
DDP
,
torch_model
:
torch
.
nn
.
Module
,
dtype
:
torch
.
dtype
):
def
check_param
(
model
:
Gemini
DDP
,
torch_model
:
torch
.
nn
.
Module
,
dtype
:
torch
.
dtype
):
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
,
dtype
=
dtype
)
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
,
dtype
=
dtype
)
torch_dict
=
torch_model
.
state_dict
()
torch_dict
=
torch_model
.
state_dict
()
...
@@ -51,10 +85,10 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype
...
@@ -51,10 +85,10 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype
msg
=
lambda
s
:
s
+
f
'
\n
{
key
}
\n
{
temp_zero_value
.
dtype
}
'
)
msg
=
lambda
s
:
s
+
f
'
\n
{
key
}
\n
{
temp_zero_value
.
dtype
}
'
)
@
parameterize
(
'placement_
policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
]
)
@
parameterize
(
'placement_
config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'model_name'
,
TEST_MODELS
)
@
parameterize
(
'model_name'
,
TEST_MODELS
)
@
parameterize
(
'mixed_precision'
,
[
torch
.
half
,
torch
.
bfloat16
])
@
parameterize
(
'mixed_precision'
,
[
torch
.
half
,
torch
.
bfloat16
])
def
exam_model_step
(
placement_
policy
,
model_name
:
str
,
mixed_precision
:
torch
.
dtype
):
def
exam_model_step
(
placement_
config
,
model_name
:
str
,
mixed_precision
:
torch
.
dtype
):
set_seed
(
42
)
set_seed
(
42
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
@@ -65,9 +99,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
...
@@ -65,9 +99,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
init_dev
=
get_current_device
()
model
=
model_builder
().
cuda
()
with
ColoInitContext
(
device
=
init_dev
):
model
=
model_builder
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
p
.
data
.
copy_
(
torch_p
.
data
)
p
.
data
.
copy_
(
torch_p
.
data
)
...
@@ -76,16 +108,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
...
@@ -76,16 +108,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
False
config_dict
[
world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
model
=
GeminiDDP
(
model
,
config_dict
,
**
placement_config
,
mixed_precision
=
mixed_precision
)
init_device
=
torch
.
device
(
'cpu'
)
else
:
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
,
mixed_precision
=
mixed_precision
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
Zero
Optimizer
(
optimizer
,
model
,
initial_scale
=
128
)
zero_optim
=
Gemini
Optimizer
(
optimizer
,
model
,
initial_scale
=
128
)
model
.
eval
()
model
.
eval
()
torch_model
.
eval
()
torch_model
.
eval
()
...
@@ -109,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
...
@@ -109,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
check_param
(
model
,
torch_model
,
mixed_precision
)
check_param
(
model
,
torch_model
,
mixed_precision
)
@
parameterize
(
'placement_
policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
]
)
@
parameterize
(
'placement_
config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'model_name'
,
EXAMPLE_MODELS
)
@
parameterize
(
'model_name'
,
EXAMPLE_MODELS
)
@
parameterize
(
'mixed_precision'
,
[
torch
.
half
,
torch
.
bfloat16
])
@
parameterize
(
'mixed_precision'
,
[
torch
.
half
,
torch
.
bfloat16
])
def
exam_tiny_example
(
placement_
policy
,
model_name
:
str
,
mixed_precision
:
torch
.
dtype
):
def
exam_tiny_example
(
placement_
config
,
model_name
:
str
,
mixed_precision
:
torch
.
dtype
):
set_seed
(
2008
)
set_seed
(
2008
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
@@ -123,18 +149,19 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
...
@@ -123,18 +149,19 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
dist
.
get_rank
()])
init_dev
=
get_current_device
()
model
=
model_builder
().
cuda
()
with
ColoInitContext
(
device
=
init_dev
):
model
=
model_builder
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
p
.
data
.
copy_
(
torch_p
.
data
)
p
.
data
.
copy_
(
torch_p
.
data
)
chunk_manager
=
init_chunk_manager
(
model
=
model
,
init_device
=
get_current_device
(),
search_range_m
=
1
)
model
=
GeminiDDP
(
model
,
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
chunk_init_device
=
get_current_device
(),
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
,
mixed_precision
=
mixed_precision
)
search_range_m
=
1
,
pin_memory
=
True
,
mixed_precision
=
mixed_precision
,
**
placement_config
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
Zero
Optimizer
(
optimizer
,
model
,
initial_scale
=
2
)
zero_optim
=
Gemini
Optimizer
(
optimizer
,
model
,
initial_scale
=
2
)
model
.
eval
()
model
.
eval
()
torch_model
.
eval
()
torch_model
.
eval
()
...
...
tests/test_zero/test_gemini/test_runtime_mem_tracer.py
View file @
a39a5c66
from
copy
import
deepcopy
from
copy
import
deepcopy
import
numpy
as
np
import
numpy
as
np
import
pytest
import
torch
import
torch
from
colossalai.testing
import
clear_cache_before_run
from
colossalai.testing
import
clear_cache_before_run
from
colossalai.zero
import
ColoInitContext
from
colossalai.zero.gemini.memory_tracer.runtime_mem_tracer
import
RuntimeMemTracer
from
colossalai.zero.gemini.memory_tracer.runtime_mem_tracer
import
RuntimeMemTracer
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test
import
run_fwd_bwd
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
@
pytest
.
mark
.
skip
(
"this is not used"
)
@
clear_cache_before_run
()
@
clear_cache_before_run
()
def
test_runtime_mem_tracer
():
def
test_runtime_mem_tracer
():
test_models
=
[
'gpt2'
,
'bert'
,
'simple_net'
,
'repeated_computed_layers'
,
'nested_model'
,
'albert'
]
test_models
=
[
'gpt2'
,
'bert'
,
'simple_net'
,
'repeated_computed_layers'
,
'nested_model'
,
'albert'
]
...
@@ -18,8 +19,7 @@ def test_runtime_mem_tracer():
...
@@ -18,8 +19,7 @@ def test_runtime_mem_tracer():
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
'cpu'
):
model
=
model_builder
(
checkpoint
=
False
).
cuda
()
model
=
model_builder
(
checkpoint
=
False
)
model_bk
=
deepcopy
(
model
)
model_bk
=
deepcopy
(
model
)
runtime_mem_tracer
=
RuntimeMemTracer
(
model
)
runtime_mem_tracer
=
RuntimeMemTracer
(
model
)
...
...
tests/test_zero/test_gemini/test_search.py
View file @
a39a5c66
...
@@ -2,33 +2,20 @@ import pytest
...
@@ -2,33 +2,20 @@ import pytest
import
torch
import
torch
import
colossalai
import
colossalai
from
colossalai.tensor
import
ComputePattern
,
ComputeSpec
,
ProcessGroup
,
ShardSpec
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.zero
import
ColoInitContext
from
colossalai.zero.gemini.chunk
import
init_chunk_manager
,
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
init_chunk_manager
,
search_chunk_configuration
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
def
init_1d_row_spec
(
model
,
pg
:
ProcessGroup
):
tensor_spec
=
(
ShardSpec
([
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
for
n
,
p
in
model
.
named_parameters
():
if
'weight'
in
n
and
'ln'
not
in
n
:
p
.
set_process_group
(
pg
)
p
.
set_tensor_spec
(
*
tensor_spec
)
def
exam_search_chunk_size
():
def
exam_search_chunk_size
():
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
pg_tp
=
ProcessGroup
(
tp_degree
=
world_size
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
# make sure torch_model and model has the same parameter values
# make sure torch_model and model has the same parameter values
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
init_1d_row_spec
(
model
,
pg_tp
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_range_m
=
1
,
search_interval
=
16
,
search_interval
=
16
,
...
@@ -37,57 +24,19 @@ def exam_search_chunk_size():
...
@@ -37,57 +24,19 @@ def exam_search_chunk_size():
for
key
in
config_dict
:
for
key
in
config_dict
:
chunk_size
=
config_dict
[
key
][
'chunk_size'
]
chunk_size
=
config_dict
[
key
][
'chunk_size'
]
if
world_size
==
1
:
if
world_size
==
1
or
True
:
assert
chunk_size
==
31616
assert
chunk_size
==
31616
else
:
else
:
assert
chunk_size
==
1024
assert
chunk_size
==
1024
def
exam_search_strict_ddp
():
world_size
=
torch
.
distributed
.
get_world_size
()
default_shard_pg
=
ProcessGroup
(
tp_degree
=
world_size
)
default_shard_spec
=
ShardSpec
([
-
1
],
[
world_size
])
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
# get the chunk configuration over replicated models
with
ColoInitContext
(
device
=
get_current_device
()):
ddp_model
=
model_builder
()
re_dict
,
re_total
,
re_wasted
=
search_chunk_configuration
(
ddp_model
,
search_range_m
=
1
,
search_interval
=
16
,
min_chunk_size_m
=
0
,
filter_exlarge_params
=
True
,
strict_ddp_flag
=
False
)
# get the chunk configuration over sharded ddp models
with
ColoInitContext
(
device
=
get_current_device
(),
default_pg
=
default_shard_pg
,
default_dist_spec
=
default_shard_spec
):
sharded_ddp_model
=
model_builder
()
sh_dict
,
sh_total
,
sh_wasted
=
search_chunk_configuration
(
sharded_ddp_model
,
search_range_m
=
1
,
search_interval
=
16
,
min_chunk_size_m
=
0
,
filter_exlarge_params
=
True
,
strict_ddp_flag
=
True
)
assert
re_dict
==
sh_dict
for
key
in
re_dict
:
assert
re_dict
[
key
]
==
sh_dict
[
key
]
assert
re_total
==
sh_total
assert
re_wasted
==
sh_wasted
def
exam_chunk_manager
():
def
exam_chunk_manager
():
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
default_shard_pg
=
ProcessGroup
(
tp_degree
=
world_size
)
default_shard_spec
=
ShardSpec
([
-
1
],
[
world_size
])
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
(),
default_pg
=
default_shard_pg
,
sharded_ddp_model
=
model_builder
()
default_dist_spec
=
default_shard_spec
):
sharded_ddp_model
=
model_builder
()
chunk_manager
=
init_chunk_manager
(
sharded_ddp_model
,
chunk_manager
=
init_chunk_manager
(
sharded_ddp_model
,
get_current_device
(),
get_current_device
(),
hidden_dim
=
16
,
hidden_dim
=
16
,
...
@@ -103,7 +52,6 @@ def exam_chunk_manager():
...
@@ -103,7 +52,6 @@ def exam_chunk_manager():
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
exam_search_chunk_size
()
exam_search_chunk_size
()
exam_search_strict_ddp
()
exam_chunk_manager
()
exam_chunk_manager
()
...
...
tests/test_zero/test_gemini/test_zeroddp_state_dict.py
View file @
a39a5c66
...
@@ -4,31 +4,46 @@ from torch.testing import assert_close
...
@@ -4,31 +4,46 @@ from torch.testing import assert_close
import
colossalai
import
colossalai
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
GeminiDDP
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
from
colossalai.zero.gemini.chunk
import
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
debug_print
,
set_seed
from
tests.test_tensor.common_utils
import
set_seed
PLACEMENT_CONFIGS
=
[
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
},
# zero2
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
1.0
},
# zero3
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.5
},
# zero3-half
{
'placement_policy'
:
'auto'
}
]
def
ignore_the_first_parameter
(
model
:
torch
.
nn
.
Module
):
def
ignore_the_first_parameter
(
model
:
torch
.
nn
.
Module
):
for
name
,
param
in
model
.
named_parameters
():
for
name
,
param
in
model
.
named_parameters
():
print
(
f
"parameter `
{
name
}
` is set ignored"
)
print
(
f
"parameter `
{
name
}
` is set ignored"
)
Zero
DDP
.
set_params_to_ignore
([
param
])
Gemini
DDP
.
set_params_to_ignore
([
param
])
return
return
@
parameterize
(
'placement_
policy'
,
[
'cuda'
,
'cpu'
,
'auto'
]
)
@
parameterize
(
'placement_
config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
def
exam_state_dict
(
placement_
policy
,
keep_gathered
,
model_name
:
str
):
def
exam_state_dict
(
placement_
config
,
keep_gathered
,
model_name
:
str
):
set_seed
(
431
)
set_seed
(
431
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
torch_model
=
model_builder
()
torch_model
=
model_builder
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
...
@@ -38,9 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
...
@@ -38,9 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
chunk_manager
=
ChunkManager
(
config_dict
)
model
=
GeminiDDP
(
model
,
config_dict
,
**
placement_config
,
pin_memory
=
True
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
model
.
train
()
model
.
train
()
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
...
@@ -52,16 +65,15 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
...
@@ -52,16 +65,15 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-5
)
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-5
)
@
parameterize
(
'placement_
policy'
,
[
'cuda'
,
'cpu'
,
'auto'
]
)
@
parameterize
(
'placement_
config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
def
exam_load_state_dict
(
placement_
policy
,
keep_gathered
,
model_name
:
str
):
def
exam_load_state_dict
(
placement_
config
,
keep_gathered
,
model_name
:
str
):
set_seed
(
431
)
set_seed
(
431
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
set_seed
(
451
)
set_seed
(
451
)
torch_model
=
model_builder
()
# get a different model
torch_model
=
model_builder
()
# get a different model
...
@@ -71,13 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
...
@@ -71,13 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
if
placement_policy
!=
'cuda'
:
model
=
GeminiDDP
(
model
,
config_dict
,
**
placement_config
,
pin_memory
=
True
)
init_device
=
torch
.
device
(
'cpu'
)
else
:
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
torch_dict
=
torch_model
.
state_dict
()
torch_dict
=
torch_model
.
state_dict
()
model
.
load_state_dict
(
torch_dict
,
strict
=
False
)
model
.
load_state_dict
(
torch_dict
,
strict
=
False
)
...
@@ -89,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
...
@@ -89,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-5
)
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
1e-5
)
@
parameterize
(
'placement_config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
def
exam_state_dict_shard
(
placement_config
,
model_name
:
str
):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model
=
model_builder
()
model_size
=
sum
(
p
.
numel
()
*
p
.
element_size
()
for
p
in
model
.
parameters
())
/
1024
**
2
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
model
=
GeminiDDP
(
model
,
config_dict
,
**
placement_config
)
model
.
train
()
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
accumulated_keys
=
set
()
# ensure number of shards > 1
for
shard
,
_
in
model
.
state_dict_shard
(
max_shard_size
=
(
model_size
/
3
),
only_rank_0
=
False
):
for
key
,
value
in
shard
.
items
():
assert
key
not
in
accumulated_keys
,
f
"key `
{
key
}
` is duplicated."
accumulated_keys
.
add
(
key
)
assert
key
in
zero_dict
,
f
"
{
key
}
not in ZeRO dictionary."
assert
torch
.
equal
(
value
,
zero_dict
[
key
]),
f
"
{
key
}
not equal."
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
exam_state_dict
()
exam_state_dict
()
exam_load_state_dict
()
exam_load_state_dict
()
exam_state_dict_shard
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
deleted
100644 → 0
View file @
e79b1e80
import
pytest
import
torch
from
torch.testing
import
assert_close
import
colossalai
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
tests.components_to_test.registry
import
non_distributed_component_funcs
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
def
exam_state_dict
(
placement_policy
,
model_name
:
str
):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model_size
=
sum
(
p
.
numel
()
*
p
.
element_size
()
for
p
in
model
.
parameters
())
/
1024
**
2
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
)
model
.
train
()
zero_dict
=
model
.
state_dict
(
only_rank_0
=
False
)
accumulated_keys
=
set
()
# ensure number of shards > 1
for
shard
,
_
in
model
.
state_dict_shard
(
max_shard_size
=
(
model_size
/
3
),
only_rank_0
=
False
):
for
key
,
value
in
shard
.
items
():
assert
key
not
in
accumulated_keys
,
f
"key `
{
key
}
` is duplicated."
accumulated_keys
.
add
(
key
)
assert
key
in
zero_dict
,
f
"
{
key
}
not in ZeRO dictionary."
assert
torch
.
equal
(
value
,
zero_dict
[
key
]),
f
"
{
key
}
not equal."
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
exam_state_dict
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
def
test_zero_ddp_state_dict_shard
(
world_size
):
spawn
(
run_dist
,
world_size
)
if
__name__
==
'__main__'
:
test_zero_ddp_state_dict_shard
(
1
)
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
View file @
a39a5c66
...
@@ -5,42 +5,53 @@ import torch.distributed as dist
...
@@ -5,42 +5,53 @@ import torch.distributed as dist
import
colossalai
import
colossalai
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero
import
GeminiDDP
,
GeminiOptimizer
from
colossalai.zero
import
ColoInitContext
,
ZeroDDP
,
ZeroOptimizer
from
colossalai.zero.gemini.chunk
import
search_chunk_configuration
from
colossalai.zero.gemini.chunk
import
ChunkManager
,
search_chunk_configuration
from
colossalai.zero.gemini.gemini_mgr
import
GeminiManager
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
debug_print
,
set_seed
from
tests.test_tensor.common_utils
import
set_seed
PLACEMENT_CONFIGS
=
[
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
])
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
0.0
},
# zero2
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
1.0
},
# zero2-offload
{
'placement_policy'
:
'static'
,
'shard_param_frac'
:
0.0
,
'offload_optim_frac'
:
0.5
},
# zero2-offload-half
{
'placement_policy'
:
'auto'
}
]
@
parameterize
(
'placement_config'
,
PLACEMENT_CONFIGS
)
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
@
parameterize
(
'keep_gathered'
,
[
True
,
False
])
def
exam_zero_optim_state_dict
(
placement_
policy
,
keep_gathered
):
def
exam_zero_optim_state_dict
(
placement_
config
,
keep_gathered
):
set_seed
(
431
)
set_seed
(
431
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
()
model
=
model_builder
()
set_seed
(
451
)
set_seed
(
451
)
torch_model
=
model_builder
()
# get a different model
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
if
placement_policy
!=
'cuda'
:
model
=
GeminiDDP
(
model
,
config_dict
,
**
placement_config
,
pin_memory
=
True
)
init_device
=
torch
.
device
(
'cpu'
)
else
:
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
optimizer
=
HybridAdam
(
model
.
parameters
())
optimizer
=
HybridAdam
(
model
.
parameters
())
optim
=
Zero
Optimizer
(
optimizer
,
model
,
initial_scale
=
32
)
# initialize the link between chunk16 and chunk32
optim
=
Gemini
Optimizer
(
optimizer
,
model
,
initial_scale
=
32
)
# initialize the link between chunk16 and chunk32
set_seed
(
dist
.
get_rank
()
*
3
+
128
)
set_seed
(
dist
.
get_rank
()
*
3
+
128
)
model
.
train
()
model
.
train
()
...
...
tests/test_zero/test_low_level/test_grad_acc.py
View file @
a39a5c66
...
@@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc():
...
@@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc():
assert
torch
.
equal
(
zero1_output
,
zero2_output
)
assert
torch
.
equal
(
zero1_output
,
zero2_output
)
# zero-dp backward
# zero-dp backward
no_sync
=
number
==
0
zero1_optimizer
.
backward
(
zero1_output
.
sum
().
float
())
with
conditional_context
(
zero1_optimizer
.
no_sync
(),
no_sync
):
zero2_optimizer
.
backward
(
zero2_output
.
sum
().
float
())
zero1_optimizer
.
backward
(
zero1_output
.
sum
().
float
())
with
conditional_context
(
zero2_optimizer
.
no_sync
(),
no_sync
):
zero2_optimizer
.
backward
(
zero2_output
.
sum
().
float
())
if
check_flag
:
for
(
n
,
z1p
),
z2p
in
zip
(
zero1_model
.
named_parameters
(),
zero2_model
.
parameters
()):
if
z2p
.
grad
is
not
None
:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert
torch
.
equal
(
z1p
.
grad
,
z2p
.
grad
)
fwd_bwd_func
(
0
,
input_data1
,
True
)
fwd_bwd_func
(
0
,
input_data1
,
True
)
fwd_bwd_func
(
1
,
input_data2
,
False
)
fwd_bwd_func
(
1
,
input_data2
,
False
)
...
@@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc():
...
@@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc():
assert
torch
.
equal
(
z1p
.
data
,
z2p
.
data
)
assert
torch
.
equal
(
z1p
.
data
,
z2p
.
data
)
def
exam_zero_1_grad_acc
():
def
exam_zero_1_grad_acc
(
sync
):
local_rank
=
torch
.
distributed
.
get_rank
()
local_rank
=
torch
.
distributed
.
get_rank
()
seed_all
(
2008
)
seed_all
(
2008
)
...
@@ -112,9 +103,8 @@ def exam_zero_1_grad_acc():
...
@@ -112,9 +103,8 @@ def exam_zero_1_grad_acc():
input_data1
=
torch
.
randn
(
32
,
128
).
cuda
()
input_data1
=
torch
.
randn
(
32
,
128
).
cuda
()
input_data2
=
torch
.
randn
(
32
,
128
).
cuda
()
input_data2
=
torch
.
randn
(
32
,
128
).
cuda
()
def
fwd_bwd_func
(
n
umber
,
cur_data
,
check_flag
):
def
fwd_bwd_func
(
n
o_sync
,
cur_data
,
check_flag
):
no_sync
=
number
==
0
# zero1 fwd and bwd
# zero1 fwd and bwd
with
conditional_context
(
zero_optimizer
.
no_sync
(),
no_sync
):
with
conditional_context
(
zero_optimizer
.
no_sync
(),
no_sync
):
zero_output
=
zero_model
(
cur_data
)
zero_output
=
zero_model
(
cur_data
)
...
@@ -131,8 +121,8 @@ def exam_zero_1_grad_acc():
...
@@ -131,8 +121,8 @@ def exam_zero_1_grad_acc():
for
(
n
,
p
),
z1p
in
zip
(
torch_model
.
named_parameters
(),
zero_model
.
parameters
()):
for
(
n
,
p
),
z1p
in
zip
(
torch_model
.
named_parameters
(),
zero_model
.
parameters
()):
assert
torch
.
equal
(
p
.
grad
,
z1p
.
grad
)
assert
torch
.
equal
(
p
.
grad
,
z1p
.
grad
)
fwd_bwd_func
(
0
,
input_data1
,
True
)
fwd_bwd_func
(
sync
,
input_data1
,
sync
)
fwd_bwd_func
(
1
,
input_data2
,
False
)
fwd_bwd_func
(
False
,
input_data2
,
False
)
zero_optimizer
.
step
()
zero_optimizer
.
step
()
torch
.
nn
.
utils
.
clip_grad_norm_
(
torch_model
.
parameters
(),
1.0
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
torch_model
.
parameters
(),
1.0
)
...
@@ -147,9 +137,9 @@ def exam_zero_1_grad_acc():
...
@@ -147,9 +137,9 @@ def exam_zero_1_grad_acc():
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
port
=
port
,
host
=
'localhost'
)
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
port
=
port
,
host
=
'localhost'
)
exam_zero_1_grad_acc
()
exam_zero_1_grad_acc
(
sync
=
True
)
# gradient accumulation is not compatible with ZeRO-2
exam_zero_1_grad_acc
(
sync
=
False
)
#
exam_zero_1_2_grad_acc()
exam_zero_1_2_grad_acc
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_zero/test_low_level/test_zero_ckpt.py
View file @
a39a5c66
...
@@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
...
@@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
atol
=
4e-3
atol
=
4e-3
a
=
a
.
detach
().
to
(
dtype
)
a
=
a
.
detach
().
to
(
dtype
)
b
=
b
.
detach
().
to
(
dtype
)
b
=
b
.
detach
().
to
(
dtype
)
.
to
(
a
.
device
)
assert_close
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
assert_close
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
...
...
tests/test_zero/test_low_level/test_zero_init.py
deleted
100644 → 0
View file @
e79b1e80
import
pytest
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
colossalai
from
colossalai.tensor
import
ProcessGroup
from
colossalai.testing
import
spawn
from
colossalai.utils
import
get_current_device
from
colossalai.zero
import
ColoInitContext
,
LowLevelZeroOptimizer
class
MlpModel
(
nn
.
Module
):
def
__init__
(
self
):
super
(
MlpModel
,
self
).
__init__
()
self
.
linear1
=
nn
.
Linear
(
128
,
256
)
self
.
linear2
=
nn
.
Linear
(
256
,
512
)
def
forward
(
self
,
x
):
x
=
self
.
linear1
(
x
)
x
=
self
.
linear2
(
x
)
return
x
def
exam_zero_init
():
dp_2_tp_2_pg
=
ProcessGroup
(
dp_degree
=
2
,
tp_degree
=
2
)
model1
=
MlpModel
().
cuda
()
with
ColoInitContext
(
device
=
get_current_device
(),
default_pg
=
dp_2_tp_2_pg
):
model2
=
MlpModel
()
optimizer1
=
LowLevelZeroOptimizer
(
torch
.
optim
.
Adam
(
model1
.
parameters
(),
lr
=
1
))
optimizer2
=
LowLevelZeroOptimizer
(
torch
.
optim
.
Adam
(
model2
.
parameters
(),
lr
=
1
))
assert
optimizer1
.
_local_rank
==
optimizer2
.
_local_rank
assert
optimizer1
.
_world_size
==
optimizer2
.
_world_size
mp_group1
=
optimizer1
.
tp_pg
mp_group2
=
optimizer2
.
tp_pg
assert
dist
.
get_world_size
(
mp_group1
)
==
dist
.
get_world_size
(
mp_group2
)
assert
dist
.
get_rank
(
mp_group1
)
==
dist
.
get_rank
(
mp_group2
)
def
run_dist
(
rank
,
world_size
,
port
):
config_dict
=
dict
(
parallel
=
dict
(
data
=
2
,
tensor
=
dict
(
size
=
2
,
mode
=
'1d'
)))
colossalai
.
launch
(
config
=
config_dict
,
rank
=
rank
,
world_size
=
world_size
,
port
=
port
,
host
=
'localhost'
)
exam_zero_init
()
@
pytest
.
mark
.
dist
def
test_zero_init
():
spawn
(
run_dist
,
4
)
if
__name__
==
'__main__'
:
test_zero_init
()
tests/test_zero/test_low_level/test_zero_tp.py
View file @
a39a5c66
...
@@ -85,6 +85,7 @@ def run_dist(rank, world_size, port):
...
@@ -85,6 +85,7 @@ def run_dist(rank, world_size, port):
exam_zero_with_tp
()
exam_zero_with_tp
()
@
pytest
.
mark
.
skip
(
'this will be rewritten by shardformer'
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_zero_with_tp
():
def
test_zero_with_tp
():
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment