OpenDAS / ColossalAI / Commits / ea0a2ed2

[hotfix] the bug of numel() in ColoTensor (#845)

Unverified commit ea0a2ed2, authored Apr 24, 2022 by Jiarui Fang, committed by GitHub on Apr 24, 2022.
Parent: c1e8d200

Showing 2 changed files with 21 additions and 6 deletions:

  colossalai/tensor/colo_tensor.py  +14 -4
  tests/test_tensor/test_op.py      +7  -2
colossalai/tensor/colo_tensor.py  (+14 -4)
@@ -1,6 +1,8 @@
+from numpy import product
 import torch
-from .op_wrapper import _COLOSSAL_OPS
 from typing import Tuple
+import numpy
+from .op_wrapper import _COLOSSAL_OPS


 class ColoTensor(object):
...
@@ -31,7 +33,7 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor

     def numel(self):
-        return sum(self._size)
+        return product(self._size)

     @staticmethod
     def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
...
@@ -44,9 +46,17 @@ class ColoTensor(object):
         return colo_t

     def del_torch_tensor(self, save_shape=False) -> None:
-        if save_shape:
+        """
+        delete the payload of the torch tensor.
+
+        Args:
+            save_shape (bool, optional): if saving the shape of the torch_tensor.
+            If saving the shape, the size of self._torch_tensor is inconsist with the self._size.
+            Defaults to False.
+        """
+        if not save_shape:
             self._size = (0,)
-        self._torch_tensor = torch.empty((0,))
+        self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)

     def torch_tensor(self) -> torch.Tensor:
         if self._torch_tensor.numel() == 0:
...
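The headline fix is in numel(): with lazy initialization, a ColoTensor may hold only its shape in self._size, so the element count must be derived from the shape, and that is the product of the dimensions, not their sum. A minimal standalone sketch of the before/after behavior (it uses math.prod for portability; the commit itself imports NumPy's product, an alias of numpy.prod that NumPy 2.0 later removed):

    from math import prod

    import torch

    size = (2, 3)  # shape of a lazily-initialized tensor

    # Before the fix: numel() was computed as sum(self._size).
    assert sum(size) == 5   # 2 + 3 -> wrong element count

    # After the fix: numel() is the product of the dimensions,
    # which matches what PyTorch reports for a real (2, 3) tensor.
    assert prod(size) == 6  # 2 * 3 -> correct
    assert torch.empty(*size).numel() == prod(size)

The second change inverts a condition in del_torch_tensor: previously the shape was zeroed out exactly when the caller asked to save it; now self._size survives when save_shape=True. The recreated empty payload also keeps the original device and dtype instead of falling back to the defaults of torch.empty.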
tests/test_tensor/test_op.py  (+7 -2)
@@ -3,6 +3,7 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy

+
 def test_linear():
     in_dim = 4
     out_dim = 5
...
@@ -44,6 +45,7 @@ def test_linear():
     # torch.nn.init.uniform_(t)
     # print(t)

+
 def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
...
@@ -59,10 +61,12 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)

+
 def test_lazy_init_tensor():
     lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.torch_tensor().numel() == 6
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()

+
 def check_all():
     test_linear()
...
@@ -70,5 +74,6 @@ def check_all():
     test_no_wrap_op()
     test_lazy_init_tensor()

+
 if __name__ == '__main__':
-    check_all()
+    test_lazy_init_tensor()
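The updated assertion checks numel() against the lazily materialized payload. A minimal sketch of the pattern under test, using a hypothetical LazyTensor stand-in rather than the real ColoTensor API (the constructor signature and materialization logic are assumptions based on the diff above; math.prod stands in for NumPy's product used in the commit):

    from math import prod

    import torch


    class LazyTensor:
        """Hypothetical stand-in for ColoTensor's lazy-init behavior."""

        def __init__(self, *size, dtype=torch.float32):
            self._size = size
            self._dtype = dtype
            self._torch_tensor = torch.empty(0, dtype=dtype)  # no payload yet

        def numel(self):
            # The fixed semantics: product of the dims, not their sum.
            return prod(self._size)

        def torch_tensor(self):
            # Materialize the payload on first access, as the diff's
            # torch_tensor() does when the stored payload is empty.
            if self._torch_tensor.numel() == 0:
                self._torch_tensor = torch.empty(*self._size, dtype=self._dtype)
            return self._torch_tensor


    lazy_t = LazyTensor(2, 3)
    assert lazy_t._torch_tensor.numel() == 0  # payload not allocated yet
    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()

Note that running the test file directly now exercises only test_lazy_init_tensor(), the regression test for this fix, instead of check_all().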