Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5efda697
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "7b9b86441fbffdd07021f234ec88d0dbc470fa5c"
Unverified
Commit
5efda697
authored
Dec 13, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 13, 2022
Browse files
[Gemini] hotfix the unittest bugs (#2125)
parent
05bb28aa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
9 deletions
+11
-9
colossalai/gemini/memory_tracer/param_runtime_order.py
colossalai/gemini/memory_tracer/param_runtime_order.py
+1
-1
tests/test_gemini/update/test_gemini_use_rmt.py
tests/test_gemini/update/test_gemini_use_rmt.py
+10
-8
No files found.
colossalai/gemini/memory_tracer/param_runtime_order.py
View file @
5efda697
...
...
@@ -36,7 +36,7 @@ class OrderedParamGenerator(ParamGenerator):
del
visited_set
def is_empty(self):
    """Return True when no parameter visits have been recorded yet.

    The diff this was scraped from shows the original returned
    ``len(self.param_visited_order) > 0`` — an inverted predicate for a
    method named ``is_empty``. The fixed form tests for emptiness.

    Returns:
        bool: ``True`` iff ``self.param_visited_order`` holds no entries.
    """
    # An empty list is falsy, so truth-testing it is the idiomatic
    # equivalent of `len(...) == 0`.
    return not self.param_visited_order
def clear(self):
    """Discard the recorded parameter visitation order.

    Rebinds ``self.param_visited_order`` to a fresh empty list; any
    previously recorded entries are dropped.
    """
    # Fresh list rather than in-place truncation: callers holding a
    # reference to the old list are unaffected, matching the original
    # rebind semantics.
    self.param_visited_order = list()
tests/test_gemini/update/test_gemini_use_rmt.py
View file @
5efda697
...
...
@@ -45,11 +45,15 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
run_fwd_bwd
(
runtime_mem_tracer
,
input_ids
,
label
,
criterion
,
runtime_mem_tracer
)
memstats
=
runtime_mem_tracer
.
memstats
()
runtime_tracer_non_model_data
=
runtime_mem_tracer
.
_memstats
.
_non_model_data_cuda_list
print
(
'runtime tracer: '
,
runtime_tracer_non_model_data
)
print
([
memstats
.
param_used_timestep
(
p
)
for
p
in
model
.
parameters
()])
print
(
'runtime tracer non model data points: '
,
len
(
runtime_tracer_non_model_data
))
model
=
GeminiDDP
(
model
,
device
=
'cuda'
,
placement_policy
=
placement_policy
,
search_range_mb
=
1
,
memstats
=
memstats
)
zero_optim
=
GeminiAdamOptimizer
(
model
,
lr
=
1e-3
,
initial_scale
=
1
)
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
_
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
keep_gather
chunk_manager
=
ChunkManager
(
config_dict
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
,
memstats
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
pg
=
ProcessGroup
()
set_seed
(
pg
.
dp_local_rank
())
...
...
@@ -61,12 +65,10 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
break
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
zero_optim
.
zero_grad
()
set_seed
(
42
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
zero_optim
.
step
()
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
model
)
gemini_non_model_data
=
model
.
gemini_manager
.
_mem_stats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)
gemini_non_model_data
=
gemini_manager
.
_mem_stats_collector
.
_memstats
.
non_model_data_list
(
'cuda'
)
# print('gemini non model data:', gemini_non_model_data)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment