Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ae861519
Unverified
Commit
ae861519
authored
Jun 22, 2022
by
ver217
Committed by
GitHub
Jun 22, 2022
Browse files
[tensor] add more element-wise ops (#1155)
* add more element-wise ops * update test_op * polish unit test
parent
e8c34eed
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
248 additions
and
6 deletions
+248
-6
colossalai/nn/_ops/element_wise.py
colossalai/nn/_ops/element_wise.py
+205
-5
tests/test_tensor/test_op.py
tests/test_tensor/test_op.py
+43
-1
No files found.
colossalai/nn/_ops/element_wise.py
View file @
ae861519
from
copy
import
copy
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
from
copy
import
copy
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor
import
ColoTensor
from
._utils
import
GeneralTensor
...
...
@@ -21,8 +23,206 @@ def register_elementwise_op(op):
return
ColoTensor
.
from_torch_tensor
(
output
)
register_elementwise_op
(
torch
.
nn
.
functional
.
gelu
)
register_elementwise_op
(
torch
.
nn
.
functional
.
relu
)
# Tensor op
register_elementwise_op
(
Tensor
.
abs
)
register_elementwise_op
(
Tensor
.
absolute
)
register_elementwise_op
(
Tensor
.
acos
)
register_elementwise_op
(
Tensor
.
arccos
)
register_elementwise_op
(
Tensor
.
angle
)
register_elementwise_op
(
Tensor
.
asin
)
register_elementwise_op
(
Tensor
.
arcsin
)
register_elementwise_op
(
Tensor
.
atan
)
register_elementwise_op
(
Tensor
.
arctan
)
register_elementwise_op
(
Tensor
.
all
)
register_elementwise_op
(
Tensor
.
any
)
register_elementwise_op
(
Tensor
.
bernoulli
)
register_elementwise_op
(
Tensor
.
bfloat16
)
register_elementwise_op
(
Tensor
.
bitwise_not
)
register_elementwise_op
(
Tensor
.
bool
)
register_elementwise_op
(
Tensor
.
byte
)
register_elementwise_op
(
Tensor
.
ceil
)
register_elementwise_op
(
Tensor
.
char
)
register_elementwise_op
(
Tensor
.
clamp
)
register_elementwise_op
(
Tensor
.
clamp_max
)
register_elementwise_op
(
Tensor
.
clamp_min
)
register_elementwise_op
(
Tensor
.
clip
)
register_elementwise_op
(
Tensor
.
clone
)
register_elementwise_op
(
Tensor
.
contiguous
)
register_elementwise_op
(
Tensor
.
copysign
)
register_elementwise_op
(
Tensor
.
cos
)
register_elementwise_op
(
Tensor
.
cosh
)
register_elementwise_op
(
Tensor
.
acosh
)
register_elementwise_op
(
Tensor
.
arccosh
)
register_elementwise_op
(
Tensor
.
cpu
)
register_elementwise_op
(
Tensor
.
cuda
)
register_elementwise_op
(
Tensor
.
deg2rad
)
register_elementwise_op
(
Tensor
.
detach
)
register_elementwise_op
(
Tensor
.
digamma
)
register_elementwise_op
(
Tensor
.
double
)
register_elementwise_op
(
Tensor
.
erf
)
register_elementwise_op
(
Tensor
.
erfc
)
register_elementwise_op
(
Tensor
.
erfinv
)
register_elementwise_op
(
Tensor
.
exp
)
register_elementwise_op
(
Tensor
.
expm1
)
register_elementwise_op
(
Tensor
.
fix
)
register_elementwise_op
(
Tensor
.
trunc
)
register_elementwise_op
(
Tensor
.
float
)
register_elementwise_op
(
Tensor
.
float_power
)
register_elementwise_op
(
Tensor
.
floor
)
register_elementwise_op
(
Tensor
.
frac
)
register_elementwise_op
(
Tensor
.
half
)
register_elementwise_op
(
Tensor
.
hardshrink
)
register_elementwise_op
(
Tensor
.
heaviside
)
register_elementwise_op
(
Tensor
.
i0
)
register_elementwise_op
(
Tensor
.
int
)
register_elementwise_op
(
Tensor
.
isfinite
)
register_elementwise_op
(
Tensor
.
isinf
)
register_elementwise_op
(
Tensor
.
isposinf
)
register_elementwise_op
(
Tensor
.
isneginf
)
register_elementwise_op
(
Tensor
.
isnan
)
register_elementwise_op
(
Tensor
.
lgamma
)
register_elementwise_op
(
Tensor
.
log
)
register_elementwise_op
(
Tensor
.
log10
)
register_elementwise_op
(
Tensor
.
log1p
)
register_elementwise_op
(
Tensor
.
log2
)
register_elementwise_op
(
Tensor
.
logical_not
)
register_elementwise_op
(
Tensor
.
logit
)
register_elementwise_op
(
Tensor
.
long
)
register_elementwise_op
(
Tensor
.
nan_to_num
)
register_elementwise_op
(
Tensor
.
neg
)
register_elementwise_op
(
Tensor
.
negative
)
register_elementwise_op
(
Tensor
.
positive
)
register_elementwise_op
(
Tensor
.
pow
)
register_elementwise_op
(
Tensor
.
rad2deg
)
register_elementwise_op
(
Tensor
.
reciprocal
)
register_elementwise_op
(
Tensor
.
round
)
register_elementwise_op
(
Tensor
.
rsqrt
)
register_elementwise_op
(
Tensor
.
short
)
register_elementwise_op
(
Tensor
.
sigmoid
)
register_elementwise_op
(
Tensor
.
sign
)
register_elementwise_op
(
Tensor
.
signbit
)
register_elementwise_op
(
Tensor
.
sgn
)
register_elementwise_op
(
Tensor
.
sin
)
register_elementwise_op
(
Tensor
.
sinc
)
register_elementwise_op
(
Tensor
.
sinh
)
register_elementwise_op
(
Tensor
.
asinh
)
register_elementwise_op
(
Tensor
.
arcsinh
)
register_elementwise_op
(
Tensor
.
sqrt
)
register_elementwise_op
(
Tensor
.
square
)
register_elementwise_op
(
Tensor
.
to
)
register_elementwise_op
(
Tensor
.
tan
)
register_elementwise_op
(
Tensor
.
tanh
)
register_elementwise_op
(
Tensor
.
atanh
)
register_elementwise_op
(
Tensor
.
arctanh
)
register_elementwise_op
(
Tensor
.
type
)
register_elementwise_op
(
Tensor
.
type_as
)
# torch OP
register_elementwise_op
(
torch
.
abs
)
register_elementwise_op
(
torch
.
absolute
)
register_elementwise_op
(
torch
.
acos
)
register_elementwise_op
(
torch
.
arccos
)
register_elementwise_op
(
torch
.
angle
)
register_elementwise_op
(
torch
.
asin
)
register_elementwise_op
(
torch
.
arcsin
)
register_elementwise_op
(
torch
.
atan
)
register_elementwise_op
(
torch
.
arctan
)
register_elementwise_op
(
torch
.
all
)
register_elementwise_op
(
torch
.
any
)
register_elementwise_op
(
torch
.
bernoulli
)
register_elementwise_op
(
torch
.
bitwise_not
)
register_elementwise_op
(
torch
.
ceil
)
register_elementwise_op
(
torch
.
clamp
)
register_elementwise_op
(
torch
.
clamp_max
)
register_elementwise_op
(
torch
.
clamp_min
)
register_elementwise_op
(
torch
.
clip
)
register_elementwise_op
(
torch
.
clone
)
register_elementwise_op
(
torch
.
Tensor
.
clone
)
register_elementwise_op
(
torch
.
Tensor
.
detach
)
register_elementwise_op
(
torch
.
copysign
)
register_elementwise_op
(
torch
.
cos
)
register_elementwise_op
(
torch
.
cosh
)
register_elementwise_op
(
torch
.
acosh
)
register_elementwise_op
(
torch
.
arccosh
)
register_elementwise_op
(
torch
.
deg2rad
)
register_elementwise_op
(
torch
.
digamma
)
register_elementwise_op
(
torch
.
erf
)
register_elementwise_op
(
torch
.
erfc
)
register_elementwise_op
(
torch
.
erfinv
)
register_elementwise_op
(
torch
.
exp
)
register_elementwise_op
(
torch
.
expm1
)
register_elementwise_op
(
torch
.
fix
)
register_elementwise_op
(
torch
.
trunc
)
register_elementwise_op
(
torch
.
float_power
)
register_elementwise_op
(
torch
.
floor
)
register_elementwise_op
(
torch
.
frac
)
register_elementwise_op
(
torch
.
hardshrink
)
register_elementwise_op
(
torch
.
heaviside
)
register_elementwise_op
(
torch
.
i0
)
register_elementwise_op
(
torch
.
isfinite
)
register_elementwise_op
(
torch
.
isinf
)
register_elementwise_op
(
torch
.
isposinf
)
register_elementwise_op
(
torch
.
isneginf
)
register_elementwise_op
(
torch
.
isnan
)
register_elementwise_op
(
torch
.
lgamma
)
register_elementwise_op
(
torch
.
log
)
register_elementwise_op
(
torch
.
log10
)
register_elementwise_op
(
torch
.
log1p
)
register_elementwise_op
(
torch
.
log2
)
register_elementwise_op
(
torch
.
logical_not
)
register_elementwise_op
(
torch
.
logit
)
register_elementwise_op
(
torch
.
nan_to_num
)
register_elementwise_op
(
torch
.
neg
)
register_elementwise_op
(
torch
.
negative
)
register_elementwise_op
(
torch
.
positive
)
register_elementwise_op
(
torch
.
pow
)
register_elementwise_op
(
torch
.
rad2deg
)
register_elementwise_op
(
torch
.
reciprocal
)
register_elementwise_op
(
torch
.
round
)
register_elementwise_op
(
torch
.
rsqrt
)
register_elementwise_op
(
torch
.
sigmoid
)
register_elementwise_op
(
torch
.
sign
)
register_elementwise_op
(
torch
.
signbit
)
register_elementwise_op
(
torch
.
sgn
)
register_elementwise_op
(
torch
.
sin
)
register_elementwise_op
(
torch
.
sinc
)
register_elementwise_op
(
torch
.
sinh
)
register_elementwise_op
(
torch
.
asinh
)
register_elementwise_op
(
torch
.
arcsinh
)
register_elementwise_op
(
torch
.
sqrt
)
register_elementwise_op
(
torch
.
square
)
register_elementwise_op
(
torch
.
tan
)
register_elementwise_op
(
torch
.
tanh
)
register_elementwise_op
(
torch
.
atanh
)
register_elementwise_op
(
torch
.
arctanh
)
# nn.functional OP
register_elementwise_op
(
F
.
threshold
)
register_elementwise_op
(
F
.
relu
)
register_elementwise_op
(
F
.
hardtanh
)
register_elementwise_op
(
F
.
hardswish
)
register_elementwise_op
(
F
.
relu6
)
register_elementwise_op
(
F
.
elu
)
register_elementwise_op
(
F
.
selu
)
register_elementwise_op
(
F
.
celu
)
register_elementwise_op
(
F
.
leaky_relu
)
register_elementwise_op
(
F
.
prelu
)
register_elementwise_op
(
F
.
rrelu
)
register_elementwise_op
(
F
.
gelu
)
register_elementwise_op
(
F
.
logsigmoid
)
register_elementwise_op
(
F
.
hardshrink
)
register_elementwise_op
(
F
.
tanhshrink
)
register_elementwise_op
(
F
.
softsign
)
register_elementwise_op
(
F
.
softplus
)
register_elementwise_op
(
F
.
softmin
)
register_elementwise_op
(
F
.
softmax
)
register_elementwise_op
(
F
.
softshrink
)
register_elementwise_op
(
F
.
gumbel_softmax
)
register_elementwise_op
(
F
.
log_softmax
)
register_elementwise_op
(
F
.
tanh
)
register_elementwise_op
(
F
.
sigmoid
)
register_elementwise_op
(
F
.
hardsigmoid
)
register_elementwise_op
(
F
.
silu
)
register_elementwise_op
(
F
.
mish
)
# TODO(ver217): dropout handles seed
register_elementwise_op
(
F
.
dropout
)
register_elementwise_op
(
F
.
alpha_dropout
)
register_elementwise_op
(
F
.
feature_alpha_dropout
)
tests/test_tensor/test_op.py
View file @
ae861519
import
torch
import
pytest
import
colossalai
import
torch.nn.functional
as
F
import
torch.multiprocessing
as
mp
from
functools
import
partial
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
from
colossalai.utils
import
get_current_device
from
torch.nn
import
Parameter
import
torch.nn.functional
as
F
from
torch.distributed.distributed_c10d
import
_get_default_group
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.tensor
import
distspec
,
TensorSpec
def
test_layernorm
():
...
...
@@ -26,8 +34,42 @@ def test_layernorm():
assert
torch
.
allclose
(
ln_op
.
weight
.
grad
,
weight
.
grad
)
def
check_spec_eq
(
tensor
,
other
):
assert
isinstance
(
tensor
,
ColoTensor
)
and
isinstance
(
other
,
ColoTensor
)
for
k
in
dir
(
tensor
.
spec
.
dist_spec
):
if
not
k
.
startswith
(
'__'
):
assert
hasattr
(
other
.
spec
.
dist_spec
,
k
)
assert
getattr
(
tensor
.
spec
.
dist_spec
,
k
)
==
getattr
(
other
.
spec
.
dist_spec
,
k
)
def
check_element_wise_ops
():
pg
=
_get_default_group
()
t
=
torch
.
rand
(
2
,
2
)
x
=
ColoTensor
(
t
,
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
size
()])))
check_spec_eq
(
x
,
x
.
cuda
())
assert
torch
.
equal
(
x
.
cuda
(),
t
.
cuda
())
check_spec_eq
(
x
,
torch
.
abs
(
x
))
assert
torch
.
equal
(
torch
.
abs
(
x
),
torch
.
abs
(
t
))
check_spec_eq
(
x
,
F
.
sigmoid
(
x
))
assert
torch
.
equal
(
F
.
sigmoid
(
x
),
F
.
sigmoid
(
t
))
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
check_element_wise_ops
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
2
])
@
rerun_if_address_is_in_use
()
def
test_element_wise_ops
(
world_size
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
def
check_all
():
test_layernorm
()
test_element_wise_ops
(
2
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment