OpenDAS / ColossalAI · Commit 638a07a7 (unverified)

[test] fixed gemini plugin test (#3411)

* [test] fixed gemini plugin test
* polish code
* polish code

Authored Apr 03, 2023 by Frank Lee; committed by GitHub on Apr 03, 2023
Parent: 30412866

Showing 7 changed files with 124 additions and 131 deletions.
colossalai/auto_parallel/offload/base_offload_module.py    +7   -9
colossalai/auto_parallel/offload/mem_optimize.py           +12  -9
colossalai/auto_parallel/offload/runtime.py                +31  -28
colossalai/auto_parallel/offload/util.py                   +21  -12
tests/test_auto_parallel/test_offload/test_perf.py         +32  -32
tests/test_booster/test_plugin/test_gemini_plugin.py       +19  -41
tests/test_zero/low_level_zero/test_zero1_2.py             +2   -0
colossalai/auto_parallel/offload/base_offload_module.py (view file @ 638a07a7)

-from typing import Optional, Set
 from functools import partial
+from typing import Optional, Set

 import torch
 import torch.nn as nn

-from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.gemini.tensor_utils import free_storage
+from colossalai.nn.parallel.data_parallel import _cast_float

 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo

@@ -20,10 +21,7 @@ class BaseOffloadModule:
         is_sync (bool): synchronous mode or not.
     """

-    def __init__(self,
-                 model: nn.Module,
-                 region_manager: RegionManager,
-                 is_sync=True):
+    def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
         self.model = model
         self.region_manager = region_manager

@@ -69,8 +67,8 @@ class BaseOffloadModule:
         for p in self.model.parameters():
             p.grad = None
-        GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
-        GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
+        GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
+        GlobalRuntimeInfo().bwd_prefetch_event_map.clear()

     def grad_handle(self, p, grad):
         empty_grad = torch.empty_like(grad)

@@ -82,7 +80,7 @@ class BaseOffloadModule:
             self.overflow_counter += region.has_inf_or_nan
         master_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self.grad_offload_stream):
-            GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
+            GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
             region.move_grad_to_cpu()
         return empty_grad
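The grad_handle hunk above synchronizes a dedicated device-to-host stream with the compute stream before moving gradients off the GPU. A minimal sketch of that pattern, not ColossalAI's implementation (the buffer and function names here are illustrative):

    import torch

    d2h_stream = torch.cuda.Stream()

    def offload_grad(grad: torch.Tensor) -> torch.Tensor:
        # Pinned memory is required for the D2H copy to be truly asynchronous.
        cpu_buf = torch.empty(grad.shape, dtype=grad.dtype, pin_memory=True)
        master = torch.cuda.current_stream()
        with torch.cuda.stream(d2h_stream):
            d2h_stream.wait_stream(master)          # don't copy before the grad is produced
            cpu_buf.copy_(grad, non_blocking=True)  # async copy issued on the side stream
        return cpu_buf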
colossalai/auto_parallel/offload/mem_optimize.py (view file @ 638a07a7)

 from typing import Dict

 import torch
 import torch.fx
 from torch.fx import GraphModule

@@ -7,10 +8,11 @@ from torch.utils._pytree import tree_map
 from colossalai.fx import ColoTracer, is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp

-from .region_manager import RegionManager
-from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
 from .base_offload_module import BaseOffloadModule
-from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
+from .region_manager import RegionManager
+from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
+from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem


 def memory_optimize(model: torch.nn.Module,
                     inps: Dict[str, torch.Tensor],

@@ -29,13 +31,14 @@ def memory_optimize(model: torch.nn.Module,
     region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
     region_manager._build_regions()
-    GlobalRuntimeInfo.region_list = region_manager.region_list
+    GlobalRuntimeInfo().region_list = region_manager.region_list

     act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
     max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
     total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
-    print(f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
+    print(
+        f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")

     if solver_name == 'syn':
         gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)

@@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module,
         raise TypeError(f"Unknown solver name {solver_name}!")

     gm.recompile()
-    optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
+    optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
     return optimized_model
colossalai/auto_parallel/offload/runtime.py (view file @ 638a07a7)

 from typing import List

 import torch
 from torch.fx.node import Node

@@ -23,13 +24,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
         ctx.bwd_info = bwd_info
         d2h_rid = fwd_info.get('d2h_rid', None)
         if d2h_rid is not None:
-            free_region = GlobalRuntimeInfo.region_list[d2h_rid]
+            free_region = GlobalRuntimeInfo().region_list[d2h_rid]
             assert isinstance(free_region, Region)
             free_region.free_cuda_data()

         h2d_rid = fwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(h2d_region, Region)
             h2d_region.move_param_to_cuda()

@@ -40,7 +41,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
         h2d_rid = ctx.bwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
             pref_region.move_param_to_cuda()

@@ -65,23 +66,22 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
         sync_rid = fwd_info.get('sync_rid', None)
         if sync_rid is not None:
-            prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
-                sync_rid, None)
+            prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
             if prefetch_event:
                 prefetch_event.wait()

         h2d_rid = fwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
             master_stream = torch.cuda.current_stream()
-            with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
-                GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
+            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
+                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                 pref_region.move_param_to_cuda()

             prefetch_event = torch.cuda.Event()
-            prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
-            GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
+            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
+            GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event
         return input_

@@ -90,10 +90,9 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
         sync_rid = ctx.bwd_info.get('sync_rid', None)
         if sync_rid is not None:
-            wait_region = GlobalRuntimeInfo.region_list[sync_rid]
+            wait_region = GlobalRuntimeInfo().region_list[sync_rid]
             assert isinstance(wait_region, Region)
-            prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
-                sync_rid, None)
+            prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
             if prefetch_event:
                 prefetch_event.wait()
             else:

@@ -101,16 +100,16 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
         h2d_rid = ctx.bwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
             master_stream = torch.cuda.current_stream()
-            with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
-                GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
+            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
+                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                 pref_region.move_param_to_cuda()

             prefetch_event = torch.cuda.Event()
-            prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
-            GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
+            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
+            GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event
         return grad_output, None, None

@@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
     ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
     return ret

+
 def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
     '''
     Convert Prefetch and Offload operation into runtime action.

@@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
         if fwd_info or bwd_info:
             with mod_graph.inserting_after(last_inp_node):
-                new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
-                                                 args=(last_inp_node, fwd_info, bwd_info))
+                new_node = mod_graph.create_node('call_function',
+                                                 convert_fwd_upload_bwd_offload_to_action,
+                                                 args=(last_inp_node, fwd_info, bwd_info))
             replace_node_users(last_inp_node, new_node)

@@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
     # upload parameters of the first region
     last_inp_node = tuple(mod_graph.nodes)[0]
-    first_region_with_p = [
-        region for region in region_list if region.param_size][0]
+    first_region_with_p = [region for region in region_list if region.param_size][0]
     fwd_info = {"h2d_rid": first_region_with_p.r_id}
     with mod_graph.inserting_after(last_inp_node):
-        upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
-                                                  args=(last_inp_node, fwd_info, {}))
+        upload_apply_node = mod_graph.create_node('call_function',
+                                                  convert_fwd_upload_bwd_offload_to_action,
+                                                  args=(last_inp_node, fwd_info, {}))
     replace_node_users(last_inp_node, upload_apply_node)
     last_inp_node = upload_apply_node

@@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
             fwd_info['h2d_rid'] = fwd_prefetch_region.r_id

         # forward offload
         if r_idx > 0 and region_list[r_idx - 1].need_offload:
             fwd_info['d2h_rid'] = r_idx - 1

         bwd_info = {}

         # backward prefetch
         if r_idx > 0 and region_list[r_idx - 1].need_offload:
             bwd_info['sync_rid'] = r_idx - 1
         if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
             bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id

         if fwd_info or bwd_info:
             with mod_graph.inserting_after(last_inp_node):
-                new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
-                                                 args=(last_inp_node, fwd_info, bwd_info))
+                new_node = mod_graph.create_node('call_function',
+                                                 convert_fwd_prefetch_bwd_offload_to_action,
+                                                 args=(last_inp_node, fwd_info, bwd_info))
             replace_node_users(last_inp_node, new_node)

@@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
     if region.bwd_prefetch_region:
         bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
         with mod_graph.inserting_after(last_inp_node):
-            new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
-                                             args=(last_inp_node, {}, bwd_info))
+            new_node = mod_graph.create_node('call_function',
+                                             convert_fwd_prefetch_bwd_offload_to_action,
+                                             args=(last_inp_node, {}, bwd_info))
         replace_node_users(last_inp_node, new_node)

     # gm.graph.print_tabular()
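The asynchronous pass above relies on a record/wait handshake between the H2D prefetch stream and the compute stream. A minimal sketch of that handshake under the assumption of a CUDA device (names like prefetch/wait_prefetch are illustrative, not the library API):

    import torch

    h2d_stream = torch.cuda.Stream()
    prefetch_events = {}

    def prefetch(rid: int, cpu_tensor: torch.Tensor) -> torch.Tensor:
        master = torch.cuda.current_stream()
        with torch.cuda.stream(h2d_stream):
            h2d_stream.wait_stream(master)   # order the copy after pending compute
            gpu_tensor = cpu_tensor.to('cuda', non_blocking=True)  # async if pinned
        event = torch.cuda.Event()
        event.record(h2d_stream)             # marks completion of the copy
        prefetch_events[rid] = event
        return gpu_tensor

    def wait_prefetch(rid: int) -> None:
        event = prefetch_events.get(rid)
        if event is not None:
            event.wait()   # the current stream waits; the host thread does not block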
colossalai/auto_parallel/offload/util.py (view file @ 638a07a7)

 from dataclasses import dataclass
 from typing import List

 import torch

+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp

 from .region import Region

@@ -12,6 +15,7 @@ class NodeInfo:
     runtime_fwd_mem: float = 0
     runtime_bwd_mem: float = 0

+
 class NvDevicePower:
     """
     NVIDIA GPU computing performance (TFLOPs).

@@ -30,12 +34,14 @@ class NvDevicePower:
     A100_FP32 = 19.5


-class GlobalRuntimeInfo:
-    h2d_stream = torch.cuda.Stream()
-    d2h_stream = torch.cuda.Stream()
-    fwd_prefetch_event_map = {}
-    bwd_prefetch_event_map = {}
-    region_list = []
+class GlobalRuntimeInfo(metaclass=SingletonMeta):
+
+    def __init__(self):
+        self.h2d_stream = torch.cuda.Stream()
+        self.d2h_stream = torch.cuda.Stream()
+        self.fwd_prefetch_event_map = {}
+        self.bwd_prefetch_event_map = {}
+        self.region_list = []


 def compute_act_peak_mem(region_list: List[Region]) -> float:

@@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
     return act_peak_mem


 def compute_max_param_mem(region_list: List[Region]) -> float:
     return max(region.param_size for region in region_list)


 def compute_total_param_mem(region_list: List[Region]) -> float:
     return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)


 def requires_upload_p_in_fwd(shared_reg: Region):
     return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
                                                           and shared_reg.need_offload)


 def requires_release_p_in_bwd(shared_reg: Region):
     return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
                                                           and shared_reg.need_offload)


 def requires_offload_g_in_bwd(region: Region):
     return region.param_size and (region.r_id <= region.shared_rid)
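This is the core change of the commit: GlobalRuntimeInfo's class-level attributes (which created CUDA streams at import time) become instance attributes on a singleton, so the streams are built lazily on first use and every `GlobalRuntimeInfo.x` access elsewhere becomes `GlobalRuntimeInfo().x`. A minimal sketch of a SingletonMeta-style metaclass, assumed to be what colossalai.context.singleton_meta provides (the GlobalState class here is illustrative):

    class SingletonMeta(type):
        _instances = {}

        def __call__(cls, *args, **kwargs):
            # Construct the instance once; return the cached one afterwards.
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]

    class GlobalState(metaclass=SingletonMeta):
        def __init__(self):
            self.events = {}   # runs only on the first call

    assert GlobalState() is GlobalState()   # same object on every call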
tests/test_auto_parallel/test_offload/test_perf.py (view file @ 638a07a7)

 import time
-import pytest
 from functools import partial

+import pytest
 import torch
-from torch.utils._pytree import tree_map
 import torch.multiprocessing as mp
+from torch.utils._pytree import tree_map

 import colossalai
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.fx.profiler import parameter_size
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.utils import free_port, get_current_device
-from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
 from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
+from colossalai.fx.profiler import parameter_size
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.testing import parameterize
-from tests.test_tensor.common_utils import set_seed
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.test_auto_parallel.test_offload.model_utils import *
+from tests.test_tensor.common_utils import set_seed


 @parameterize('model_name', ['gpt2_'])
 @parameterize('memory_budget', [5000])
 @parameterize('solver_name', ['asyn'])
 def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):

     # build model
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, data_gen = get_components_func()
     label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
     criterion = LMLoss()

     set_seed(42)
     start_time = time.time()
     model = model_builder()
     model.train()
     param_size = parameter_size(model) / 1024**2 / 2
     init_time = time.time() - start_time
     print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")

@@ -92,13 +90,11 @@ def exam_fwd_bwd(
     torch.cuda.synchronize()
     exec_time = sum(sorted(time_list)[:5]) / 5
     runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
     runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
     print(f'gemini | model_name: {model_name}')
     print(
         f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
         f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
     print(time_list)

     del data_args

@@ -129,22 +125,26 @@ def exam_fwd_bwd(
     torch.cuda.synchronize()
     exec_time = sum(sorted(time_list)[:5]) / 5
     runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
     runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
     print(f'solver_name: {solver_name} | model_name: {model_name}')
     print(
         f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
         f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
     print(time_list)


-@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
-def test_perf(rank, world_size, port):
+def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     exam_fwd_bwd()


-if __name__ == '__main__':
-    run_func = partial(test_perf, world_size=1, port=free_port())
+@pytest.mark.skip("this test failed")
+@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+def test_perf():
+    run_func = partial(run_dist, world_size=1, port=free_port())
     mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == '__main__':
+    test_perf()
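The diff splits the distributed body (run_dist) from the pytest entry point (test_perf), so skip decorators apply to the spawning function rather than the per-rank worker. A minimal sketch of that wrapper pattern under assumed names (run_dist/test_entry are illustrative; port 29500 is an arbitrary example):

    from functools import partial

    import torch.multiprocessing as mp

    def run_dist(rank: int, world_size: int, port: int) -> None:
        # initialize the process group here, then run the actual checks
        print(f'rank {rank}/{world_size} on port {port}')

    def test_entry() -> None:
        # mp.spawn calls run_func(rank); partial binds the remaining arguments
        run_func = partial(run_dist, world_size=1, port=29500)
        mp.spawn(run_func, nprocs=1)

    if __name__ == '__main__':
        test_entry()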
tests/test_booster/test_plugin/test_gemini_plugin.py (view file @ 638a07a7)

@@ -21,9 +21,6 @@ def check_gemini_plugin(early_stop: bool = True):
     Args:
         early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
     """
-    plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
-    booster = Booster(plugin=plugin)
     passed_models = []
     failed_info = {}    # (model_name, error) pair

@@ -34,46 +31,23 @@ def check_gemini_plugin(early_stop: bool = True):
             continue
         # These models are not compatible with gemini
         if name in [
-                'diffusers_clip_vision_model',
-                'timm_resnet',
-                'timm_beit',
-                'timm_beitv2',
-                'timm_eca_nfnet',
-                'timm_efficientformer',
-                'timm_hrnet_w18_small',
-                'timm_nf_ecaresnet101',
-                'timm_nf_regnet_b0',
-                'timm_skresnet18',
-                'timm_wide_resnet50_2',
-                'timm_convit',
-                'timm_dm_nfnet',
-                'timm_swin_transformer',
-                'torchaudio_conformer',
-                'torchaudio_deepspeech',
-                'torchaudio_wavernn',
-                'torchaudio_tacotron',
-                'deepfm_interactionarch',
-                'deepfm_simpledeepfmnn',
-                'dlrm',
-                'dlrm_interactionarch',
-                'torchvision_googlenet',
-                'torchvision_inception_v3',
-                'torchvision_mobilenet_v3_small',
-                'torchvision_resnet18',
-                'torchvision_resnext50_32x4d',
-                'torchvision_wide_resnet50_2',
-                'torchvision_vit_b_16',
-                'torchvision_convnext_base',
-                'torchvision_swin_s',
-                'transformers_albert',
-                'transformers_albert_for_pretraining',
-                'transformers_bert',
-                'transformers_bert_for_pretraining',
-                'transformers_gpt_double_heads',
-                'torchaudio_hubert_base',
+                'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet',
+                'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0',
+                'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer',
+                'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron',
+                'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch',
+                'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small',
+                'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2',
+                'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert',
+                'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining',
+                'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base',
+                'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model'
         ]:
             continue
         try:
+            plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
+            booster = Booster(plugin=plugin)
             model = model_fn()
             optimizer = HybridAdam(model.parameters(), lr=1e-3)
             criterion = lambda x: x.mean()

@@ -97,10 +71,15 @@ def check_gemini_plugin(early_stop: bool = True):
             booster.backward(loss, optimizer)
             optimizer.step()
             passed_models.append(name)
+
+            del booster, plugin, model, optimizer, criterion, data, output, loss
         except Exception as e:
             failed_info[name] = e
             if early_stop:
                 raise e
+
+        torch.cuda.empty_cache()
+
     if dist.get_rank() == 0:
         print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
         print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')

@@ -138,7 +117,6 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
     check_gemini_plugin(early_stop=early_stop)


-@pytest.mark.skip(reason='Skip gemini plugin test due to OOM')
 @rerun_if_address_is_in_use()
 def test_gemini_plugin(early_stop: bool = True):
     world_size = 2
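The fix that lets the OOM skip come off is the per-model isolation above: build the plugin and booster fresh inside the loop, drop every reference once a model passes, and empty the CUDA cache between iterations. A minimal sketch of that pattern, not the test itself (check_models, model_fns and sample_input are hypothetical):

    import torch

    def check_models(model_fns, sample_input):
        failed = {}
        for name, model_fn in model_fns.items():
            try:
                model = model_fn()           # fresh state per model, like the fresh plugin/booster
                output = model(sample_input)
                del model, output            # drop references so memory can actually be freed
            except Exception as e:
                failed[name] = e             # record; optionally re-raise for early stop
            torch.cuda.empty_cache()         # release cached CUDA blocks between models
        return failed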
tests/test_zero/low_level_zero/test_zero1_2.py (view file @ 638a07a7)

@@ -9,6 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
+from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer

@@ -176,6 +177,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_if_address_is_in_use()
 def test_zero_1_2():
     world_size = 2
     run_func = partial(run_dist, world_size=world_size, port=free_port())