Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8213f89f
Unverified
Commit
8213f89f
authored
Feb 13, 2023
by
HELSON
Committed by
GitHub
Feb 13, 2023
Browse files
[gemini] add fake_release_chunk for keep-gathered chunk in the inference mode (#2671)
parent
09660088
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
16 deletions
+43
-16
colossalai/gemini/chunk/manager.py
colossalai/gemini/chunk/manager.py
+8
-0
colossalai/nn/parallel/data_parallel.py
colossalai/nn/parallel/data_parallel.py
+5
-2
tests/test_gemini/update/test_inference.py
tests/test_gemini/update/test_inference.py
+30
-14
No files found.
colossalai/gemini/chunk/manager.py
View file @
8213f89f
...
...
@@ -140,6 +140,14 @@ class ChunkManager:
self
.
__add_memory_usage
(
chunk
.
memory_usage
)
return
True
    def fake_release_chunk(self, chunk: Chunk) -> None:
        """Release gathered chunk in a fake mode.
        This function is used for keep-gathered chunk in the inference mode.
        """
        # Only chunks pinned in the gathered state may be "fake" released;
        # a keep-gathered chunk must never be physically scattered.
        assert chunk.keep_gathered
        # Every tensor in the chunk must currently be in the HOLD state —
        # i.e. no tensor is still being computed or holding gradients.
        assert chunk.tensor_state_cnter[TensorState.HOLD] == chunk.num_tensors
        # Update only the manager's accessed-chunk bookkeeping; the chunk's
        # payload is left gathered (presumably __sub_accessed_chunk removes it
        # from the accessed set / memory accounting — confirm against the
        # helper's definition elsewhere in this class).
        self.__sub_accessed_chunk(chunk)
def
copy_tensor_to_chunk_slice
(
self
,
tensor
:
torch
.
Tensor
,
data
:
torch
.
Tensor
)
->
None
:
"""
Copy data to the chunk.
...
...
colossalai/nn/parallel/data_parallel.py
View file @
8213f89f
...
...
@@ -257,6 +257,9 @@ class ZeroDDP(ColoDDP):
access_list
=
list
(
self
.
chunk_manager
.
accessed_chunks
)
# we need to scatter all accessed chunks and move them to their original places
for
chunk
in
access_list
:
if
chunk
.
keep_gathered
:
self
.
chunk_manager
.
fake_release_chunk
(
chunk
)
else
:
assert
chunk
.
can_release
self
.
chunk_manager
.
release_chunk
(
chunk
)
first_param
=
next
(
iter
(
chunk
.
tensors_info
))
...
...
tests/test_gemini/update/test_inference.py
View file @
8213f89f
from
functools
import
partial
from
typing
import
Callable
import
pytest
import
torch
...
...
@@ -13,7 +14,7 @@ from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chu
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer.zero_optimizer
import
ZeroOptimizer
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.parallel
import
ZeroDDP
,
zero_model_wrapper
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
...
...
@@ -36,9 +37,35 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close
(
value
,
temp_zero_value
,
rtol
=
1e-3
,
atol
=
4e-3
)
def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
    """Wrap ``model`` in ``ZeroDDP`` using a hand-tuned multi-chunk setup.

    A chunk configuration is first searched automatically, then overridden
    with a small fixed chunk size (so the model spans several chunks) and
    ``keep_gathered`` disabled. Parameters are initialized on CPU unless the
    placement policy is ``'cuda'``.
    """
    world_size = dist.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    # Force many small chunks and real release behavior for this test path.
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    # Non-CUDA policies stage the initial parameters on the CPU.
    init_device = None if placement_policy == 'cuda' else torch.device('cpu')
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    return ZeroDDP(model, gemini_manager, pin_memory=True)
def single_chunk_init(model: torch.nn.Module, placement_policy: str):
    """Wrap ``model`` via the high-level ``zero_model_wrapper`` helper.

    Uses zero stage 3 with a default Gemini configuration (current device,
    the given placement policy, pinned memory) — no manual chunk tuning.
    """
    cfg = dict(device=get_current_device(), placement_policy=placement_policy, pin_memory=True)
    return zero_model_wrapper(model=model, zero_stage=3, gemini_config=cfg)
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
])
def
exam_inference
(
placement_policy
,
model_name
:
str
):
@
parameterize
(
'model_init_func'
,
[
single_chunk_init
,
multi_chunk_init
])
def
exam_inference
(
placement_policy
:
str
,
model_name
:
str
,
model_init_func
:
Callable
):
set_seed
(
19360226
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
...
@@ -56,18 +83,7 @@ def exam_inference(placement_policy, model_name: str):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
p
.
data
.
copy_
(
torch_p
.
data
)
world_size
=
torch
.
distributed
.
get_world_size
()
config_dict
,
*
_
=
search_chunk_configuration
(
model
,
search_range_mb
=
1
,
search_interval_byte
=
100
)
config_dict
[
world_size
][
'chunk_size'
]
=
5000
config_dict
[
world_size
][
'keep_gathered'
]
=
False
if
placement_policy
!=
'cuda'
:
init_device
=
torch
.
device
(
'cpu'
)
else
:
init_device
=
None
chunk_manager
=
ChunkManager
(
config_dict
,
init_device
=
init_device
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pin_memory
=
True
)
model
=
model_init_func
(
model
,
placement_policy
)
optimizer
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
zero_optim
=
ZeroOptimizer
(
optimizer
,
model
,
initial_scale
=
128
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment