Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4ca73234
Unverified
Commit
4ca73234
authored
May 10, 2022
by
ver217
Committed by
GitHub
May 10, 2022
Browse files
[tensor] colo tensor overrides mul (#927)
* colo tensor overrides mul * polish code
parent
45b9124d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
1 deletion
+15
-1
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+15
-1
No files found.
colossalai/tensor/colo_tensor.py
View file @
4ca73234
...
...
@@ -232,16 +232,20 @@ class ColoTensor(object):
def
__add__
(
self
,
o
)
->
"ColoTensor"
:
if
isinstance
(
o
,
ColoTensor
):
return
ColoTensor
.
init_from_torch_tensor
(
self
.
torch_tensor
()
+
o
.
torch_tensor
())
elif
isinstance
(
o
,
torch
.
Tensor
):
elif
isinstance
(
o
,
(
torch
.
Tensor
,
int
,
float
)
):
return
ColoTensor
.
init_from_torch_tensor
(
self
.
torch_tensor
()
+
o
)
else
:
raise
TypeError
(
f
'
{
type
(
o
)
}
is not supported in ColoTensor __add__'
)
__radd__
=
__add__
def
__truediv__
(
self
,
o
)
->
"ColoTensor"
:
return
ColoTensor
.
init_from_torch_tensor
(
self
.
torch_tensor
()
/
o
)
def
__getattr__
(
self
,
name
):
def
replace_tensor_with_colo
(
func
):
def
execute_func
(
*
args
,
**
kwargs
):
# transform the ColoTensor args to torch Tensor.
args
=
[
arg
.
torch_tensor
()
if
isinstance
(
arg
,
ColoTensor
)
else
arg
for
arg
in
args
]
...
...
@@ -282,3 +286,13 @@ class ColoTensor(object):
else
:
raise
NotImplementedError
return
dim
def
__mul__
(
self
,
other
)
->
"ColoTensor"
:
if
isinstance
(
other
,
ColoTensor
):
return
ColoTensor
.
init_from_torch_tensor
(
self
.
torch_tensor
()
*
other
.
torch_tensor
())
elif
isinstance
(
other
,
(
torch
.
Tensor
,
int
,
float
)):
return
ColoTensor
.
init_from_torch_tensor
(
self
.
torch_tensor
()
*
other
)
else
:
raise
TypeError
(
f
'
{
type
(
other
)
}
is not supported in ColoTensor __mul__'
)
__rmul__
=
__mul__
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment