OpenDAS / ColossalAI — Commit 42ab36b7

Unverified commit 42ab36b7, authored Jul 07, 2022 by HELSON, committed by GitHub on Jul 07, 2022.
[tensor] add unitest for colo_tensor 1DTP cross_entropy (#1230)
Parent: 04537bf8

Showing 3 changed files with 73 additions and 19 deletions:

    colossalai/nn/_ops/loss.py            +9   -3
    colossalai/nn/loss/loss_1d.py         +12  -16
    tests/test_tensor/test_loss_func.py   +52  -0
colossalai/nn/_ops/loss.py — view file @ 42ab36b7

@@ -23,6 +23,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
     input_tensor = convert_to_colo_tensor(input_tensor, pg)
     if input_tensor.is_replicate():    # Input is gathered
+        assert target.is_replicate() and weight.is_replicate(), \
+            "Target tensor and weight tensor both should be complete"
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
@@ -31,11 +33,15 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduce=reduce,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
-        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
         if input_tensor.is_shard_1dcol():
-            output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
-            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
+            assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in"
+            assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function"
+            output = VocabParallelCrossEntropyLoss1D()(input_tensor,
+                                                       target,
+                                                       process_group=input_tensor.process_group.tp_process_group())
+            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
         else:
             raise NotImplementedError
     else:
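What this buys callers: colo_cross_entropy no longer hard-codes the global 1D parallel group. When the input ColoTensor is column-sharded, the op forwards the tensor's own tp_process_group() into VocabParallelCrossEntropyLoss1D, and the loss module itself now accepts the group as an argument. A minimal sketch of calling the loss directly with an explicit group — the import path, group construction, sizes, and device below are illustrative assumptions, and torch.distributed is assumed to be initialized on every tensor-parallel rank:

import torch
import torch.distributed as dist
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D

# Assumption: dist.init_process_group('nccl', ...) has already run and each
# rank holds its own shard of the vocabulary dimension of the logits.
world_size = dist.get_world_size()
tp_group = dist.new_group(list(range(world_size)))    # illustrative: all ranks form one TP group

vocab_size = 1024                                      # hypothetical full vocabulary size
logits_shard = torch.randn(8, vocab_size // world_size, device='cuda', requires_grad=True)
targets = torch.randint(0, vocab_size, (8,), device='cuda')

criterion = VocabParallelCrossEntropyLoss1D()
loss = criterion(logits_shard, targets, process_group=tp_group)    # keyword added by this commit
loss.backward()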
colossalai/nn/loss/loss_1d.py — view file @ 42ab36b7

 import torch
+import torch.distributed as dist
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import LOSSES
@@ -10,19 +11,19 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, vocab_parallel_logits, targets):
+    def forward(ctx, vocab_parallel_logits, targets, process_group):
+        if process_group is None:
+            process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
-        torch.distributed.all_reduce(logits_max,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

         # Get the partition's vocab indecies
         partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+        rank = dist.get_rank(process_group)
         vocab_start_index = partition_vocab_size * rank
         vocab_end_index = vocab_start_index + partition_vocab_size
@@ -42,17 +43,12 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         predicted_logits = predicted_logits_1d.view_as(targets)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
-        torch.distributed.all_reduce(predicted_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Sum of exponential of logits along vocab dimension across all GPUs.
-        exp_logits = vocab_parallel_logits
-        torch.exp(vocab_parallel_logits, out=exp_logits)
+        exp_logits = torch.exp(vocab_parallel_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
@@ -81,7 +77,7 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function):
         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))

-        return grad_input, None
+        return grad_input, None, None


 @LOSSES.register_module
@@ -96,14 +92,14 @@ class VocabParallelCrossEntropyLoss1D(_Loss):
         super().__init__()
         self.reduction_mean = reduction

-    def forward(self, logits, targets):
+    def forward(self, logits, targets, process_group=None):
         """Calculate loss between logits and targets.

         Args:
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
-        loss = _VocabParallelCrossEntropy1D.apply(logits, targets)
+        loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
         if self.reduction_mean:
             loss = loss.mean()
         return loss
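For reference, the arithmetic that _VocabParallelCrossEntropy1D spreads across ranks is an ordinary log-sum-exp over a vocabulary that happens to be sharded: each rank contributes a partial max, the logit of whichever targets fall inside its vocab slice, and a partial sum of exponentials, and the MAX/SUM all-reduces above stitch the partials together. A single-process sketch (no torch.distributed; the shard bookkeeping is done by hand and every name is local to the sketch) that checks this decomposition against F.cross_entropy:

import torch
import torch.nn.functional as F

# Split the vocab dimension into "shards" and combine partial results the
# same way the kernel's MAX/SUM all-reduces combine them across ranks.
torch.manual_seed(0)
num_shards, batch, vocab = 2, 4, 8
logits = torch.randn(batch, vocab)
targets = torch.randint(0, vocab, (batch,))
shards = logits.chunk(num_shards, dim=-1)
shard_size = vocab // num_shards

# Equivalent of the MAX all-reduce: global max over every shard.
logits_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]

predicted = torch.zeros(batch)    # target-class logit (after max subtraction)
sum_exp = torch.zeros(batch)      # sum of exp over the full vocabulary
for i, shard in enumerate(shards):
    shard = shard - logits_max.unsqueeze(-1)            # subtract the global max
    lo = i * shard_size
    owned = ((targets >= lo) & (targets < lo + shard_size)).nonzero(as_tuple=True)[0]
    predicted[owned] += shard[owned, targets[owned] - lo]    # equivalent of the SUM all-reduce
    sum_exp += shard.exp().sum(dim=-1)                       # equivalent of the SUM all-reduce

loss = (torch.log(sum_exp) - predicted).mean()
assert torch.allclose(loss, F.cross_entropy(logits, targets), atol=1e-6)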
tests/test_tensor/test_loss_func.py — new file (0 → 100644), view file @ 42ab36b7

import torch
import pytest
import colossalai
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, ComputeSpec, ComputePattern


def check_cross_entropy():
    input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
    input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
    with torch.no_grad():
        input_ct.copy_(input_t)

    target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
    input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
    input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

    output = F.cross_entropy(input_t, target)
    output_colo = F.cross_entropy(input_shard, target)
    assert torch.allclose(output_colo, output)

    output.backward()
    output_colo.backward()
    assert torch.allclose(input_t.grad, input_ct.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_cross_entropy()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_loss_func(2)
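The new test exercises exactly this path end to end: it builds a ProcessGroup with tp_degree equal to the world size, shards the input along the class dimension as a TP1D ColoTensor, runs F.cross_entropy on both the plain and the sharded tensors, and checks that the loss values and the input gradients match. Since test_loss_func spawns its own two NCCL ranks via mp.spawn, it should be runnable directly on a machine with at least two GPUs, for example with pytest tests/test_tensor/test_loss_func.py, or with python tests/test_tensor/test_loss_func.py when debugging outside pytest.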