OpenDAS / ColossalAI / Commits / 0ce8924c

Unverified commit 0ce8924c, authored Apr 21, 2022 by Jiarui Fang, committed by GitHub on Apr 21, 2022

[tensor] reorganize files (#820)

parent ab962b97

Changes: 11 changed files with 71 additions and 76 deletions (+71 -76)
colossalai/gemini/tensor/_ops/__init__.py   +0  -3
colossalai/gemini/tensor/api.py             +0  -17
colossalai/tensor/__init__.py               +7  -0
colossalai/tensor/_ops/__init__.py          +3  -0
colossalai/tensor/_ops/element_wise.py      +7  -7
colossalai/tensor/_ops/init.py              +3  -3
colossalai/tensor/_ops/linear.py            +6  -6
colossalai/tensor/colo_tensor.py            +10 -11
colossalai/tensor/op_wrapper.py             +25 -10
colossalai/tensor/utils.py                  +5  -9
tests/test_tensor/test_op.py                +5  -10
colossalai/gemini/tensor/_ops/__init__.py  (deleted, 100644 → 0)

-from .init import stateful_uniform
-from .linear import stateful_linear
-from .element_wise import stateful_mean
\ No newline at end of file
colossalai/gemini/tensor/api.py  (deleted, 100644 → 0)

-from typing import (
-    Callable,
-    Dict,
-)
-
-# Custom sharded ops
-_STATEFUL_OPS: Dict[str, Callable] = {}
-
-
-def _register_stateful_op(op, func):
-    from inspect import signature
-    if len(signature(func).parameters) != 4:
-        raise TypeError(f'Custom stateful op function expects signature: '
-                        f'(types, args, kwargs, process_group), but received '
-                        f'signature: {signature(func)}')
-    global _STATEFUL_OPS
-    _STATEFUL_OPS[op] = func
colossalai/tensor/__init__.py  (new file, 0 → 100644)

+from .op_wrapper import (colo_op_impl,)
+from .colo_tensor import ColoTensor
+from .utils import convert_parameter
+from ._ops import *
+
+__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
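After the reorganization these are the public entry points; a minimal import smoke check (a sketch, not part of the diff, assuming ColossalAI at this commit is installed):

import torch
from colossalai.tensor import ColoTensor, convert_parameter, colo_op_impl  # names from __all__ above

t = ColoTensor(torch.randn(2, 2))   # wrap a plain torch.Tensor
print(torch.mean(t))                # handled by the colo_mean registered via `from ._ops import *`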
colossalai/tensor/_ops/__init__.py  (new file, 0 → 100644)

+from .init import colo_uniform
+from .linear import colo_linear
+from .element_wise import colo_mean
\ No newline at end of file
colossalai/gemini/tensor/_ops/element_wise.py → colossalai/tensor/_ops/element_wise.py

 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor


-@stateful_op_impl(torch.mean)
-def stateful_mean(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.mean)
+def colo_mean(types, args=(), kwargs=None, pg=None):
     stateful_tensor = args[0]
     return torch.mean(stateful_tensor.torch_tensor())


 def register_elementwise_op(op):

-    @stateful_op_impl(op)
+    @colo_op_impl(op)
     def elementwise_op(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for the elementwise op such
         ...

@@ -20,8 +20,8 @@ def register_elementwise_op(op):
         """
         input_tensor = args[0]
         # Validate types
-        if not isinstance(input_tensor, StatefulTensorV2):
-            raise TypeError("input needs to be a StatefulTensorV2")
+        if not isinstance(input_tensor, ColoTensor):
+            raise TypeError("input needs to be a ColoTensor")
         return op(input_tensor.torch_tensor())
 ...
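The test file at the end of this diff exercises this path; as a short usage sketch (assuming relu/gelu were registered with register_elementwise_op outside the lines shown here, as the tests imply):

import torch
from colossalai.tensor import ColoTensor

t_ref = torch.randn(3, 5)
t = ColoTensor(t_ref.clone())

# torch.mean on a ColoTensor is intercepted by __torch_function__ and routed to
# colo_mean above, which unwraps the payload with torch_tensor().
assert torch.mean(t) == torch.mean(t_ref)
# relu goes through elementwise_op, assuming it was registered elsewhere in this module.
assert torch.allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))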
colossalai/gemini/tensor/_ops/init.py → colossalai/tensor/_ops/init.py

 import torch
-from colossalai.gemini.tensor import stateful_op_impl
+from colossalai.tensor.op_wrapper import colo_op_impl


 def validate_param(param, param_name):
 ...

@@ -7,8 +7,8 @@ def validate_param(param, param_name):
         raise ValueError(f"param: {param_name} shouldn't be None!")


-@stateful_op_impl(torch.nn.init.uniform_)
-def stateful_uniform(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.nn.init.uniform_)
+def colo_uniform(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
     distribution :math:`\mathcal{U}(a, b)`.
 ...
colossalai/gemini/tensor/_ops/linear.py → colossalai/tensor/_ops/linear.py

 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from ..stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor.colo_tensor import ColoTensor
 from packaging import version


-@stateful_op_impl(torch.nn.functional.linear)
-def stateful_linear(types, args, kwargs, pg):
+@colo_op_impl(torch.nn.functional.linear)
+def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
 ...

@@ -19,11 +19,11 @@ def stateful_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)

-    if isinstance(bias, StatefulTensorV2):
+    if isinstance(bias, ColoTensor):
         bias = bias.torch_tensor()

     # Add communication logic before and after linear call.
-    if isinstance(weight, StatefulTensorV2):
+    if isinstance(weight, ColoTensor):
         return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
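A hedged usage sketch mirroring test_op.py at the bottom of this diff (the test's setattr step and colo_linear's argument parsing are elided in the hunks shown, so treat the exact call path as an assumption):

import torch
from colossalai.tensor import ColoTensor

in_dim, out_dim = 4, 8
fc = torch.nn.Linear(in_dim, out_dim)
x = torch.randn(1, in_dim)

# Swap the nn.Parameters for ColoTensor wrappers, as the test does.
sharded_weight = ColoTensor(fc.weight)
sharded_bias = ColoTensor(fc.bias)
delattr(fc, 'weight')
delattr(fc, 'bias')
setattr(fc, 'weight', sharded_weight)
setattr(fc, 'bias', sharded_bias)

# F.linear now sees a ColoTensor weight, dispatches through __torch_function__,
# and colo_linear unwraps the wrappers before calling the regular kernel.
out = fc(x)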
colossalai/gemini/tensor/stateful_tensor.py → colossalai/tensor/colo_tensor.py

 import torch

-from .api import _STATEFUL_OPS
+from .op_wrapper import _COLOSSAL_OPS


-class StatefulTensorV2(object):
+class ColoTensor(object):

     def __new__(cls, *args, **kwargs):
-        return super(StatefulTensorV2, cls).__new__(cls)
+        return super(ColoTensor, cls).__new__(cls)

     def __init__(self, t: torch.Tensor) -> None:
         self._torch_tensor = t
 ...

@@ -15,16 +15,15 @@ class StatefulTensorV2(object):
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
-        global _STATEFUL_OPS
-        if func in _STATEFUL_OPS:
-            # Find StatefulTensorV2 instance to get process_group.
+        global _COLOSSAL_OPS
+        if func in _COLOSSAL_OPS:
             for arg in args:
-                if isinstance(arg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(arg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)

             for kwarg in kwargs.values():
-                if isinstance(kwarg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(kwarg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)

         raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
-                           f"kwargs: {kwargs} not supported for StatefulTensorV2!")
+                           f"kwargs: {kwargs} not supported for ColoTensor!")
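For orientation, a short sketch of how this dispatch table behaves (which handlers exist comes from colossalai/tensor/_ops; torch.cat is used below only as an example of an op assumed to have no handler at this commit):

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor(torch.ones(3, 5))

# torch.mean has an entry in _COLOSSAL_OPS, so __torch_function__ forwards the
# call as _COLOSSAL_OPS[torch.mean](types, args, kwargs, None).
print(torch.mean(t))

try:
    torch.cat([t, t])   # assumed unregistered -> falls through to the RuntimeError above
except RuntimeError as err:
    print(err)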
colossalai/gemini/tensor/__init__.py → colossalai/tensor/op_wrapper.py

 from typing import (
     Callable,
     Dict,
 )
 import functools

-from .api import (_register_stateful_op,)
+# Custom sharded ops
+_COLOSSAL_OPS: Dict[str, Callable] = {}


-def stateful_op_impl(func):
+def _register_colo_op(op, func):
+    from inspect import signature
+    if len(signature(func).parameters) != 4:
+        raise TypeError(f'Custom stateful op function expects signature: '
+                        f'(types, args, kwargs, process_group), but received '
+                        f'signature: {signature(func)}')
+
+    global _COLOSSAL_OPS
+    _COLOSSAL_OPS[op] = func
+
+
+def colo_op_impl(func):
     """
     Provides a way for users to write their own custom operator. This
-    can be used to override existing StatefulTensorV2 operators or write a new
-    one not supported by StatefulTensorV2. If the operator in question is covered
-    by ``__torch_function__`` dispatch and has a StatefulTensorV2 as any of its
+    can be used to override existing ColoTensor operators or write a new
+    one not supported by ColoTensor. If the operator in question is covered
+    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
     parameters, the function provided will be invoked for that operator.

     Example::
-        >>> @stateful_op_impl(torch.nn.functional.linear)
+        >>> @colo_op_impl(torch.nn.functional.linear)
         >>> def my_custom_linear(types, args, kwargs, process_group):
         >>>     ....
         >>>
         >>> input = torch.rand(10, 32)
-        >>> weight = StatefulTensorV2(torch.rand(32, 16))
-        >>> bias = StatefulTensorV2(torch.rand(16))
+        >>> weight = ColoTensor(torch.rand(32, 16))
+        >>> bias = ColoTensor(torch.rand(16))
         >>> # This will call `my_custom_linear` instead of the default.
         >>> torch.nn.functional.linear(input, weight, bias)
 ...

@@ -32,7 +47,7 @@ def stateful_op_impl(func):
     """

     def decorator_sharded_func(wrapped_func):
-        _register_stateful_op(func, wrapped_func)
+        _register_colo_op(func, wrapped_func)

         @functools.wraps(wrapped_func)
         def wrapper(*args, **kwargs):
 ...
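The docstring example above, made concrete as a hedged sketch; my_custom_linear is the hypothetical handler name from the docstring, and registering it simply overwrites the colo_linear entry for F.linear in _COLOSSAL_OPS:

import torch
from colossalai.tensor import ColoTensor, colo_op_impl


@colo_op_impl(torch.nn.functional.linear)
def my_custom_linear(types, args, kwargs, process_group):
    # Unwrap ColoTensor arguments and fall back to the plain kernel; a real
    # override would insert sharding/communication logic here instead.
    unwrapped = [a.torch_tensor() if isinstance(a, ColoTensor) else a for a in args]
    return torch.nn.functional.linear(*unwrapped)


inp = torch.rand(10, 32)
weight = ColoTensor(torch.rand(32, 16))
bias = ColoTensor(torch.rand(16))
# This now calls my_custom_linear instead of the default colo_linear.
out = torch.nn.functional.linear(inp, weight, bias)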
colossalai/gemini/tensor/utils.py → colossalai/tensor/utils.py

 import torch
 import torch.distributed as dist

 from torch.distributed import distributed_c10d
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.colo_tensor import ColoTensor


-def _convert_tensor(tensor: torch.Tensor) -> StatefulTensorV2:
-    if not tensor.is_contiguous():
-        raise ValueError('input tensor is not a contiguous Tensor')
-    return StatefulTensorV2(tensor)
+def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
+    return ColoTensor(tensor)


 def convert_parameter(module: torch.nn.Module, param_name: str):
 ...

@@ -26,10 +22,10 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
     st = _convert_tensor(tensor)
-    # Replace param with StatefulTensorV2.
+    # Replace param with ColoTensor.

     # Need to delete the attribute first since param_name might be
-    # torch.nn.Parameter and can't be replaced with StatefulTensorV2 which is
+    # torch.nn.Parameter and can't be replaced with ColoTensor which is
     # not torch.nn.Parameter.
     delattr(module, param_name)
 ...
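A hedged usage sketch for convert_parameter (most of its body is elided above; the assertion assumes the elided code re-attaches the ColoTensor under the same attribute name):

import torch
from colossalai.tensor import ColoTensor, convert_parameter

fc = torch.nn.Linear(4, 8)
convert_parameter(fc, 'weight')           # 'weight' is deleted and replaced; no longer an nn.Parameter
assert isinstance(fc.weight, ColoTensor)  # assumption based on the elided remainder of the function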
tests/test_gemini/test_tensor.py → tests/test_tensor/test_op.py

 from numpy import allclose
 import torch
 from torch import nn
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
-# TODO(jiaruifang) auto import
-from colossalai.gemini.tensor._ops import *
-from colossalai.gemini.tensor.api import _STATEFUL_OPS
+from colossalai.tensor import ColoTensor
 from copy import deepcopy
 ...

@@ -18,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = StatefulTensorV2(fc_ref.weight)
-    sharded_bias = StatefulTensorV2(fc_ref.bias)
+    sharded_weight = ColoTensor(fc_ref.weight)
+    sharded_bias = ColoTensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
 ...

@@ -45,15 +41,14 @@ def test_linear():
 # The test case failed
 # def test_uniform():
-#     t = StatefulTensorV2(torch.zeros(3, 5))
-#     # print(_STATEFUL_OPS)
+#     t = ColoTensor(torch.zeros(3, 5))
 #     torch.nn.init.uniform_(t)
 #     print(t)


 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = StatefulTensorV2(t_ref.clone())
+    t = ColoTensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
 ...