Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9bc5a77c
Unverified
Commit
9bc5a77c
authored
Apr 26, 2022
by
Ziyue Jiang
Committed by
GitHub
Apr 26, 2022
Browse files
[tensor] wrap function in the torch_tensor to ColoTensor (#881)
parent
4df6471f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
19 deletions
+41
-19
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+20
-18
tests/test_tensor/test_op.py
tests/test_tensor/test_op.py
+21
-1
No files found.
colossalai/tensor/colo_tensor.py
View file @
9bc5a77c
...
...
@@ -2,7 +2,7 @@ from colossalai.context import parallel_mode
from
.op_wrapper
import
_COLOSSAL_OPS
import
torch
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
,
Callable
from
numpy
import
product
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.utils
import
divide
...
...
@@ -152,26 +152,28 @@ class ColoTensor(object):
kwargs
=
{}
kwargs
=
{
k
:
v
.
torch_tensor
()
if
isinstance
(
v
,
ColoTensor
)
else
v
for
k
,
v
in
kwargs
.
items
()}
return
ColoTensor
.
init_from_torch_tensor
(
func
(
*
args
,
**
kwargs
))
return
cls
.
_filter_outputs_with_colo
(
func
(
*
args
,
**
kwargs
))
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
    """Run autograd backward on the wrapped torch tensor.

    Delegates directly to ``torch.Tensor.backward`` on the payload stored in
    ``self._torch_tensor``; both arguments are forwarded unchanged.

    Args:
        gradient: upstream gradient to seed backward with (``None`` for a
            scalar result, as in plain ``torch``).
        retain_graph: forwarded to torch; keep the autograd graph alive for
            a second backward pass when ``True``.
    """
    wrapped = self._torch_tensor
    wrapped.backward(gradient=gradient, retain_graph=retain_graph)
## TODO(fjr) reduce redundancy of the following code
def __add__(self, o) -> "ColoTensor":
    """Elementwise add two ColoTensors; the sum is re-wrapped as a ColoTensor."""
    total = self.torch_tensor() + o.torch_tensor()
    return ColoTensor.init_from_torch_tensor(total)
def __getattr__(self, name):
    # Fallback lookup for attributes not defined on ColoTensor itself.
    # NOTE(review): in this side-by-side diff extraction the remainder of
    # this method (the getattr on self._torch_tensor and the Callable
    # dispatch) appears further down, interleaved with other methods.

    def replace_tensor_with_colo(func):
        # Wrap a borrowed torch callable so that its outputs are converted
        # back to ColoTensor where they are plain torch.Tensors.
        def execute_func(*args, **kwargs):
            return self._filter_outputs_with_colo(func(*args, **kwargs))

        return execute_func
def __truediv__(self, o) -> "ColoTensor":
    """Divide the wrapped tensor by ``o`` and wrap the quotient as a ColoTensor."""
    quotient = self.torch_tensor() / o
    return ColoTensor.init_from_torch_tensor(quotient)
def view(self, *args: int) -> "ColoTensor":
    """Return a reshaped view of the wrapped tensor, re-wrapped as a ColoTensor."""
    reshaped = self.torch_tensor().view(*args)
    return ColoTensor.init_from_torch_tensor(reshaped)
def permute(self, *args) -> "ColoTensor":
    """Permute the wrapped tensor's dimensions; the result is a ColoTensor."""
    permuted = self.torch_tensor().permute(*args)
    return ColoTensor.init_from_torch_tensor(permuted)
def transpose(self, *args) -> "ColoTensor":
    """Transpose the wrapped tensor's dimensions; the result is a ColoTensor."""
    swapped = self.torch_tensor().transpose(*args)
    return ColoTensor.init_from_torch_tensor(swapped)
# NOTE(review): this is the continuation of __getattr__ above — the diff
# extraction interleaved it with other method definitions.
# Look the requested attribute up on the wrapped torch tensor.
attr = getattr(self._torch_tensor, name)
if isinstance(attr, Callable):
    # Callable attributes (tensor methods) get their outputs re-wrapped
    # into ColoTensor via the decorator defined at the top of __getattr__.
    return replace_tensor_with_colo(attr)
else:
    # Plain data attributes pass through untouched.
    return attr
def contiguous(self):
    """Return a contiguous-memory copy of the wrapped tensor as a ColoTensor."""
    compacted = self.torch_tensor().contiguous()
    return ColoTensor.init_from_torch_tensor(compacted)
@classmethod
def _filter_outputs_with_colo(cls, outputs):
    """Convert raw torch.Tensor result(s) into ColoTensor(s).

    ``None`` passes through unchanged; a single value is wrapped iff it is
    exactly a ``torch.Tensor``; a tuple is processed element-wise. Any
    non-tensor value is returned untouched.
    """
    def _wrap(value):
        # Exact type match (`type(...) is torch.Tensor`) on purpose:
        # subclasses and already-wrapped values are passed through.
        if type(value) is torch.Tensor:
            return ColoTensor.init_from_torch_tensor(value)
        return value

    if outputs is None:
        # no return value
        return None
    if type(outputs) is not tuple:
        # single return value
        return _wrap(outputs)
    # multiple return values
    return tuple(_wrap(item) for item in outputs)
tests/test_tensor/test_op.py
View file @
9bc5a77c
...
...
@@ -86,12 +86,32 @@ def test_no_wrap_op():
assert
torch
.
sum
(
t
)
==
torch
.
sum
(
t_ref
)
assert
torch
.
sum
(
input
=
t
)
==
torch
.
sum
(
input
=
t_ref
)
def test_wrapped_tensor_func():
    """Exercise ColoTensor's delegation of torch.Tensor attributes and methods."""
    reference = torch.randn(4, 5)
    wrapped = ColoTensor.init_from_torch_tensor(reference.clone())

    # A plain (non-callable) attribute passes straight through.
    assert wrapped.is_cuda == reference.is_cuda

    # TODO I don't find out a tensor function which returns None.

    # A single-tensor result comes back wrapped as a ColoTensor.
    wrapped_abs = wrapped.abs()
    assert isinstance(wrapped_abs, ColoTensor) and torch.equal(wrapped_abs.torch_tensor(), reference.abs())

    # A scalar (non-tensor) result is returned unwrapped.
    assert wrapped.dim() == reference.dim()

    # Multi-tensor results are wrapped element-wise.
    part_a, part_b = wrapped.split(2)
    assert isinstance(part_a, ColoTensor) and isinstance(part_b, ColoTensor)
def check_all():
    """Run every test case in this module, in order."""
    for case in (test_linear, test_element_wise, test_no_wrap_op, test_wrapped_tensor_func):
        case()
# Allow running this test module directly (outside pytest).
if __name__ == '__main__':
    check_all()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment