OpenDAS / ColossalAI · Commits · 676f1915

[Tensor] activation is an attr of ColoTensor (#897)

Unverified commit 676f1915, authored Apr 28, 2022 by Jiarui Fang; committed by GitHub on Apr 28, 2022.
Parent: e76f76c0

Changes: 4 changed files with 51 additions and 35 deletions (+51 -35).
colossalai/tensor/_ops/linear.py              +13  -14
colossalai/tensor/colo_tensor.py              +25  -16
colossalai/tensor/spec.py                      +8   -3
colossalai/utils/model/colo_init_context.py    +5   -2
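At a glance, the commit replaces the spec-based activation test with a construction-time tag: `init_from_torch_tensor` gains an `is_model_data` flag, and `is_activation()` now checks the resulting `TensorType` instead of scanning the shard spec for `ComputePattern.Activation`. A minimal sketch of the new behavior, using only the interfaces visible in this diff (single process, no sharding spec involved):

    import torch
    from colossalai.tensor import ColoTensor

    # Default construction: tagged NONMODEL, i.e. treated as an activation.
    act = ColoTensor.init_from_torch_tensor(torch.randn(4, 8))
    assert act.is_activation()

    # Parameters are tagged as model data at construction time.
    w = ColoTensor.init_from_torch_tensor(torch.randn(8, 8), is_model_data=True)
    assert not w.is_activation()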
colossalai/tensor/_ops/linear.py
@@ -9,17 +9,19 @@ from packaging import version
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern


 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
     parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     if input_tensor.is_gathered():
         # Not splited yet.
         assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
             'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
                 input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
         input_per_partition = split_forward_gather_backward(input_tensor.torch_tensor(), parallel_action.parallel_mode, dim=-1)
     elif input_tensor.shard_pattern == ShardPattern.Col:
         # Splited by 1Dcol
         assert input_tensor.shape[-1] == weight.size(-1), \
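The comments above encode the row-parallel contract: the input is sharded along its last dim (S[1]), the weight along its input dim (S[0] in the math layout; the last dim of torch's (out, in) weight), each rank produces a partial product (P), and an all-reduce plus the bias yields the result. A self-contained PyTorch illustration of that identity (my own sketch, independent of the ColossalAI helpers; a plain Python `sum` stands in for the all-reduce):

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 8)          # full input
    w = torch.randn(4, 8)          # full torch-layout weight: (out_features, in_features)
    world_size = 2

    partials = []
    for rank in range(world_size):
        x_shard = x.chunk(world_size, dim=-1)[rank]    # Input:S[1]
        w_shard = w.chunk(world_size, dim=-1)[rank]    # weight sharded along in_features
        partials.append(x_shard @ w_shard.t())         # Output:P, one partial sum per rank
    out = sum(partials)                                # stands in for the all-reduce

    assert torch.allclose(out, x @ w.t(), atol=1e-5)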
@@ -40,7 +42,8 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTen
     output = ColoTensor.init_from_torch_tensor(output)
     return output


 def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
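Likewise for the column-parallel contract sketched in these comments: the input is replicated (B), weight and bias are sharded along the output dim (S[1]), each rank computes a slice of the output, and an all-gather reassembles it. Again a standalone illustration of my own, with `torch.cat` standing in for the all-gather:

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 8)          # replicated input (B)
    w = torch.randn(4, 8)
    b = torch.randn(4)
    world_size = 2

    shards = []
    for rank in range(world_size):
        w_shard = w.chunk(world_size, dim=0)[rank]     # Weight:S[1] (out_features is dim 0 in torch layout)
        b_shard = b.chunk(world_size, dim=0)[rank]     # Bias:S[1]
        shards.append(torch.nn.functional.linear(x, w_shard, b_shard))
    out = torch.cat(shards, dim=-1)                    # stands in for the all-gather

    assert torch.allclose(out, torch.nn.functional.linear(x, w, b), atol=1e-5)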
@@ -59,14 +62,9 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTen
         'Invalid bias spec for 1Dcol Linear op'
     output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor())
     output = ColoTensor.init_from_torch_tensor(output_parallel)
-    out_parallel_action_list = [ParallelAction(priority=1, compute_pattern=ComputePattern.Activation, parallel_mode=parallel_action.parallel_mode)]
+    out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]
     output_spec = TensorSpec(out_parallel_action_list)
     output.set_spec(output_spec, shard=False)
     output.set_shard_pattern(ShardPattern.Col)
@@ -75,6 +73,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTen
         output.gather()
     return output


 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
@@ -99,15 +98,15 @@ def colo_linear(types, args, kwargs, pg):
     if bias is not None and not isinstance(bias, ColoTensor):
         bias = ColoTensor.init_from_torch_tensor(bias)

     # Add communication logic before and after linear call.
     if not weight.has_spec():    # No Model Parallel Applied
         assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
         input_tensor = input_tensor.torch_tensor()
         weight = weight.torch_tensor()
         bias = bias.torch_tensor()
         return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:    # Single Model Parallel Applied
         compute_patterns = weight.shard_spec.compute_patterns
         if ComputePattern.TP1DRow in compute_patterns:
             return colo_linear_1Drow(input_tensor, weight, bias)
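With the registration above, calls to `torch.nn.functional.linear` on ColoTensors are routed through `colo_linear` via the `__torch_function__` hook its docstring mentions. A hedged single-process sketch of the no-spec ("native") branch, assuming that dispatch also works outside a launched parallel context:

    import torch
    from colossalai.tensor import ColoTensor

    x = ColoTensor.init_from_torch_tensor(torch.randn(2, 8))
    w = ColoTensor.init_from_torch_tensor(torch.randn(4, 8), is_model_data=True)
    b = ColoTensor.init_from_torch_tensor(torch.randn(4), is_model_data=True)

    # No spec on the weight, so colo_linear falls through to the native path
    # and re-wraps the plain torch result as a ColoTensor.
    out = torch.nn.functional.linear(x, w, b)
    assert isinstance(out, ColoTensor)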
colossalai/tensor/colo_tensor.py
@@ -7,6 +7,13 @@ from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern
 from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward
+from enum import Enum
+
+
+class TensorType(Enum):
+    MODEL = 0
+    NONMODEL = 1    # mainly activations
+

 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
@@ -28,6 +35,7 @@ class ColoTensor(object):
                  device=None,
                  torch_tensor=torch.empty(0),
                  shard_spec: TensorSpec = TensorSpec(),
+                 is_model_data: bool = False,
                  ):
         self._size = size
         self._dtype = dtype
@@ -37,6 +45,10 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
         self._shard_spec = shard_spec
         self._shard_pattern = ShardPattern.NA
+        if is_model_data:
+            self._type = TensorType.MODEL
+        else:
+            self._type = TensorType.NONMODEL

     def __getitem__(self, key):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
@@ -85,13 +97,14 @@ class ColoTensor(object):
         return product(self._size)

     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
+    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True, is_model_data=False) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.is_pinned(),
                             device=tensor.device,
-                            torch_tensor=tensor if save_payload else torch.empty(0))
+                            torch_tensor=tensor if save_payload else torch.empty(0),
+                            is_model_data=is_model_data)
         return colo_t

     def del_torch_tensor(self, save_shape=False) -> None:
@@ -120,31 +133,28 @@ class ColoTensor(object):
         self._shard_spec = spec
         if shard == True:
             self.shard()

     def set_shard_pattern(self, shard_pattern: ShardPattern):
         self._shard_pattern = shard_pattern

     def shard(self):
         assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
         if self._shard_pattern is not ShardPattern.NA:    # reshard
             self.gather()
         # Model Parameters
         if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
             parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
             self._shard_1d(parallel_action=parallel_action, dim=-1)
             self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
         elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
             parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
             self._shard_1d(parallel_action=parallel_action, dim=0)
             self._shard_pattern = ShardPattern.Row

     def gather(self):
         assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
-        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.Activation)
+        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
         if self._shard_pattern == ShardPattern.Row:
             dim = 0
         elif self._shard_pattern == ShardPattern.Col:
@@ -159,9 +169,8 @@ class ColoTensor(object):
         return self._shard_spec is not None and self._shard_spec.num_action > 0

     def is_activation(self) -> bool:
-        return self._shard_spec is not None and self._shard_spec.num_action == 1 \
-            and ComputePattern.Activation in self._shard_spec.compute_patterns
+        return self._type == TensorType.NONMODEL

     def _shard_1d(self, parallel_action, dim=-1):
         num_partition = gpc.get_world_size(parallel_action.parallel_mode)
         local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
@@ -169,8 +178,8 @@ class ColoTensor(object):
         # Reshape to get shard for this rank and we don't want autograd
         # recording here for the narrow op and 'local_shard' should be a
         # leaf variable in the autograd graph.
         self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()
         # TODO Shall we clone() here since detach() will point to the old tensor?
         self._torch_tensor.requires_grad = self._requires_grad
         self._size = self._torch_tensor.size()
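The `narrow(...).detach().contiguous()` line above is the whole sharding step: each rank keeps a contiguous slice of `size[dim] // world_size` elements at its own offset, detached so the local shard is a leaf in the autograd graph. A standalone sketch with made-up process-group coordinates:

    import torch

    full = torch.arange(12.0)                  # the pre-shard payload
    world_size, local_rank = 4, 2              # hypothetical group size and rank
    chunk_size = full.size(-1) // world_size   # 3
    local_shard = full.narrow(-1, local_rank * chunk_size, chunk_size).detach().contiguous()
    print(local_shard)                         # tensor([6., 7., 8.])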
colossalai/tensor/spec.py
@@ -4,20 +4,25 @@ from colossalai.context.parallel_mode import ParallelMode
 class ComputePattern(Enum):
-    Activation = 0    # TODO(jzy) A tmp place to store Activation info. Find a better place in future.
     TP1DRow = 1
     TP1DCol = 2
     ZeRO = 3
     DP = 4


 class ShardPattern(Enum):
     NA = 0
     Row = 1
     Col = 2


 class ParallelAction(object):

-    def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA, gather_out=True) -> None:
+    def __init__(self,
+                 priority=0,
+                 compute_pattern=ComputePattern.DP,
+                 parallel_mode=ParallelMode.DATA,
+                 gather_out=True) -> None:
         self.priority = priority
         self.compute_pattern = compute_pattern
         self.parallel_mode = parallel_mode
@@ -64,7 +69,7 @@ class TensorSpec(object):
     @property
     def compute_patterns(self):
         return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]

     @property
     def shard_pattern(self):
         return self._shard_pattern
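For reference, composing these types now looks like the following (a sketch based only on the constructors and properties shown in this diff; `ParallelMode.PARALLEL_1D` is my assumption for a typical 1D tensor-parallel mode):

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.tensor.spec import ComputePattern, ParallelAction, TensorSpec

    action = ParallelAction(priority=1,
                            compute_pattern=ComputePattern.TP1DRow,
                            parallel_mode=ParallelMode.PARALLEL_1D)
    spec = TensorSpec([action])
    assert ComputePattern.TP1DRow in spec.compute_patterns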
colossalai/utils/model/colo_init_context.py
@@ -94,7 +94,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         save_torch_payload = True if not self._lazy_memory_allocate else False
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name,
-                    ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
+            setattr(module, name,
+                    ColoTensor.init_from_torch_tensor(tensor=param.to(self._device),
+                                                      save_payload=save_torch_payload,
+                                                      is_model_data=True))

         ColoModulize(module)
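The practical effect: parameters materialized under `ColoInitContext` are now constructed with `is_model_data=True`, so `is_activation()` cleanly separates them from activations produced at runtime. A hedged usage sketch (assuming the context accepts a `device` argument, per the `self._device` reference above):

    import torch
    from colossalai.utils.model.colo_init_context import ColoInitContext

    with ColoInitContext(device=torch.device('cpu')):
        model = torch.nn.Linear(8, 4)

    # Parameters were swapped for ColoTensors tagged as model data.
    assert not model.weight.is_activation()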