Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8daf1b4d
Unverified
Commit
8daf1b4d
authored
Nov 25, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 25, 2022
Browse files
[Gemini] patch for supporting orch.add_ function for ColoTensor (#2003)
parent
632753ab
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
59 additions
and
94 deletions
+59
-94
colossalai/gemini/ophooks/param_trace_hook.py
colossalai/gemini/ophooks/param_trace_hook.py
+0
-81
colossalai/nn/_ops/__init__.py
colossalai/nn/_ops/__init__.py
+6
-5
colossalai/nn/_ops/batch_norm.py
colossalai/nn/_ops/batch_norm.py
+33
-0
colossalai/nn/_ops/element_wise.py
colossalai/nn/_ops/element_wise.py
+12
-0
colossalai/nn/parallel/data_parallel.py
colossalai/nn/parallel/data_parallel.py
+1
-1
tests/components_to_test/inline_op_model.py
tests/components_to_test/inline_op_model.py
+3
-3
tests/test_gemini/test_gemini_train.py
tests/test_gemini/test_gemini_train.py
+4
-4
No files found.
colossalai/gemini/ophooks/param_trace_hook.py
deleted
100644 → 0
View file @
632753ab
from
contextlib
import
contextmanager
from
enum
import
Enum
from
functools
import
partial
from
typing
import
List
import
torch
from
colossalai.gemini.memory_tracer
import
SyncCudaMemoryMonitor
from
colossalai.tensor.param_op_hook
import
ParamOpHook
class
TrainingPhase
(
Enum
):
FORWARD
=
0
BACKWARD
=
1
class
ParamMemHook
(
ParamOpHook
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
_training_phase
=
TrainingPhase
.
FORWARD
self
.
mem_monitor
=
SyncCudaMemoryMonitor
()
self
.
_non_model_data_list
=
[]
self
.
_model_data_list
=
[]
def
_move_params_to_dev
(
self
,
params
,
dev
:
str
)
->
int
:
assert
isinstance
(
dev
,
str
),
f
"device should be a str not torch.device"
comm_volume
=
0
for
p
in
params
:
if
p
.
data
.
device
.
type
!=
dev
:
p
.
data
=
p
.
data
.
to
(
dev
)
comm_volume
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
if
p
.
grad
is
not
None
:
if
p
.
grad
.
device
.
type
!=
dev
:
p
.
grad
=
p
.
grad
.
to
(
dev
)
comm_volume
+=
p
.
grad
.
numel
()
*
p
.
grad
.
element_size
()
return
comm_volume
def
sample_model_data
(
self
,
params
):
data_volume
=
0
for
p
in
params
:
data_volume
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
if
self
.
_training_phase
==
TrainingPhase
.
BACKWARD
:
# add param.grad, actually param.grad is None in this time
data_volume
*=
2
self
.
_model_data_list
.
append
(
data_volume
)
def
pre_op
(
self
,
params
):
cuda_volume
=
self
.
mem_monitor
.
finish
()
if
len
(
self
.
_model_data_list
):
self
.
_non_model_data_list
.
append
(
cuda_volume
-
self
.
_model_data_list
[
-
1
])
self
.
_move_params_to_dev
(
params
,
'cuda'
)
self
.
sample_model_data
(
params
)
self
.
mem_monitor
.
start
()
def
post_op
(
self
,
params
):
self
.
_move_params_to_dev
(
params
,
'cpu'
)
def
pre_forward
(
self
,
params
:
List
[
torch
.
Tensor
])
->
None
:
self
.
pre_op
(
params
)
def
post_forward
(
self
,
params
:
List
[
torch
.
Tensor
])
->
None
:
self
.
post_op
(
params
)
def
pre_backward
(
self
,
params
:
List
[
torch
.
Tensor
])
->
None
:
self
.
pre_op
(
params
)
def
post_backward
(
self
,
params
:
List
[
torch
.
Tensor
])
->
None
:
self
.
post_op
(
params
)
@
contextmanager
def
switch_training_phase
(
self
,
training_phase
:
TrainingPhase
=
TrainingPhase
.
BACKWARD
):
old_training_phase
=
self
.
_training_phase
try
:
self
.
_training_phase
=
training_phase
yield
finally
:
self
.
_training_phase
=
old_training_phase
switch_to_backward
=
switch_training_phase
switch_to_forward
=
partial
(
switch_to_backward
,
training_phase
=
TrainingPhase
.
FORWARD
)
\ No newline at end of file
colossalai/nn/_ops/__init__.py
View file @
8daf1b4d
from
.linear
import
colo_linear
from
.addmm
import
colo_addmm
from
.batch_norm
import
colo_batch_norm
from
.element_wise
import
*
from
.element_wise
import
*
from
.layernorm
import
colo_layernorm
from
.loss
import
colo_cross_entropy
from
.embedding
import
colo_embedding
from
.embedding
import
colo_embedding
from
.addmm
import
colo_addmm
from
.embedding_bag
import
colo_embedding_bag
from
.embedding_bag
import
colo_embedding_bag
from
.layernorm
import
colo_layernorm
from
.linear
import
colo_linear
from
.loss
import
colo_cross_entropy
from
.view
import
colo_view
from
.view
import
colo_view
colossalai/nn/_ops/batch_norm.py
0 → 100644
View file @
8daf1b4d
from
typing
import
Optional
import
torch.nn.functional
as
F
from
colossalai.tensor
import
ColoTensor
,
ColoTensorSpec
,
ReplicaSpec
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
._utils
import
GeneralTensor
,
convert_to_colo_tensor
@
colo_op_impl
(
F
.
batch_norm
)
def
colo_batch_norm
(
input
:
GeneralTensor
,
running_mean
:
Optional
[
GeneralTensor
],
running_var
:
Optional
[
GeneralTensor
],
weight
:
Optional
[
GeneralTensor
]
=
None
,
bias
:
Optional
[
GeneralTensor
]
=
None
,
training
:
bool
=
False
,
momentum
:
float
=
0.1
,
eps
:
float
=
1e-5
,
):
assert
isinstance
(
weight
,
ColoTensor
)
running_mean
=
running_mean
.
detach
()
running_var
=
running_var
.
detach
()
input
=
convert_to_colo_tensor
(
input
,
weight
.
get_process_group
())
bias
=
convert_to_colo_tensor
(
bias
,
weight
.
get_process_group
())
input
=
input
.
redistribute
(
ReplicaSpec
())
bias
=
bias
.
redistribute
(
ReplicaSpec
())
output
=
F
.
batch_norm
(
input
,
running_mean
,
running_var
,
weight
,
bias
,
training
,
momentum
,
eps
)
output
=
ColoTensor
.
from_torch_tensor
(
tensor
=
output
,
spec
=
ColoTensorSpec
(
pg
=
weight
.
get_process_group
()))
return
output
colossalai/nn/_ops/element_wise.py
View file @
8daf1b4d
...
@@ -34,6 +34,18 @@ def register_elementwise_op(op):
...
@@ -34,6 +34,18 @@ def register_elementwise_op(op):
dist_attr
=
input_tensor
.
dist_spec
))
dist_attr
=
input_tensor
.
dist_spec
))
@
colo_op_impl
(
torch
.
relu_
)
def
elementwise_op
(
input_tensor
):
torch
.
relu_
(
input_tensor
.
data
)
return
input_tensor
@
colo_op_impl
(
Tensor
.
add_
)
def
elementwise_op
(
input_tensor
:
ColoTensor
,
*
args
,
**
kwargs
):
input_tensor
=
input_tensor
.
data
.
add_
(
*
args
,
**
kwargs
)
return
input_tensor
# Tensor op
# Tensor op
register_elementwise_op
(
Tensor
.
abs
)
register_elementwise_op
(
Tensor
.
abs
)
register_elementwise_op
(
Tensor
.
absolute
)
register_elementwise_op
(
Tensor
.
absolute
)
...
...
colossalai/nn/parallel/data_parallel.py
View file @
8daf1b4d
...
@@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP):
...
@@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP):
p
.
grad
=
None
p
.
grad
=
None
def
_post_backward
(
self
):
def
_post_backward
(
self
):
assert
self
.
chunk_manager
.
accessed_mem
==
0
#
assert self.chunk_manager.accessed_mem == 0
self
.
_setup_grads_ptr
()
self
.
_setup_grads_ptr
()
self
.
_logger
.
debug
(
self
.
_logger
.
debug
(
f
'comp cuda demand time:
{
self
.
gemini_manager
.
_comp_cuda_demand_time
}
, layout time:
{
self
.
gemini_manager
.
_layout_time
}
, evict time:
{
self
.
gemini_manager
.
_evict_time
}
, CPU->CUDA vol:
{
self
.
gemini_manager
.
_h2d_volume
}
B, CUDA->CPU vol:
{
self
.
gemini_manager
.
_d2h_volume
}
'
f
'comp cuda demand time:
{
self
.
gemini_manager
.
_comp_cuda_demand_time
}
, layout time:
{
self
.
gemini_manager
.
_layout_time
}
, evict time:
{
self
.
gemini_manager
.
_evict_time
}
, CPU->CUDA vol:
{
self
.
gemini_manager
.
_h2d_volume
}
B, CUDA->CPU vol:
{
self
.
gemini_manager
.
_d2h_volume
}
'
...
...
tests/components_to_test/inline_op_model.py
View file @
8daf1b4d
...
@@ -16,14 +16,14 @@ class InlineOpModule(CheckpointModule):
...
@@ -16,14 +16,14 @@ class InlineOpModule(CheckpointModule):
def
__init__
(
self
,
checkpoint
=
False
)
->
None
:
def
__init__
(
self
,
checkpoint
=
False
)
->
None
:
super
().
__init__
(
checkpoint
=
checkpoint
)
super
().
__init__
(
checkpoint
=
checkpoint
)
self
.
proj1
=
nn
.
Linear
(
4
,
8
)
self
.
proj1
=
nn
.
Linear
(
4
,
8
)
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
8
,
8
))
self
.
proj2
=
nn
.
Linear
(
8
,
8
)
self
.
proj2
=
nn
.
Linear
(
8
,
4
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
proj1
(
x
)
x
=
self
.
proj1
(
x
)
# inline add_
# inline add_
x
.
add_
(
10
)
x
.
add_
(
10
)
x
=
F
.
linear
(
x
,
self
.
weight
)
x
=
self
.
proj2
(
x
)
# inline relu_
# inline relu_
x
=
torch
.
relu_
(
x
)
x
=
torch
.
relu_
(
x
)
x
=
self
.
proj2
(
x
)
x
=
self
.
proj2
(
x
)
...
...
tests/test_gemini/test_gemini_train.py
View file @
8daf1b4d
...
@@ -15,7 +15,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
...
@@ -15,7 +15,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def
run_gemini_fwd_bwd
(
rank
,
world_size
,
port
,
model_name
:
str
,
iter_num
=
2
):
def
run_gemini_fwd_bwd
(
rank
,
world_size
,
port
,
model_name
:
str
,
iter_num
=
2
):
PLACEMENT_POLICY
=
'
cuda
'
PLACEMENT_POLICY
=
'
auto
'
disable_existing_loggers
()
disable_existing_loggers
()
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
...
@@ -52,9 +52,9 @@ def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
...
@@ -52,9 +52,9 @@ def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
print
(
f
'pass test
{
model_name
}
'
)
print
(
f
'pass test
{
model_name
}
'
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
'bert'
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"inline_op_model"
,
"bert"
,
"simple_net"
,
"gpt2"
,
"resnet18"
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_gemini_train
(
model_name
,
iter_num
=
2
):
def
test_gemini_train
(
model_name
,
iter_num
=
4
):
run_func
=
partial
(
run_gemini_fwd_bwd
,
world_size
=
1
,
port
=
free_port
(),
model_name
=
model_name
,
iter_num
=
iter_num
)
run_func
=
partial
(
run_gemini_fwd_bwd
,
world_size
=
1
,
port
=
free_port
(),
model_name
=
model_name
,
iter_num
=
iter_num
)
mp
.
spawn
(
run_func
,
nprocs
=
1
)
mp
.
spawn
(
run_func
,
nprocs
=
1
)
...
@@ -63,5 +63,5 @@ if __name__ == '__main__':
...
@@ -63,5 +63,5 @@ if __name__ == '__main__':
# for model_name in ["bert", "resnet18", "inline_op_model"]:
# for model_name in ["bert", "resnet18", "inline_op_model"]:
# bert, gpt, inline_op_model, nested_model, no_leaf_module,
# bert, gpt, inline_op_model, nested_model, no_leaf_module,
# repeated_computed_layer, resnet, simple_net
# repeated_computed_layer, resnet, simple_net
for
model_name
in
[
"
n
es
ted_model"
,
"no_leaf_module
"
]:
for
model_name
in
[
"
r
es
net18
"
]:
test_gemini_train
(
model_name
=
model_name
,
iter_num
=
4
)
test_gemini_train
(
model_name
=
model_name
,
iter_num
=
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment