Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
75d22191
Unverified
Commit
75d22191
authored
May 07, 2022
by
Ziyue Jiang
Committed by
GitHub
May 07, 2022
Browse files
[Tensor] add 1d vocab loss (#918)
* add 1d vocab loss * polish
parent
dfaff4e2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
13 deletions
+37
-13
colossalai/tensor/_ops/loss.py
colossalai/tensor/_ops/loss.py
+20
-8
tests/components_to_test/simple_net.py
tests/components_to_test/simple_net.py
+4
-2
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+13
-3
No files found.
colossalai/tensor/_ops/loss.py
View file @
75d22191
from
colossalai.tensor.spec
import
ShardPattern
import
torch
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.tensor
import
ColoTensor
from
colossalai.nn.loss.loss_1d
import
VocabParallelCrossEntropyLoss1D
@
colo_op_impl
(
torch
.
nn
.
functional
.
cross_entropy
)
def
colo_cross_entropy
(
types
,
args
=
(),
kwargs
=
None
,
pg
=
None
):
...
...
@@ -12,18 +13,29 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
if
arg_num
>
1
:
target
=
args
[
1
]
if
arg_num
>
2
:
weight
=
args
[
3
]
weight
=
args
[
2
]
if
'input'
in
kwargs
:
input_tensor
=
kwargs
[
'input'
]
input_tensor
=
kwargs
.
pop
(
'input'
)
if
'target'
in
kwargs
:
target
=
kwargs
[
'target'
]
target
=
kwargs
.
pop
(
'target'
)
if
'weight'
in
kwargs
:
weight
=
kwargs
[
'weight'
]
weight
=
kwargs
.
pop
(
'weight'
)
if
isinstance
(
input_tensor
,
ColoTensor
):
input_tensor
=
input_tensor
.
torch_tensor
()
if
not
isinstance
(
input_tensor
,
ColoTensor
):
input_tensor
=
ColoTensor
.
init_from_
torch_tensor
(
input_tensor
)
if
isinstance
(
target
,
ColoTensor
):
target
=
target
.
torch_tensor
()
return
ColoTensor
.
init_from_torch_tensor
(
torch
.
nn
.
functional
.
cross_entropy
(
input_tensor
,
target
,
weight
))
if
input_tensor
.
is_gathered
():
# Input is gathered
# TODO(jzy) Shall we make the result of loss function a ColoTensor?
return
ColoTensor
.
init_from_torch_tensor
(
torch
.
nn
.
functional
.
cross_entropy
(
input_tensor
.
torch_tensor
(),
target
,
weight
))
elif
input_tensor
.
has_spec
()
and
input_tensor
.
shard_spec
.
num_action
==
1
:
# Single Model Parallel Applied
if
input_tensor
.
shard_pattern
==
ShardPattern
.
Col
:
return
ColoTensor
.
init_from_torch_tensor
(
VocabParallelCrossEntropyLoss1D
()(
input_tensor
.
torch_tensor
(),
target
))
else
:
raise
NotImplementedError
else
:
raise
NotImplementedError
tests/components_to_test/simple_net.py
View file @
75d22191
...
...
@@ -17,6 +17,7 @@ class SimpleNet(CheckpointModule):
self
.
ln1
=
nn
.
LayerNorm
(
8
)
self
.
proj2
=
nn
.
Linear
(
8
,
4
)
self
.
ln2
=
nn
.
LayerNorm
(
4
)
self
.
classifier
=
nn
.
Linear
(
4
,
4
)
def
forward
(
self
,
x
):
x
=
self
.
embed
(
x
)
...
...
@@ -24,6 +25,7 @@ class SimpleNet(CheckpointModule):
x
=
self
.
ln1
(
x
)
x
=
self
.
proj2
(
x
)
x
=
self
.
ln2
(
x
)
x
=
self
.
classifier
(
x
)
return
x
...
...
@@ -31,8 +33,8 @@ class SimpleNet(CheckpointModule):
class
DummyDataLoader
(
DummyDataGenerator
):
def
generate
(
self
):
data
=
torch
.
randint
(
low
=
0
,
high
=
20
,
size
=
(
16
,
20
),
device
=
get_current_device
())
label
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
16
,
4
),
device
=
get_current_device
())
data
=
torch
.
randint
(
low
=
0
,
high
=
20
,
size
=
(
16
,),
device
=
get_current_device
())
label
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
16
,),
device
=
get_current_device
())
return
data
,
label
...
...
tests/test_tensor/test_model.py
View file @
75d22191
...
...
@@ -144,10 +144,18 @@ def run_1d_hybrid_tp(model_name):
parallel_action_list_col
=
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1DCol_Linear
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)
,
]
spec_col
=
TensorSpec
(
parallel_action_list_col
)
parallel_action_list_classifier_col
=
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1DCol_Linear
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
,
gather_out
=
False
),
]
spec_classifier_col
=
TensorSpec
(
parallel_action_list_classifier_col
)
parallel_action_list_embedding_col
=
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1DCol_Embedding
,
...
...
@@ -158,12 +166,14 @@ def run_1d_hybrid_tp(model_name):
for
name
,
p
in
model
.
colo_named_parameters
():
if
not
isinstance
(
p
,
ColoTensor
):
continue
if
'embed'
in
name
and
'weight'
in
name
:
p
.
set_spec
(
spec_embedding_col
)
if
'proj1'
in
name
and
(
'weight'
in
name
or
'bias'
in
name
):
p
.
set_spec
(
spec_col
)
if
'proj2'
in
name
and
'weight'
in
name
:
p
.
set_spec
(
spec_row
)
if
'
embed
'
in
name
and
'weight'
in
name
:
p
.
set_spec
(
spec_
embedding
_col
)
if
'
classifier
'
in
name
and
(
'weight'
in
name
or
'bias'
in
name
)
:
p
.
set_spec
(
spec_
classifier
_col
)
set_seed
(
1
)
if
rank
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment