Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1d0aba41
Unverified
Commit
1d0aba41
authored
Apr 27, 2022
by
Ziyue Jiang
Committed by
GitHub
Apr 27, 2022
Browse files
[tensor] add ColoTensor 1Dcol (#888)
parent
a0e59716
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
166 additions
and
28 deletions
+166
-28
colossalai/tensor/_ops/linear.py
colossalai/tensor/_ops/linear.py
+81
-13
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+19
-12
colossalai/tensor/spec.py
colossalai/tensor/spec.py
+2
-2
tests/test_tensor/test_linear_tp.py
tests/test_tensor/test_linear_tp.py
+64
-1
No files found.
colossalai/tensor/_ops/linear.py
View file @
1d0aba41
import
torch
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor.colo_tensor
import
ColoTensor
from
colossalai.context
import
ParallelMode
from
colossalai.nn.layer.parallel_1d._utils
import
split_forward_gather_backward
,
reduce_input
from
colossalai.nn.layer.parallel_1d._utils
import
split_forward_gather_backward
,
reduce_input
,
\
gather_forward_split_backward
,
reduce_grad
from
colossalai.nn.layer.utils
import
divide
from
colossalai.core
import
global_context
as
gpc
from
packaging
import
version
from
colossalai.tensor
import
ComputePattern
from
colossalai.tensor
import
ComputePattern
,
TensorSpec
,
ComputePattern
,
ParallelAction
,
ColoTensor
@
colo_op_impl
(
torch
.
nn
.
functional
.
linear
)
...
...
@@ -25,39 +25,107 @@ def colo_linear(types, args, kwargs, pg):
else
:
bias
=
kwargs
.
get
(
'bias'
,
None
)
bias_spec
=
None
if
isinstance
(
bias
,
ColoTensor
):
assert
bias
.
shard_spec
.
num_action
==
0
,
f
"We currently only support bias is duplicated among processes in the linear operator"
bias_spec
=
bias
.
shard_spec
bias
=
bias
.
torch_tensor
()
# Add communication logic before and after linear call.
if
isinstance
(
weight
,
ColoTensor
):
if
weight
.
shard_spec
==
None
or
weight
.
shard_spec
.
num_action
==
0
:
assert
bias_spec
==
None
or
bias_spec
.
num_action
==
0
,
'Invalid bias spec for native Linear op'
if
isinstance
(
input_tensor
,
ColoTensor
):
input_tensor
=
input_tensor
.
torch_tensor
()
if
isinstance
(
weight
,
ColoTensor
):
weight
=
weight
.
torch_tensor
()
return
ColoTensor
.
init_from_torch_tensor
(
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
,
bias
))
elif
weight
.
shard_spec
.
num_action
==
1
:
if
ComputePattern
.
TP1DRow
in
weight
.
shard_spec
.
compute_patterns
:
parallel_action
=
weight
.
shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DRow
)
compute_patterns
=
weight
.
shard_spec
.
compute_patterns
if
ComputePattern
.
TP1DRow
in
compute_patterns
:
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
assert
divide
(
input_tensor
.
shape
[
-
1
],
gpc
.
tensor_parallel_size
)
==
weight
.
size
(
-
1
),
\
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
input_tensor
.
shape
,
weight
.
size
,
weight
.
size
(
-
1
)
*
gpc
.
tensor_parallel_size
)
# Input:S[1]
input_spec
=
None
if
isinstance
(
input_tensor
,
ColoTensor
):
input_spec
=
input_tensor
.
shard_spec
input_tensor
=
input_tensor
.
torch_tensor
()
parallel_action
=
weight
.
shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DRow
)
input_per_partition
=
split_forward_gather_backward
(
input_tensor
,
parallel_action
.
parallel_mode
,
dim
=-
1
)
if
input_spec
==
None
or
input_spec
.
num_action
==
0
:
# Not splited yet.
assert
divide
(
input_tensor
.
shape
[
-
1
],
gpc
.
tensor_parallel_size
)
==
weight
.
size
(
-
1
),
\
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
input_tensor
.
shape
,
weight
.
size
,
weight
.
size
(
-
1
)
*
gpc
.
tensor_parallel_size
)
input_per_partition
=
split_forward_gather_backward
(
input_tensor
,
parallel_action
.
parallel_mode
,
dim
=-
1
)
elif
input_tensor
.
shard_spec
.
num_action
==
1
:
if
ComputePattern
.
TP1DCol
in
input_spec
.
compute_patterns
:
# Splited by 1Dcol
assert
input_tensor
.
shape
[
-
1
]
==
weight
.
size
(
-
1
),
\
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
input_tensor
.
shape
,
weight
.
size
,
weight
.
size
(
-
1
))
input_per_partition
=
input_tensor
else
:
raise
NotImplementedError
else
:
raise
NotImplementedError
# Output:P
weight_
=
weight
.
torch_tensor
()
partial_output
=
torch
.
nn
.
functional
.
linear
(
input_per_partition
,
weight_
)
# Reduce(Output)
output
=
reduce_input
(
partial_output
,
P
arallel
M
ode
.
PARALLEL_1D
)
output
=
reduce_input
(
partial_output
,
parallel_action
.
p
arallel
_m
ode
)
# Bias
if
bias
is
not
None
:
assert
bias_spec
==
None
or
bias_spec
.
num_action
==
0
,
'Invalid bias spec for 1Drow Linear op'
output
=
output
+
bias
return
ColoTensor
.
init_from_torch_tensor
(
output
)
output
=
ColoTensor
.
init_from_torch_tensor
(
output
)
return
output
elif
ComputePattern
.
TP1DCol
in
compute_patterns
:
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
input_spec
=
None
output_spec
=
None
parallel_action
=
weight
.
shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DCol
)
if
isinstance
(
input_tensor
,
ColoTensor
):
input_spec
=
input_tensor
.
shard_spec
input_tensor
=
input_tensor
.
torch_tensor
()
if
input_spec
==
None
or
input_spec
.
num_action
==
0
:
# Not splited yet.
assert
input_tensor
.
shape
[
-
1
]
==
weight
.
size
(
-
1
),
\
'Invalid shapes in 1Dcol forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
input_tensor
.
shape
,
weight
.
size
,
weight
.
size
(
-
1
))
input_parallel
=
reduce_grad
(
input_tensor
,
parallel_action
.
parallel_mode
)
else
:
raise
NotImplementedError
# Bias:S[1]
if
bias
is
not
None
:
assert
bias_spec
is
not
None
and
bias_spec
.
num_action
==
1
and
\
ComputePattern
.
TP1DCol
in
bias_spec
.
compute_patterns
,
\
'Invalid bias spec for 1Dcol Linear op'
weight_
=
weight
.
torch_tensor
()
output_parallel
=
torch
.
nn
.
functional
.
linear
(
input_parallel
,
weight_
,
bias
)
if
parallel_action
.
gather_out
:
# All-Gather(Output)
output
=
gather_forward_split_backward
(
output_parallel
,
parallel_action
.
parallel_mode
,
dim
=-
1
)
output
=
ColoTensor
.
init_from_torch_tensor
(
output
)
else
:
output
=
ColoTensor
.
init_from_torch_tensor
(
output_parallel
)
out_parallel_action_list
=
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1DCol
,
parallel_mode
=
parallel_action
.
parallel_mode
)
]
output_spec
=
TensorSpec
(
out_parallel_action_list
)
# set ColoTensor spec
if
output_spec
is
not
None
:
output
.
set_spec
(
output_spec
)
return
output
else
:
raise
NotImplementedError
else
:
...
...
colossalai/tensor/colo_tensor.py
View file @
1d0aba41
...
...
@@ -121,18 +121,25 @@ class ColoTensor(object):
assert
self
.
_shard_spec
is
not
None
,
'You should call set_spec() before _shard() ColoTensor.'
if
self
.
_shard_spec
.
num_action
==
1
:
if
ComputePattern
.
TP1DRow
in
self
.
_shard_spec
.
compute_patterns
:
parallel_action
=
self
.
_shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DRow
)
num_partition
=
gpc
.
get_world_size
(
parallel_action
.
parallel_mode
)
local_rank
=
gpc
.
get_local_rank
(
parallel_action
.
parallel_mode
)
dim
=
-
1
chunk_size
=
divide
(
self
.
_size
[
dim
],
num_partition
)
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
self
.
_torch_tensor
=
self
.
_torch_tensor
.
narrow
(
dim
,
local_rank
*
chunk_size
,
chunk_size
).
detach
(
).
contiguous
()
# TODO Shall we clone() here since detach() will point to the old tensor?
self
.
_torch_tensor
.
requires_grad
=
self
.
_requires_grad
self
.
_size
=
self
.
_torch_tensor
.
size
()
parallel_action
=
self
.
_shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DRow
)
self
.
_shard_1d
(
parallel_action
=
parallel_action
,
dim
=-
1
)
elif
ComputePattern
.
TP1DCol
in
self
.
_shard_spec
.
compute_patterns
:
parallel_action
=
self
.
_shard_spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1DCol
)
self
.
_shard_1d
(
parallel_action
=
parallel_action
,
dim
=
0
)
def
_shard_1d
(
self
,
parallel_action
,
dim
=-
1
):
num_partition
=
gpc
.
get_world_size
(
parallel_action
.
parallel_mode
)
local_rank
=
gpc
.
get_local_rank
(
parallel_action
.
parallel_mode
)
chunk_size
=
divide
(
self
.
_size
[
dim
],
num_partition
)
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
self
.
_torch_tensor
=
self
.
_torch_tensor
.
narrow
(
dim
,
local_rank
*
chunk_size
,
chunk_size
).
detach
(
).
contiguous
()
# TODO Shall we clone() here since detach() will point to the old tensor?
self
.
_torch_tensor
.
requires_grad
=
self
.
_requires_grad
self
.
_size
=
self
.
_torch_tensor
.
size
()
@
classmethod
def
__torch_function__
(
cls
,
func
,
types
,
args
=
(),
kwargs
=
None
):
...
...
colossalai/tensor/spec.py
View file @
1d0aba41
...
...
@@ -12,11 +12,11 @@ class ComputePattern(Enum):
class
ParallelAction
(
object
):
def
__init__
(
self
,
priority
=
0
,
compute_pattern
=
ComputePattern
.
DP
,
parallel_mode
=
ParallelMode
.
DATA
)
->
None
:
def
__init__
(
self
,
priority
=
0
,
compute_pattern
=
ComputePattern
.
DP
,
parallel_mode
=
ParallelMode
.
DATA
,
gather_out
=
True
)
->
None
:
self
.
priority
=
priority
self
.
compute_pattern
=
compute_pattern
self
.
parallel_mode
=
parallel_mode
self
.
gather_out
=
gather_out
class
TensorSpec
(
object
):
"""
...
...
tests/test_tensor/test_linear_tp.py
View file @
1d0aba41
...
...
@@ -16,6 +16,69 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from
_utils
import
check_equal
,
replace_parameter_add_grad
,
broadcast_tensor_chunk
def
run_linear_tp1d_col_test
():
device
=
get_current_device
()
dtype
=
torch
.
float32
DEPTH
=
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)
in_features
=
4
out_features
=
8
local_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_1D
)
layer_master
=
torch
.
nn
.
Linear
(
in_features
,
out_features
)
layer
=
torch
.
nn
.
Linear
(
in_features
,
out_features
)
A_shape
=
(
2
,
in_features
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
A
=
broadcast_tensor_chunk
(
A_master
,
chunk_size
=
1
)
A
.
requires_grad
=
True
W_shape
=
(
out_features
,
in_features
)
W_master
=
torch
.
randn
(
W_shape
,
dtype
=
dtype
,
device
=
device
)
W
=
broadcast_tensor_chunk
(
W_master
,
chunk_size
=
1
)
W
.
requires_grad
=
True
B_shape
=
(
out_features
)
B_master
=
torch
.
randn
(
B_shape
,
dtype
=
dtype
,
device
=
device
)
B
=
broadcast_tensor_chunk
(
B_master
,
chunk_size
=
1
)
B
.
requires_grad
=
True
# replace the torch nn.Parameters with ColoTensor
sharded_weight
=
ColoTensor
.
init_from_torch_tensor
(
W
)
sharded_bias
=
ColoTensor
.
init_from_torch_tensor
(
B
)
parallel_action_list
=
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1DCol
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)
]
spec
=
TensorSpec
(
parallel_action_list
)
sharded_weight
.
set_spec
(
spec
)
# reshard
sharded_bias
.
set_spec
(
spec
)
replace_parameter_add_grad
(
layer
,
sharded_weight
,
sharded_bias
)
out
=
layer
(
A
)
replace_parameter_add_grad
(
layer_master
,
W_master
,
B_master
)
A_master
.
requires_grad
=
True
#C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C_master
=
layer_master
(
A_master
)
C
=
C_master
.
clone
()
check_equal
(
out
,
C
)
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
get_current_device
())
grad
=
broadcast_tensor_chunk
(
grad_master
,
chunk_size
=
1
)
out
.
backward
(
grad
)
grad_master
=
grad_master
.
clone
()
C_master
.
backward
(
grad_master
)
W_grad
=
W_master
.
grad
W_grad
=
torch
.
chunk
(
W_grad
,
DEPTH
,
dim
=
0
)[
local_rank
]
check_equal
(
W_grad
,
layer
.
weight
.
grad
)
B_grad
=
B_master
.
grad
B_grad
=
torch
.
chunk
(
B_grad
,
DEPTH
,
dim
=
0
)[
local_rank
]
check_equal
(
B_grad
,
layer
.
bias
.
grad
)
def
run_linear_tp1d_row_test
():
device
=
get_current_device
()
...
...
@@ -83,7 +146,7 @@ def run_dist(rank, world_size, port):
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_linear_tp1d_row_test
()
run_linear_tp1d_col_test
()
@
pytest
.
mark
.
dist
@
parameterize
(
'world_size'
,
[
1
,
4
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment