OpenDAS / ColossalAI · Commit 4b3d6cae

[fx]patch nn.functional convolution (#1528)

Unverified commit, authored Sep 01, 2022 by YuliangLiu0306 and committed via GitHub on Sep 01, 2022. Parent: 5156d5b4.

Showing 3 changed files with 228 additions and 1 deletion (+228 −1).
Files changed:

- colossalai/fx/tracer/meta_patch/patched_function/__init__.py (+2 −1)
- colossalai/fx/tracer/meta_patch/patched_function/convolution.py (+178 −0)
- tests/test_fx/test_tracer/test_functional_conv.py (+48 −0)
colossalai/fx/tracer/meta_patch/patched_function/__init__.py

```diff
@@ -3,4 +3,5 @@ from .arithmetic import *
 from .embedding import *
 from .normalization import *
 from .python_ops import *
-from .torch_ops import *
\ No newline at end of file
+from .torch_ops import *
+from .convolution import *
\ No newline at end of file
```
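The wildcard import mirrors the existing pattern in this `__init__.py`: pulling in `.convolution` at import time is what triggers the `meta_patched_function.register` decorators in the new module, so the patches take effect as soon as `patched_function` is imported.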
colossalai/fx/tracer/meta_patch/patched_function/convolution.py (new file, mode 100644)

```python
import torch
import collections
from itertools import repeat
from ..registry import meta_patched_function
import math


def _ntuple(n, name="parse"):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    parse.__name__ = name
    return parse


_single = _ntuple(1, "_single")
_pair = _ntuple(2, "_pair")
_triple = _ntuple(3, "_triple")


def _extract_kwargs(kwargs):
    if 'stride' in kwargs:
        stride = kwargs['stride']
    else:
        stride = 1
    # TODO: process str type padding
    if 'padding' in kwargs:
        padding = kwargs['padding']
    else:
        padding = 0
    if 'dilation' in kwargs:
        dilation = kwargs['dilation']
    else:
        dilation = 1
    if 'output_padding' in kwargs:
        output_padding = kwargs['output_padding']
    else:
        output_padding = 0

    return stride, padding, dilation, output_padding
```
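These helpers mirror `torch.nn.modules.utils._ntuple`: a scalar is broadcast to an n-tuple, while an iterable is passed through as a tuple, and `_extract_kwargs` supplies the same defaults as the `F.conv*` signatures (stride 1, padding 0, dilation 1, output_padding 0). A quick illustration of the intended behavior (hypothetical REPL checks, not part of the commit):

```python
assert _single(2) == (2,)                     # scalar -> 1-tuple
assert _pair((2, 3)) == (2, 3)                # iterable passed through unchanged
assert _extract_kwargs({}) == (1, 0, 1, 0)    # all defaults
```

Per the TODO, string paddings such as `'same'`/`'valid'` are not handled yet.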
```python
@meta_patched_function.register(torch.nn.functional.conv1d)
def torch_nn_functional_conv1d(input, weight, **kwargs):
    stride, padding, dilation, _ = _extract_kwargs(kwargs)

    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)

    kernel_size = weight.shape[2:]
    l_in = input.shape[-1]
    c_out = weight.shape[0]
    l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    result_shape = input.shape[:-2] + (
        c_out,
        l_out,
    )
    return torch.empty(result_shape, device='meta')
```
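The length arithmetic is the standard output-size formula from the `torch.nn.Conv1d` docs,

    L_out = floor((L_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride + 1),

so e.g. L_in = 10, kernel_size = 3, stride = 1, padding = 0, dilation = 1 gives L_out = floor((10 − 2 − 1)/1 + 1) = 8. The `conv2d` and `conv3d` patches below apply the same formula once per spatial dimension.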
```python
@meta_patched_function.register(torch.nn.functional.conv2d)
def torch_nn_functional_conv2d(input, weight, **kwargs):
    stride, padding, dilation, _ = _extract_kwargs(kwargs)

    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    kernel_size = weight.shape[2:]
    h_in, w_in = input.shape[-2:]
    c_out = weight.shape[0]
    h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
```
```python
@meta_patched_function.register(torch.nn.functional.conv3d)
def torch_nn_functional_conv3d(input, weight, **kwargs):
    stride, padding, dilation, _ = _extract_kwargs(kwargs)

    stride = _triple(stride)
    padding = _triple(padding)
    dilation = _triple(dilation)

    kernel_size = weight.shape[2:]
    d_in, h_in, w_in = input.shape[-3:]
    c_out = weight.shape[0]
    d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
    w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
```
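The transposed variants below invert that formula, per the `torch.nn.ConvTranspose1d` docs,

    L_out = (L_in − 1)·stride − 2·padding + dilation·(kernel_size − 1) + output_padding + 1,

and read the output channels from `weight.shape[1]` rather than `weight.shape[0]`, because a transposed-convolution weight is laid out as (in_channels, out_channels, *kernel_size) when groups = 1. For L_in = 10, kernel_size = 3 and the defaults, this gives L_out = 9 + 2 + 1 = 12.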
```python
@meta_patched_function.register(torch.nn.functional.conv_transpose1d)
def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)

    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)
    output_padding = _single(output_padding)

    kernel_size = weight.shape[2:]
    l_in = input.shape[-1]
    c_out = weight.shape[1]
    l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
                       output_padding[0] + 1)
    result_shape = input.shape[:-2] + (
        c_out,
        l_out,
    )
    return torch.empty(result_shape, device='meta')
```
```python
@meta_patched_function.register(torch.nn.functional.conv_transpose2d)
def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)

    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    output_padding = _pair(output_padding)

    kernel_size = weight.shape[2:]
    h_in, w_in = input.shape[-2:]
    c_out = weight.shape[1]
    h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
                       output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
                       output_padding[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
```
```python
@meta_patched_function.register(torch.nn.functional.conv_transpose3d)
def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)

    stride = _triple(stride)
    padding = _triple(padding)
    dilation = _triple(dilation)
    output_padding = _triple(output_padding)

    kernel_size = weight.shape[2:]
    d_in, h_in, w_in = input.shape[-3:]
    c_out = weight.shape[1]
    d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
                       output_padding[0] + 1)
    h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
                       output_padding[1] + 1)
    w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) +
                       output_padding[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
```
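With these registrations in place, output shapes can be inferred on the `meta` device without running a real kernel. A minimal sketch of a direct call (assuming ColossalAI is installed; the FX tracer normally invokes these patches itself):

```python
import torch
from colossalai.fx.tracer.meta_patch import patched_function

# Meta tensors carry shape and dtype only; no data is allocated.
meta_input = torch.empty(3, 16, 10, device='meta')
meta_weight = torch.empty(3, 16, 3, device='meta')

out = patched_function.torch_nn_functional_conv1d(meta_input, meta_weight, stride=2)
print(out.shape)  # torch.Size([3, 3, 4]): floor((10 - 2 - 1)/2 + 1) = 4
```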
tests/test_fx/test_tracer/test_functional_conv.py (new file, mode 100644)

```python
import torch
from torch.nn import functional as F

from colossalai.fx.tracer.meta_patch import patched_function


def test_conv():
    # test F.conv_1d
    data_1d = torch.rand(3, 16, 10)
    weight_1d = torch.rand(3, 16, 3)
    out_1d = F.conv1d(data_1d, weight_1d)
    patched_out_1d = patched_function.torch_nn_functional_conv1d(data_1d, weight_1d)
    assert out_1d.shape == patched_out_1d.shape

    # test F.conv_transpose1d
    weight_1d = torch.transpose(weight_1d, 0, 1)
    out_transpose_1d = F.conv_transpose1d(data_1d, weight_1d)
    patched_out_transpose_1d = patched_function.torch_nn_functional_convtranspose1d(data_1d, weight_1d)
    assert out_transpose_1d.shape == patched_out_transpose_1d.shape

    # test F.conv2d
    data_2d = torch.rand(3, 16, 10, 10)
    weight_2d = torch.rand(3, 16, 3, 3)
    out_2d = F.conv2d(data_2d, weight_2d)
    patched_out_2d = patched_function.torch_nn_functional_conv2d(data_2d, weight_2d)
    assert out_2d.shape == patched_out_2d.shape

    # test F.conv_transpose2d
    weight_2d = torch.transpose(weight_2d, 0, 1)
    out_transpose_2d = F.conv_transpose2d(data_2d, weight_2d)
    patched_out_transpose_2d = patched_function.torch_nn_functional_convtranspose2d(data_2d, weight_2d)
    assert out_transpose_2d.shape == patched_out_transpose_2d.shape

    # test F.conv3d
    data_3d = torch.rand(3, 16, 10, 10, 10)
    weight_3d = torch.rand(3, 16, 3, 3, 3)
    out_3d = F.conv3d(data_3d, weight_3d)
    patched_out_3d = patched_function.torch_nn_functional_conv3d(data_3d, weight_3d)
    assert out_3d.shape == patched_out_3d.shape

    # test F.conv_transpose3d
    weight_3d = torch.transpose(weight_3d, 0, 1)
    out_transpose_3d = F.conv_transpose3d(data_3d, weight_3d)
    patched_out_transpose_3d = patched_function.torch_nn_functional_convtranspose3d(data_3d, weight_3d)
    assert out_transpose_3d.shape == patched_out_transpose_3d.shape


if __name__ == '__main__':
    test_conv()
```
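Each case builds concrete tensors, runs the real `torch.nn.functional` op, and asserts that the patched meta implementation reports the same shape. The file can be run directly as a script, or via pytest, e.g. `pytest tests/test_fx/test_tracer/test_functional_conv.py`.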