OpenDAS / ColossalAI · Commits

Commit fa3d66fe (unverified)
support unet metainfo prop (#2544)
Authored Feb 02, 2023 by oahzxl; committed by GitHub on Feb 02, 2023.
Parent: c4b15661

Showing 2 changed files with 35 additions and 17 deletions:
- colossalai/fx/_meta_registrations.py (+14, -17)
- colossalai/fx/profiler/opcount.py (+21, -0)
colossalai/fx/_meta_registrations.py (+14, -17)

@@ -164,18 +164,9 @@ def meta_conv(
 @register_meta(aten._convolution.default)
-def meta_conv_1(
-        input_tensor: torch.Tensor,
-        weight: torch.Tensor,
-        bias: torch.Tensor,
-        stride: List[int],
-        padding: List[int],
-        dilation: List[int],
-        is_transposed: bool,
-        output_padding: List[int],
-        groups: int,
-        *extra_args):
+def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
+                padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
+                *extra_args):
     out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
     return out
@@ -233,11 +224,8 @@ def meta_cuda_rnn(
     if is_input_packed:
         out_shape = [batch_sizes_sum, out_size * num_directions]
     else:
-        out_shape = (
-            [mini_batch, seq_length, out_size * num_directions]
-            if batch_first
-            else [seq_length, mini_batch, out_size * num_directions]
-        )
+        out_shape = ([mini_batch, seq_length, out_size * num_directions]
+                     if batch_first else [seq_length, mini_batch, out_size * num_directions])
     output = input.new_empty(out_shape)
     cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -372,6 +360,15 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
     return dX, dgamma, dbeta
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp
+@register_meta(aten.native_group_norm_backward.default)
+def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask):
+    dX = torch.empty_like(input)
+    dgamma = torch.empty_like(gamma)
+    dbeta = torch.empty_like(gamma)
+    return dX, dgamma, dbeta
 
 # ================================== Misc ==========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 @register_meta(aten.roll.default)
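The new meta_gn_backward only fabricates output tensors of the right shape; no gradients are actually computed. As a minimal illustration of why that is enough for metainfo propagation, the following standalone sketch (not part of this commit; shapes chosen arbitrarily) runs the same three-allocation body on PyTorch's storage-less meta device:

import torch

# Inputs live on the "meta" device: they carry shape/dtype but no storage,
# so tracing a model's backward allocates no real memory.
x = torch.empty(2, 4, 8, 8, device="meta")    # forward input (N, C, H, W)
gamma = torch.empty(4, device="meta")         # affine weight, one per channel

# Mirrors the body of meta_gn_backward above: allocate correctly-shaped
# outputs and return them; no arithmetic is performed.
dX = torch.empty_like(x)
dgamma = torch.empty_like(gamma)
dbeta = torch.empty_like(gamma)
print(dX.shape, dgamma.shape, dbeta.shape)
# torch.Size([2, 4, 8, 8]) torch.Size([4]) torch.Size([4])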
colossalai/fx/profiler/opcount.py (+21, -0)

@@ -70,6 +70,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     return flops
+
+
+def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+    """
+    Count flops for the baddbmm (batch add and batch matmul) operation.
+    """
+    # Inputs = [input, batch1, batch2]
+    # out = input + batch1 x batch2
+    assert len(inputs) == 3, len(inputs)
+    n, c, t = inputs[1].shape
+    d = inputs[2].shape[-1]
+    flops = n * c * t * d
+    return flops
 
 
 def conv_flop_count(
     x_shape: List[int],
     w_shape: List[int],
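For reference, baddbmm computes out = input + batch1 @ batch2, where batch1 has shape (n, c, t) and batch2 has shape (n, t, d), so the output is (n, c, d) and each output element needs t multiply-adds. The counter reports n * c * t * d, following the same MAC convention as bmm_flop_jit, and ignores the elementwise add of input. A quick arithmetic check with made-up shapes:

# Illustrative only: verify the baddbmm count for one set of shapes.
n, c, t, d = 8, 64, 32, 128
flops = n * c * t * d               # one MAC per inner-product term
assert flops == (n * c * d) * t == 2_097_152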
@@ -196,6 +209,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     aten.matmul.default: matmul_flop_jit,
     aten.addmm.default: addmm_flop_jit,
     aten.bmm.default: bmm_flop_jit,
+    aten.baddbmm.default: baddbmm_flop_jit,
     # convolution
     aten.convolution.default: conv_flop_jit,
@@ -209,6 +223,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
     aten.native_layer_norm.default: norm_flop_counter(2, 0),
     aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+    aten.native_group_norm.default: norm_flop_counter(2, 0),
+    aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
     # pooling
     aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
@@ -230,6 +246,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
     aten.embedding.default: elementwise_flop_counter(1, 0),
+    aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1),
+    aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1),
 }
 
 elementwise_flop_aten = [
@@ -251,6 +269,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     aten.mean.dim,
     aten.sub.Tensor,
     aten.sub_.Tensor,
+    aten.exp.default,
+    aten.sin.default,
+    aten.cos.default,
     # activation op
     aten.hardswish.default,
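For context, a mapping like the one extended above is consumed by looking up each traced ATen overload and invoking its counter. The sketch below shows that dispatch pattern with illustrative names; it is not ColossalAI's actual profiler API:

from numbers import Number
from typing import Any, Callable, Dict, List

def count_flops(target: Any, inputs: List[Any], outputs: List[Any],
                mapping: Dict[Any, Callable[[List[Any], List[Any]], Number]]) -> Number:
    # Look up the counter registered for this ATen overload; ops without
    # a registered counter are treated as contributing 0 FLOPs.
    counter = mapping.get(target)
    return counter(inputs, outputs) if counter is not None else 0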