OpenDAS / ColossalAI — Commits

Commit 08f2920e
Authored Apr 23, 2023 by zhuwenwen
Commit message: init colossalai, support dtk2304
Parent: da3f0934
Pipeline #237 failed with stages in 0 seconds

This commit changes 380 files; this page shows 20 changed files with 1537 additions and 0 deletions (+1537 / -0).
Changed files shown on this page (additions / deletions):

colossalai/fx/tracer/meta_patch/patched_function/__init__.py (+6 / -0)
colossalai/fx/tracer/meta_patch/patched_function/activation_function.py (+8 / -0)
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py (+95 / -0)
colossalai/fx/tracer/meta_patch/patched_function/convolution.py (+180 / -0)
colossalai/fx/tracer/meta_patch/patched_function/embedding.py (+14 / -0)
colossalai/fx/tracer/meta_patch/patched_function/normalization.py (+20 / -0)
colossalai/fx/tracer/meta_patch/patched_function/python_ops.py (+60 / -0)
colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py (+174 / -0)
colossalai/fx/tracer/meta_patch/patched_module/__init__.py (+7 / -0)
colossalai/fx/tracer/meta_patch/patched_module/activation_function.py (+13 / -0)
colossalai/fx/tracer/meta_patch/patched_module/convolution.py (+113 / -0)
colossalai/fx/tracer/meta_patch/patched_module/embedding.py (+9 / -0)
colossalai/fx/tracer/meta_patch/patched_module/linear.py (+10 / -0)
colossalai/fx/tracer/meta_patch/patched_module/normalization.py (+31 / -0)
colossalai/fx/tracer/meta_patch/patched_module/pooling.py (+202 / -0)
colossalai/fx/tracer/meta_patch/patched_module/rnn.py (+16 / -0)
colossalai/fx/tracer/registry.py (+28 / -0)
colossalai/fx/tracer/tracer.py (+536 / -0)
colossalai/gemini/__init__.py (+9 / -0)
colossalai/gemini/chunk/__init__.py (+6 / -0)
colossalai/fx/tracer/meta_patch/patched_function/__init__.py (new file, mode 100644)

from .activation_function import *
from .arithmetic import *
from .convolution import *
from .embedding import *
from .normalization import *
from .torch_ops import *
colossalai/fx/tracer/meta_patch/patched_function/activation_function.py (new file, mode 100644)

import torch

from ...registry import meta_patched_function


@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
    return torch.empty(input.shape, device='meta')
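To see what a patched function buys the tracer, the sketch below (not part of the commit; it assumes the package above is importable) calls the relu patch directly on a meta tensor: only the shape is propagated, and no real computation or memory allocation happens.

import torch
from colossalai.fx.tracer.meta_patch.patched_function.activation_function import torch_nn_func_relu

# a meta tensor carries shape and dtype but no storage
x = torch.empty(4, 10, device='meta')
y = torch_nn_func_relu(x)
assert y.shape == (4, 10) and y.is_meta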
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py (new file, mode 100644)

import torch

from ...registry import meta_patched_function


@meta_patched_function.register(torch.matmul)
@meta_patched_function.register('matmul')    # use the name for the built-in op @
def torch_matmul(input, other, *, out=None):
    # copied from huggingface.utils.fx
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        shape = (input.size(0),)
    else:
        max_length = max(input.dim(), other.dim())
        shape1 = list(input.shape)
        shape2 = list(other.shape)
        if d1 == 1:
            shape1 = [1] + shape1
        if d2 == 1:
            shape2.append(1)
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")


@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
    assert out is None, 'out is not supported yet'
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")


@meta_patched_function.register(torch.nn.functional.linear)
def torch_linear(input, mat2, bias=None, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place linear for MetaTensor analysis")
    output_shape = list(input.shape)
    output_feature = list(mat2.shape)[0]
    output_shape[-1] = output_feature
    return torch.empty(*output_shape, device="meta")


@meta_patched_function.register(torch.addbmm)
@meta_patched_function.register(torch.Tensor.addbmm)
def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
    if out is not None:
        raise ValueError("Don't support in-place addbmm for MetaTensor analysis")
    _, n, _ = mat1.shape
    _, _, p = mat2.shape
    return torch.empty(n, p, device="meta")


@meta_patched_function.register(torch.addmm)
@meta_patched_function.register(torch.Tensor.addmm)
def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
    if out is not None:
        raise ValueError("Don't support in-place addmm for MetaTensor analysis")
    n, _ = mat1.shape
    _, p = mat2.shape
    return torch.empty(n, p, device="meta")


@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
    assert out is None, 'saving to out is not supported yet'
    var = torch.empty(1).squeeze(0).to('meta')
    mean = torch.empty(1).squeeze(0).to('meta')
    return var, mean
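As a quick sanity check of the matmul rule above (a sketch, not part of the commit; it assumes the module is importable), the inferred shape should match PyTorch's batched-matmul broadcasting: batch dimensions broadcast while the last two dimensions contract.

import torch
from colossalai.fx.tracer.meta_patch.patched_function.arithmetic import torch_matmul

a = torch.empty(8, 1, 4, 5, device='meta')
b = torch.empty(6, 5, 7, device='meta')
# batch dims (8, 1) and (6,) broadcast to (8, 6); (4, 5) @ (5, 7) -> (4, 7)
out = torch_matmul(a, b)
assert out.shape == (8, 6, 4, 7)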
colossalai/fx/tracer/meta_patch/patched_function/convolution.py (new file, mode 100644)

import collections
import math
from itertools import repeat

import torch

from ...registry import meta_patched_function


def _ntuple(n, name="parse"):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    parse.__name__ = name
    return parse


_single = _ntuple(1, "_single")
_pair = _ntuple(2, "_pair")
_triple = _ntuple(3, "_triple")


def _extract_kwargs(kwargs):
    if 'stride' in kwargs:
        stride = kwargs['stride']
    else:
        stride = 1
    # TODO: process str type padding
    if 'padding' in kwargs:
        padding = kwargs['padding']
    else:
        padding = 0
    if 'dilation' in kwargs:
        dilation = kwargs['dilation']
    else:
        dilation = 1
    if 'output_padding' in kwargs:
        output_padding = kwargs['output_padding']
    else:
        output_padding = 0
    return stride, padding, dilation, output_padding


@meta_patched_function.register(torch.nn.functional.conv1d)
def torch_nn_functional_conv1d(input, weight, **kwargs):
    stride, padding, dilation, _ = _extract_kwargs(kwargs)

    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)

    kernel_size = weight.shape[2:]
    l_in = input.shape[-1]
    c_out = weight.shape[0]
    l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    result_shape = input.shape[:-2] + (c_out, l_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_function.register(torch.nn.functional.conv2d)
def torch_nn_functional_conv2d(input, weight, **kwargs):
    stride, padding, dilation, _ = _extract_kwargs(kwargs)

    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    kernel_size = weight.shape[2:]
    h_in, w_in = input.shape[-2:]
    c_out = weight.shape[0]
    h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
    result_shape = input.shape[:-3] + (c_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_function.register(torch.nn.functional.conv3d)
def torch_nn_functional_conv3d(input, weight, **kwargs):
    stride, padding, dilation, _ = _extract_kwargs(kwargs)

    stride = _triple(stride)
    padding = _triple(padding)
    dilation = _triple(dilation)

    kernel_size = weight.shape[2:]
    d_in, h_in, w_in = input.shape[-3:]
    c_out = weight.shape[0]
    d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
    w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
    result_shape = input.shape[:-4] + (c_out, d_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_function.register(torch.nn.functional.conv_transpose1d)
def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)

    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)
    output_padding = _single(output_padding)

    kernel_size = weight.shape[2:]
    l_in = input.shape[-1]
    c_out = weight.shape[1]
    l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1)
    result_shape = input.shape[:-2] + (c_out, l_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_function.register(torch.nn.functional.conv_transpose2d)
def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)

    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    output_padding = _pair(output_padding)

    kernel_size = weight.shape[2:]
    h_in, w_in = input.shape[-2:]
    c_out = weight.shape[1]
    h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1)
    result_shape = input.shape[:-3] + (c_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_function.register(torch.nn.functional.conv_transpose3d)
def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)

    stride = _triple(stride)
    padding = _triple(padding)
    dilation = _triple(dilation)
    output_padding = _triple(output_padding)

    kernel_size = weight.shape[2:]
    d_in, h_in, w_in = input.shape[-3:]
    c_out = weight.shape[1]
    d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1)
    h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1)
    w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1)
    result_shape = input.shape[:-4] + (c_out, d_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')
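The functional conv patches only evaluate the closed-form output-size formula from the PyTorch docs. A minimal check (a sketch, not part of the commit; it assumes the module is importable):

import torch
from colossalai.fx.tracer.meta_patch.patched_function.convolution import torch_nn_functional_conv2d

x = torch.empty(1, 3, 32, 32, device='meta')    # (N, C_in, H, W)
w = torch.empty(16, 3, 3, 3, device='meta')     # (C_out, C_in, kH, kW)
# H_out = floor((32 + 2*1 - 1*(3 - 1) - 1) / 2 + 1) = 16
out = torch_nn_functional_conv2d(x, w, stride=2, padding=1)
assert out.shape == (1, 16, 16, 16)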
colossalai/fx/tracer/meta_patch/patched_function/embedding.py (new file, mode 100644)

import torch

from ...registry import meta_patched_function


@meta_patched_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(input,
                                  weight,
                                  padding_idx=None,
                                  max_norm=None,
                                  norm_type=2.0,
                                  scale_grad_by_freq=False,
                                  sparse=False):
    return torch.empty(*input.shape, weight.shape[-1], device="meta")
colossalai/fx/tracer/meta_patch/patched_function/normalization.py (new file, mode 100644)

import torch

from ...registry import meta_patched_function


@meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(input,
                            running_mean,
                            running_var,
                            weight=None,
                            bias=None,
                            training=False,
                            momentum=0.1,
                            eps=1e-05):
    return torch.empty(input.shape, device='meta')
colossalai/fx/tracer/meta_patch/patched_function/python_ops.py (new file, mode 100644)

import operator

import torch

from colossalai.fx.proxy import ColoProxy

from ...registry import meta_patched_function


@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    # copied from huggingface.utils.fx
    def to_concrete(t):
        if isinstance(t, torch.Tensor):
            concrete = torch.ones_like(t, device="cpu")
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        return t

    def _slice_convert(slice_obj):
        attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step}
        new_attrs = _slice_attr_convert(attrs)
        attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step'])
        return slice(*attr_dict_to_tuple)

    def _slice_attr_convert(attrs):
        new_attrs = {}
        for key, value in attrs.items():
            if isinstance(value, ColoProxy):
                new_attrs[key] = value.meta_data
            else:
                new_attrs[key] = value
        return new_attrs

    if isinstance(b, tuple):
        b = list(b)
        for index, element in enumerate(b):
            if isinstance(element, slice):
                b[index] = _slice_convert(element)
        b = tuple(b)
    elif isinstance(b, slice):
        b = _slice_convert(b)

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")

    if isinstance(a, ColoProxy):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a.meta_data, device="cpu"), b).to("meta")

    return operator.getitem(a, b)
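The getitem patch sidesteps shape inference by indexing a throwaway CPU tensor of the same shape and moving the result back to the meta device. A small sketch (not part of the commit; it assumes the module is importable):

import torch
from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem

x = torch.empty(4, 10, device='meta')
out = operator_getitem(x, (slice(None), slice(0, 5)))    # equivalent to x[:, 0:5]
assert out.shape == (4, 5) and out.is_meta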
colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py (new file, mode 100644)

import torch

from ...registry import meta_patched_function


@meta_patched_function.register(torch.arange)
def torch_arange(*args, **kwargs):
    n = len(args)
    step = 1
    if n == 1:
        start = 0
        end = args[0]
    elif n == 2:
        start, end = args
    else:
        start, end, step = args
    if isinstance(start, float):
        start = int(start)
    if isinstance(end, float):
        end = int(end)
    if isinstance(step, float):
        step = int(step)
    step = kwargs.get("step", step)
    dtype = kwargs.get("dtype")
    return torch.empty((end - start) // step, dtype=dtype, device="meta")


@meta_patched_function.register(torch.finfo)
def torch_finfo(*args):
    return torch.finfo(*args)


@meta_patched_function.register(torch.where)
def torch_where(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")


@meta_patched_function.register(torch.Tensor.repeat)
def torch_tensor_repeat(self, *sizes):
    shape = list(self.shape)
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.index_select)
def torch_index_select(input, dim, index, *, out=None):
    shape = list(input.shape)
    shape[dim] = len(index)
    return torch.empty(*shape, device="meta")


@meta_patched_function.register(torch.Tensor.index_select)
def torch_tensor_index_select(self, dim, index):
    return torch_index_select(self, dim, index)


@meta_patched_function.register(torch.squeeze)
def torch_squeeze(input, dim=None):
    shape = list(input.shape)
    if dim is not None:
        if dim < 0:
            dim = input.dim() + dim
        if shape[dim] == 1:
            shape.pop(dim)
    else:
        new_shape = []
        for dim_value in shape:
            if dim_value == 1:
                continue
            new_shape.append(dim_value)
        shape = new_shape
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.squeeze)
def torch_tensor_squeeze(self, dim=None):
    return torch_squeeze(self, dim)


@meta_patched_function.register(torch.unsqueeze)
def torch_unsqueeze(input, dim):
    shape = list(input.shape)
    if dim < 0:
        dim = input.dim() + 1 + dim
    shape.insert(dim, 1)
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.unsqueeze)
def torch_tensor_unsqueeze(self, dim):
    return torch_unsqueeze(self, dim)


@meta_patched_function.register(torch.cat)
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
        dim = axis
    if dim < 0:
        dim = tensors[0].dim() + dim
    shapes = [t.shape for t in tensors]
    shape = list(shapes[0])
    concatenated_dim = sum(shape[dim] for shape in shapes)
    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
    return torch.empty(final_shape, device="meta")


@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
    assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
        "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"

    shape = list(input.shape) if dim is not None else [input.numel()]
    dim = dim if dim is not None else 0
    dim = input.dim() + dim if dim < 0 else dim

    if isinstance(repeats, int):
        shape[dim] = shape[dim] * repeats
    elif isinstance(repeats, torch.Tensor):
        shape[dim] = repeats.sum()

    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.repeat_interleave)
def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None):
    return torch_repeat_interleave(self, repeats, dim, output_size)


@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
    assert out is None, 'assigning result to out is not supported yet'
    return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)


@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
    assert out is None, 'assigning value to out is not supported yet'
    if dim is not None:
        if isinstance(dim, int):
            shape = list(input.shape)
            shape.pop(dim)
            if keepdim:
                shape.insert(dim, 1)
            return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape, device='meta', dtype=input.dtype)
        elif isinstance(dim, torch.Tensor):
            # when dim is a 0D or 1D tensor, it will maintain the same shape
            num_dims = dim.dim()
            if num_dims in [0, 1]:
                return torch.empty_like(input, device='meta')
            else:
                raise ValueError(f"Expected dim to be a 0D or 1D tensor but got {num_dims} dimensions")
    else:
        return torch.empty([], device='meta', dtype=input.dtype)


@meta_patched_function.register(torch.Tensor.cpu)
def torch_tensor_cpu(input):
    return input.clone()


@meta_patched_function.register(torch.Tensor.cuda)
def torch_tensor_cuda(input, *args, **kwargs):
    return input.clone()
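Most of the patches above are pure shape arithmetic; torch_cat, for example, sums the sizes along the concatenation dimension. A sketch (not part of the commit; it assumes the module is importable):

import torch
from colossalai.fx.tracer.meta_patch.patched_function.torch_ops import torch_cat

a = torch.empty(2, 3, device='meta')
b = torch.empty(4, 3, device='meta')
# sizes along dim 0 add up: 2 + 4 = 6
assert torch_cat([a, b], dim=0).shape == (6, 3)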
colossalai/fx/tracer/meta_patch/patched_module/__init__.py (new file, mode 100644)

from .activation_function import *
from .convolution import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
\ No newline at end of file
colossalai/fx/tracer/meta_patch/patched_module/activation_function.py (new file, mode 100644)

import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.ReLU)
@meta_patched_module.register(torch.nn.Sigmoid)
@meta_patched_module.register(torch.nn.GELU)
@meta_patched_module.register(torch.nn.Tanh)
@meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input):
    return torch.empty(input.shape, device='meta')
colossalai/fx/tracer/meta_patch/patched_module/convolution.py (new file, mode 100644)

import math

import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    l_in = input.shape[-1]
    c_out = self.out_channels
    l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    result_shape = input.shape[:-2] + (c_out, l_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    h_in, w_in = input.shape[-2:]
    c_out = self.out_channels
    h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    result_shape = input.shape[:-3] + (c_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
    d_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
    result_shape = input.shape[:-4] + (c_out, d_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
    l_in = input.shape[-1]
    c_out = self.out_channels
    l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    result_shape = input.shape[:-2] + (c_out, l_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    h_in, w_in = input.shape[-2:]
    c_out = self.out_channels
    h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
    result_shape = input.shape[:-3] + (c_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
    d_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
    w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
    result_shape = input.shape[:-4] + (c_out, d_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')
colossalai/fx/tracer/meta_patch/patched_module/embedding.py (new file, mode 100644)

import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
    result_shape = input.shape + (self.embedding_dim,)
    return torch.empty(result_shape, device='meta')
colossalai/fx/tracer/meta_patch/patched_module/linear.py (new file, mode 100644)

import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    last_dim = input.shape[-1]
    assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
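Module patches take the module instance itself as the first argument, so its attributes (here in_features and out_features) drive the shape inference. A sketch (not part of the commit; it assumes the module is importable):

import torch
from colossalai.fx.tracer.meta_patch.patched_module.linear import torch_nn_linear

fc = torch.nn.Linear(10, 20)
x = torch.empty(4, 10, device='meta')
# the last dimension is replaced by out_features
assert torch_nn_linear(fc, x).shape == (4, 20)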
colossalai/fx/tracer/meta_patch/patched_module/normalization.py (new file, mode 100644)

import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.LayerNorm)
@meta_patched_module.register(torch.nn.GroupNorm)
@meta_patched_module.register(torch.nn.BatchNorm1d)
@meta_patched_module.register(torch.nn.BatchNorm2d)
@meta_patched_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self, input):
    # check shape
    if isinstance(self, torch.nn.BatchNorm1d):
        assert input.dim() in [2, 3]
    elif isinstance(self, torch.nn.BatchNorm2d):
        assert input.dim() == 4
    elif isinstance(self, torch.nn.BatchNorm3d):
        assert input.dim() == 5

    # normalization maintains the same shape as the input
    return input.clone()


try:
    import apex
    meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
    meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
    meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
    meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
except (ImportError, AttributeError):
    pass
colossalai/fx/tracer/meta_patch/patched_module/pooling.py (new file, mode 100644)

import math

import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.AvgPool1d)
def torch_nn_avgpool1d(self, input):
    num_dim = input.dim()
    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'

    l_in = input.shape[-1]

    def _convert_int_to_list(item):
        if isinstance(item, int):
            return [item] * 1
        else:
            return item

    padding = _convert_int_to_list(self.padding)
    kernel_size = _convert_int_to_list(self.kernel_size)
    stride = _convert_int_to_list(self.stride)

    l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)

    result_shape = tuple(input.shape[:-1]) + (l_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.AvgPool2d)
def torch_nn_avgpool2d(self, input):
    num_dim = input.dim()
    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'

    h_in, w_in = input.shape[-2:]

    def _convert_int_to_list(item):
        if isinstance(item, int):
            return [item] * 2
        else:
            return item

    padding = _convert_int_to_list(self.padding)
    kernel_size = _convert_int_to_list(self.kernel_size)
    stride = _convert_int_to_list(self.stride)

    h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
    w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)

    result_shape = tuple(input.shape[:-2]) + (h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.AvgPool3d)
def torch_nn_avgpool3d(self, input):
    num_dim = input.dim()
    assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'

    d_in, h_in, w_in = input.shape[-3:]

    def _convert_int_to_list(item):
        if isinstance(item, int):
            return [item] * 3
        else:
            return item

    padding = _convert_int_to_list(self.padding)
    kernel_size = _convert_int_to_list(self.kernel_size)
    stride = _convert_int_to_list(self.stride)

    d_out = math.floor((d_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
    h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
    w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1)

    result_shape = tuple(input.shape[:-3]) + (d_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.MaxPool1d)
def torch_nn_maxpool1d(self, input):
    num_dim = input.dim()
    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'

    l_in = input.shape[-1]

    def _convert_int_to_list(item):
        if isinstance(item, int):
            return [item] * 1
        else:
            return item

    padding = _convert_int_to_list(self.padding)
    dilation = _convert_int_to_list(self.dilation)
    kernel_size = _convert_int_to_list(self.kernel_size)
    stride = _convert_int_to_list(self.stride)

    l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)

    result_shape = tuple(input.shape[:-1]) + (l_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.MaxPool2d)
def torch_nn_maxpool2d(self, input):
    num_dim = input.dim()
    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'

    h_in, w_in = input.shape[-2:]

    def _convert_int_to_list(item):
        if isinstance(item, int):
            return [item] * 2
        else:
            return item

    padding = _convert_int_to_list(self.padding)
    dilation = _convert_int_to_list(self.dilation)
    kernel_size = _convert_int_to_list(self.kernel_size)
    stride = _convert_int_to_list(self.stride)

    h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)

    result_shape = tuple(input.shape[:-2]) + (h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.MaxPool3d)
def torch_nn_maxpool3d(self, input):
    num_dim = input.dim()
    assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'

    d_in, h_in, w_in = input.shape[-3:]

    def _convert_int_to_list(item):
        if isinstance(item, int):
            return [item] * 3
        else:
            return item

    padding = _convert_int_to_list(self.padding)
    dilation = _convert_int_to_list(self.dilation)
    kernel_size = _convert_int_to_list(self.kernel_size)
    stride = _convert_int_to_list(self.stride)

    d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
    w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)

    result_shape = tuple(input.shape[:-3]) + (d_out, h_out, w_out,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
def torch_nn_adapative_pooling_1d(self, input):
    assert input.dim() in [2, 3]
    if isinstance(self.output_size, int):
        output_size = (self.output_size,)
    else:
        output_size = self.output_size
    result_shape = tuple(input.shape[:-1]) + output_size
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
def torch_nn_adapative_pooling_2d(self, input):
    assert input.dim() in [3, 4]
    if isinstance(self.output_size, int):
        output_size = (self.output_size,) * 2
    else:
        output_size = self.output_size
    result_shape = tuple(input.shape[:-2]) + output_size
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_adapative_pooling_3d(self, input):
    assert input.dim() in [4, 5]
    if isinstance(self.output_size, int):
        output_size = (self.output_size,) * 3
    else:
        output_size = self.output_size
    result_shape = tuple(input.shape[:-3]) + output_size
    return torch.empty(result_shape, device='meta')
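A quick check of the pooling formula above (a sketch, not part of the commit; it assumes the module is importable): a 2x2 max pool with stride 2 halves the spatial dimensions.

import torch
from colossalai.fx.tracer.meta_patch.patched_module.pooling import torch_nn_maxpool2d

pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
x = torch.empty(1, 3, 32, 32, device='meta')
# H_out = floor((32 + 0 - 1*(2 - 1) - 1) / 2 + 1) = 16
assert torch_nn_maxpool2d(pool, x).shape == (1, 3, 16, 16)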
colossalai/fx/tracer/meta_patch/patched_module/rnn.py (new file, mode 100644)

from typing import Optional

import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx):
    assert input.shape[-1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
    assert hx.shape[-1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'
    d = 2 if self.bidirectional else 1
    return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
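The RNN/GRU patch returns an (output, hx) pair like the real modules, widening the last dimension to hidden_size (doubled when bidirectional). A sketch (not part of the commit; it assumes the module is importable):

import torch
from colossalai.fx.tracer.meta_patch.patched_module.rnn import torch_nn_rnn

rnn = torch.nn.RNN(input_size=10, hidden_size=20, batch_first=True)
x = torch.empty(4, 5, 10, device='meta')     # (batch, seq, input_size)
hx = torch.empty(1, 4, 20, device='meta')    # (layers, batch, hidden_size)
out, hidden = torch_nn_rnn(rnn, x, hx)
assert out.shape == (4, 5, 20)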
colossalai/fx/tracer/registry.py (new file, mode 100644)

class PatchRegistry:

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):

        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        assert source in self.store
        target = self.store[source]
        return target

    def has(self, source):
        return source in self.store


meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
colossalai/fx/tracer/tracer.py (new file, mode 100644)

#!/usr/bin/env python
"""
tracer.py:
    Implements a tracer which supports control flow and user-defined meta arguments.
    The implementation is partly inspired by HuggingFace's fx tracer.
"""
import enum
import functools
import inspect
import operator
from contextlib import contextmanager
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.fx import Node, Tracer
from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
from torch.fx.proxy import ParameterProxy, Proxy

from ..proxy import ColoProxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from .registry import (
    bias_addition_function,
    bias_addition_method,
    bias_addition_module,
    meta_patched_function,
    meta_patched_module,
)

__all__ = ['ColoTracer']


class TracerType(enum.Enum):
    DEFAULT = 1
    META = 2


class ColoTracer(Tracer):
    """
    ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.

    This tracer is initialized in the same way as the original torch.fx.Tracer.

    Usage::

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = nn.Linear(10, 10)
                self.linear2 = nn.Linear(10, 10)

            def forward(self, x, y):
                x1 = self.linear1(x)
                y1 = self.linear2(y)

                if x1.dim() == 2:
                    return x1 + y1
                else:
                    return x1 - y1

        model = Model()
        tracer = ColoTracer()
        graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
    """

    def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tracer_type = TracerType.META
        self.proxy_cls = ColoProxy

        # whether the tracer will record the usage of torch.utils.checkpoint
        self.trace_act_ckpt = trace_act_ckpt
        # whether the current tracing occurs within the activation checkpoint functions
        self.inside_torch_checkpoint_func = False
        self.act_ckpt_region_count = 0

    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True

    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
        """
        Create a proxy for different kinds of operations.
        """
        if self.tracer_type == TracerType.DEFAULT:
            # since meta_args is not given,
            # we just fall back to the original torch.fx.Tracer
            proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
            return proxy

        # if the graph is traced for an auto-parallelism module, some extra nodes will be added during
        # graph construction to deal with the compatibility between bias addition and all-reduce.
        # if no extra manipulation is applied, we just pass the original arguments to create_proxy
        # to create the node on the computation graph
        origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        # dispatch the arguments generator depending on the kind and target in origin arguments.
        args_metas, _ = extract_meta(*args, **kwargs)
        handle = None

        if kind == "call_function":
            if bias_addition_function.has(target):
                if target == torch.nn.functional.linear:
                    if 'bias' in kwargs and kwargs['bias'] is not None:
                        function_to_substitute = func_to_func_dict[target]
                        handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
                else:
                    function_to_substitute = func_to_func_dict[target]
                    handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
            elif bias_addition_function.has(target.__name__):
                # use name for some builtin op like @ (matmul)
                function_to_substitute = func_to_func_dict[target]
                handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)

        elif kind == "call_method":
            method = getattr(args_metas[0].__class__, target)
            if bias_addition_method.has(method):
                function_to_substitute = method_to_func_dict[method]
                handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)

        elif kind == "call_module":
            if not hasattr(self, "orig_forward"):
                raise AttributeError(f"{self} does not have an attribute called orig_forward")
            self._disable_module_getattr = True
            try:
                mod = self.root.get_submodule(target)
                mod_type = type(mod)
                if bias_addition_module.has(mod_type) and mod.bias is not None:
                    function_to_substitute = module_to_func_dict[mod_type]
                    handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
            finally:
                self._disable_module_getattr = False

        if handle is not None:
            return handle.generate()

        # create nodes using patched arguments
        proxy = super().create_proxy(*origin_arguments)
        proxy: ColoProxy
        meta_out = self._meta_data_computing(
            kind,
            target,
            args,
            kwargs,
        )
        proxy.meta_data = meta_out
        return proxy

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, "_disable_module_getattr", False):
            return attr_val
        else:
            # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
                for n, p in collection_to_search:
                    if attr_val is p:
                        if n not in parameter_proxy_cache:
                            kwargs = {}
                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                                kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
                                                              lambda node: ParameterProxy(self, node, n, attr_val))
                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)    # type: ignore[arg-type]
                            parameter_proxy_cache[n] = val_proxy
                        return parameter_proxy_cache[n]
                return None

            if isinstance(attr_val, torch.nn.Parameter):
                maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
                                                                 parameter_proxy_cache)
                if maybe_parameter_proxy is not None:
                    return maybe_parameter_proxy

            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
                maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
                                                              parameter_proxy_cache)
                if maybe_buffer_proxy is not None:
                    return maybe_buffer_proxy

            return attr_val

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        module_qualified_name = self.path_of_module(m)

        # a leaf module is a torch.nn.Module subclass starting with `torch.nn`,
        # which means customized modules are not leaf modules by default.
        # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
        # we should treat it as a leaf module as well
        if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
        else:
            return forward(*args, **kwargs)

    def proxy(self, node) -> Proxy:
        """
        Returns a ColoProxy object.
        """
        return self.proxy_cls(node, self)

    def _configure_tracer_type(self, tracer_type: TracerType):
        if tracer_type == TracerType.DEFAULT:
            self.proxy_cls = Proxy
            self.tracer_type = TracerType.DEFAULT
        elif tracer_type == TracerType.META:
            self.proxy_cls = ColoProxy
            self.tracer_type = TracerType.META
        else:
            raise ValueError(f"Unrecognised tracer type {tracer_type}")

    def _meta_data_computing(self, kind, target, args, kwargs):
        if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
            meta_out = self.meta_args[target]
            return meta_out

        if target in self.orig_torch_tensor_methods:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if "device" in kwargs:
                kwargs["device"] = "meta"

        try:
            args_metas, kwargs_metas = extract_meta(*args, **kwargs)

            if kind == "call_function":
                # fetch the patched function
                if meta_patched_function.has(target):
                    meta_target = meta_patched_function.get(target)
                elif meta_patched_function.has(target.__name__):
                    # use name for some builtin op like @ (matmul)
                    meta_target = meta_patched_function.get(target.__name__)
                else:
                    meta_target = target

                meta_out = meta_target(*args_metas, **kwargs_metas)
                if isinstance(meta_out, torch.Tensor):
                    meta_out = meta_out.to(device="meta")

            elif kind == "call_method":
                method = getattr(args_metas[0].__class__, target)

                # fetch the patched method
                if meta_patched_function.has(method):
                    meta_target = meta_patched_function.get(method)
                else:
                    meta_target = method

                meta_out = meta_target(*args_metas, **kwargs_metas)

            elif kind == "call_module":
                if not hasattr(self, "orig_forward"):
                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if meta_patched_module.has(mod_type):
                        meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False

            elif kind == "get_attr":
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split(".")
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    if isinstance(attr_itr, torch.nn.parameter.Parameter):
                        meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
                    elif isinstance(attr_itr, torch.Tensor):
                        meta_out = attr_itr.to(device="meta")
                    else:
                        meta_out = attr_itr
                finally:
                    self._disable_module_getattr = False
            else:
                return None
        except Exception as e:
            raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")

        return meta_out

    def trace(self,
              root: nn.Module,
              concrete_args: Optional[Dict[str, Tensor]] = None,
              meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
        """
        Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.

        Args:
            root (nn.Module): an `nn.Module` object to trace the computation graph from
            meta_args (Optional[Dict[str, Tensor]]): the meta tensor arguments used to trace the computation graph.
                These arguments are the sample data fed to the model during actual computation, but just converted to meta tensors.
            concrete_args (Optional[Dict[str, Tensor]]): the concrete arguments that should not be treated as Proxies.
        """
        if meta_args is None:
            meta_args = {}

        if concrete_args is None:
            concrete_args = {}

        if len(meta_args) == 0:
            self._configure_tracer_type(TracerType.DEFAULT)
        else:
            self._configure_tracer_type(TracerType.META)

        # check that the concrete and meta args have valid names
        sig = inspect.signature(root.forward)
        sig_names = set(sig.parameters.keys())
        meta_arg_names = set(meta_args.keys())

        # update concrete args with default values
        non_meta_arg_names = sig_names - meta_arg_names
        for k, v in sig.parameters.items():
            if k in non_meta_arg_names and \
                    k not in concrete_args and \
                    v.default is not inspect.Parameter.empty:
                concrete_args[k] = v.default

        # get non-concrete arg names
        concrete_arg_names = set(concrete_args.keys())
        non_concrete_arg_names = sig_names - concrete_arg_names

        def _check_arg_name_valid(names):
            success, element = is_element_in_list(names, sig_names)
            if not success:
                raise KeyError(
                    f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")

        _check_arg_name_valid(meta_arg_names)
        _check_arg_name_valid(concrete_arg_names)

        # assign as attributes for later reference
        def _check_kwargs(kwargs, should_be_meta: bool):
            for k, v in kwargs.items():
                if not should_be_meta:
                    assert not torch.is_tensor(v) or not v.is_meta, \
                        f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
                else:
                    assert v.is_meta == should_be_meta, \
                        f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'

        _check_kwargs(concrete_args, should_be_meta=False)
        _check_kwargs(meta_args, should_be_meta=True)

        self.concrete_args = concrete_args
        self.meta_args = meta_args

        self.patched_torch_tensor_methods = {}
        if self.tracer_type == TracerType.META:
            # wrap the torch tensor constructing methods so that they are captured in the graph
            self.patched_torch_tensor_methods = {
                target: wrap_tensor_constructor_method(getattr(torch, target))
                for target in self._TORCH_METHODS_TO_PATCH
            }

            # patch these methods to replace their original use
            for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
                setattr(torch, name, wrapper)

            # cache these methods so that we can detect whether a method call
            # should be patched during tracing
            self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]

        try:
            # to track the usage of torch.utils.checkpoint
            with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
                self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            # recover the patched methods
            for name, (_, orig) in self.patched_torch_tensor_methods.items():
                setattr(torch, name, orig)

        if self.tracer_type == TracerType.DEFAULT:
            return self.graph

        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in non_concrete_arg_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
                        # Newer versions of torch.fx emit an assert statement
                        # for concrete arguments; delete those before we delete
                        # the concrete arg.
                        to_delete = []
                        for user in node.users:
                            if user.target == torch.fx._symbolic_trace._assert_is_none:
                                to_delete.append(user)
                        for user in to_delete:
                            self.graph.erase_node(user)

                    self.graph.erase_node(node)

            # TODO: solve GraphModule creation.
            # Without this, the return type annotation "Tuple" causes code execution failure.
            if node.op == "output":
                node.type = None

        return self.graph

    @contextmanager
    def trace_activation_checkpoint(self, enabled: bool):
        if enabled:
            orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction

            class PatchedCheckpointFunction(torch.autograd.Function):

                @staticmethod
                def forward(ctx, run_function, preserve_rng_state, *args):
                    # signal that the current tracing occurs within the activation checkpoint part
                    self.inside_torch_checkpoint_func = True
                    out = run_function(*args)
                    self.inside_torch_checkpoint_func = False
                    self.act_ckpt_region_count += 1
                    return out

                @staticmethod
                def backward(ctx: Any, *grad_outputs: Any) -> Any:
                    raise NotImplementedError("We do not implement the backward pass as we only trace the forward pass.")

            # override the checkpoint function
            torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction

        yield

        if enabled:
            # recover the checkpoint function upon exit
            torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func

    def create_node(self, *args, **kwargs) -> Node:
        node = super().create_node(*args, **kwargs)
        if self.inside_torch_checkpoint_func:
            # annotate the activation checkpoint module
            node.meta['activation_checkpoint'] = self.act_ckpt_region_count
        return node


def wrap_tensor_constructor_method(target):

    def look_for_proxy(*args, **kwargs):
        # find in positional args
        for arg in args:
            if isinstance(arg, Proxy):
                return arg
            if isinstance(arg, (tuple, list)):
                return look_for_proxy(*arg)

        # find in keyword args
        for k, v in kwargs.items():
            if isinstance(v, Proxy):
                return v
            if isinstance(v, (tuple, list)):
                return look_for_proxy(*v)
        return None

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = look_for_proxy(*args, **kwargs)

        if proxy is not None:
            # if the arg is a proxy, then we need to record this function call on this proxy,
            # e.g. torch.ones(size) where size is an input proxy
            colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
            if not isinstance(colo_proxy, ColoProxy):
                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
                colo_proxy = ColoProxy(proxy.node)
                colo_proxy.meta_data = meta_out
            return colo_proxy
        else:
            # this is called directly when the inputs do not contain a proxy,
            # e.g. torch.ones(4) where the input is static
            return target(*args, **kwargs)

    return wrapper, target


# Patch magic methods onto ColoProxy so that the tracer can record magic-method calls like __sub__
# and add the meta_data attribute to the created proxy.
for method in magic_methods:

    def _scope(method):

        def impl(*args, **kwargs):
            tracer = args[0].tracer
            target = getattr(operator, method)
            proxy = tracer.create_proxy('call_function', target, args, kwargs)
            if not isinstance(proxy, ColoProxy):
                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
                proxy = ColoProxy(proxy.node)
                proxy.meta_data = meta_out
            return proxy

        impl.__name__ = method
        as_magic = f'__{method.strip("_")}__'
        setattr(ColoProxy, as_magic, impl)

    _scope(method)


def _define_reflectable(orig_method_name):
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        target = getattr(operator, orig_method_name)
        proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
        if not isinstance(proxy, ColoProxy):
            meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
            proxy = ColoProxy(proxy.node)
            proxy.meta_data = meta_out
        return proxy

    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(ColoProxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
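Following the class docstring, a traced graph can be wrapped into a standard torch.fx.GraphModule; the meta argument lets the tracer evaluate the data-dependent branch at trace time. The sketch below is modeled on the docstring's usage and is not part of the commit:

import torch
from torch.fx import GraphModule
from colossalai.fx.tracer.tracer import ColoTracer

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        x = self.linear(x)
        # data-dependent control flow, resolved at trace time via the meta tensor
        return x + 1 if x.dim() == 2 else x - 1

model = Model()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={'x': torch.rand(4, 10, device='meta')})
gm = GraphModule(model, graph)    # runnable module rebuilt from the traced graph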
colossalai/gemini/__init__.py (new file, mode 100644)

from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .gemini_mgr import GeminiManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory

__all__ = [
    'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager',
    'search_chunk_configuration'
]
colossalai/gemini/chunk/__init__.py (new file, mode 100644)

from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .manager import ChunkManager
from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
from .utils import init_chunk_manager

__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']