OpenDAS / ColossalAI · Commit 96211c2c (unverified)
Authored Apr 26, 2022 by Jiarui Fang; committed by GitHub on Apr 26, 2022
[tensor] customized op returns ColoTensor (#875)
* [tensor] customized op returns ColoTensor
* polish
* polish code
Parent: 26d4ab8b

Showing 7 changed files with 33 additions and 45 deletions (+33 −45):
  colossalai/tensor/_ops/__init__.py      +1   −2
  colossalai/tensor/_ops/element_wise.py  +19  −0
  colossalai/tensor/_ops/init.py          +0   −29
  colossalai/tensor/_ops/linear.py        +5   −4
  colossalai/tensor/spec.py               +6   −2
  tests/test_tensor/test_net_tp.py        +0   −6
  tests/test_tensor/test_op.py            +2   −2
colossalai/tensor/_ops/__init__.py  (view file @ 96211c2c)

-from .init import colo_uniform
 from .linear import colo_linear
-from .element_wise import colo_mean
+from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
colossalai/tensor/_ops/element_wise.py  (view file @ 96211c2c)

 ...
 @@ -29,3 +29,22 @@ def register_elementwise_op(op):
 register_elementwise_op(torch.nn.functional.gelu)
 register_elementwise_op(torch.nn.functional.relu)
+
+
+@colo_op_impl(torch.sum)
+def sum_op(types, args=(), kwargs=None, pg=None):
+    """
+    Handles ``__torch_function__`` dispatch for the elementwise op such
+    as ``torch.sum``.
+    This method computes on either a normal tensor or a sharded tensor.
+    """
+    if len(args) > 0:
+        input_tensor = args[0]
+    if kwargs is None:
+        kwargs = {}
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    # Validate types
+    if not isinstance(input_tensor, ColoTensor):
+        raise TypeError("input needs to be a ColoTensor")
+    return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
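Illustrative usage of the new handler — a minimal sketch that is not part of the commit; it assumes ColossalAI at this revision (with ColoTensor in colossalai.tensor.colo_tensor, as the tests import it) and uses hypothetical tensor values:

    import torch
    from colossalai.tensor.colo_tensor import ColoTensor

    # Wrap an ordinary tensor; init_from_torch_tensor is the constructor
    # used throughout this commit's diff.
    t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))

    # torch.sum is routed to the sum_op handler registered above, so the
    # result should come back as a ColoTensor rather than a bare torch.Tensor.
    s = torch.sum(t)
    print(type(s))                  # expected: ColoTensor
    print(s.torch_tensor().item())  # underlying scalar value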
colossalai/tensor/_ops/init.py  (deleted, 100644 → 0; view file @ 26d4ab8b)

-import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-
-def validate_param(param, param_name):
-    if param is None:
-        raise ValueError(f"param: {param_name} shouldn't be None!")
-
-
-@colo_op_impl(torch.nn.init.uniform_)
-def colo_uniform(types, args=(), kwargs=None, pg=None):
-    r"""
-    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
-    distribution :math:`\mathcal{U}(a, b)`.
-    Args:
-        sharded_tensor: tensor sharded across devices
-        a: the lower bound of the uniform distribution
-        b: the upper bound of the uniform distribution
-    """
-    validate_param(kwargs, "kwargs")
-    stateful_tensor = kwargs["tensor"]
-    validate_param(stateful_tensor, "stateful_tensor")
-    a = kwargs['a']
-    validate_param(a, "a")
-    b = kwargs['b']
-    validate_param(b, "b")
-    torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
-    return stateful_tensor
colossalai/tensor/_ops/linear.py  (view file @ 96211c2c)

 ...
 @@ -6,7 +6,8 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern


 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
 ...
 @@ -25,6 +26,7 @@ def colo_linear(types, args, kwargs, pg):
     bias = kwargs.get('bias', None)
     if isinstance(bias, ColoTensor):
+        assert bias.shard_spec.num_action == 0, f"We currently only support bias is duplicated among processes in the linear operator"
         bias = bias.torch_tensor()

     # Add communication logic before and after linear call.
 ...
 @@ -34,7 +36,7 @@ def colo_linear(types, args, kwargs, pg):
             input_tensor = input_tensor.torch_tensor()
             if isinstance(weight, ColoTensor):
                 weight = weight.torch_tensor()
-            return torch.nn.functional.linear(input_tensor, weight, bias)
+            return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
         elif weight.shard_spec.num_action == 1:
             if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
                 # Input:S[1] x Weight:S[0] = Output:P
 ...
 @@ -54,8 +56,7 @@ def colo_linear(types, args, kwargs, pg):
                 output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
                 # Bias
                 if bias is not None:
-                    bias_ = bias
-                    output = output + bias_
+                    output = output + bias
                 return ColoTensor.init_from_torch_tensor(output)
             else:
                 raise NotImplementedError
 ...
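The comment "Input:S[1] x Weight:S[0] = Output:P" describes row tensor parallelism: each rank holds a column shard of the input and the matching row shard of the weight, produces a partial output, and the partial outputs are summed by reduce_input over ParallelMode.PARALLEL_1D. Below is a self-contained sketch of that identity in plain torch — no process group; the all-reduce is simulated with a Python sum, and the weight is laid out (in_features, out_features) rather than the transposed layout that F.linear stores:

    import torch

    torch.manual_seed(0)
    world_size = 2
    x = torch.randn(4, 8)   # full input  (batch, in_features)
    w = torch.randn(8, 6)   # full weight (in_features, out_features)

    full = x @ w             # reference output

    x_shards = x.chunk(world_size, dim=1)   # Input:S[1]
    w_shards = w.chunk(world_size, dim=0)   # Weight:S[0]

    # Each "rank" computes a partial output; summing them stands in for the
    # reduce_input(...) all-reduce in colo_linear.
    partial_outputs = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
    output = sum(partial_outputs)

    assert torch.allclose(full, output, atol=1e-5)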
colossalai/tensor/spec.py  (view file @ 96211c2c)

 ...
 @@ -2,6 +2,7 @@ from enum import Enum
 from typing import Tuple, List
 from colossalai.context.parallel_mode import ParallelMode


 class ComputePattern(Enum):
     TP1DRow = 1
     TP1DCol = 2
 ...
 @@ -10,6 +11,7 @@ class ComputePattern(Enum):
 class ParallelAction(object):

     def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
         self.priority = priority
         self.compute_pattern = compute_pattern
 ...
 @@ -24,6 +26,7 @@ class TensorSpec(object):
     parallel computation pattern of the Operator (Layer).
     We have to consider the hybrid parallel mode.
     """
     # a list of parallel actions.
     # For example: On 8 GPUs, a hybrid parallel strategy is applied using
     # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
 ...
 @@ -38,6 +41,7 @@ class TensorSpec(object):
     # Before Linear Op, we gather the tensors according to ZeRO.
     # We perform Linear Op according to compute pattern of TP1DRow.
     # After Linear Op, we split the tensors according to ZeRO.

     def __init__(self, parallel_action_list: List[ParallelAction] = []):
         self._parallel_action_list = parallel_action_list
         self.sort()
 ...
 @@ -56,8 +60,8 @@ class TensorSpec(object):
     def sort(self):
         if len(self._parallel_action_list) > 0:
             self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)

     def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
         for parallel_action in self._parallel_action_list:
             if parallel_action.compute_pattern == compute_pattern:
 ...
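A hedged usage sketch of the spec machinery shown above, mirroring the docstring's hybrid example of ZeRO data parallelism combined with 1D row tensor parallelism. It assumes ColossalAI at this commit, where TensorSpec, ComputePattern and ParallelAction were importable from colossalai.tensor (as the import removed from linear.py suggests):

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

    # Two actions with explicit priorities; TensorSpec.__init__ calls sort(),
    # which orders them by ascending priority.
    actions = [
        ParallelAction(priority=2, compute_pattern=ComputePattern.TP1DRow,
                       parallel_mode=ParallelMode.PARALLEL_1D),
        ParallelAction(priority=1, compute_pattern=ComputePattern.DP,
                       parallel_mode=ParallelMode.DATA),
    ]
    spec = TensorSpec(actions)

    # Look up the tensor-parallel action, as colo_linear does for TP1DRow.
    tp_action = spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
    print(tp_action.priority)  # expected: 2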
tests/test_tensor/test_net_tp.py
View file @
96211c2c
from
cProfile
import
label
from
statistics
import
mode
from
colossalai.tensor.colo_tensor
import
ColoTensor
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
import
colossalai
import
colossalai
import
pytest
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing
import
parameterize
,
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
ColoInitContext
from
colossalai.utils
import
ColoInitContext
import
torch.distributed
as
dist
from
functools
import
partial
from
functools
import
partial
...
...
tests/test_tensor/test_op.py  (view file @ 96211c2c)

 ...
 @@ -53,11 +53,11 @@ def test_linear():
     # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
     out = fc(input_tensor)
-    loss = out.sum()
+    loss = torch.sum(out)
     loss.backward()

     out_ref = fc_ref(input_ref)
-    loss_ref = out_ref.sum()
+    loss_ref = torch.sum(out_ref)
     loss_ref.backward()

     assert (loss_ref == loss)
 ...
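The test edits follow from the new torch.sum handler in element_wise.py: fc(input_tensor) presumably returns a ColoTensor after this commit, so the tests switch from the Tensor method out.sum() to the free function torch.sum(out), which is what the colo_op_impl(torch.sum) dispatcher intercepts and wraps back into a ColoTensor.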