Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1190b2c4
Unverified
Commit
1190b2c4
authored
Apr 25, 2022
by
Jiarui Fang
Committed by
GitHub
Apr 25, 2022
Browse files
[tensor] add cross_entropy_loss (#868)
parent
31078171
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
48 additions
and
7 deletions
+48
-7
colossalai/tensor/_ops/__init__.py
colossalai/tensor/_ops/__init__.py
+2
-1
colossalai/tensor/_ops/layernorm.py
colossalai/tensor/_ops/layernorm.py
+0
-1
colossalai/tensor/_ops/linear.py
colossalai/tensor/_ops/linear.py
+5
-1
colossalai/tensor/_ops/loss.py
colossalai/tensor/_ops/loss.py
+29
-0
tests/components_to_test/simple_net.py
tests/components_to_test/simple_net.py
+4
-0
tests/test_tensor/test_net_tp.py
tests/test_tensor/test_net_tp.py
+8
-4
No files found.
colossalai/tensor/_ops/__init__.py
View file @
1190b2c4
...
@@ -2,3 +2,4 @@ from .init import colo_uniform
...
@@ -2,3 +2,4 @@ from .init import colo_uniform
from
.linear
import
colo_linear
from
.linear
import
colo_linear
from
.element_wise
import
colo_mean
from
.element_wise
import
colo_mean
from
.layernorm
import
colo_layernorm
from
.layernorm
import
colo_layernorm
from
.loss
import
colo_cross_entropy
\ No newline at end of file
colossalai/tensor/_ops/layernorm.py
View file @
1190b2c4
from
numpy
import
isin
,
kaiser
import
torch
import
torch
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor
import
ColoTensor
from
colossalai.tensor
import
ColoTensor
...
...
colossalai/tensor/_ops/linear.py
View file @
1190b2c4
...
@@ -31,7 +31,11 @@ def colo_linear(types, args, kwargs, pg):
...
@@ -31,7 +31,11 @@ def colo_linear(types, args, kwargs, pg):
# Add communication logic before and after linear call.
# Add communication logic before and after linear call.
if
isinstance
(
weight
,
ColoTensor
):
if
isinstance
(
weight
,
ColoTensor
):
if
weight
.
shard_spec
==
None
:
if
weight
.
shard_spec
==
None
:
return
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
.
torch_tensor
(),
bias
)
if
isinstance
(
input_tensor
,
ColoTensor
):
input_tensor
=
input_tensor
.
torch_tensor
()
if
isinstance
(
weight
,
ColoTensor
):
weight
=
weight
.
torch_tensor
()
return
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
,
bias
)
elif
weight
.
shard_spec
==
'1Drow'
:
elif
weight
.
shard_spec
==
'1Drow'
:
# Input:S[1] x Weight:S[0] = Output:P
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# All-Reduce(Output) + bias = res
...
...
colossalai/tensor/_ops/loss.py
0 → 100644
View file @
1190b2c4
import
torch
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor
import
ColoTensor
@colo_op_impl(torch.nn.functional.cross_entropy)
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
    """Dispatch handler for ``torch.nn.functional.cross_entropy`` on ColoTensors.

    Unwraps any ``ColoTensor`` among ``input``/``target``/``weight`` to its
    underlying ``torch.Tensor``, delegates to the stock PyTorch loss, and wraps
    the resulting loss tensor back into a ``ColoTensor``.

    Args:
        types: dispatch types forwarded by the op wrapper (unused here).
        args: positional arguments of ``cross_entropy`` — (input, target, weight).
        kwargs: keyword arguments of ``cross_entropy``; may be ``None``.
        pg: process group forwarded by the op wrapper (unused here).

    Returns:
        ColoTensor wrapping the cross-entropy loss.
    """
    # Guard: the wrapper may pass kwargs=None; the lookups below need a dict.
    if kwargs is None:
        kwargs = {}
    # Default to None so every name is bound even when an argument is absent
    # (the original left input_tensor/target/weight unbound in that case).
    input_tensor = args[0] if len(args) > 0 else None
    target = args[1] if len(args) > 1 else None
    # FIX: positional ``weight`` is the third argument, args[2] — the original
    # indexed args[3], an off-by-one that grabbed the wrong argument.
    weight = args[2] if len(args) > 2 else None
    # Keyword arguments supply any value not given positionally.
    input_tensor = kwargs.get('input', input_tensor)
    target = kwargs.get('target', target)
    weight = kwargs.get('weight', weight)
    # Unwrap ColoTensors to plain torch.Tensors before calling the torch op.
    if isinstance(input_tensor, ColoTensor):
        input_tensor = input_tensor.torch_tensor()
    if isinstance(target, ColoTensor):
        target = target.torch_tensor()
    # FIX: also unwrap ``weight`` — the original unwrapped only input/target,
    # so a ColoTensor class-weight would have been passed through raw.
    if isinstance(weight, ColoTensor):
        weight = weight.torch_tensor()
    return ColoTensor.init_from_torch_tensor(
        torch.nn.functional.cross_entropy(input_tensor, target, weight))
tests/components_to_test/simple_net.py
View file @
1190b2c4
...
@@ -14,11 +14,15 @@ class SimpleNet(CheckpointModule):
...
@@ -14,11 +14,15 @@ class SimpleNet(CheckpointModule):
def __init__(self, checkpoint=False) -> None:
    """Build the two-layer projection net used by the tensor-parallel tests.

    Args:
        checkpoint: forwarded to ``CheckpointModule`` to enable activation
            checkpointing.
    """
    super().__init__(checkpoint=checkpoint)
    # Layer sizes: 4 -> 8 -> 4, each projection followed by a LayerNorm.
    in_dim, hidden_dim, out_dim = 4, 8, 4
    self.proj1 = nn.Linear(in_dim, hidden_dim)
    self.ln1 = nn.LayerNorm(hidden_dim)
    self.proj2 = nn.Linear(hidden_dim, out_dim)
    self.ln2 = nn.LayerNorm(out_dim)
def forward(self, x):
    """Run the input through proj1 -> ln1 -> proj2 -> ln2 and return the result."""
    # Apply the four sub-modules in registration order.
    for layer in (self.proj1, self.ln1, self.proj2, self.ln2):
        x = layer(x)
    return x
...
...
tests/test_tensor/test_net_tp.py
View file @
1190b2c4
from
cProfile
import
label
from
cProfile
import
label
from
statistics
import
mode
from
statistics
import
mode
from
colossalai.tensor.colo_tensor
import
ColoTensor
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
import
colossalai
import
colossalai
...
@@ -20,21 +21,23 @@ def run_simple_net():
...
@@ -20,21 +21,23 @@ def run_simple_net():
# A simple net with two stacked nn.Linear
# A simple net with two stacked nn.Linear
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'simple_net'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'simple_net'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
():
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
model_builder
(
checkpoint
=
True
)
model
=
model_builder
(
checkpoint
=
True
)
# we set the Specs for weight of each linear.
# we set the Specs for weight of each linear.
model
.
proj1
.
weight
.
set_spec
(
'1Drow'
)
#
model.proj1.weight.set_spec('1Drow')
model
.
proj2
.
weight
.
set_spec
(
'1Drow'
)
#
model.proj2.weight.set_spec('1Drow')
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
output
=
model
(
data
)
output
=
model
(
data
)
print
(
output
)
if
criterion
:
if
criterion
:
loss
=
criterion
(
output
,
label
)
loss
=
criterion
(
output
,
label
)
else
:
else
:
loss
=
output
loss
=
output
print
(
loss
.
torch_tensor
())
loss
.
backward
()
loss
.
backward
()
if
i
>
5
:
if
i
>
5
:
...
@@ -49,6 +52,7 @@ def run_dist(rank, world_size, port):
...
@@ -49,6 +52,7 @@ def run_dist(rank, world_size, port):
run_simple_net
()
run_simple_net
()
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
parameterize
(
'world_size'
,
[
1
,
4
])
@
parameterize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment