OpenDAS / ColossalAI · Commits

Commit 126ba573 (unverified)
[Tensor] add layer norm Op (#852)

Authored Apr 25, 2022 by Jiarui Fang; committed via GitHub on Apr 25, 2022.
Parent: a82da26f

Showing 5 changed files with 79 additions and 8 deletions (+79 -8):
colossalai/tensor/_ops/__init__.py      +2  -1
colossalai/tensor/_ops/element_wise.py  +5  -3
colossalai/tensor/_ops/layernorm.py     +38 -0
colossalai/tensor/colo_tensor.py        +4  -0
tests/test_tensor/test_op.py            +30 -4
colossalai/tensor/_ops/__init__.py

 from .init import colo_uniform
 from .linear import colo_linear
-from .element_wise import colo_mean
+from .element_wise import colo_mean
+from .layernorm import colo_layernorm
\ No newline at end of file
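Importing .layernorm here is what activates the new op: colo_op_impl runs at import time. As a rough sketch of the mechanism (hypothetical body, not the actual source in colossalai/tensor/op_wrapper.py), the decorator fills a table mapping each intercepted torch function to its ColoTensor handler, which ColoTensor.__torch_function__ consults at dispatch time:

    # Hypothetical sketch of the op-wrapper registry; the real
    # implementation in colossalai/tensor/op_wrapper.py may differ.
    _COLO_OP_TABLE = {}  # torch function -> ColoTensor implementation

    def colo_op_impl(torch_func):
        """Register the decorated function as the handler invoked when
        `torch_func` is called with ColoTensor arguments."""
        def decorator(func):
            _COLO_OP_TABLE[torch_func] = func
            return func
        return decorator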
colossalai/tensor/_ops/element_wise.py

@@ -5,8 +5,10 @@ from colossalai.tensor import ColoTensor
 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
-    stateful_tensor = args[0]
-    return torch.mean(stateful_tensor.torch_tensor())
+    input_t = args[0]
+    if isinstance(input_t, ColoTensor):
+        input_t = input_t.torch_tensor()
+    return ColoTensor.init_from_torch_tensor(torch.mean(input_t))

 def register_elementwise_op(op):
@@ -22,7 +24,7 @@ def register_elementwise_op(op):
         # Validate types
         if not isinstance(input_tensor, ColoTensor):
             raise TypeError("input needs to be a ColoTensor")
-        return op(input_tensor.torch_tensor())
+        return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))


 register_elementwise_op(torch.nn.functional.gelu)
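The change above means every wrapped element-wise op now returns a ColoTensor instead of leaking a raw torch.Tensor. A minimal usage sketch, assuming only the dispatch behavior shown in this commit:

    import torch
    from colossalai.tensor import ColoTensor

    t = ColoTensor.init_from_torch_tensor(torch.randn(3, 5))
    # ColoTensor's __torch_function__ routes this call to the registered
    # colo_mean, which now wraps its result back into a ColoTensor.
    m = torch.mean(t)
    assert isinstance(m, ColoTensor)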
colossalai/tensor/_ops/layernorm.py (new file, mode 0 → 100644)

import torch

from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor


@colo_op_impl(torch.nn.functional.layer_norm)
def colo_layernorm(types, args=(), kwargs=None, pg=None):
    # Defaults match torch.nn.functional.layer_norm(input, normalized_shape,
    # weight=None, bias=None, eps=1e-5).
    weight, bias, eps = None, None, 1e-5

    # Unpack positional arguments.
    arg_num = len(args)
    if arg_num > 0:
        input_tensor = args[0]
    if arg_num > 1:
        normalized_shape = args[1]
    if arg_num > 2:
        weight = args[2]
    if arg_num > 3:
        bias = args[3]
    if arg_num > 4:
        eps = args[4]

    # Keyword arguments take precedence over positional ones.
    kwargs = kwargs or {}
    if 'input' in kwargs:
        input_tensor = kwargs['input']
    if 'weight' in kwargs:
        weight = kwargs['weight']
    if 'bias' in kwargs:
        bias = kwargs['bias']
    if 'eps' in kwargs:
        eps = kwargs['eps']

    # Unwrap any ColoTensor arguments to the underlying torch.Tensor.
    if isinstance(input_tensor, ColoTensor):
        input_tensor = input_tensor.torch_tensor()
    if isinstance(weight, ColoTensor):
        weight = weight.torch_tensor()
    if isinstance(bias, ColoTensor):
        bias = bias.torch_tensor()

    # Run the regular layer norm and wrap the result back into a ColoTensor.
    return ColoTensor.init_from_torch_tensor(
        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
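A minimal sketch of exercising the new op (assuming init_from_torch_tensor and the dispatch path above; this snippet is not taken from the commit):

    import torch
    import torch.nn.functional as F
    from colossalai.tensor import ColoTensor

    x = ColoTensor.init_from_torch_tensor(torch.randn(3, 2))
    w = ColoTensor.init_from_torch_tensor(torch.ones(2))
    b = ColoTensor.init_from_torch_tensor(torch.zeros(2))

    # colo_layernorm unwraps the ColoTensor arguments, runs the regular
    # F.layer_norm, and wraps the result back into a ColoTensor.
    y = F.layer_norm(x, (2,), weight=w, bias=b)
    assert isinstance(y, ColoTensor)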
colossalai/tensor/colo_tensor.py

@@ -8,6 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
 from colossalai.utils.cuda import get_current_device


 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -145,3 +146,6 @@ class ColoTensor(object):
         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)
+
+    def backward(self, retain_graph: bool = False):
+        self._torch_tensor.backward(retain_graph=retain_graph)
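With this pass-through, a ColoTensor produced by a wrapped op can drive autograd directly instead of forcing callers to reach for the wrapped tensor. A hedged sketch (assuming the wrapped tensor stays in the autograd graph):

    import torch
    from colossalai.tensor import ColoTensor

    w = torch.randn(3, 2, requires_grad=True)
    t = ColoTensor.init_from_torch_tensor(w * 2.0)
    loss = torch.mean(t)   # dispatches to colo_mean, returns a ColoTensor
    loss.backward()        # new method: forwards to self._torch_tensor.backward()
    assert w.grad is not None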
tests/test_tensor/test_op.py

-from numpy import allclose, require
+from numpy import allclose
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
 from colossalai.utils import get_current_device


+def test_layernorm():
+    # NB: the second positional argument of nn.LayerNorm is eps, so this
+    # builds LayerNorm(normalized_shape=2, eps=3); deepcopy shares it with
+    # the ColoTensor variant, so the comparison below is still valid.
+    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
+    ln_op_colo = deepcopy(ln_op)
+
+    input_t = torch.randn(3, 2, device=get_current_device())
+    input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
+
+    # prepare colossalai LN: swap the torch Parameter for a ColoTensor weight
+    delattr(ln_op_colo, 'weight')
+    weight_clone = ln_op.weight.clone().detach()
+    weight_clone.requires_grad = True
+    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+
+    output = ln_op(input_t)
+    output_colo = ln_op_colo(input_t_colo)
+
+    assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
+
+    torch.mean(output).backward()
+    torch.mean(output_colo).backward()
+
+    assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())
 def test_linear():
     ...
@@ -50,8 +75,8 @@ def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
-    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
-    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
+    assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
+    assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))
     # Test a function not wrapped by
 ...
@@ -76,4 +101,5 @@ def check_all():
 if __name__ == '__main__':
     test_lazy_init_tensor()
     # test_lazy_init_ptensor()
+    test_layernorm()
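To run just the new test outside the script's __main__ path, a standard pytest invocation should work (not part of the commit; assumes the environment provides whatever device get_current_device expects):

    pytest tests/test_tensor/test_op.py -k test_layernorm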