OpenDAS / ColossalAI · Commits · 96211c2c

Unverified commit 96211c2c, authored Apr 26, 2022 by Jiarui Fang and committed via GitHub on Apr 26, 2022.
[tensor] customized op returns ColoTensor (#875)

* [tensor] customized op returns ColoTensor
* polish
* polish code
Parent: 26d4ab8b
Showing 7 changed files with 33 additions and 45 deletions (+33, -45):
colossalai/tensor/_ops/__init__.py       +1  -2
colossalai/tensor/_ops/element_wise.py   +19 -0
colossalai/tensor/_ops/init.py           +0  -29
colossalai/tensor/_ops/linear.py         +5  -4
colossalai/tensor/spec.py                +6  -2
tests/test_tensor/test_net_tp.py         +0  -6
tests/test_tensor/test_op.py             +2  -2
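In short, tensor ops registered through `colo_op_impl` now hand back a `ColoTensor` instead of a bare `torch.Tensor`. A minimal usage sketch of the new behaviour, assuming only names visible in this diff (`ColoTensor.init_from_torch_tensor`, the new `torch.sum` handler in `element_wise.py`, and the `colossalai.tensor.colo_tensor` import path used by the tests), and assuming `colo_op_impl` hooks into `ColoTensor`'s `__torch_function__` dispatch as the docstring below describes:

```python
# Hedged sketch, not part of the commit: after this change, calling torch.sum
# on a ColoTensor is dispatched to the new sum_op handler, which wraps the
# plain-tensor result back into a ColoTensor.
import torch
from colossalai.tensor.colo_tensor import ColoTensor  # import path taken from test_net_tp.py

t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))
s = torch.sum(t)                    # routed to sum_op via __torch_function__
assert isinstance(s, ColoTensor)    # the result is wrapped, not a bare torch.Tensor
```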
colossalai/tensor/_ops/__init__.py

```
from .init import colo_uniform
from .linear import colo_linear
from .element_wise import colo_mean
from .element_wise import *
from .layernorm import colo_layernorm
from .loss import colo_cross_entropy
```
colossalai/tensor/_ops/element_wise.py

```
@@ -29,3 +29,22 @@ def register_elementwise_op(op):
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)


@colo_op_impl(torch.sum)
def sum_op(types, args=(), kwargs=None, pg=None):
    """
    Handles ``__torch_function__`` dispatch for an elementwise op such
    as ``torch.sum``.
    This method computes on either a normal tensor or a sharded tensor.
    """
    if len(args) > 0:
        input_tensor = args[0]
    if kwargs is None:
        kwargs = {}
    if 'input' in kwargs:
        input_tensor = kwargs['input']
    # Validate types
    if not isinstance(input_tensor, ColoTensor):
        raise TypeError("input needs to be a ColoTensor")
    return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
```
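For readers unfamiliar with the dispatch mechanism named in the docstring, here is a self-contained sketch in plain PyTorch, independent of ColossalAI, of the `__torch_function__` protocol that `colo_op_impl` builds on; the `WrappedTensor` class and its names are illustrative only:

```python
import torch

class WrappedTensor:
    """Illustrative stand-in for ColoTensor: holds a plain tensor internally."""

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.sum:
            # Mirror sum_op above: unwrap the input, compute, re-wrap the result.
            inp = kwargs.get('input', args[0] if args else None)
            return cls(torch.sum(inp._data))
        return NotImplemented

out = torch.sum(WrappedTensor(torch.ones(3)))
print(type(out).__name__)  # -> WrappedTensor
```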
colossalai/tensor/_ops/init.py (deleted, 100644 → 0)

```
import torch
from colossalai.tensor.op_wrapper import colo_op_impl


def validate_param(param, param_name):
    if param is None:
        raise ValueError(f"param: {param_name} shouldn't be None!")


@colo_op_impl(torch.nn.init.uniform_)
def colo_uniform(types, args=(), kwargs=None, pg=None):
    r"""
    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.
    Args:
        sharded_tensor: tensor sharded across devices
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
    """
    validate_param(kwargs, "kwargs")
    stateful_tensor = kwargs["tensor"]
    validate_param(stateful_tensor, "stateful_tensor")
    a = kwargs['a']
    validate_param(a, "a")
    b = kwargs['b']
    validate_param(b, "b")
    torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
    return stateful_tensor
```
colossalai/tensor/_ops/linear.py

```
@@ -6,7 +6,8 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.tensor import ComputePattern


@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):

@@ -25,6 +26,7 @@ def colo_linear(types, args, kwargs, pg):
    bias = kwargs.get('bias', None)
    if isinstance(bias, ColoTensor):
        assert bias.shard_spec.num_action == 0, f"We currently only support bias is duplicated among processes in the linear operator"
        bias = bias.torch_tensor()
    # Add communication logic before and after linear call.

@@ -34,7 +36,7 @@ def colo_linear(types, args, kwargs, pg):
        input_tensor = input_tensor.torch_tensor()
        if isinstance(weight, ColoTensor):
            weight = weight.torch_tensor()
        return torch.nn.functional.linear(input_tensor, weight, bias)
        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
    elif weight.shard_spec.num_action == 1:
        if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
            # Input:S[1] x Weight:S[0] = Output:P

@@ -54,8 +56,7 @@ def colo_linear(types, args, kwargs, pg):
            output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
            # Bias
            if bias is not None:
                bias_ = bias
                output = output + bias_
                output = output + bias
            return ColoTensor.init_from_torch_tensor(output)
        else:
            raise NotImplementedError
```
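The `TP1DRow` branch above produces a partial output on each rank and then all-reduces it via `reduce_input(partial_output, ParallelMode.PARALLEL_1D)`. The following single-process sketch in plain PyTorch only demonstrates the underlying algebra; the exact shard-axis convention behind `Input:S[1] x Weight:S[0] = Output:P` is ColossalAI's and is not reproduced here:

```python
# Hedged sketch: splitting the reduction (in_features) dimension of both input
# and weight makes each shard's F.linear a partial result; summing the partials
# (the role of the all-reduce in colo_linear) recovers the full output, with the
# duplicated bias added once after the reduction, as in the code above.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)      # (batch, in_features)
w = torch.randn(4, 8)      # (out_features, in_features)
b = torch.randn(4)

x0, x1 = x.chunk(2, dim=1)  # shard the reduction dimension
w0, w1 = w.chunk(2, dim=1)

partial = F.linear(x0, w0) + F.linear(x1, w1)   # stands in for reduce_input / all-reduce
assert torch.allclose(partial + b, F.linear(x, w, b), atol=1e-5)
```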
colossalai/tensor/spec.py

```
@@ -2,6 +2,7 @@ from enum import Enum
from typing import Tuple, List
from colossalai.context.parallel_mode import ParallelMode


class ComputePattern(Enum):
    TP1DRow = 1
    TP1DCol = 2

@@ -10,6 +11,7 @@ class ComputePattern(Enum):
class ParallelAction(object):

    def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
        self.priority = priority
        self.compute_pattern = compute_pattern

@@ -24,6 +26,7 @@ class TensorSpec(object):
    parallel computation pattern of the Operator (Layer).
    We have to consider the hybrid parallel mode.
    """
    # a list of parallel actions.
    # For example: on 8 GPUs, a hybrid parallel strategy is applied
    # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.

@@ -38,6 +41,7 @@ class TensorSpec(object):
    # Before the Linear Op, we gather the tensors according to ZeRO.
    # We perform the Linear Op according to the compute pattern of TP1DRow.
    # After the Linear Op, we split the tensors according to ZeRO.

    def __init__(self, parallel_action_list: List[ParallelAction] = []):
        self._parallel_action_list = parallel_action_list
        self.sort()

@@ -56,7 +60,7 @@ class TensorSpec(object):
    def sort(self):
        if len(self._parallel_action_list) > 0:
            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)

    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
        for parallel_action in self._parallel_action_list:
```
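A hedged construction sketch for the two classes touched above, using only the signatures visible in this diff (`ParallelAction(priority, compute_pattern, parallel_mode)`, `TensorSpec(parallel_action_list)`, `get_action_by_compute_pattern`, and the import seen in linear.py); the particular patterns and priorities are illustrative, not taken from the commit:

```python
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction  # import as in linear.py

# Two actions with different priorities; TensorSpec.__init__ calls sort(),
# ordering the list by ParallelAction.priority.
actions = [
    ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow,
                   parallel_mode=ParallelMode.PARALLEL_1D),
    ParallelAction(priority=0, compute_pattern=ComputePattern.DP,
                   parallel_mode=ParallelMode.DATA),
]
spec = TensorSpec(actions)
row_action = spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
```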
tests/test_tensor/test_net_tp.py

```
from cProfile import label
from statistics import mode
from colossalai.tensor.colo_tensor import ColoTensor
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.utils import ColoInitContext
import torch.distributed as dist
from functools import partial
```
tests/test_tensor/test_op.py

```
@@ -53,11 +53,11 @@ def test_linear():
    # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
    out = fc(input_tensor)
    loss = out.sum()
    loss = torch.sum(out)
    loss.backward()
    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref = torch.sum(out_ref)
    loss_ref.backward()
    assert (loss_ref == loss)
```