OpenDAS / ColossalAI · Commits

Commit 08f2920e, authored Apr 23, 2023 by zhuwenwen

    init colossalai, support dtk2304

Parent: da3f0934
Pipeline #237 failed with stages in 0 seconds
Changes: 380 · Pipelines: 1
Showing 20 changed files with 1094 additions and 0 deletions (+1094, -0):
colossalai/fx/profiler/experimental/profiler_function/linear.py             +13   -0
colossalai/fx/profiler/experimental/profiler_function/normalization.py      +66   -0
colossalai/fx/profiler/experimental/profiler_function/pooling.py            +22   -0
colossalai/fx/profiler/experimental/profiler_function/python_ops.py         +18   -0
colossalai/fx/profiler/experimental/profiler_function/torch_ops.py          +60   -0
colossalai/fx/profiler/experimental/profiler_module/__init__.py             +10   -0
colossalai/fx/profiler/experimental/profiler_module/activation_function.py  +33   -0
colossalai/fx/profiler/experimental/profiler_module/attention.py            +81   -0
colossalai/fx/profiler/experimental/profiler_module/convolution.py          +152  -0
colossalai/fx/profiler/experimental/profiler_module/dropout.py              +11   -0
colossalai/fx/profiler/experimental/profiler_module/embedding.py            +11   -0
colossalai/fx/profiler/experimental/profiler_module/linear.py               +14   -0
colossalai/fx/profiler/experimental/profiler_module/normalization.py        +33   -0
colossalai/fx/profiler/experimental/profiler_module/pooling.py              +22   -0
colossalai/fx/profiler/experimental/profiler_module/rnn.py                  +75   -0
colossalai/fx/profiler/experimental/profiler_module/torch_op.py             +11   -0
colossalai/fx/profiler/experimental/registry.py                             +25   -0
colossalai/fx/profiler/experimental/shard_utils.py                          +48   -0
colossalai/fx/profiler/memory_utils.py                                      +71   -0
colossalai/fx/profiler/opcount.py                                           +318  -0
colossalai/fx/profiler/experimental/profiler_function/linear.py (new file, mode 100644)

```python
from typing import Tuple

import torch

from ..registry import meta_profiler_function


@meta_profiler_function.register(torch.nn.functional.linear)
def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]:
    out_features = weight.shape[0]
    macs = torch.numel(input) * out_features
    flops = 2 * macs
    if bias is not None:
        flops += bias.numel()
    return flops, macs
```
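
As a sanity check on the counter above (a sketch with made-up sizes, not part of this commit), the same arithmetic can be reproduced on meta tensors, which carry shape and dtype but allocate no data:

```python
import torch

# Hypothetical shapes: a (4, 128) batch through a linear layer with 256 output features.
input = torch.empty(4, 128, device='meta')
weight = torch.empty(256, 128, device='meta')
bias = torch.empty(256, device='meta')

out_features = weight.shape[0]               # 256
macs = torch.numel(input) * out_features     # 4 * 128 * 256 = 131072
flops = 2 * macs + bias.numel()              # 262144 + 256 = 262400
```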
colossalai/fx/profiler/experimental/profiler_function/normalization.py (new file, mode 100644)

```python
from typing import List, Optional, Tuple

import torch

from ..registry import meta_profiler_function


@meta_profiler_function.register(torch.nn.functional.instance_norm)
def torch_nn_func_instancenorm(
    input: torch.Tensor,
    running_mean: Optional[torch.Tensor] = None,
    running_var: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tuple[int, int]:
    has_affine = weight is not None
    flops = input.numel() * (5 if has_affine else 4)
    macs = 0
    return flops, macs


@meta_profiler_function.register(torch.nn.functional.group_norm)
def torch_nn_func_groupnorm(input: torch.Tensor,
                            num_groups: int,
                            weight: Optional[torch.Tensor] = None,
                            bias: Optional[torch.Tensor] = None,
                            eps: float = 1e-5) -> Tuple[int, int]:
    has_affine = weight is not None
    flops = input.numel() * (5 if has_affine else 4)
    macs = 0
    return flops, macs


@meta_profiler_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(
    input: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> Tuple[int, int]:
    has_affine = weight is not None
    flops = input.numel() * (5 if has_affine else 4)
    macs = 0
    return flops, macs


@meta_profiler_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(
    input: torch.Tensor,
    running_mean: Optional[torch.Tensor],
    running_var: Optional[torch.Tensor],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tuple[int, int]:
    has_affine = weight is not None
    if training:
        flops = input.numel() * (2 if has_affine else 1)
    else:
        flops = input.numel() * (5 if has_affine else 4)
    macs = 0
    return flops, macs
```
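
To make the constants concrete (a worked example with assumed sizes, not from the commit): each counter charges 4 flops per input element for the normalization itself plus 1 when an affine weight is applied, while batch_norm in training mode is charged only 1 or 2 per element:

```python
import torch

x = torch.empty(8, 512, device='meta')    # hypothetical activation

# layer_norm / group_norm / instance_norm with an affine weight: 5 flops per element
flops_ln = x.numel() * 5    # 8 * 512 * 5 = 20480

# batch_norm, training=True, with an affine weight: 2 flops per element
flops_bn = x.numel() * 2    # 8 * 512 * 2 = 8192
```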
colossalai/fx/profiler/experimental/profiler_function/pooling.py (new file, mode 100644)

```python
from typing import Tuple, Union

import torch

from ..registry import meta_profiler_function


@meta_profiler_function.register(torch.nn.functional.avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.max_pool1d)
@meta_profiler_function.register(torch.nn.functional.max_pool2d)
@meta_profiler_function.register(torch.nn.functional.max_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d)
def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
    # all pooling could be considered as going over each input element only once
    # (https://stackoverflow.com/a/67301217)
    flops = input.numel()
    macs = 0
    return flops, macs
```
colossalai/fx/profiler/experimental/profiler_function/python_ops.py (new file, mode 100644)

```python
import operator
from typing import Any, Tuple

import torch

from ..registry import meta_profiler_function


@meta_profiler_function.register(operator.getitem)
def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
    flops = 0
    macs = 0
    return flops, macs


@meta_profiler_function.register(getattr)
def python_getattr(a: Any, b: Any) -> Tuple[int, int]:
    flops = 0
    macs = 0
    return flops, macs
```
colossalai/fx/profiler/experimental/profiler_function/torch_ops.py (new file, mode 100644)

```python
import operator
from functools import reduce
from typing import Any, Optional, Tuple

import torch

from ..registry import meta_profiler_function


@meta_profiler_function.register(torch.arange)
@meta_profiler_function.register(torch.finfo)
@meta_profiler_function.register(torch.permute)
@meta_profiler_function.register(torch.Tensor.permute)
@meta_profiler_function.register(torch.Tensor.repeat)
@meta_profiler_function.register(torch.index_select)
@meta_profiler_function.register(torch.Tensor.index_select)
@meta_profiler_function.register(torch.squeeze)
@meta_profiler_function.register(torch.Tensor.squeeze)
@meta_profiler_function.register(torch.unsqueeze)
@meta_profiler_function.register(torch.Tensor.unsqueeze)
@meta_profiler_function.register(torch.cat)
@meta_profiler_function.register(torch.concat)
@meta_profiler_function.register(torch.repeat_interleave)
@meta_profiler_function.register(torch.Tensor.repeat_interleave)
@meta_profiler_function.register(torch.flatten)
@meta_profiler_function.register(torch.Tensor.flatten)
@meta_profiler_function.register(torch.roll)
@meta_profiler_function.register(torch.full)
@meta_profiler_function.register(torch.Tensor.cpu)
@meta_profiler_function.register(torch.Tensor.cuda)
@meta_profiler_function.register(torch._assert)
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
    flops = 0
    macs = 0
    return flops, macs


@meta_profiler_function.register(torch.where)
def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    flops = condition.numel()
    macs = 0
    return flops, macs


@meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor,
              dim: Optional[int] = None,
              keepdim: bool = False,
              *,
              out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    macs = 0
    assert out is None, 'assigning value to out is not supported yet'
    if dim is not None:
        # one flop per element of the reduced result
        shape = list(input.shape)
        shape.pop(int(dim))
        flops = reduce(operator.mul, shape)
        return flops, macs
    else:
        flops = input.numel()
        return flops, macs
```
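
A quick trace of the `torch.max` counter with assumed shapes (illustrative only, not part of the commit):

```python
import operator
from functools import reduce

import torch

x = torch.empty(4, 8, 16, device='meta')    # hypothetical input

flops_full = x.numel()    # torch.max(x): one flop per input element = 512

shape = list(x.shape)
shape.pop(1)                               # torch.max(x, dim=1) -> result shape [4, 16]
flops_dim = reduce(operator.mul, shape)    # 64, one flop per output element
```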
colossalai/fx/profiler/experimental/profiler_module/__init__.py (new file, mode 100644)

```python
from .activation_function import *
from .attention import *
from .convolution import *
from .dropout import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
from .torch_op import *
```
colossalai/fx/profiler/experimental/profiler_module/activation_function.py (new file, mode 100644)

```python
from typing import Tuple

import torch

from ..registry import meta_profiler_module

# TODO: different activation has different FLOPs count, currently unused.
_multiplier = {
    torch.nn.ReLU: 1,
    torch.nn.PReLU: 4,
    torch.nn.Sigmoid: 4,
    torch.nn.Tanh: 5,
    torch.nn.LeakyReLU: 3,
    torch.nn.ELU: 4,
    torch.nn.ReLU6: 2,
    torch.nn.GELU: 9,
    torch.nn.Hardswish: 5,
    torch.nn.Hardsigmoid: 4,
}


@meta_profiler_module.register(torch.nn.ELU)
@meta_profiler_module.register(torch.nn.LeakyReLU)
@meta_profiler_module.register(torch.nn.ReLU)
@meta_profiler_module.register(torch.nn.GELU)
@meta_profiler_module.register(torch.nn.Sigmoid)
@meta_profiler_module.register(torch.nn.Tanh)
@meta_profiler_module.register(torch.nn.ReLU6)
@meta_profiler_module.register(torch.nn.PReLU)
@meta_profiler_module.register(torch.nn.Hardswish)
@meta_profiler_module.register(torch.nn.Hardsigmoid)
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
    flops = input.numel()
    macs = 0
    return flops, macs
```
colossalai/fx/profiler/experimental/profiler_module/attention.py (new file, mode 100644)

```python
from typing import Optional, Tuple

import torch

from ..registry import meta_profiler_module


# TODO: This is hard to compute memory cost
@meta_profiler_module.register(torch.nn.MultiheadAttention)
def torch_nn_msa(self: torch.nn.MultiheadAttention,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 key_padding_mask: Optional[torch.Tensor] = None,
                 need_weights: bool = True,
                 attn_mask: Optional[torch.Tensor] = None,
                 average_attn_weights: bool = True) -> Tuple[int, int]:
    if getattr(self, 'batch_first', False):
        batch_size = query.shape[0]
        len_idx = 1
    else:
        batch_size = query.shape[1]
        len_idx = 0
    dim_idx = 2

    qdim = query.shape[dim_idx]
    kdim = key.shape[dim_idx]
    vdim = value.shape[dim_idx]

    qlen = query.shape[len_idx]
    klen = key.shape[len_idx]
    vlen = value.shape[len_idx]

    num_heads = self.num_heads
    assert qdim == self.embed_dim

    if self.kdim is None:
        assert kdim == qdim
    if self.vdim is None:
        assert vdim == qdim

    flops = 0
    macs = 0

    # Q scaling
    flops += qlen * qdim

    # Initial projections
    flops += 2 * ((qlen * qdim * qdim)      # QW
                  + (klen * kdim * kdim)    # KW
                  + (vlen * vdim * vdim))   # VW
    macs += ((qlen * qdim * qdim)      # QW
             + (klen * kdim * kdim)    # KW
             + (vlen * vdim * vdim))   # VW

    if self.in_proj_bias is not None:
        flops += (qlen + klen + vlen) * qdim

    # attention heads: scale, matmul, softmax, matmul
    qk_head_dim = qdim // num_heads
    v_head_dim = vdim // num_heads

    head_flops = (2 * (qlen * klen * qk_head_dim)     # QK^T
                  + (qlen * klen)                     # softmax
                  + 2 * (qlen * klen * v_head_dim))   # AV
    head_macs = ((qlen * klen * qk_head_dim)          # QK^T
                 + 2 * (qlen * klen * v_head_dim))    # AV

    flops += num_heads * head_flops
    macs += num_heads * head_macs

    # final projection, bias is always enabled
    flops += qlen * vdim * (vdim + 1)

    flops *= batch_size
    macs *= batch_size
    return flops, macs
```
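
For a sense of where this count concentrates (a worked example with made-up sizes; `batch_first=False`, batch size 1):

```python
# Hypothetical sizes for illustration only.
qlen = klen = vlen = 128
qdim = kdim = vdim = 512          # embed_dim
num_heads = 8

# Input projections: 201326592 flops, dominant at this sequence length.
proj_flops = 2 * (qlen * qdim * qdim + klen * kdim * kdim + vlen * vdim * vdim)

qk_head_dim = v_head_dim = qdim // num_heads    # 64
head_flops = (2 * qlen * klen * qk_head_dim     # QK^T
              + qlen * klen                     # softmax
              + 2 * qlen * klen * v_head_dim)   # AV
attn_flops = num_heads * head_flops             # 33685504 flops
```

Since the attention term grows quadratically with sequence length, it overtakes the projections for much longer sequences.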
colossalai/fx/profiler/experimental/profiler_module/convolution.py (new file, mode 100644)

```python
import operator
from functools import reduce
import math
from typing import Tuple

import torch

from ..registry import meta_profiler_module


@meta_profiler_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    c_in, l_in = input.shape[-2:]
    c_out = self.out_channels
    l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    result_shape = input.shape[:-2] + (c_out, l_out)
    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
    num_elem = reduce(operator.mul, result_shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += num_elem
    return flops, macs


@meta_profiler_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    c_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    result_shape = input.shape[:-3] + (c_out, h_out, w_out)
    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
    num_elem = reduce(operator.mul, result_shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += num_elem
    return flops, macs


@meta_profiler_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
    c_in, d_in, h_in, w_in = input.shape[-4:]
    c_out = self.out_channels
    d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
                        (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
    result_shape = input.shape[:-4] + (c_out, d_out, h_out, w_out)
    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
    num_elem = reduce(operator.mul, result_shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += num_elem
    return flops, macs


@meta_profiler_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
    c_in, l_in = input.shape[-2:]
    c_out = self.out_channels
    l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
                       (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    result_shape = input.shape[:-2] + (c_out, l_out)
    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
    # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
    num_elem = reduce(operator.mul, input.shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += reduce(operator.mul, result_shape)
    return flops, macs


@meta_profiler_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    c_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
                       (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
                       (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
    result_shape = input.shape[:-3] + (c_out, h_out, w_out)
    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
    num_elem = reduce(operator.mul, input.shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += reduce(operator.mul, result_shape)
    return flops, macs


@meta_profiler_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
    c_in, d_in, h_in, w_in = input.shape[-4:]
    c_out = self.out_channels
    d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
                       (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
                       (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
    w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
                       (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
    result_shape = input.shape[:-4] + (c_out, d_out, h_out, w_out)
    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
    num_elem = reduce(operator.mul, input.shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += reduce(operator.mul, result_shape)
    return flops, macs
```
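
The output-shape formula can be checked against a familiar layer (a sketch with assumed sizes; it reproduces the well-known cost of a ResNet stem convolution):

```python
import math

# Hypothetical layer: Conv2d(3, 64, kernel_size=7, stride=2, padding=3) on a 224x224 input.
h_in = 224
padding, dilation, kernel_size, stride = 3, 1, 7, 2

h_out = math.floor((h_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)    # 112

c_in, c_out, groups = 3, 64, 1
macs_per_elem = kernel_size * kernel_size * c_in // groups    # 147
num_elem = c_out * h_out * h_out                              # 802816 output elements
macs = macs_per_elem * num_elem                               # 118013952
flops = 2 * macs                                              # ~236M flops per image
```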
colossalai/fx/profiler/experimental/profiler_module/dropout.py (new file, mode 100644)

```python
from typing import Tuple

import torch

from ..registry import meta_profiler_module


@meta_profiler_module.register(torch.nn.Dropout)
def torch_nn_dropout(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
    # nn.Dropout only masks its input, so it is counted as 0 FLOPs here.
    flops = 0
    macs = 0
    return flops, macs
```
colossalai/fx/profiler/experimental/profiler_module/embedding.py (new file, mode 100644)

```python
from typing import Tuple

import torch

from ..registry import meta_profiler_module


@meta_profiler_module.register(torch.nn.Embedding)
def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]:
    # nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs.
    # (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
    flops = 0
    macs = 0
    return flops, macs
```
colossalai/fx/profiler/experimental/profiler_module/linear.py (new file, mode 100644)

```python
from typing import Tuple

import torch

from ..registry import meta_profiler_module


@meta_profiler_module.register(torch.nn.Linear)
@meta_profiler_module.register(torch.nn.modules.linear.NonDynamicallyQuantizableLinear)
def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
    out_features = self.weight.shape[0]
    macs = input.numel() * out_features
    flops = 2 * macs
    if self.bias is not None:
        flops += self.bias.numel()
    return flops, macs
```
colossalai/fx/profiler/experimental/profiler_module/normalization.py (new file, mode 100644)

```python
from typing import Tuple, Union

import torch

from ..registry import meta_profiler_module


@meta_profiler_module.register(torch.nn.InstanceNorm1d)
@meta_profiler_module.register(torch.nn.InstanceNorm2d)
@meta_profiler_module.register(torch.nn.InstanceNorm3d)
@meta_profiler_module.register(torch.nn.LayerNorm)
@meta_profiler_module.register(torch.nn.GroupNorm)
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d,
                                   torch.nn.BatchNorm2d, torch.nn.BatchNorm3d],
                       input: torch.Tensor) -> Tuple[int, int]:
    # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
    has_affine = self.weight is not None
    if self.training:
        flops = input.numel() * (2 if has_affine else 1)
    else:
        flops = input.numel() * (5 if has_affine else 4)
    macs = 0
    return flops, macs


try:
    import apex
    meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
    meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
    meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
    meta_profiler_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
except (ImportError, AttributeError):
    pass
```
0 → 100644
View file @
08f2920e
from
typing
import
Tuple
import
torch
from
..registry
import
meta_profiler_module
@
meta_profiler_module
.
register
(
torch
.
nn
.
AvgPool1d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AvgPool2d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AvgPool3d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
MaxPool1d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
MaxPool2d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
MaxPool3d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AdaptiveAvgPool1d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AdaptiveMaxPool1d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AdaptiveAvgPool2d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AdaptiveMaxPool2d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AdaptiveAvgPool3d
)
@
meta_profiler_module
.
register
(
torch
.
nn
.
AdaptiveMaxPool3d
)
def
torch_nn_pooling
(
self
:
torch
.
nn
.
Module
,
input
:
torch
.
Tensor
)
->
Tuple
[
int
,
int
]:
# all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217)
flops
=
input
.
numel
()
macs
=
0
return
flops
,
macs
colossalai/fx/profiler/experimental/profiler_module/rnn.py (new file, mode 100644)

```python
from functools import reduce
import operator
from typing import Optional, Tuple, Union

import torch

from ..registry import meta_profiler_module


def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
               w_hh: torch.Tensor) -> Tuple[int, int]:
    # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
    # matrix matrix mult ih state and internal state
    macs += reduce(operator.mul, w_ih.shape)
    flops += 2 * reduce(operator.mul, w_ih.shape)
    # matrix matrix mult hh state and internal state
    macs += reduce(operator.mul, w_hh.shape)
    flops += 2 * reduce(operator.mul, w_hh.shape)
    if isinstance(module, (torch.nn.RNN, torch.nn.RNNCell)):
        # add both operations
        flops += module.hidden_size
    elif isinstance(module, (torch.nn.GRU, torch.nn.GRUCell)):
        # hadamard of r
        flops += module.hidden_size
        # adding operations from both states
        flops += module.hidden_size * 3
        # last two hadamard product and add
        flops += module.hidden_size * 3
    elif isinstance(module, (torch.nn.LSTM, torch.nn.LSTMCell)):
        # adding operations from both states
        flops += module.hidden_size * 4
        # two hadamard product and add for C state
        flops += module.hidden_size * 3
        # final hadamard
        flops += module.hidden_size * 3
    return flops, macs


@meta_profiler_module.register(torch.nn.LSTM)
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    flops = 0
    macs = 0
    for i in range(self.num_layers):
        w_ih = self.__getattr__('weight_ih_l' + str(i))
        w_hh = self.__getattr__('weight_hh_l' + str(i))
        flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
        if self.bias:
            b_ih = self.__getattr__('bias_ih_l' + str(i))
            b_hh = self.__getattr__('bias_hh_l' + str(i))
            # one addition per bias element
            flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
    flops *= reduce(operator.mul, input.shape[:2])
    macs *= reduce(operator.mul, input.shape[:2])
    if self.bidirectional:
        flops *= 2
        macs *= 2
    return flops, macs


@meta_profiler_module.register(torch.nn.LSTMCell)
@meta_profiler_module.register(torch.nn.GRUCell)
@meta_profiler_module.register(torch.nn.RNNCell)
def torch_nn_rnn_cell(self: torch.nn.RNNCellBase, input: torch.Tensor,
                      hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    flops = 0
    macs = 0
    w_ih = self.__getattr__('weight_ih')
    w_hh = self.__getattr__('weight_hh')
    flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
    if self.bias:
        b_ih = self.__getattr__('bias_ih')
        b_hh = self.__getattr__('bias_hh')
        # one addition per bias element
        flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
    flops *= input.shape[0]
    macs *= input.shape[0]
    return flops, macs
```
colossalai/fx/profiler/experimental/profiler_module/torch_op.py (new file, mode 100644)

```python
import operator
from typing import Optional, Tuple, Union

import torch

from ..registry import meta_profiler_module


@meta_profiler_module.register(torch.nn.Flatten)
def torch_nn_flatten(self: torch.nn.Flatten, input: torch.Tensor) -> Tuple[int, int]:
    flops = 0
    macs = 0
    return flops, macs
```
colossalai/fx/profiler/experimental/registry.py (new file, mode 100644)

```python
class ProfilerRegistry:

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):

        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        assert source in self.store
        target = self.store[source]
        return target

    def has(self, source):
        return source in self.store


meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
```
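
A sketch of how the registry is consumed (assuming the `profiler_function` modules have been imported so their `register` decorators have run; not part of this commit): a profiler checks `has()` for the op it encounters, then calls the stored counter with the op's arguments to obtain `(flops, macs)`.

```python
import torch

if meta_profiler_function.has(torch.nn.functional.linear):
    counter = meta_profiler_function.get(torch.nn.functional.linear)
    x = torch.empty(4, 128, device='meta')
    w = torch.empty(256, 128, device='meta')
    flops, macs = counter(x, w)    # 262144, 131072 (no bias)
```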
colossalai/fx/profiler/experimental/shard_utils.py (new file, mode 100644)

```python
# for PyTorch 1.11 compatibility uses
from typing import Dict, List, Tuple, Union

import torch
from torch.fx import GraphModule, Node

from ..._compatibility import compatibility

__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


@compatibility(is_backward_compatible=True)
def calculate_fwd_in(n: Node) -> bool:
    """A helper function to calculate `fwd_in`

    Args:
        n (Node): a node from the graph

    Returns:
        save_fwd_in (bool): the result of `save_fwd_in`
    """
    return n.meta['save_fwd_in']


@compatibility(is_backward_compatible=True)
def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_tmp (int): the result of `fwd_tmp`
    """
    return n.meta["fwd_mem_tmp"]


@compatibility(is_backward_compatible=True)
def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_out (int): the result of `fwd_out`
    """
    return n.meta['fwd_mem_out']
```
colossalai/fx/profiler/memory_utils.py (new file, mode 100644)

```python
from typing import Dict, List, Tuple, Union

import torch
from torch.fx import GraphModule, Node

from .._compatibility import compatibility, is_compatible_with_meta

__all__ = ['activation_size', 'parameter_size', 'is_inplace']


@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Calculate activation size of a node.

    Args:
        out (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`.

    Returns:
        int: The activation size, unit is byte.
    """
    act_size = 0
    if isinstance(out, torch.Tensor):
        if out.is_quantized:
            act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
        else:
            act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
    elif isinstance(out, dict):
        value_list = [v for _, v in out.items()]
        act_size += activation_size(value_list)
    elif isinstance(out, (tuple, list, set)):
        for element in out:
            act_size += activation_size(element)
    return act_size


@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
    """Calculate parameter size of a module.

    Args:
        mod (torch.nn.Module): The target `torch.nn.Module`.

    Returns:
        int: The parameter size, unit is byte.
    """
    param_size = 0
    for param in mod.parameters():
        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
    return param_size


def is_inplace(n: Node):
    """Get the inplace argument from torch.fx.Node

    Args:
        n (Node): torch.fx.Node

    Returns:
        bool: indicates whether this op is inplace
    """
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
        if is_compatible_with_meta():
            from .constants import ALIAS_ATEN
            if n.target in ALIAS_ATEN:
                inplace = True
    elif n.op == "call_module":
        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
    return inplace
```
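
A sketch of `activation_size` on assumed outputs (illustrative, not from the commit): it recurses through containers and multiplies each tensor's `numel()` by its dtype's element size.

```python
import torch

t = torch.empty(4, 128, dtype=torch.float32, device='meta')    # hypothetical activation

print(activation_size(t))                 # 4 * 128 * 4 bytes = 2048
print(activation_size([t, t]))            # 4096
print(activation_size({'logits': t}))     # 2048
```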
colossalai/fx/profiler/opcount.py (new file, mode 100644)

```python
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
import operator
from functools import partial, reduce
from numbers import Number
from typing import Any, Callable, List

import torch

aten = torch.ops.aten


def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [v.shape for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
    return flops


def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [v.shape for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops


def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the aten::linear operator.
    """
    # Inputs is a list of length 3; unlike aten::addmm, it is the first
    # two elements that are relevant.
    input_shapes = [v.shape for v in inputs[0:2]]
    # input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
    # input_shapes[1]: [output_feature_dim, input_feature_dim]
    assert input_shapes[0][-1] == input_shapes[1][-1]
    flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0]
    return flops


def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [v.shape for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flops = n * c * t * d
    return flops


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = (x_shape[2:] * prod(w_shape) * batch_size).

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed

    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape)
    return flops


def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape)
    transposed = inputs[6]
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
    grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = outputs[0].shape
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = outputs[1].shape
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count


def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
    """
    Args:
        affine_arg_index: index of the affine argument in inputs
        input_arg_index: index of the input tensor in inputs
    """

    def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
        """
        Count flops for norm layers.
        """
        # inputs[input_arg_index] contains the shape of the input.
        input_shape = inputs[input_arg_index].shape
        has_affine = inputs[affine_arg_index].shape is not None if hasattr(
            inputs[affine_arg_index], 'shape') else inputs[affine_arg_index]
        assert 2 <= len(input_shape) <= 5, input_shape
        # 5 is just a rough estimate
        flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
        return flop

    return norm_flop_jit


def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
    if training is None:
        training = inputs[-3]
    assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
    if training:
        return norm_flop_counter(1, 0)(inputs, outputs)    # pyre-ignore
    has_affine = inputs[1].shape is not None
    input_shape = reduce(operator.mul, inputs[0].shape)
    return input_shape * (2 if has_affine else 1)


def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable:
    """
    Count flops by
        input_tensor.numel() * input_scale + output_tensor.numel() * output_scale

    Args:
        input_scale: scale of the input tensor (first argument)
        output_scale: scale of the output tensor (first element in outputs)
    """

    def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
        ret = 0
        if input_scale != 0:
            shape = inputs[0].shape
            ret += input_scale * reduce(operator.mul, shape) if shape else 0
        if output_scale != 0:
            shape = outputs[0].shape
            ret += output_scale * reduce(operator.mul, shape) if shape else 0
        return ret

    return elementwise_flop


def zero_flop_jit(*args):
    """
    Count flops for zero flop layers.
    """
    return 0


flop_mapping = {
    # gemm
    aten.mm.default: matmul_flop_jit,
    aten.matmul.default: matmul_flop_jit,
    aten.addmm.default: addmm_flop_jit,
    aten.bmm.default: bmm_flop_jit,

    # convolution
    aten.convolution.default: conv_flop_jit,
    aten._convolution.default: conv_flop_jit,
    aten.convolution_backward.default: conv_backward_flop_jit,

    # normalization
    aten.native_batch_norm.default: batchnorm_flop_jit,
    aten.native_batch_norm_backward.default: batchnorm_flop_jit,
    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
    aten.native_layer_norm.default: norm_flop_counter(2, 0),
    aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),

    # pooling
    aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
    aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
    aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
    aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
    aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
    aten.max_pool1d.default: elementwise_flop_counter(1, 0),
    aten.max_pool2d.default: elementwise_flop_counter(1, 0),
    aten.max_pool3d.default: elementwise_flop_counter(1, 0),
    aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
    aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
    aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
    aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
    aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
    aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
    aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
    aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
    aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
    aten.embedding.default: elementwise_flop_counter(1, 0),
}

elementwise_flop_aten = [
    # basic op
    aten.add.Tensor,
    aten.add_.Tensor,
    aten.div.Tensor,
    aten.div_.Tensor,
    aten.div.Scalar,
    aten.div_.Scalar,
    aten.mul.Tensor,
    aten.mul.Scalar,
    aten.mul_.Tensor,
    aten.neg.default,
    aten.pow.Tensor_Scalar,
    aten.rsub.Scalar,
    aten.sum.default,
    aten.sum.dim_IntList,
    aten.mean.dim,

    # activation op
    aten.hardswish.default,
    aten.hardswish_.default,
    aten.hardswish_backward.default,
    aten.hardtanh.default,
    aten.hardtanh_.default,
    aten.hardtanh_backward.default,
    aten.hardsigmoid_backward.default,
    aten.hardsigmoid.default,
    aten.gelu.default,
    aten.gelu_backward.default,
    aten.silu.default,
    aten.silu_.default,
    aten.silu_backward.default,
    aten.sigmoid.default,
    aten.sigmoid_backward.default,
    aten._softmax.default,
    aten._softmax_backward_data.default,
    aten.relu_.default,
    aten.relu.default,
    aten.tanh.default,
    aten.tanh_backward.default,
    aten.threshold_backward.default,

    # dropout
    aten.native_dropout.default,
    aten.native_dropout_backward.default,
]

for op in elementwise_flop_aten:
    flop_mapping[op] = elementwise_flop_counter(1, 0)

# TODO: this will be removed in future
zero_flop_aten = [
    aten.as_strided.default,
    aten.as_strided_.default,
    aten.bernoulli_.float,
    aten.cat.default,
    aten.clone.default,
    aten.copy_.default,
    aten.detach.default,
    aten.expand.default,
    aten.empty_like.default,
    aten.new_empty.default,
    aten.new_empty_strided.default,
    aten.ones_like.default,
    aten._reshape_alias.default,
    aten.select.int,
    aten.select_backward.default,
    aten.squeeze.dim,
    aten.slice.Tensor,
    aten.slice_backward.default,
    aten.split.Tensor,
    aten.permute.default,
    aten.t.default,
    aten.transpose.int,
    aten._to_copy.default,
    aten.unsqueeze.default,
    aten.unbind.int,
    aten._unsafe_view.default,
    aten.view.default,
    aten.where.self,
    aten.zero_.default,
    aten.zeros_like.default,
]

for op in zero_flop_aten:
    flop_mapping[op] = zero_flop_jit
```
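
A sketch of how `flop_mapping` is meant to be consulted (an assumed dispatch, with meta tensors standing in for traced values; not part of this commit): look up the aten overload, then call the handle with the op's inputs and outputs.

```python
import torch

aten = torch.ops.aten

a = torch.empty(64, 128, device='meta')
b = torch.empty(128, 256, device='meta')
out = torch.empty(64, 256, device='meta')

counter = flop_mapping[aten.mm.default]
flops = counter([a, b], [out])    # 64 * 128 * 256 = 2097152
```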