Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f99f56df
"examples/vscode:/vscode.git/clone" did not exist on "cf4792c9757e071217f0b99f4e2bcc85f2d048b7"
Unverified
Commit
f99f56df
authored
Jun 15, 2022
by
ver217
Committed by
GitHub
Jun 15, 2022
Browse files
fix colo parameter torch function (#1117)
parent
e1620dda
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
4 deletions
+22
-4
colossalai/tensor/colo_parameter.py
colossalai/tensor/colo_parameter.py
+22
-4
No files found.
colossalai/tensor/colo_parameter.py
View file @
f99f56df
...
...
@@ -7,6 +7,23 @@ from colossalai.tensor.param_op_hook import ParamOpHookManager
from
typing
import
Optional
def
filter_args
(
func
,
*
args
):
return
[
arg
for
arg
in
args
if
func
(
arg
)]
def
unpack_args
(
*
args
):
if
len
(
args
)
==
1
:
return
args
[
0
]
return
args
def
replace_args
(
args
,
kwargs
,
new_args
):
args
=
new_args
[:
len
(
args
)]
for
k
,
v
in
zip
(
kwargs
.
keys
(),
new_args
[
len
(
args
):]):
kwargs
[
k
]
=
v
return
unpack_args
(
args
),
kwargs
class
ColoParameter
(
ColoTensor
,
torch
.
nn
.
Parameter
):
r
"""A kind of ColoTensor to be considered as a module parameter.
...
...
@@ -50,12 +67,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def
__torch_function__
(
cls
,
func
,
types
,
args
=
...,
kwargs
=
None
):
if
ParamOpHookManager
.
has_hook
():
if
not
func
.
__name__
.
startswith
(
'__'
):
params
=
list
(
filter
(
lambda
arg
:
isinstance
(
arg
,
ColoParameter
),
args
))
if
kwargs
is
not
None
:
params
.
extend
(
list
(
filter
(
lambda
arg
:
isinstance
(
arg
,
ColoParameter
),
kwargs
.
values
())
))
if
kwargs
is
None
:
kwargs
=
{}
params
=
filter
_args
(
lambda
arg
:
isinstance
(
arg
,
ColoParameter
),
*
args
,
*
kwargs
.
values
())
if
len
(
params
)
>
0
:
with
torch
.
_C
.
DisableTorchFunction
():
args
=
ParamOpHookManager
.
pre_op
(
params
,
*
args
)
new_args
=
ParamOpHookManager
.
pre_op
(
params
,
*
args
,
*
kwargs
.
values
())
args
,
kwargs
=
replace_args
(
args
,
kwargs
,
new_args
)
ret
=
super
().
__torch_function__
(
func
,
types
,
args
,
kwargs
)
with
torch
.
_C
.
DisableTorchFunction
():
ret
=
ParamOpHookManager
.
post_op
(
params
,
ret
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment