gaoqiong / flash-attention · Commits

Commit 5400fdc4, authored Sep 15, 2023 by Tri Dao (parent 56b7fc6e)

[CE] Implement CrossEntropyLoss in Triton

Showing 5 changed files with 370 additions and 135 deletions:

- csrc/xentropy/README.md (+5, -0)
- flash_attn/losses/cross_entropy.py (+34, -119)
- flash_attn/ops/triton/cross_entropy.py (+293, -0)
- tests/losses/test_cross_entropy.py (+20, -8)
- tests/losses/test_cross_entropy_parallel.py (+18, -8)
csrc/xentropy/README.md (view file @ 5400fdc4)

@@ -7,3 +7,8 @@ It has only been tested on A100s.

```sh
cd csrc/xentropy && pip install .
```

Added after the build instructions:

As of 2023-09-15, this extension is no longer used in the FlashAttention repo.
We've instead switched to a Triton-based
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py).
See the CrossEntropyLoss
[module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py)
for more details.
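For orientation, here is a minimal usage sketch of the replacement path referenced above. It is not part of this commit; it assumes the `flash_attn` package is installed and a CUDA device is available, and the shapes follow the docstrings in the module below.

```python
# Minimal usage sketch (illustration only): the CrossEntropyLoss module that now
# dispatches to the Triton kernel introduced in this commit.
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

vocab_size = 50257
logits = torch.randn(8, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, vocab_size, (8,), device="cuda")
labels[0] = -100  # ignored positions contribute 0 to the loss

loss_fn = CrossEntropyLoss(label_smoothing=0.1, lse_square_scale=1e-4)
loss = loss_fn(logits, labels)  # scalar: "mean" reduction over non-ignored labels
loss.backward()
```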
flash_attn/losses/cross_entropy.py (view file @ 5400fdc4)

The CUDA-extension path is removed: the `import xentropy_cuda_lib`, the `torch.distributed.all_gather_into_tensor` backward-compatibility shim, and the entire `SoftmaxCrossEntropyLossFn` autograd Function are deleted. Inside `CrossEntropyLoss.forward`, the positional call `SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing, self.ignore_index, self.inplace_backward, self.process_group)` (which implicitly cast to float) is replaced by a keyword-argument call to the new Triton `cross_entropy_loss`.

Removed (old header and xentropy_cuda_lib-based implementation):

```python
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size
        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size
            )
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(
                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
            )
            torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(), group=process_group
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
            lse_local = lse_allgather[
                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
            ]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0)
                )
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
        )
        return grad_logits, None, None, None, None, None, None
```

The file after this commit, with `CrossEntropyLoss` now wrapping the Triton `cross_entropy_loss` and gaining `lse_square_scale` and a "sum" reduction:

```python
# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn

from flash_attn.ops.triton.cross_entropy import cross_entropy_loss


class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        lse_square_scale=0.0,
        inplace_backward=False,
        process_group=None,
    ):
        """
        Arguments:
            ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
            label_smoothing: float
            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                This is also referred to as "z-loss".
            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                This saves memory.
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                one part of the vocab. The loss will be aggregated across processes.
        """
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group

    def forward(self, input, target):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss = cross_entropy_loss(
            input,
            target,
            label_smoothing=self.label_smoothing,
            lse_square_scale=self.lse_square_scale,
            ignored_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        if self.reduction == "mean":
            return loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss
```
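The removed comments above describe the tensor-parallel aggregation trick that the new Triton path keeps: each rank computes a partial loss against its local LSE, and only the per-rank LSEs (plus one sum over the partial losses) need to be exchanged. Below is a single-process sketch of that identity with smoothing omitted; the names `n_ranks`, `shard`, and `owner` are just for illustration.

```python
# Single-process sketch: simulate 2 "ranks" that each own a contiguous vocab shard,
# compute partial losses against their local LSEs, then recover the global cross
# entropy by exchanging only the per-rank LSEs.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab, n_ranks = 4, 10, 2
shard = vocab // n_ranks
logits = torch.randn(batch, vocab, dtype=torch.float64)
labels = torch.randint(0, vocab, (batch,))

# Per-rank local LSEs over the owned vocab shard: shape (n_ranks, batch).
lse_local = torch.stack(
    [torch.logsumexp(logits[:, r * shard:(r + 1) * shard], dim=-1) for r in range(n_ranks)]
)
# Partial loss on the rank that owns each label: lse_local - predicted logit.
owner = labels // shard
losses = lse_local[owner, torch.arange(batch)] - logits[torch.arange(batch), labels]

# "Exchange": the global LSE is a logsumexp over the per-rank LSEs; swap the local
# LSE for the global one in each partial loss.
lse = torch.logsumexp(lse_local, dim=0)
losses += lse - lse_local[owner, torch.arange(batch)]

assert torch.allclose(losses, F.cross_entropy(logits, labels, reduction="none"))
```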
flash_attn/ops/triton/cross_entropy.py (new file, mode 100644, view file @ 5400fdc4)

```python
# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional, Union

import torch

from einops import rearrange

import triton
import triton.language as tl

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    )
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignored_index:
        loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx)
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            loss += lse_square_scale * lse * lse
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignored_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    )
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)


class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing,
        lse_square_scale=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then do a reduction
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                logits,
                labels,
                smoothing,
                lse_square_scale,
                ignored_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            else:
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                losses += lse_square_scale * lse.square()
            losses.masked_fill_(labels == ignored_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.smoothing = smoothing
        ctx.lse_square_scale = lse_square_scale
        ctx.ignored_index = ignored_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_losses):
        logits, lse, labels = ctx.saved_tensors
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.lse_square_scale,
                ctx.ignored_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        return dlogits, None, None, None, None, None, None, None


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    lse_square_scale: float = 0.0,
    ignored_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
    """
    return CrossEntropyLoss.apply(
        logits,
        labels,
        label_smoothing,
        lse_square_scale,
        ignored_index,
        inplace_backward,
        process_group,
    )
```
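A short sketch of the split-LSE reduction the forward path above relies on: a row's log-sum-exp can be computed per block (or per tensor-parallel rank) and then reduced with another log-sum-exp, which is why no block ever needs to see the whole row. The numbers below are arbitrary.

```python
# Sketch: block-local log-sum-exps compose under another log-sum-exp.
import torch

torch.manual_seed(0)
row = torch.randn(10_000, dtype=torch.float64)
block_size = 1024
chunks = row.split(block_size)

lse_per_block = torch.stack([torch.logsumexp(c, dim=0) for c in chunks])
lse_global = torch.logsumexp(lse_per_block, dim=0)

assert torch.allclose(lse_global, torch.logsumexp(row, dim=0))

# With SPLIT=True the kernel stores only -logits[label] (plus the smoothing term);
# the host then adds the reduced global LSE, giving cross entropy = lse_global - logits[label].
```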
tests/losses/test_cross_entropy.py (view file @ 5400fdc4)

@@ -4,7 +4,7 @@ import pytest

The import switches from the Apex-backed class (`CrossEntropyLossApex`) to the new `CrossEntropyLoss`:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
```

@@ -12,12 +12,16 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8

The test is renamed from `test_cross_entropy_loss_apex` to `test_cross_entropy_loss`, gains a `lse_square_scale` parameter, and adds a 128K vocab size to exercise the split path:

```python
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [12])
def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
    device = "cuda"
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # set seed
```

...

@@ -29,12 +33,20 @@ def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype)

The ignored-label masking is now guarded for tiny batches, the model under test takes `lse_square_scale`, and the PyTorch reference loss is adjusted for the z-loss term before the comparison:

```python
    )
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    if batch_size * seqlen > 10:
        y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        lse_square_scale=lse_square_scale,
        inplace_backward=inplace_backward,
    )
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    if lse_square_scale > 0.0:
        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
        out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

    g = torch.randn_like(out)
    out_pt.backward(g)
```

...
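One way to invoke the single-GPU test above (my invocation, mirroring the torchrun command noted in the parallel test below; a CUDA device is required):

```sh
pytest -q -s tests/losses/test_cross_entropy.py
```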
tests/losses/test_cross_entropy_parallel.py (view file @ 5400fdc4)

The suggested launch command drops from 8 to 4 processes:

```python
# Run test with:
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math
```

...

@@ -15,15 +15,20 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8

As in the single-GPU test, a `lse_square_scale` parameter is added, the vocab sizes gain 128K to exercise the split path, and the `world_size` values change from [1, 2, 4, 8] to [1, 2, 4]:

```python
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50264, 128 * 1024])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
@pytest.mark.parametrize("world_size", [1, 2, 4])
# @pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype):
    assert vocab_size % world_size == 0
    rtol, atol = (
        (1e-5, 1e-6)
```

...

@@ -56,11 +61,16 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_

The model under test takes `lse_square_scale`, and the reference loss is adjusted for the z-loss term (with ignored positions re-masked) before the comparison:

```python
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        reduction="none",
        lse_square_scale=lse_square_scale,
        inplace_backward=inplace_backward,
        process_group=parallel_state.get_tensor_model_parallel_group(),
    )
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    if lse_square_scale > 0.0:
        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
        out_pt += lse_square_scale * lse_pt.square()
        out_pt.masked_fill_(y == -100, 0.0)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

    g = torch.randn_like(out)
```

...
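As a side check on the z-loss term exercised by `lse_square_scale` in these tests, here is a small autograd sketch (arbitrary numbers) confirming the gradient identity the Triton backward kernel relies on, d/dz [lse(z)^2] = 2 · lse(z) · softmax(z):

```python
# Sketch: the z-loss gradient matches the `probs += 2.0 * lse_square_scale * lse * probs`
# adjustment in the backward kernel (here with the scale factored out).
import torch

z = torch.randn(16, dtype=torch.float64, requires_grad=True)
lse = torch.logsumexp(z, dim=0)
(lse ** 2).backward()
assert torch.allclose(z.grad, 2.0 * lse.detach() * torch.softmax(z.detach(), dim=0))
```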