Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0bb0b481
Commit
0bb0b481
authored
Jun 25, 2023
by
Baizhou Zhang
Browse files
[gemini] fix argument naming during chunk configuration searching
parent
b463651f
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
62 additions
and
64 deletions
+62
-64
colossalai/booster/plugin/gemini_plugin.py
colossalai/booster/plugin/gemini_plugin.py
+7
-10
colossalai/zero/gemini/chunk/search_utils.py
colossalai/zero/gemini/chunk/search_utils.py
+13
-13
colossalai/zero/gemini/chunk/utils.py
colossalai/zero/gemini/chunk/utils.py
+7
-7
colossalai/zero/gemini/gemini_ddp.py
colossalai/zero/gemini/gemini_ddp.py
+8
-8
tests/test_auto_parallel/test_offload/test_perf.py
tests/test_auto_parallel/test_offload/test_perf.py
+1
-1
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
...allel/test_tensor_shard/test_compatibility_with_gemini.py
+1
-1
tests/test_checkpoint_io/test_gemini_checkpoint_io.py
tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+1
-1
tests/test_tensor/test_tp_with_zero.py
tests/test_tensor/test_tp_with_zero.py
+1
-1
tests/test_zero/test_gemini/test_fwd_bwd.py
tests/test_zero/test_gemini/test_fwd_bwd.py
+2
-2
tests/test_zero/test_gemini/test_gemini_use_rmt.py
tests/test_zero/test_gemini/test_gemini_use_rmt.py
+1
-1
tests/test_zero/test_gemini/test_grad_clip.py
tests/test_zero/test_gemini/test_grad_clip.py
+1
-1
tests/test_zero/test_gemini/test_inference.py
tests/test_zero/test_gemini/test_inference.py
+1
-1
tests/test_zero/test_gemini/test_optim.py
tests/test_zero/test_gemini/test_optim.py
+2
-2
tests/test_zero/test_gemini/test_search.py
tests/test_zero/test_gemini/test_search.py
+11
-11
tests/test_zero/test_gemini/test_zeroddp_state_dict.py
tests/test_zero/test_gemini/test_zeroddp_state_dict.py
+2
-2
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
+2
-1
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+1
-1
No files found.
colossalai/booster/plugin/gemini_plugin.py
View file @
0bb0b481
...
@@ -181,11 +181,11 @@ class GeminiPlugin(DPPluginBase):
...
@@ -181,11 +181,11 @@ class GeminiPlugin(DPPluginBase):
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
search_range_m
b
(int, optional): chunk size searching range
in MegaByte
. Defaults to 32.
search_range_m (int, optional): chunk size searching range
divided by 2^20
. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_m
b
(float, optional): the minimum chunk size
in MegaByte
.
min_chunk_size_m (float, optional): the minimum chunk size
divided by 2^20
.
If the aggregate size of parameters is still smaller than the minimum chunk size,
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
all parameters will be compacted into one small chunk.
memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
...
@@ -214,9 +214,9 @@ class GeminiPlugin(DPPluginBase):
...
@@ -214,9 +214,9 @@ class GeminiPlugin(DPPluginBase):
pin_memory
:
bool
=
False
,
pin_memory
:
bool
=
False
,
force_outputs_fp32
:
bool
=
False
,
force_outputs_fp32
:
bool
=
False
,
strict_ddp_mode
:
bool
=
False
,
strict_ddp_mode
:
bool
=
False
,
search_range_m
b
:
int
=
32
,
search_range_m
:
int
=
32
,
hidden_dim
:
Optional
[
int
]
=
None
,
hidden_dim
:
Optional
[
int
]
=
None
,
min_chunk_size_m
b
:
float
=
32
,
min_chunk_size_m
:
float
=
32
,
memstats
:
Optional
[
MemStats
]
=
None
,
memstats
:
Optional
[
MemStats
]
=
None
,
gpu_margin_mem_ratio
:
float
=
0.0
,
gpu_margin_mem_ratio
:
float
=
0.0
,
initial_scale
:
float
=
2
**
32
,
initial_scale
:
float
=
2
**
32
,
...
@@ -238,9 +238,9 @@ class GeminiPlugin(DPPluginBase):
...
@@ -238,9 +238,9 @@ class GeminiPlugin(DPPluginBase):
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
force_outputs_fp32
=
force_outputs_fp32
,
force_outputs_fp32
=
force_outputs_fp32
,
strict_ddp_mode
=
strict_ddp_mode
,
strict_ddp_mode
=
strict_ddp_mode
,
search_range_m
b
=
search_range_m
b
,
search_range_m
=
search_range_m
,
hidden_dim
=
hidden_dim
,
hidden_dim
=
hidden_dim
,
min_chunk_size_m
b
=
min_chunk_size_m
b
,
min_chunk_size_m
=
min_chunk_size_m
,
memstats
=
memstats
,
memstats
=
memstats
,
mixed_precision
=
PRECISION_STR_TO_DTYPE
[
precision
],
mixed_precision
=
PRECISION_STR_TO_DTYPE
[
precision
],
)
)
...
@@ -295,10 +295,7 @@ class GeminiPlugin(DPPluginBase):
...
@@ -295,10 +295,7 @@ class GeminiPlugin(DPPluginBase):
if
optimizer
is
not
None
and
\
if
optimizer
is
not
None
and
\
not
isinstance
(
optimizer
,
OptimizerWrapper
):
not
isinstance
(
optimizer
,
OptimizerWrapper
):
optimizer
=
GeminiOptimizer
(
model
.
unwrap
(),
optimizer
=
GeminiOptimizer
(
model
.
unwrap
(),
optimizer
,
self
.
zero_optim_config
,
self
.
optim_kwargs
,
optimizer
,
self
.
zero_optim_config
,
self
.
optim_kwargs
,
self
.
verbose
)
self
.
verbose
)
return
model
,
optimizer
,
criterion
,
dataloader
,
lr_scheduler
return
model
,
optimizer
,
criterion
,
dataloader
,
lr_scheduler
...
...
colossalai/zero/gemini/chunk/search_utils.py
View file @
0bb0b481
...
@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
...
@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
def
search_chunk_configuration
(
def
search_chunk_configuration
(
model
:
nn
.
Module
,
model
:
nn
.
Module
,
search_range_m
b
:
float
,
search_range_m
:
float
,
search_interval
_byte
:
int
,
# hidden size is the best value for the interval
search_interval
:
int
,
# hidden size is the best value for the interval
min_chunk_size_m
b
:
float
=
32
,
min_chunk_size_m
:
float
=
32
,
filter_exlarge_params
:
bool
=
True
,
filter_exlarge_params
:
bool
=
True
,
strict_ddp_flag
:
bool
=
False
,
strict_ddp_flag
:
bool
=
False
,
memstas
:
Optional
[
MemStats
]
=
None
)
->
Tuple
[
Dict
,
int
,
int
]:
memstas
:
Optional
[
MemStats
]
=
None
)
->
Tuple
[
Dict
,
int
,
int
]:
...
@@ -126,9 +126,9 @@ def search_chunk_configuration(
...
@@ -126,9 +126,9 @@ def search_chunk_configuration(
Args:
Args:
model (nn.Module): torch module
model (nn.Module): torch module
search_range_m
b
(float): searching range
in mega byte
.
search_range_m (float): searching range
divided by 2^20
.
search_interval
_byte
(int): searching interval
in byte
.
search_interval (int): searching interval.
min_chunk_size_m
b
(float, optional): the minimum size of a distributed chunk.
min_chunk_size_m (float, optional): the minimum size of a distributed chunk
, divided by 2^20.
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
all parameters keep replicated in this mode.
all parameters keep replicated in this mode.
...
@@ -145,9 +145,9 @@ def search_chunk_configuration(
...
@@ -145,9 +145,9 @@ def search_chunk_configuration(
for
p
in
model
.
parameters
():
for
p
in
model
.
parameters
():
param_order
.
append
(
p
)
param_order
.
append
(
p
)
search_range
_byte
=
round
(
search_range_m
b
*
1024
**
2
)
search_range
=
round
(
search_range_m
*
1024
**
2
)
min_chunk_size
_byte
=
round
(
min_chunk_size_m
b
*
1024
**
2
)
min_chunk_size
=
round
(
min_chunk_size_m
*
1024
**
2
)
assert
search_range
_byte
>=
0
assert
search_range
>=
0
params_dict
=
classify_params_by_dp_degree
(
param_order
,
strict_ddp_flag
)
params_dict
=
classify_params_by_dp_degree
(
param_order
,
strict_ddp_flag
)
size_lcm
=
np
.
lcm
.
reduce
(
list
(
params_dict
.
keys
()))
size_lcm
=
np
.
lcm
.
reduce
(
list
(
params_dict
.
keys
()))
...
@@ -162,7 +162,7 @@ def search_chunk_configuration(
...
@@ -162,7 +162,7 @@ def search_chunk_configuration(
total_param_size
+=
group_acc_size
total_param_size
+=
group_acc_size
# let small parameters keep gathered in CUDA all the time
# let small parameters keep gathered in CUDA all the time
if
group_acc_size
<
min_chunk_size
_byte
:
if
group_acc_size
<
min_chunk_size
:
config_dict
[
dp_degree
]
=
dict
(
chunk_size
=
group_acc_size
,
keep_gathered
=
True
)
config_dict
[
dp_degree
]
=
dict
(
chunk_size
=
group_acc_size
,
keep_gathered
=
True
)
else
:
else
:
size_dict
[
dp_degree
]
=
size_list
size_dict
[
dp_degree
]
=
size_list
...
@@ -170,15 +170,15 @@ def search_chunk_configuration(
...
@@ -170,15 +170,15 @@ def search_chunk_configuration(
if
filter_exlarge_params
:
if
filter_exlarge_params
:
_filter_exlarge_params
(
model
,
size_dict
)
_filter_exlarge_params
(
model
,
size_dict
)
max_size
=
min_chunk_size
_byte
max_size
=
min_chunk_size
for
key
in
size_dict
:
for
key
in
size_dict
:
max_size
=
max
(
max_size
,
max
(
size_dict
[
key
]))
max_size
=
max
(
max_size
,
max
(
size_dict
[
key
]))
start_size
=
int
(
math
.
ceil
(
max_size
/
search_interval
_byte
)
*
search_interval
_byte
)
start_size
=
int
(
math
.
ceil
(
max_size
/
search_interval
)
*
search_interval
)
min_chunk_waste
=
float
(
'+inf'
)
min_chunk_waste
=
float
(
'+inf'
)
best_chunk_size
=
start_size
best_chunk_size
=
start_size
for
chunk_size
in
range
(
start_size
,
start_size
+
search_range
_byte
+
1
,
search_interval
_byte
):
for
chunk_size
in
range
(
start_size
,
start_size
+
search_range
+
1
,
search_interval
):
temp_waste
=
0
temp_waste
=
0
for
key
in
size_dict
:
for
key
in
size_dict
:
temp_waste
+=
_get_unused_byte
(
size_dict
[
key
],
chunk_size
)
temp_waste
+=
_get_unused_byte
(
size_dict
[
key
],
chunk_size
)
...
...
colossalai/zero/gemini/chunk/utils.py
View file @
0bb0b481
...
@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
...
@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
verbose
:
bool
=
False
,
verbose
:
bool
=
False
,
**
kwargs
)
->
ChunkManager
:
**
kwargs
)
->
ChunkManager
:
if
hidden_dim
:
if
hidden_dim
:
search_interval
_byte
=
hidden_dim
search_interval
=
hidden_dim
else
:
else
:
search_interval
_byte
=
1024
# defaults to 1
kb
search_interval
=
1024
# defaults to 1
024
kwargs
[
"search_interval
_byte
"
]
=
search_interval
_byte
kwargs
[
"search_interval"
]
=
search_interval
dist
.
barrier
()
dist
.
barrier
()
begin
=
time
()
begin
=
time
()
...
@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
...
@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
dist
.
barrier
()
dist
.
barrier
()
end
=
time
()
end
=
time
()
span_s
=
end
-
begin
span_s
=
end
-
begin
m
b_size
=
1024
**
2
m
ega_unit
=
1024
**
2
total_size
/=
m
b_size
total_size
/=
m
ega_unit
wasted_size
/=
m
b_size
wasted_size
/=
m
ega_unit
if
verbose
and
dist
.
get_rank
()
==
0
:
if
verbose
and
dist
.
get_rank
()
==
0
:
print
(
"searching chunk configuration is completed in {:.2f} s.
\n
"
.
format
(
span_s
),
print
(
"searching chunk configuration is completed in {:.2f} s.
\n
"
.
format
(
span_s
),
"used number: {:.2f}
MB
, wasted number: {:.2f}
MB
\n
"
.
format
(
total_size
,
wasted_size
),
"used number: {:.2f}
* 2^20
, wasted number: {:.2f}
* 2^20
\n
"
.
format
(
total_size
,
wasted_size
),
"total wasted percentage is {:.2f}%"
.
format
(
100
*
safe_div
(
wasted_size
,
total_size
+
wasted_size
)),
"total wasted percentage is {:.2f}%"
.
format
(
100
*
safe_div
(
wasted_size
,
total_size
+
wasted_size
)),
sep
=
''
,
sep
=
''
,
flush
=
True
)
flush
=
True
)
...
...
colossalai/zero/gemini/gemini_ddp.py
View file @
0bb0b481
...
@@ -739,9 +739,9 @@ class GeminiDDP(ZeroDDP):
...
@@ -739,9 +739,9 @@ class GeminiDDP(ZeroDDP):
force_outputs_fp32
:
bool
=
False
,
force_outputs_fp32
:
bool
=
False
,
strict_ddp_mode
:
bool
=
False
,
strict_ddp_mode
:
bool
=
False
,
scatter_after_inference
:
bool
=
True
,
scatter_after_inference
:
bool
=
True
,
search_range_m
b
:
int
=
32
,
search_range_m
:
int
=
32
,
hidden_dim
:
Optional
[
int
]
=
None
,
hidden_dim
:
Optional
[
int
]
=
None
,
min_chunk_size_m
b
:
float
=
32
,
min_chunk_size_m
:
float
=
32
,
memstats
:
Optional
[
MemStats
]
=
None
,
memstats
:
Optional
[
MemStats
]
=
None
,
mixed_precision
:
torch
.
dtype
=
torch
.
float16
,
mixed_precision
:
torch
.
dtype
=
torch
.
float16
,
verbose
:
bool
=
False
)
->
None
:
verbose
:
bool
=
False
)
->
None
:
...
@@ -763,24 +763,24 @@ class GeminiDDP(ZeroDDP):
...
@@ -763,24 +763,24 @@ class GeminiDDP(ZeroDDP):
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
search_range_m
b
(int, optional): chunk size searching range
in MegaByte
. Defaults to 32.
search_range_m (int, optional): chunk size searching range
divided by 2^20
. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_m
b
(float, optional): the minimum chunk size
in MegaByte
.
min_chunk_size_m (float, optional): the minimum chunk size
divided by 2^20
.
If the aggregate size of parameters is still smaller than the minimum chunk size,
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
all parameters will be compacted into one small chunk.
memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
"""
"""
# some ugly hotfix for the compatibility with Lightning
# some ugly hotfix for the compatibility with Lightning
if
search_range_m
b
is
None
:
if
search_range_m
is
None
:
search_range_m
b
=
32
search_range_m
=
32
chunk_manager
=
init_chunk_manager
(
model
=
module
,
chunk_manager
=
init_chunk_manager
(
model
=
module
,
init_device
=
device
,
init_device
=
device
,
hidden_dim
=
hidden_dim
,
hidden_dim
=
hidden_dim
,
search_range_m
b
=
search_range_m
b
,
search_range_m
=
search_range_m
,
min_chunk_size_m
b
=
min_chunk_size_m
b
,
min_chunk_size_m
=
min_chunk_size_m
,
strict_ddp_flag
=
strict_ddp_mode
,
strict_ddp_flag
=
strict_ddp_mode
,
verbose
=
verbose
)
verbose
=
verbose
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
,
memstats
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
,
memstats
)
...
...
tests/test_auto_parallel/test_offload/test_perf.py
View file @
0bb0b481
...
@@ -60,7 +60,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
...
@@ -60,7 +60,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
placement_policy
=
'cpu'
,
placement_policy
=
'cpu'
,
pin_memory
=
True
,
pin_memory
=
True
,
hidden_dim
=
8192
,
hidden_dim
=
8192
,
search_range_m
b
=
128
)
search_range_m
=
128
)
gemini_model
=
zero_model_wrapper
(
gemini_model
,
3
,
gemini_config
)
gemini_model
=
zero_model_wrapper
(
gemini_model
,
3
,
gemini_config
)
optim_config
=
dict
(
reduce_bucket_size
=
12
*
1024
*
1024
,
overlap_communication
=
True
,
verbose
=
True
)
optim_config
=
dict
(
reduce_bucket_size
=
12
*
1024
*
1024
,
overlap_communication
=
True
,
verbose
=
True
)
gemini_optim
=
zero_optim_wrapper
(
gemini_model
,
hybrid_optimizer
,
optim_config
=
optim_config
)
gemini_optim
=
zero_optim_wrapper
(
gemini_model
,
hybrid_optimizer
,
optim_config
=
optim_config
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
View file @
0bb0b481
...
@@ -75,7 +75,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
...
@@ -75,7 +75,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
device
=
get_current_device
(),
device
=
get_current_device
(),
placement_policy
=
'cpu'
,
placement_policy
=
'cpu'
,
pin_memory
=
True
,
pin_memory
=
True
,
search_range_m
b
=
128
)
search_range_m
=
128
)
post_process_colo_init_ctx
(
gm
,
device
=
get_current_device
(),
default_pg
=
dp_process_group
)
post_process_colo_init_ctx
(
gm
,
device
=
get_current_device
(),
default_pg
=
dp_process_group
)
gm
=
zero_model_wrapper
(
gm
,
zero_stage
=
3
,
gemini_config
=
gemini_config
)
gm
=
zero_model_wrapper
(
gm
,
zero_stage
=
3
,
gemini_config
=
gemini_config
)
...
...
tests/test_checkpoint_io/test_gemini_checkpoint_io.py
View file @
0bb0b481
...
@@ -30,7 +30,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
...
@@ -30,7 +30,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
bert_model
.
config
.
save_pretrained
(
save_directory
=
pretrained_path
)
bert_model
.
config
.
save_pretrained
(
save_directory
=
pretrained_path
)
# TODO(ver217): use boost api
# TODO(ver217): use boost api
config_dict
,
*
_
=
search_chunk_configuration
(
bert_model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
bert_model
,
search_range_m
=
1
,
search_interval
=
100
)
chunk_manager
=
ChunkManager
(
config_dict
)
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
bert_model
=
ZeroDDP
(
bert_model
,
gemini_manager
)
bert_model
=
ZeroDDP
(
bert_model
,
gemini_manager
)
...
...
tests/test_tensor/test_tp_with_zero.py
View file @
0bb0b481
...
@@ -79,7 +79,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
...
@@ -79,7 +79,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
tp_init_spec_func
(
model
,
pg
)
tp_init_spec_func
(
model
,
pg
)
dp_world_size
=
pg
.
dp_world_size
()
dp_world_size
=
pg
.
dp_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
dp_world_size
][
'chunk_size'
]
=
5000
config_dict
[
dp_world_size
][
'chunk_size'
]
=
5000
config_dict
[
dp_world_size
][
'keep_gathered'
]
=
False
config_dict
[
dp_world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
if
placement_policy
!=
'cuda'
:
...
...
tests/test_zero/test_gemini/test_fwd_bwd.py
View file @
0bb0b481
...
@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
...
@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
torch_p
.
data
.
copy_
(
p
.
data
)
torch_p
.
data
.
copy_
(
p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
chunk_manager
=
ChunkManager
(
config_dict
)
...
@@ -113,7 +113,7 @@ def exam_gpt_inference(
...
@@ -113,7 +113,7 @@ def exam_gpt_inference(
torch_p
.
data
.
copy_
(
p
.
data
)
torch_p
.
data
.
copy_
(
p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
chunk_manager
=
ChunkManager
(
config_dict
)
...
...
tests/test_zero/test_gemini/test_gemini_use_rmt.py
View file @
0bb0b481
...
@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
...
@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
assert
len
(
step_list
)
==
4
assert
len
(
step_list
)
==
4
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
chunk_manager
=
ChunkManager
(
config_dict
)
...
...
tests/test_zero/test_gemini/test_grad_clip.py
View file @
0bb0b481
...
@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
...
@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
p
.
data
.
copy_
(
torch_p
.
data
)
p
.
data
.
copy_
(
torch_p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
False
config_dict
[
world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
if
placement_policy
!=
'cuda'
:
...
...
tests/test_zero/test_gemini/test_inference.py
View file @
0bb0b481
...
@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
...
@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
def
multi_chunk_init
(
model
:
torch
.
nn
.
Module
,
placement_policy
:
str
):
def
multi_chunk_init
(
model
:
torch
.
nn
.
Module
,
placement_policy
:
str
):
world_size
=
dist
.
get_world_size
()
world_size
=
dist
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
False
config_dict
[
world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
if
placement_policy
!=
'cuda'
:
...
...
tests/test_zero/test_gemini/test_optim.py
View file @
0bb0b481
...
@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
...
@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
p
.
data
.
copy_
(
torch_p
.
data
)
p
.
data
.
copy_
(
torch_p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
False
config_dict
[
world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
if
placement_policy
!=
'cuda'
:
...
@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
...
@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
p
.
data
.
copy_
(
torch_p
.
data
)
p
.
data
.
copy_
(
torch_p
.
data
)
chunk_manager
=
init_chunk_manager
(
model
=
model
,
init_device
=
get_current_device
(),
search_range_m
b
=
1
)
chunk_manager
=
init_chunk_manager
(
model
=
model
,
init_device
=
get_current_device
(),
search_range_m
=
1
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
,
mixed_precision
=
mixed_precision
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
,
mixed_precision
=
mixed_precision
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
...
...
tests/test_zero/test_gemini/test_search.py
View file @
0bb0b481
...
@@ -30,9 +30,9 @@ def exam_search_chunk_size():
...
@@ -30,9 +30,9 @@ def exam_search_chunk_size():
model
=
model_builder
()
model
=
model_builder
()
init_1d_row_spec
(
model
,
pg_tp
)
init_1d_row_spec
(
model
,
pg_tp
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_range_m
=
1
,
search_interval
_byte
=
16
,
search_interval
=
16
,
min_chunk_size_m
b
=
0
,
min_chunk_size_m
=
0
,
filter_exlarge_params
=
True
)
filter_exlarge_params
=
True
)
for
key
in
config_dict
:
for
key
in
config_dict
:
...
@@ -54,9 +54,9 @@ def exam_search_strict_ddp():
...
@@ -54,9 +54,9 @@ def exam_search_strict_ddp():
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
ddp_model
=
model_builder
()
ddp_model
=
model_builder
()
re_dict
,
re_total
,
re_wasted
=
search_chunk_configuration
(
ddp_model
,
re_dict
,
re_total
,
re_wasted
=
search_chunk_configuration
(
ddp_model
,
search_range_m
b
=
1
,
search_range_m
=
1
,
search_interval
_byte
=
16
,
search_interval
=
16
,
min_chunk_size_m
b
=
0
,
min_chunk_size_m
=
0
,
filter_exlarge_params
=
True
,
filter_exlarge_params
=
True
,
strict_ddp_flag
=
False
)
strict_ddp_flag
=
False
)
# get the chunk configuration over sharded ddp models
# get the chunk configuration over sharded ddp models
...
@@ -64,9 +64,9 @@ def exam_search_strict_ddp():
...
@@ -64,9 +64,9 @@ def exam_search_strict_ddp():
default_dist_spec
=
default_shard_spec
):
default_dist_spec
=
default_shard_spec
):
sharded_ddp_model
=
model_builder
()
sharded_ddp_model
=
model_builder
()
sh_dict
,
sh_total
,
sh_wasted
=
search_chunk_configuration
(
sharded_ddp_model
,
sh_dict
,
sh_total
,
sh_wasted
=
search_chunk_configuration
(
sharded_ddp_model
,
search_range_m
b
=
1
,
search_range_m
=
1
,
search_interval
_byte
=
16
,
search_interval
=
16
,
min_chunk_size_m
b
=
0
,
min_chunk_size_m
=
0
,
filter_exlarge_params
=
True
,
filter_exlarge_params
=
True
,
strict_ddp_flag
=
True
)
strict_ddp_flag
=
True
)
assert
re_dict
==
sh_dict
assert
re_dict
==
sh_dict
...
@@ -91,8 +91,8 @@ def exam_chunk_manager():
...
@@ -91,8 +91,8 @@ def exam_chunk_manager():
chunk_manager
=
init_chunk_manager
(
sharded_ddp_model
,
chunk_manager
=
init_chunk_manager
(
sharded_ddp_model
,
get_current_device
(),
get_current_device
(),
hidden_dim
=
16
,
hidden_dim
=
16
,
search_range_m
b
=
1
,
search_range_m
=
1
,
min_chunk_size_m
b
=
0
,
min_chunk_size_m
=
0
,
filter_exlarge_params
=
True
,
filter_exlarge_params
=
True
,
strict_ddp_flag
=
True
)
strict_ddp_flag
=
True
)
config_dict
=
chunk_manager
.
dp_degree_chunk_size_dict
config_dict
=
chunk_manager
.
dp_degree_chunk_size_dict
...
...
tests/test_zero/test_gemini/test_zeroddp_state_dict.py
View file @
0bb0b481
...
@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
...
@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
torch_p
.
data
.
copy_
(
p
.
data
)
torch_p
.
data
.
copy_
(
p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
chunk_manager
=
ChunkManager
(
config_dict
)
chunk_manager
=
ChunkManager
(
config_dict
)
...
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
...
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
torch_model
=
model_builder
()
# get a different model
torch_model
=
model_builder
()
# get a different model
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
...
...
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
View file @
0bb0b481
...
@@ -22,7 +22,7 @@ def exam_state_dict(placement_policy, model_name: str):
...
@@ -22,7 +22,7 @@ def exam_state_dict(placement_policy, model_name: str):
model_size
=
sum
(
p
.
numel
()
*
p
.
element_size
()
for
p
in
model
.
parameters
())
/
1024
**
2
model_size
=
sum
(
p
.
numel
()
*
p
.
element_size
()
for
p
in
model
.
parameters
())
/
1024
**
2
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
chunk_manager
=
ChunkManager
(
config_dict
)
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
)
...
@@ -38,6 +38,7 @@ def exam_state_dict(placement_policy, model_name: str):
...
@@ -38,6 +38,7 @@ def exam_state_dict(placement_policy, model_name: str):
assert
key
in
zero_dict
,
f
"
{
key
}
not in ZeRO dictionary."
assert
key
in
zero_dict
,
f
"
{
key
}
not in ZeRO dictionary."
assert
torch
.
equal
(
value
,
zero_dict
[
key
]),
f
"
{
key
}
not equal."
assert
torch
.
equal
(
value
,
zero_dict
[
key
]),
f
"
{
key
}
not equal."
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
config
=
{}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
...
...
tests/test_zero/test_gemini/test_zerooptim_state_dict.py
View file @
0bb0b481
...
@@ -27,7 +27,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
...
@@ -27,7 +27,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
torch_model
=
model_builder
()
# get a different model
torch_model
=
model_builder
()
# get a different model
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
b
=
1
,
search_interval
_byte
=
100
)
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_m
=
1
,
search_interval
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gathered
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment