OpenDAS / ColossalAI

Commit 2c8c0567 (unverified)
Authored Jun 29, 2022 by Frank Lee; committed by GitHub on Jun 29, 2022

[fx] patched conv and normalization (#1188)

Parent: 6f0733a1
Showing 2 changed files, with 308 additions and 0 deletions:

colossalai/fx/tracer/meta_patch/patched_module.py (+81, -0)
tests/test_fx/test_tracer/test_patched_module.py (+227, -0)
colossalai/fx/tracer/meta_patch/patched_module.py (view file @ 2c8c0567)
import math

import torch

from .registry import meta_patched_module


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    last_dim = input.shape[-1]
    assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
@meta_patched_module.register(torch.nn.LayerNorm)
@meta_patched_module.register(torch.nn.GroupNorm)
@meta_patched_module.register(torch.nn.BatchNorm1d)
@meta_patched_module.register(torch.nn.BatchNorm2d)
@meta_patched_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self, input):
    # check shape
    if isinstance(self, torch.nn.BatchNorm1d):
        # (N, C) or (N, C, L)
        assert input.dim() in [2, 3]
    elif isinstance(self, torch.nn.BatchNorm2d):
        # (N, C, H, W)
        assert input.dim() == 4
    elif isinstance(self, torch.nn.BatchNorm3d):
        # (N, C, D, H, W)
        assert input.dim() == 5

    # normalization maintains the same shape as the input
    return input.clone()
@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
    # an embedding lookup appends the embedding dimension to the index shape
    result_shape = input.shape + (self.embedding_dim,)
    return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
    l_in = input.shape[-1]
    c_out = self.out_channels
    l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    result_shape = input.shape[:-2] + (
        c_out,
        l_out,
    )
    return torch.empty(result_shape, device='meta')
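
# worked example of the formula above, with the dilated configuration used in
# tests/test_fx/test_tracer/test_patched_module.py (l_in = 4, kernel_size = 2,
# padding = 1, dilation = 2, stride = 1):
#   l_out = floor((4 + 2*1 - 2*(2 - 1) - 1) / 1 + 1) = 4
# i.e. a (2, 3, 4) input maps to a (2, 4, 4) output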
@meta_patched_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
    h_in, w_in = input.shape[-2:]
    c_out = self.out_channels
    h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d
    d_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
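
For context, these patches let the fx tracer infer output shapes on "meta" tensors (shape-only tensors with no storage) without running the real kernels. A minimal usage sketch, calling one of the patched functions above directly; the module configuration and input sizes here are illustrative, not taken from the commit:

import torch
from colossalai.fx.tracer.meta_patch import patched_module

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.empty(2, 3, 32, 32, device='meta')    # shape-only input, no data
out = patched_module.torch_nn_conv2d(conv, x)   # shape inference, no compute
print(out.shape, out.is_meta)                   # torch.Size([2, 8, 32, 32]) True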
tests/test_fx/test_tracer/test_patched_module.py (new file, mode 100644; view file @ 2c8c0567)
import torch

from colossalai.fx.tracer.meta_patch import patched_module


def _run(data, module, patch_fn):
    try:
        output = patch_fn(module, data)
        return output
    except Exception as e:
        return e


def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape):
    output = _run(data, module, patch_fn)

    if expect_exception:
        assert isinstance(output, AssertionError)
    else:
        assert not isinstance(output, Exception)
        assert output.is_meta
        assert output.shape == output_shape
def test_linear():
    # test if the linear patch can produce a meta output with the correct shape
    data = torch.rand(2, 4, device='meta')
    module = torch.nn.Linear(4, 2)
    _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2]))

    # test if the linear patch can catch the exception when the dimension does not match
    data = torch.rand(2, 2, device='meta')
    _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None)
def test_normalize():
    data = torch.rand(2, 4, device='meta')

    # test layer norm
    ln = torch.nn.LayerNorm(4)
    _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape)

    # test group norm
    gn = torch.nn.GroupNorm(2, num_channels=4)
    _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape)

    # test batch norm 1d
    bn1d = torch.nn.BatchNorm1d(4)

    data = torch.rand(2, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(2, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=True,
                         output_shape=None)

    # test batch norm 2d
    bn2d = torch.nn.BatchNorm2d(4)

    data = torch.rand(1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn2d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn2d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=True,
                         output_shape=None)

    # test batch norm 3d
    bn3d = torch.nn.BatchNorm3d(4)

    data = torch.rand(1, 1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn3d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn3d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=True,
                         output_shape=None)
def test_conv1d():
    # test conv 1d
    data = torch.rand(2, 3, 4)

    conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = conv1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=conv1d,
                         patch_fn=patched_module.torch_nn_conv1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = conv1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=conv1d,
                         patch_fn=patched_module.torch_nn_conv1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode='reflect')
    materialized_output = conv1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=conv1d,
                         patch_fn=patched_module.torch_nn_conv1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)
def test_conv2d():
    # test conv 2d
    data = torch.rand(2, 3, 4, 4)

    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2)
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode='reflect')
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)
def test_conv3d():
    # test conv 3d
    data = torch.rand(2, 3, 4, 4, 4)

    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2)
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode='reflect')
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)
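
These tests follow pytest conventions (plain test_* functions), so a test runner can collect them from this path. As a quick sketch, they can also be driven directly; the __main__ runner below is an illustration, not part of the commit:

if __name__ == '__main__':
    # run each patched-module test in sequence; any failure raises AssertionError
    test_linear()
    test_normalize()
    test_conv1d()
    test_conv2d()
    test_conv3d()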