OpenDAS / ColossalAI · Commits · 2ecc3d7a
"...Chat/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "268b3cd80d106c2b700156b1993675c7421abd15"
Unverified commit 2ecc3d7a, authored Apr 21, 2022 by Jiarui Fang, committed by GitHub on Apr 21, 2022

[tensor] lazy init (#823)

Parent: 68dcd51d
Showing 2 changed files, with 46 additions and 8 deletions:

- colossalai/tensor/colo_tensor.py (+34, -2)
- tests/test_tensor/test_op.py (+12, -6)
colossalai/tensor/colo_tensor.py (view file @ 2ecc3d7a)
```diff
 import torch
 from .op_wrapper import _COLOSSAL_OPS
 from typing import Tuple


 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
     2. It supports lazy init the tensor's payload.
     3. It can hijack the torch functions which using ColoTensors as args to our customized functions.
     4. It supports distributing the tensor's payload to the shards among processes. (TODO)
     """

     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)

-    def __init__(self, t: torch.Tensor) -> None:
-        self._torch_tensor = t
+    def __init__(
+        self,
+        *size: Tuple[int],
+        dtype=None,
+        requires_grad=False,
+        pin_memory=False,
+        torch_tensor=None,
+    ):
+        self._size = size
+        self._dtype = dtype
+        self._requires_grad = requires_grad
+        self._pin_memory = pin_memory
+        self._torch_tensor = torch_tensor
+
+    @staticmethod
+    def init_from_torch_tensor(tensor: torch.Tensor):
+        colo_t = ColoTensor(*tensor.size(),
+                            dtype=tensor.dtype,
+                            requires_grad=tensor.requires_grad,
+                            pin_memory=tensor.pin_memory,
+                            torch_tensor=tensor)
+        return colo_t

     def torch_tensor(self) -> torch.Tensor:
+        if self._torch_tensor == None:
+            self._torch_tensor = torch.empty(*self._size,
+                                             dtype=self._dtype,
+                                             requires_grad=self._requires_grad,
+                                             pin_memory=self._pin_memory)
         return self._torch_tensor

     @classmethod
     ...
```
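The lazy-init path added here stores only the tensor's metadata (`_size`, `_dtype`, `_requires_grad`, `_pin_memory`) at construction time; the payload is materialized by the first call to `torch_tensor()` via `torch.empty`. The snippet below is a usage sketch, not part of the commit; it mirrors `test_lazy_init_tensor` from the test diff further down and relies only on the methods shown above (`_torch_tensor` is ColoTensor's internal attribute).

```python
import torch
from colossalai.tensor import ColoTensor

# Lazily constructed: only metadata is stored, no payload is allocated yet.
lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor is None

# The first call to torch_tensor() materializes the payload with torch.empty().
payload = lazy_t.torch_tensor()
assert payload.numel() == 6

# Wrapping an existing tensor via init_from_torch_tensor() bypasses the lazy
# path: the given tensor itself becomes the payload.
eager_t = ColoTensor.init_from_torch_tensor(torch.randn(3, 5))
assert eager_t.torch_tensor().numel() == 15
```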
tests/test_tensor/test_op.py (view file @ 2ecc3d7a)
```diff
-from numpy import allclose
+from numpy import allclose, require
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
 ...

@@ -14,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor(fc_ref.weight)
-    sharded_bias = ColoTensor(fc_ref.bias)
+    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
 ...

@@ -48,7 +48,7 @@ def test_linear():
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
 ...

@@ -57,10 +57,16 @@ def test_element_wise():
 # Test a function not wrapped by
 def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)


+def test_lazy_init_tensor():
+    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor == None
+    assert lazy_t.torch_tensor().numel() == 6


 if __name__ == '__main__':
-    test_no_wrap_op()
+    test_lazy_init_tensor()
     # test_element_wise()
```
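In the tests above, stock torch functions such as `torch.mean`, `torch.sum`, `torch.nn.functional.gelu` and `torch.nn.functional.relu` are called directly on `ColoTensor` objects. That only works if those calls are intercepted, which is what point 3 of the class docstring describes; the interception machinery (`_COLOSSAL_OPS` from `.op_wrapper`) is imported but not shown in this diff. The sketch below is a generic, hypothetical illustration of how such hijacking can be done with PyTorch's `__torch_function__` protocol; `WrappedTensor` and `_payload` are invented names for the example and are not ColossalAI APIs.

```python
import torch


class WrappedTensor:
    """Minimal tensor wrapper (hypothetical, for illustration only)."""

    def __init__(self, payload: torch.Tensor):
        self._payload = payload

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap any wrapper objects among the positional arguments, then fall
        # back to the stock torch function. A real implementation could first
        # consult a registry of customized ops (e.g. _COLOSSAL_OPS) here.
        unwrapped = tuple(a._payload if isinstance(a, WrappedTensor) else a for a in args)
        return func(*unwrapped, **kwargs)


t = WrappedTensor(torch.randn(3, 5))
print(torch.sum(t))   # dispatched through WrappedTensor.__torch_function__
print(torch.mean(t))  # same mechanism; both return plain torch.Tensors here
```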