Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
26d4ab8b
Unverified
Commit
26d4ab8b
authored
Apr 26, 2022
by
Ziyue Jiang
Committed by
GitHub
Apr 26, 2022
Browse files
[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)
parent
11f54c7b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
85 additions
and
58 deletions
+85
-58
colossalai/tensor/__init__.py
colossalai/tensor/__init__.py
+3
-1
colossalai/tensor/_ops/__init__.py
colossalai/tensor/_ops/__init__.py
+1
-1
colossalai/tensor/_ops/linear.py
colossalai/tensor/_ops/linear.py
+26
-23
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+21
-22
colossalai/tensor/spec.py
colossalai/tensor/spec.py
+28
-10
tests/test_tensor/test_linear_tp.py
tests/test_tensor/test_linear_tp.py
+6
-1
No files found.
colossalai/tensor/__init__.py
View file @
26d4ab8b
from
.spec
import
ComputePattern
,
ParallelAction
,
TensorSpec
from
.op_wrapper
import
(
from
.op_wrapper
import
(
colo_op_impl
,)
colo_op_impl
,)
from
.colo_tensor
import
ColoTensor
from
.colo_tensor
import
ColoTensor
from
.utils
import
convert_parameter
from
.utils
import
convert_parameter
from
._ops
import
*
from
._ops
import
*
__all__
=
[
'ColoTensor'
,
'convert_parameter'
,
'colo_op_impl'
]
__all__
=
[
'ColoTensor'
,
'convert_parameter'
,
'colo_op_impl'
,
'ComputePattern'
,
'TensorSpec'
,
'ParallelAction'
]
colossalai/tensor/_ops/__init__.py
View file @
26d4ab8b
...
@@ -2,4 +2,4 @@ from .init import colo_uniform
...
@@ -2,4 +2,4 @@ from .init import colo_uniform
from
.linear
import
colo_linear
from
.linear
import
colo_linear
from
.element_wise
import
colo_mean
from
.element_wise
import
colo_mean
from
.layernorm
import
colo_layernorm
from
.layernorm
import
colo_layernorm
from
.loss
import
colo_cross_entropy
from
.loss
import
colo_cross_entropy
\ No newline at end of file
colossalai/tensor/_ops/linear.py
View file @
26d4ab8b
...
@@ -6,8 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
...
@@ -6,8 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
from
colossalai.nn.layer.utils
import
divide
from
colossalai.nn.layer.utils
import
divide
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
packaging
import
version
from
packaging
import
version
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ParallelAction
@
colo_op_impl
(
torch
.
nn
.
functional
.
linear
)
@
colo_op_impl
(
torch
.
nn
.
functional
.
linear
)
def
colo_linear
(
types
,
args
,
kwargs
,
pg
):
def
colo_linear
(
types
,
args
,
kwargs
,
pg
):
...
@@ -30,32 +29,36 @@ def colo_linear(types, args, kwargs, pg):
...
@@ -30,32 +29,36 @@ def colo_linear(types, args, kwargs, pg):
# Add communication logic before and after linear call.
# Add communication logic before and after linear call.
if
isinstance
(
weight
,
ColoTensor
):
if
isinstance
(
weight
,
ColoTensor
):
if
weight
.
shard_spec
==
None
:
if
weight
.
shard_spec
==
None
or
weight
.
shard_spec
.
num_action
==
0
:
if
isinstance
(
input_tensor
,
ColoTensor
):
if
isinstance
(
input_tensor
,
ColoTensor
):
input_tensor
=
input_tensor
.
torch_tensor
()
input_tensor
=
input_tensor
.
torch_tensor
()
if
isinstance
(
weight
,
ColoTensor
):
if
isinstance
(
weight
,
ColoTensor
):
weight
=
weight
.
torch_tensor
()
weight
=
weight
.
torch_tensor
()
return
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
,
bias
)
return
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
,
bias
)
elif
weight
.
shard_spec
==
'1Drow'
:
elif
weight
.
shard_spec
.
num_action
==
1
:
# Input:S[1] x Weight:S[0] = Output:P
if
ComputePattern
.
TP1DRow
in
weight
.
shard_spec
.
compute_patterns
:
# All-Reduce(Output) + bias = res
# Input:S[1] x Weight:S[0] = Output:P
assert
divide
(
input_tensor
.
shape
[
-
1
],
gpc
.
tensor_parallel_size
)
==
weight
.
size
(
-
1
),
\
# All-Reduce(Output) + bias = res
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
assert
divide
(
input_tensor
.
shape
[
-
1
],
gpc
.
tensor_parallel_size
)
==
weight
.
size
(
-
1
),
\
input_tensor
.
shape
,
weight
.
size
,
weight
.
size
[
-
1
]
*
gpc
.
tensor_parallel_size
)
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
# Input:S[1]
input_tensor
.
shape
,
weight
.
size
,
weight
.
size
(
-
1
)
*
gpc
.
tensor_parallel_size
)
input_per_partition
=
split_forward_gather_backward
(
input_tensor
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
# Input:S[1]
# Output:P
if
isinstance
(
input_tensor
,
ColoTensor
):
device
=
get_current_device
()
# TODO where to put to(deivce)?
input_tensor
=
input_tensor
.
torch_tensor
()
weight_
=
weight
.
torch_tensor
().
to
(
device
)
parallel_action
=
weight
.
shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DRow
)
partial_output
=
torch
.
nn
.
functional
.
linear
(
input_per_partition
,
weight_
)
input_per_partition
=
split_forward_gather_backward
(
input_tensor
,
parallel_action
.
parallel_mode
,
dim
=-
1
)
# Reduce(Output)
# Output:P
output
=
reduce_input
(
partial_output
,
ParallelMode
.
PARALLEL_1D
)
weight_
=
weight
.
torch_tensor
()
# Bias
partial_output
=
torch
.
nn
.
functional
.
linear
(
input_per_partition
,
weight_
)
if
bias
is
not
None
:
# Reduce(Output)
bias_
=
bias
.
to
(
device
)
output
=
reduce_input
(
partial_output
,
ParallelMode
.
PARALLEL_1D
)
output
=
output
+
bias_
# Bias
return
output
if
bias
is
not
None
:
bias_
=
bias
output
=
output
+
bias_
return
ColoTensor
.
init_from_torch_tensor
(
output
)
else
:
raise
NotImplementedError
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
else
:
else
:
...
...
colossalai/tensor/colo_tensor.py
View file @
26d4ab8b
from
colossalai.context
import
parallel_mode
from
.op_wrapper
import
_COLOSSAL_OPS
from
.op_wrapper
import
_COLOSSAL_OPS
import
torch
import
torch
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
from
numpy
import
product
from
numpy
import
product
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context
import
ParallelMode
from
colossalai.nn.layer.utils
import
divide
from
colossalai.nn.layer.utils
import
divide
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ParallelAction
class
ColoTensor
(
object
):
class
ColoTensor
(
object
):
""" Data Structure for Tensor in Colossal-AI
""" Data Structure for Tensor in Colossal-AI
...
@@ -28,7 +27,7 @@ class ColoTensor(object):
...
@@ -28,7 +27,7 @@ class ColoTensor(object):
pin_memory
=
False
,
pin_memory
=
False
,
device
=
None
,
device
=
None
,
torch_tensor
=
torch
.
empty
(
0
),
torch_tensor
=
torch
.
empty
(
0
),
shard_spec
:
str
=
None
,
shard_spec
:
TensorSpec
=
TensorSpec
()
,
):
):
self
.
_size
=
size
self
.
_size
=
size
self
.
_dtype
=
dtype
self
.
_dtype
=
dtype
...
@@ -39,7 +38,7 @@ class ColoTensor(object):
...
@@ -39,7 +38,7 @@ class ColoTensor(object):
self
.
_shard_spec
=
shard_spec
self
.
_shard_spec
=
shard_spec
@
property
@
property
def
shard_spec
(
self
)
->
Optional
[
str
]
:
def
shard_spec
(
self
)
->
TensorSpec
:
return
self
.
_shard_spec
return
self
.
_shard_spec
@
property
@
property
...
@@ -109,27 +108,27 @@ class ColoTensor(object):
...
@@ -109,27 +108,27 @@ class ColoTensor(object):
device
=
self
.
_device
)
device
=
self
.
_device
)
return
self
.
_torch_tensor
return
self
.
_torch_tensor
def
set_spec
(
self
,
spec
:
str
,
lazy_shard
:
bool
=
False
)
->
None
:
def
set_spec
(
self
,
spec
:
TensorSpec
,
lazy_shard
:
bool
=
False
)
->
None
:
self
.
_shard_spec
=
spec
self
.
_shard_spec
=
spec
if
lazy_shard
==
False
:
if
lazy_shard
==
False
:
self
.
_shard
()
self
.
_shard
()
def
_shard
(
self
):
def
_shard
(
self
):
assert
self
.
_shard_spec
is
not
None
,
'You should call set_spec() before _shard() ColoTensor.'
assert
self
.
_shard_spec
is
not
None
,
'You should call set_spec() before _shard() ColoTensor.'
if
self
.
_shard_spec
==
"1Drow"
:
# TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
if
self
.
_shard_spec
.
num_action
==
1
:
num_partition
=
gpc
.
get_world_size
(
ParallelMode
.
TENSOR
)
if
ComputePattern
.
TP1DRow
in
self
.
_shard_spec
.
compute_patterns
:
local_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
TENSOR
)
parallel_action
=
self
.
_shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DRow
)
dim
=
-
1
num_partition
=
gpc
.
get_world_size
(
parallel_action
.
parallel_mode
)
chunk_size
=
divide
(
self
.
_size
[
dim
],
num_partition
)
local_rank
=
gpc
.
get_local_rank
(
parallel_action
.
parallel_mode
)
device
=
get_current_device
()
dim
=
-
1
# Reshape to get shard for this rank and we don't want autograd
chunk_size
=
divide
(
self
.
_size
[
dim
],
num_partition
)
# recording here for the narrow op and 'local_shard' should be a
# Reshape to get shard for this rank and we don't want autograd
# leaf variable in the autograd graph.
# recording here for the narrow op and 'local_shard' should be a
self
.
_torch_tensor
=
self
.
_torch_tensor
.
narrow
(
dim
,
local_rank
*
chunk_size
,
chunk_size
).
detach
(
# leaf variable in the autograd graph.
).
contiguous
()
# TODO Shall we clone() here since detach() will point to the old tensor?
self
.
_torch_tensor
=
self
.
_torch_tensor
.
narrow
(
dim
,
local_rank
*
chunk_size
,
chunk_size
).
detach
(
self
.
_torch_tensor
.
requires_grad
=
self
.
_requires_grad
).
contiguous
()
# TODO Shall we clone() here since detach() will point to the old tensor?
self
.
_size
=
self
.
_torch_tensor
.
size
()
self
.
_torch_tensor
.
requires_grad
=
self
.
_requires_grad
self
.
_
device
=
device
# TODO A `fake` device now because torch_tensor.device always = cpu
self
.
_
size
=
self
.
_torch_tensor
.
size
()
@
classmethod
@
classmethod
def
__torch_function__
(
cls
,
func
,
types
,
args
=
(),
kwargs
=
None
):
def
__torch_function__
(
cls
,
func
,
types
,
args
=
(),
kwargs
=
None
):
...
@@ -151,5 +150,5 @@ class ColoTensor(object):
...
@@ -151,5 +150,5 @@ class ColoTensor(object):
kwargs
=
{
k
:
v
.
torch_tensor
()
if
isinstance
(
v
,
ColoTensor
)
else
v
for
k
,
v
in
kwargs
.
items
()}
kwargs
=
{
k
:
v
.
torch_tensor
()
if
isinstance
(
v
,
ColoTensor
)
else
v
for
k
,
v
in
kwargs
.
items
()}
return
func
(
*
args
,
**
kwargs
)
return
func
(
*
args
,
**
kwargs
)
def
backward
(
self
,
retain_graph
:
bool
=
False
):
def
backward
(
self
,
gradient
:
Optional
[
torch
.
Tensor
]
=
None
,
retain_graph
:
bool
=
False
):
self
.
_torch_tensor
.
backward
(
retain_graph
=
retain_graph
)
self
.
_torch_tensor
.
backward
(
gradient
=
gradient
,
retain_graph
=
retain_graph
)
colossalai/tensor/spec.py
View file @
26d4ab8b
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Tuple
,
List
from
typing
import
Tuple
,
List
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
class
ComputePattern
(
Enum
):
class
ComputePattern
(
Enum
):
TP1DRow
=
1
TP1DRow
=
1
...
@@ -12,17 +10,13 @@ class ComputePattern(Enum):
...
@@ -12,17 +10,13 @@ class ComputePattern(Enum):
class
ParallelAction
(
object
):
class
ParallelAction
(
object
):
priority
=
0
def
__init__
(
self
,
priority
=
0
,
compute_pattern
=
ComputePattern
.
DP
,
parallel_mode
=
ParallelMode
.
DATA
)
->
None
:
compute_pattern
=
ComputePattern
.
DP
process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
)
def
__init__
(
self
,
priority
,
compute_pattern
,
process_group
)
->
None
:
self
.
priority
=
priority
self
.
priority
=
priority
self
.
compute_pattern
=
compute_pattern
self
.
compute_pattern
=
compute_pattern
self
.
p
rocess_group
=
process_group
self
.
p
arallel_mode
=
parallel_mode
class
TensorSpec
(
Enum
):
class
TensorSpec
(
object
):
"""
"""
It contains two aspects of information:
It contains two aspects of information:
First, How are tensors distributed in Heterougenous memory space.
First, How are tensors distributed in Heterougenous memory space.
...
@@ -44,4 +38,28 @@ class TensorSpec(Enum):
...
@@ -44,4 +38,28 @@ class TensorSpec(Enum):
# Before Linear Op, we gather the tensors according to ZeRO.
# Before Linear Op, we gather the tensors according to ZeRO.
# We perform Linear Op according to compute pattern of TP1DRow.
# We perform Linear Op according to compute pattern of TP1DRow.
# After Linear Op, we split the tensors according to ZeRO.
# After Linear Op, we split the tensors according to ZeRO.
parallel_action_list
:
List
[
ParallelAction
]
=
[]
def
__init__
(
self
,
parallel_action_list
:
List
[
ParallelAction
]
=
[]):
self
.
_parallel_action_list
=
parallel_action_list
self
.
sort
()
@
property
def
parallel_action_list
(
self
):
return
self
.
_parallel_action_list
@
property
def
num_action
(
self
):
return
len
(
self
.
_parallel_action_list
)
@
property
def
compute_patterns
(
self
):
return
[
parallel_action
.
compute_pattern
for
parallel_action
in
self
.
_parallel_action_list
]
def
sort
(
self
):
if
len
(
self
.
_parallel_action_list
)
>
0
:
self
.
_parallel_action_list
.
sort
(
key
=
lambda
parallel_action
:
parallel_action
.
priority
)
def
get_action_by_compute_pattern
(
self
,
compute_pattern
:
ComputePattern
):
for
parallel_action
in
self
.
_parallel_action_list
:
if
parallel_action
.
compute_pattern
==
compute_pattern
:
return
parallel_action
return
None
tests/test_tensor/test_linear_tp.py
View file @
26d4ab8b
...
@@ -12,6 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
...
@@ -12,6 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ParallelAction
from
_utils
import
check_equal
,
replace_parameter_add_grad
,
broadcast_tensor_chunk
from
_utils
import
check_equal
,
replace_parameter_add_grad
,
broadcast_tensor_chunk
...
@@ -45,7 +46,11 @@ def run_linear_tp1d_row_test():
...
@@ -45,7 +46,11 @@ def run_linear_tp1d_row_test():
# replace the torch nn.Parameters with ColoTensor
# replace the torch nn.Parameters with ColoTensor
sharded_weight
=
ColoTensor
.
init_from_torch_tensor
(
W
)
sharded_weight
=
ColoTensor
.
init_from_torch_tensor
(
W
)
sharded_weight
.
set_spec
(
spec
=
"1Drow"
)
# reshard
parallel_action_list
=
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1DRow
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)
]
spec
=
TensorSpec
(
parallel_action_list
)
sharded_weight
.
set_spec
(
spec
=
spec
)
# reshard
sharded_bias
=
ColoTensor
.
init_from_torch_tensor
(
B
)
sharded_bias
=
ColoTensor
.
init_from_torch_tensor
(
B
)
replace_parameter_add_grad
(
layer
,
sharded_weight
,
sharded_bias
)
replace_parameter_add_grad
(
layer
,
sharded_weight
,
sharded_bias
)
out
=
layer
(
A
)
out
=
layer
(
A
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment