OpenDAS / ColossalAI

Commit 0bb0b481, authored Jun 25, 2023 by Baizhou Zhang

[gemini] fix argument naming during chunk configuration searching

parent b463651f

Showing 17 changed files with 62 additions and 64 deletions
colossalai/booster/plugin/gemini_plugin.py (+7 -10)
colossalai/zero/gemini/chunk/search_utils.py (+13 -13)
colossalai/zero/gemini/chunk/utils.py (+7 -7)
colossalai/zero/gemini/gemini_ddp.py (+8 -8)
tests/test_auto_parallel/test_offload/test_perf.py (+1 -1)
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py (+1 -1)
tests/test_checkpoint_io/test_gemini_checkpoint_io.py (+1 -1)
tests/test_tensor/test_tp_with_zero.py (+1 -1)
tests/test_zero/test_gemini/test_fwd_bwd.py (+2 -2)
tests/test_zero/test_gemini/test_gemini_use_rmt.py (+1 -1)
tests/test_zero/test_gemini/test_grad_clip.py (+1 -1)
tests/test_zero/test_gemini/test_inference.py (+1 -1)
tests/test_zero/test_gemini/test_optim.py (+2 -2)
tests/test_zero/test_gemini/test_search.py (+11 -11)
tests/test_zero/test_gemini/test_zeroddp_state_dict.py (+2 -2)
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py (+2 -1)
tests/test_zero/test_gemini/test_zerooptim_state_dict.py (+1 -1)
colossalai/booster/plugin/gemini_plugin.py

@@ -181,11 +181,11 @@ class GeminiPlugin(DPPluginBase):
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
-        search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+        search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
         hidden_dim (int, optional): the hidden dimension of DNN.
             Users can provide this argument to speed up searching.
             If users do not know this argument before training, it is ok. We will use a default value 1024.
-        min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+        min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
             If the aggregate size of parameters is still smaller than the minimum chunk size,
             all parameters will be compacted into one small chunk.
         memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
@@ -214,9 +214,9 @@ class GeminiPlugin(DPPluginBase):
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
-        search_range_mb: int = 32,
+        search_range_m: int = 32,
         hidden_dim: Optional[int] = None,
-        min_chunk_size_mb: float = 32,
+        min_chunk_size_m: float = 32,
         memstats: Optional[MemStats] = None,
         gpu_margin_mem_ratio: float = 0.0,
         initial_scale: float = 2**32,
@@ -238,9 +238,9 @@ class GeminiPlugin(DPPluginBase):
             pin_memory=pin_memory,
             force_outputs_fp32=force_outputs_fp32,
             strict_ddp_mode=strict_ddp_mode,
-            search_range_mb=search_range_mb,
+            search_range_m=search_range_m,
             hidden_dim=hidden_dim,
-            min_chunk_size_mb=min_chunk_size_mb,
+            min_chunk_size_m=min_chunk_size_m,
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
         )
@@ -295,10 +295,7 @@ class GeminiPlugin(DPPluginBase):
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer,
-                                        self.zero_optim_config,
-                                        self.optim_kwargs,
-                                        self.verbose)
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler
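A hedged usage sketch of the renamed plugin arguments; the values are illustrative, and only the keyword names change in this commit:

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    placement_policy='cpu',
    search_range_m=32,      # formerly search_range_mb
    min_chunk_size_m=32,    # formerly min_chunk_size_mb
)
booster = Booster(plugin=plugin)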
colossalai/zero/gemini/chunk/search_utils.py

@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
 def search_chunk_configuration(
         model: nn.Module,
-        search_range_mb: float,
-        search_interval_byte: int,    # hidden size is the best value for the interval
-        min_chunk_size_mb: float = 32,
+        search_range_m: float,
+        search_interval: int,    # hidden size is the best value for the interval
+        min_chunk_size_m: float = 32,
         filter_exlarge_params: bool = True,
         strict_ddp_flag: bool = False,
         memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
@@ -126,9 +126,9 @@ def search_chunk_configuration(
     Args:
         model (nn.Module): torch module
-        search_range_mb (float): searching range in mega byte.
-        search_interval_byte (int): searching interval in byte.
-        min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
+        search_range_m (float): searching range divided by 2^20.
+        search_interval (int): searching interval.
+        min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20.
         filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
         strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
             all parameters keep replicated in this mode.
@@ -145,9 +145,9 @@ def search_chunk_configuration(
     for p in model.parameters():
         param_order.append(p)

-    search_range_byte = round(search_range_mb * 1024**2)
-    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
-    assert search_range_byte >= 0
+    search_range = round(search_range_m * 1024**2)
+    min_chunk_size = round(min_chunk_size_m * 1024**2)
+    assert search_range >= 0

     params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
     size_lcm = np.lcm.reduce(list(params_dict.keys()))
@@ -162,7 +162,7 @@ def search_chunk_configuration(
         total_param_size += group_acc_size

         # let small parameters keep gathered in CUDA all the time
-        if group_acc_size < min_chunk_size_byte:
+        if group_acc_size < min_chunk_size:
             config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
         else:
             size_dict[dp_degree] = size_list
@@ -170,15 +170,15 @@ def search_chunk_configuration(
     if filter_exlarge_params:
         _filter_exlarge_params(model, size_dict)

-    max_size = min_chunk_size_byte
+    max_size = min_chunk_size
     for key in size_dict:
         max_size = max(max_size, max(size_dict[key]))
-    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
+    start_size = int(math.ceil(max_size / search_interval) * search_interval)

     min_chunk_waste = float('+inf')
     best_chunk_size = start_size

-    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
+    for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
         temp_waste = 0
         for key in size_dict:
             temp_waste += _get_unused_byte(size_dict[key], chunk_size)
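For readers unfamiliar with the search itself, here is a self-contained toy sketch (plain Python, simplified waste accounting, not the ColossalAI implementation) of what search_chunk_configuration does with the renamed variables: convert the *_m arguments to byte counts with round(x * 1024**2), then scan candidate chunk sizes in steps of search_interval and keep the size that wastes the fewest bytes.

import math

# Toy re-implementation for illustration only; parameter grouping by DP degree,
# LCM alignment, and the exact _get_unused_byte accounting are simplified away.
def toy_search_chunk_size(param_sizes, search_range_m=1.0, search_interval=100, min_chunk_size_m=0.0):
    search_range = round(search_range_m * 1024**2)      # renamed from search_range_byte
    min_chunk_size = round(min_chunk_size_m * 1024**2)  # renamed from min_chunk_size_byte
    max_size = max([min_chunk_size] + list(param_sizes))
    start_size = int(math.ceil(max_size / search_interval) * search_interval)

    def wasted(chunk_size):
        used = waste = 0
        for s in param_sizes:
            if used + s > chunk_size:    # parameter does not fit: close the current chunk
                waste += chunk_size - used
                used = 0
            used += s
        return waste

    candidates = range(start_size, start_size + search_range + 1, search_interval)
    return min(candidates, key=wasted)

print(toy_search_chunk_size([30_000, 50_000, 70_000, 90_000]))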
colossalai/zero/gemini/chunk/utils.py

@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
                        verbose: bool = False,
                        **kwargs) -> ChunkManager:
     if hidden_dim:
-        search_interval_byte = hidden_dim
+        search_interval = hidden_dim
     else:
-        search_interval_byte = 1024    # defaults to 1kb
-    kwargs["search_interval_byte"] = search_interval_byte
+        search_interval = 1024    # defaults to 1024
+    kwargs["search_interval"] = search_interval

     dist.barrier()
     begin = time()
@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
     dist.barrier()
     end = time()
     span_s = end - begin
-    mb_size = 1024**2
-    total_size /= mb_size
-    wasted_size /= mb_size
+    mega_unit = 1024**2
+    total_size /= mega_unit
+    wasted_size /= mega_unit

     if verbose and dist.get_rank() == 0:
         print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
-              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
+              "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
               "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
               sep='',
               flush=True)
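The renamed mega_unit mirrors the docstring wording: sizes are reported as byte counts divided by 2^20 instead of being labeled "MB". A standalone illustration of that reporting convention (the byte count is an example, not from the commit):

total_size_byte = 50_331_648
mega_unit = 1024**2
print("used number: {:.2f} * 2^20".format(total_size_byte / mega_unit))
# -> used number: 48.00 * 2^20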
colossalai/zero/gemini/gemini_ddp.py

@@ -739,9 +739,9 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
                  scatter_after_inference: bool = True,
-                 search_range_mb: int = 32,
+                 search_range_m: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: float = 32,
+                 min_chunk_size_m: float = 32,
                  memstats: Optional[MemStats] = None,
                  mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
@@ -763,24 +763,24 @@ class GeminiDDP(ZeroDDP):
             placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
             pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
             force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
-            search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+            search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
             hidden_dim (int, optional): the hidden dimension of DNN.
                 Users can provide this argument to speed up searching.
                 If users do not know this argument before training, it is ok. We will use a default value 1024.
-            min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+            min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
                 If the aggregate size of parameters is still smaller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
             memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
         # some ugly hotfix for the compatibility with Lightning
-        if search_range_mb is None:
-            search_range_mb = 32
+        if search_range_m is None:
+            search_range_m = 32

         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
-                                           search_range_mb=search_range_mb,
-                                           min_chunk_size_mb=min_chunk_size_mb,
+                                           search_range_m=search_range_m,
+                                           min_chunk_size_m=min_chunk_size_m,
                                            strict_ddp_flag=strict_ddp_mode,
                                            verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
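A sketch of constructing GeminiDDP with the renamed keywords. The import path and values are assumptions, and a real run additionally needs an initialized distributed environment (e.g. colossalai.launch), omitted here:

import torch
import torch.nn as nn
from colossalai.zero import GeminiDDP  # assumed re-export of colossalai.zero.gemini.gemini_ddp.GeminiDDP

model = GeminiDDP(nn.Linear(256, 256),
                  device=torch.device('cpu'),
                  placement_policy='cpu',
                  search_range_m=32,       # formerly search_range_mb
                  min_chunk_size_m=32)     # formerly min_chunk_size_mb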
tests/test_auto_parallel/test_offload/test_perf.py

@@ -60,7 +60,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
                          placement_policy='cpu',
                          pin_memory=True,
                          hidden_dim=8192,
-                         search_range_mb=128)
+                         search_range_m=128)
     gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
     optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
     gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py

@@ -75,7 +75,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
                          device=get_current_device(),
                          placement_policy='cpu',
                          pin_memory=True,
-                         search_range_mb=128)
+                         search_range_m=128)
     post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
     gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
tests/test_checkpoint_io/test_gemini_checkpoint_io.py

@@ -30,7 +30,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
     bert_model.config.save_pretrained(save_directory=pretrained_path)

     # TODO(ver217): use boost api
-    config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
     chunk_manager = ChunkManager(config_dict)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     bert_model = ZeroDDP(bert_model, gemini_manager)
tests/test_tensor/test_tp_with_zero.py

@@ -79,7 +79,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         tp_init_spec_func(model, pg)

     dp_world_size = pg.dp_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[dp_world_size]['chunk_size'] = 5000
     config_dict[dp_world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
tests/test_zero/test_gemini/test_fwd_bwd.py

@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
         torch_p.data.copy_(p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
@@ -113,7 +113,7 @@ def exam_gpt_inference(
         torch_p.data.copy_(p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
tests/test_zero/test_gemini/test_gemini_use_rmt.py

@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint):
         assert len(step_list) == 4

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
tests/test_zero/test_gemini/test_grad_clip.py

@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
         p.data.copy_(torch_p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
tests/test_zero/test_gemini/test_inference.py

@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
 def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
     world_size = dist.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
tests/test_zero/test_gemini/test_optim.py

@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
         p.data.copy_(torch_p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)

-    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
+    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
tests/test_zero/test_gemini/test_search.py

@@ -30,9 +30,9 @@ def exam_search_chunk_size():
         model = model_builder()
     init_1d_row_spec(model, pg_tp)
     config_dict, *_ = search_chunk_configuration(model,
-                                                 search_range_mb=1,
-                                                 search_interval_byte=16,
-                                                 min_chunk_size_mb=0,
+                                                 search_range_m=1,
+                                                 search_interval=16,
+                                                 min_chunk_size_m=0,
                                                  filter_exlarge_params=True)

     for key in config_dict:
@@ -54,9 +54,9 @@ def exam_search_strict_ddp():
     with ColoInitContext(device=get_current_device()):
         ddp_model = model_builder()
     re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
-                                                              search_range_mb=1,
-                                                              search_interval_byte=16,
-                                                              min_chunk_size_mb=0,
+                                                              search_range_m=1,
+                                                              search_interval=16,
+                                                              min_chunk_size_m=0,
                                                               filter_exlarge_params=True,
                                                               strict_ddp_flag=False)
     # get the chunk configuration over sharded ddp models
@@ -64,9 +64,9 @@ def exam_search_strict_ddp():
                          default_dist_spec=default_shard_spec):
         sharded_ddp_model = model_builder()
     sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
-                                                              search_range_mb=1,
-                                                              search_interval_byte=16,
-                                                              min_chunk_size_mb=0,
+                                                              search_range_m=1,
+                                                              search_interval=16,
+                                                              min_chunk_size_m=0,
                                                               filter_exlarge_params=True,
                                                               strict_ddp_flag=True)
     assert re_dict == sh_dict
@@ -91,8 +91,8 @@ def exam_chunk_manager():
     chunk_manager = init_chunk_manager(sharded_ddp_model,
                                        get_current_device(),
                                        hidden_dim=16,
-                                       search_range_mb=1,
-                                       min_chunk_size_mb=0,
+                                       search_range_m=1,
+                                       min_chunk_size_m=0,
                                        filter_exlarge_params=True,
                                        strict_ddp_flag=True)
     config_dict = chunk_manager.dp_degree_chunk_size_dict
tests/test_zero/test_gemini/test_zeroddp_state_dict.py

@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
         torch_p.data.copy_(p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
     chunk_manager = ChunkManager(config_dict)
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
     torch_model = model_builder()    # get a different model

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py

@@ -22,7 +22,7 @@ def exam_state_dict(placement_policy, model_name: str):
     model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     chunk_manager = ChunkManager(config_dict)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     model = ZeroDDP(model, gemini_manager)
@@ -38,6 +38,7 @@ def exam_state_dict(placement_policy, model_name: str):
             assert key in zero_dict, f"{key} not in ZeRO dictionary."
             assert torch.equal(value, zero_dict[key]), f"{key} not equal."

+
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
tests/test_zero/test_gemini/test_zerooptim_state_dict.py

@@ -27,7 +27,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
     torch_model = model_builder()    # get a different model

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered