OpenDAS / ColossalAI, commit bb4e9a31 (unverified)

[zero] add inference mode and its unit test (#2418)

Authored Jan 11, 2023 by HELSON; committed by GitHub on Jan 11, 2023.
Parent: 63be79d5

Showing 3 changed files with 157 additions and 6 deletions (+157 -6):
    colossalai/gemini/gemini_mgr.py              +12  -6
    colossalai/nn/parallel/data_parallel.py      +23  -0
    tests/test_gemini/update/test_inference.py  +122  -0
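The net effect of the commit: a ZeroDDP model can now run a plain forward pass under torch.no_grad() once one complete training iteration has finished (Gemini's warmup). A minimal usage sketch of that workflow follows; model, zero_optim, criterion, run_fwd_bwd, and the dataloader are illustrative placeholders, not names this diff introduces.

# Minimal sketch of the workflow this commit enables. `model` is assumed to be
# a ZeroDDP-wrapped module and `zero_optim` a ZeroOptimizer; `run_fwd_bwd` and
# the dataloader stand in for your own training utilities.
import torch

# 1. Warmup: run one *complete* training iteration first, so GeminiManager
#    finishes collecting memory stats (is_warmup() becomes False).
input_ids, label = next(train_dataloader)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
zero_optim.step()

# 2. Inference: with gradients disabled, ZeroDDP.forward takes the new
#    inference path and scatters all accessed chunks after the call.
model.eval()
with torch.no_grad():
    logits = model(input_ids)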
colossalai/gemini/gemini_mgr.py
@@ -50,6 +50,17 @@ class GeminiManager:
         self._warmup = True
         self._comp_cuda_demand_time = 0

+    def reset_attributes(self):
+        self._compute_idx = -1
+        self._h2d_volume = 0
+        self._d2h_volume = 0
+        self._layout_time = 0
+        self._evict_time = 0
+        self._comp_cuda_demand_time = 0
+
+    def is_warmup(self):
+        return self._warmup
+
     def memstats(self):
         """memstats

@@ -73,12 +84,7 @@ class GeminiManager:
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.finish_collection()
         self._warmup = False
-        self._compute_idx = -1
-        self._h2d_volume = 0
-        self._d2h_volume = 0
-        self._layout_time = 0
-        self._evict_time = 0
-        self._comp_cuda_demand_time = 0
+        self.reset_attributes()

     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
         """ Adjust the layout of stateful tensors according to the information provided
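This file's change is a small refactor plus one new query: the six per-iteration counters are now cleared in one place, reset_attributes(), which post_iter() calls at the end of each training iteration and which ZeroDDP._post_forward() (next file) can call after a no-grad forward; is_warmup() exposes the warmup flag so callers can check whether stats collection has finished. A stripped-down, runnable sketch of the pattern, using a stub class for illustration rather than the real GeminiManager:

# Stub illustrating the refactor; only the warmup/reset lifecycle is shown.
class StatsTrackerSketch:

    def __init__(self):
        self._warmup = True
        self.reset_attributes()

    def reset_attributes(self):
        # single entry point for clearing per-iteration statistics
        self._compute_idx = -1
        self._h2d_volume = 0
        self._d2h_volume = 0

    def is_warmup(self):
        return self._warmup

    def post_iter(self):
        # end of a training iteration: warmup is over, counters are cleared
        self._warmup = False
        self.reset_attributes()    # previously six duplicated assignments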
colossalai/nn/parallel/data_parallel.py
@@ -268,12 +268,35 @@ class ZeroDDP(ColoDDP):
         self._logger = get_dist_logger()

+    def _post_forward(self):
+        """This function is only triggered for inference.
+        """
+        access_list = list(self.chunk_manager.accessed_chunks)
+        # we need to scatter all accessed chunks and move them to their original places
+        for chunk in access_list:
+            assert chunk.can_release
+            self.chunk_manager.release_chunk(chunk)
+            first_param = next(iter(chunk.tensors_info))
+            self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
+        assert self.chunk_manager.accessed_mem == 0
+        # reset all recorded attributes
+        self.gemini_manager.reset_attributes()
+
     def forward(self, *args, **kwargs):
+        # check whether we are in an inference mode
+        grad_flag = torch.is_grad_enabled()
+        if not grad_flag:
+            assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
+
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
         with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
+        # scatter chunks in the inference mode
+        if not grad_flag:
+            self._post_forward()
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
         return outputs
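Design note (inferred from the diff, not stated in the commit): _post_forward() exists because a no-grad forward never runs the backward pass that would normally release the gathered chunks, so the forward itself must scatter every accessed chunk, move it back to the device recorded in grads_device, and clear Gemini's per-iteration statistics; the final accessed_mem == 0 assert checks that nothing stays gathered on the GPU. The is_warmup() assert guards this path because chunk placement relies on the memory statistics Gemini collects during one completed warmup iteration.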
tests/test_gemini/update/test_inference.py
(new file, mode 100644)
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['gpt2'])
def exam_inference(placement_policy, model_name: str):
    set_seed(19360226)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    torch_model = model_builder().cuda()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    init_dev = get_current_device()
    with ColoInitContext(device=init_dev):
        model = model_builder()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    train_dataloader = iter(train_dataloader)

    def train_iter():
        input_ids, label = next(train_dataloader)
        input_ids, label = input_ids.cuda(), label.cuda()
        zero_optim.zero_grad()
        torch_optim.zero_grad()
        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss)
        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model)

    def inference_iter():
        input_ids, label = next(train_dataloader)
        input_ids, label = input_ids.cuda(), label.cuda()
        with torch.no_grad():
            torch_output = torch_model(input_ids)
            torch_loss = criterion(torch_output.float(), label)
            zero_output = model(input_ids)
            zero_loss = criterion(zero_output.float(), label)
        assert_close(torch_loss, zero_loss)

    train_iter()
    inference_iter()
    train_iter()


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_inference()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_inference(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_inference(1)
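The test exercises the new path by interleaving modes: the first train_iter() serves as the warmup iteration, inference_iter() runs both models under torch.no_grad() and compares losses, and a second train_iter() confirms that training still works after an inference pass. Running the file directly executes test_inference(1) in a single process; under pytest, the dist marker and the world_size parametrization spawn 1 or 4 NCCL processes via mp.spawn.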