gaoqiong / flash-attention · Commits

Commit 71befc19, authored Dec 31, 2022 by Tri Dao
Parent: cadfa396

[Loss] Use flash_attn.losses.cross_entropy.CrossEntropyLoss
Showing 4 changed files with 3 additions and 132 deletions (+3, -132)
training/configs/experiment/owt/base.yaml    +1  -1
training/configs/experiment/pile/base.yaml   +1  -1
training/src/losses/cross_entropy.py         +0  -129
training/src/metrics/perplexity.py           +1  -1
training/configs/experiment/owt/base.yaml

```diff
@@ -54,7 +54,7 @@ train:
   loss_fn:
     # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
     # It's also more numerically stable if we're using DeepSpeed 16 bits.
-    _target_: src.losses.cross_entropy.CrossEntropyLoss
+    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
     inplace_backward: True  # to save memory
   eval:
```
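The `_target_` key is a Hydra-style instantiation path, so this one-line change points the config at the CrossEntropyLoss that now ships inside the installed `flash_attn` package instead of the local `src.losses` copy deleted below; the Pile config gets the identical change. A minimal sketch of what the config resolves to at runtime (an illustration, not the project's training code; it assumes `flash_attn` is installed with its fused cross-entropy CUDA extension, and uses only the keyword argument shown in the hunk):

```python
# Sketch only: roughly what Hydra builds from the `loss_fn` block above.
from flash_attn.losses.cross_entropy import CrossEntropyLoss

loss_fn = CrossEntropyLoss(inplace_backward=True)  # inplace_backward=True saves memory

# The loss expects CUDA tensors: logits of shape (batch, vocab_size) and integer
# labels of shape (batch,), e.g.
#   loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))
```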
training/configs/experiment/pile/base.yaml

```diff
@@ -54,7 +54,7 @@ train:
   loss_fn:
     # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
     # It's also more numerically stable if we're using DeepSpeed 16 bits.
-    _target_: src.losses.cross_entropy.CrossEntropyLoss
+    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
     inplace_backward: True  # to save memory
   eval:
```
training/src/losses/cross_entropy.py (deleted, 100644 → 0; contents as of parent cadfa396)

```python
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn

import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
                process_group=None):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
                                                          world_size * vocab_size)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=process_group)
            handle_losses = torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
                                                         group=process_group, async_op=True)
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += ((1 - smoothing) * (lse - lse_local)
                           + smoothing * (lse - lse_allgather.sum(dim=0)))
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward,
                                                 ctx.total_classes)
        return grad_logits, None, None, None, None, None, None


class CrossEntropyLoss(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False, process_group=None):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward
        self.process_group = process_group

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward,
                                               self.process_group)
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
```
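Outside the tensor-parallel case, the quantity the fused kernel returns per sample is just the log-sum-exp minus the predicted logit, with label smoothing mixing in the mean logit over all classes, exactly as the comments in the deleted file describe. The following unfused pure-PyTorch sketch mirrors the `world_size == 1` path as a reference for the math; it is an illustration with a hypothetical helper name, not code from the repo:

```python
import torch

def reference_cross_entropy(logits, labels, smoothing=0.0, ignored_index=-100):
    """Unfused reference for the per-sample losses of the world_size == 1 path:
    loss_i = (1 - smoothing) * (lse_i - logits[i, labels_i])
             + smoothing * (lse_i - mean_j logits[i, j]),
    with loss_i = 0 wherever labels_i == ignored_index."""
    logits = logits.float()                        # the fused op also computes in float
    lse = torch.logsumexp(logits, dim=-1)          # (batch,)
    ignored = labels == ignored_index
    safe_labels = labels.masked_fill(ignored, 0)   # keep gather indices in range
    picked = logits.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    losses = (1 - smoothing) * (lse - picked) + smoothing * (lse - logits.mean(dim=-1))
    return losses.masked_fill(ignored, 0)
```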
training/src/metrics/perplexity.py

```diff
@@ -11,7 +11,7 @@ from torch import Tensor
 from torchmetrics import Metric

 try:
-    from src.losses.cross_entropy import CrossEntropyLoss
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
 except ImportError:
     CrossEntropyLoss = torch.nn.CrossEntropyLoss
```
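The try/except keeps the perplexity metric importable when `flash_attn` (or its CUDA extension) is not installed, falling back to `torch.nn.CrossEntropyLoss`. With either loss, perplexity is the exponential of the mean per-token cross entropy over non-ignored tokens; a small hypothetical helper (not the metric class in this file) makes the relationship concrete:

```python
import torch

def perplexity_from_logits(logits, labels, loss_fn):
    # Hypothetical helper: with reduction='mean', both CrossEntropyLoss variants
    # average the per-token loss over non-ignored tokens, so perplexity is
    # simply exp(mean loss).
    mean_loss = loss_fn(logits, labels)
    return torch.exp(mean_loss)
```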