OpenDAS / ColossalAI · Commit 02b05803 (unverified)

Authored Mar 27, 2023 by HELSON; committed by GitHub on Mar 27, 2023
Parent: 73d3e4d3

[fx] meta registration compatibility (#3253)

* [fx] meta registration compatibility
* fix error
Showing 3 changed files with 71 additions and 4 deletions:

  colossalai/fx/_compatibility.py     +14  -4
  colossalai/fx/_meta_regist_12.py     +0  -0
  colossalai/fx/_meta_regist_13.py    +57  -0
colossalai/fx/_compatibility.py @ 02b05803

@@ -2,11 +2,21 @@ from typing import Callable
 
 import torch
 
-try:
-    from . import _meta_registrations
-    META_COMPATIBILITY = True
-except:
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
     META_COMPATIBILITY = False
+elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
+    from . import _meta_regist_12
+    META_COMPATIBILITY = True
+elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
+    from . import _meta_regist_13
+    META_COMPATIBILITY = True
+elif TORCH_MAJOR == 2:
+    from . import _meta_regist_13
+    META_COMPATIBILITY = True
+    raise UserWarning("Colossalai is not tested with torch2.0 yet!!!")
 
 
 def compatibility(is_backward_compatible: bool = False) -> Callable:
     ...
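For reference, a minimal sketch (not part of this commit) of how downstream code might consume the META_COMPATIBILITY flag this file exports; the guard and tensor shapes below are illustrative assumptions, and it presumes a torch 1.12/1.13 install where the module imports without raising.

import torch

from colossalai.fx._compatibility import META_COMPATIBILITY

# META_COMPATIBILITY is True only when a version-matched meta-registration
# module was imported, so "meta" (shape-only) tensors are safe to use.
if META_COMPATIBILITY:
    x = torch.empty(4, 4, device='meta')   # no real memory is allocated
else:
    x = torch.empty(4, 4)                  # fall back to a real tensor
print(x.device)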
colossalai/fx/_meta_registrations.py → colossalai/fx/_meta_regist_12.py @ 02b05803

File moved (no content changes).
colossalai/fx/_meta_regist_13.py @ 02b05803 (new file, mode 100644)

import torch
from torch._meta_registrations import register_meta
from torch._prims_common import check

aten = torch.ops.aten

# since we fix the torch version to 1.13.1, we have to add unimplemented meta ops
# all these functions are from here https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py


@register_meta([aten.convolution_backward.default])
def meta_convolution_backward(
    grad_output_,
    input_,
    weight_,
    bias_sizes_opt,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    # High level logic taken from slow_conv3d_backward_cpu which should
    # be representative of all convolution_backward impls
    backend_grad_input = None
    backend_grad_weight = None
    backend_grad_bias = None

    if output_mask[0]:
        backend_grad_input = grad_output_.new_empty(input_.size())
    if output_mask[1]:
        backend_grad_weight = grad_output_.new_empty(weight_.size())
    if output_mask[2]:
        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)

    return (backend_grad_input, backend_grad_weight, backend_grad_bias)


@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta__adaptive_avg_pool2d_backward(grad_out, self):
    ndim = grad_out.ndim
    for i in range(1, ndim):
        check(
            grad_out.size(i) > 0,
            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
                      size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
        )
    check(
        ndim == 3 or ndim == 4,
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
    )
    check(
        self.dtype == grad_out.dtype,
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
    )
    return self.new_empty(self.shape)
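For illustration only (not part of the commit), a short sketch of exercising one of the meta kernels registered in this file; it assumes torch 1.13.x, that importing colossalai.fx._compatibility runs the version gate shown earlier, and the tensor shapes are made up for the example.

import torch

from colossalai.fx import _compatibility  # triggers the version-gated registration above

# shape-only propagation on the 'meta' device: adaptive_avg_pool2d backward
# with a (1, 1) pooled output and an (8, 16, 7, 7) input
grad_out = torch.empty(8, 16, 1, 1, device='meta')
inp = torch.empty(8, 16, 7, 7, device='meta')
grad_in = torch.ops.aten._adaptive_avg_pool2d_backward.default(grad_out, inp)
print(grad_in.shape, grad_in.device)  # expected: torch.Size([8, 16, 7, 7]) meta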