gaoqiong / flash-attention
"torchvision/vscode:/vscode.git/clone" did not exist on "767b23ea361c944870c80057f274c15a8475e204"
Unverified commit d8aacc51, authored Jan 21, 2024 by Curtis "Fjord" Hawthorne, committed by GitHub on Jan 21, 2024

return z_loss (#768)
Parent: 43ceab63
Showing 3 changed files with 57 additions and 14 deletions (+57 −14)
flash_attn/losses/cross_entropy.py      +22 −4
flash_attn/ops/triton/cross_entropy.py  +24 −7
tests/losses/test_cross_entropy.py      +11 −3
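Taken together, the three files thread an auxiliary z-loss through the public loss module, the Triton autograd function, and the test. As a minimal usage sketch of the new option (the shapes and the 1e-4 scale below are illustrative assumptions, not values from the commit):

```python
# With return_z_loss=True the module returns (loss, z_loss) instead of a single tensor.
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

loss_fn = CrossEntropyLoss(lse_square_scale=1e-4, return_z_loss=True)
logits = torch.randn(8, 50257, device="cuda", requires_grad=True)  # example shape
labels = torch.randint(0, 50257, (8,), device="cuda")

loss, z_loss = loss_fn(logits, labels)
loss.backward()  # gradients flow through loss; z_loss is for logging only
print(loss.item(), z_loss.item())
```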
flash_attn/losses/cross_entropy.py (+22 −4)

@@ -16,6 +16,7 @@ class CrossEntropyLoss(nn.Module):
         lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
+        return_z_loss=False,
     ):
         """
         Arguments:
@@ -26,7 +27,10 @@ class CrossEntropyLoss(nn.Module):
             inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                 This saves memory.
             process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                 one part of the vocab. The loss will be aggregated across processes.
+            return_z_loss: bool. If True, we return the component of the loss contributed by
+                the lse_square_scale value. This value is only for logging and does not support
+                backprop.
         """
         super().__init__()
         if reduction not in ["mean", "none", "sum"]:
@@ -38,6 +42,7 @@ class CrossEntropyLoss(nn.Module):
         self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
+        self.return_z_loss = return_z_loss

     def forward(self, input, target):
         """
@@ -46,9 +51,10 @@ class CrossEntropyLoss(nn.Module):
             target: (batch,)
         Returns:
             losses: (batch,) if reduction is 'none', else (1,), dtype float
+            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
         """
         assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
-        loss = cross_entropy_loss(
+        loss, z_loss = cross_entropy_loss(
             input,
             target,
             label_smoothing=self.label_smoothing,
@@ -59,8 +65,20 @@ class CrossEntropyLoss(nn.Module):
             process_group=self.process_group,
         )
         if self.reduction == "mean":
-            return loss.sum() / (target != self.ignore_index).sum()
+            loss = loss.sum() / (target != self.ignore_index).sum()
         elif self.reduction == "sum":
-            return loss.sum()
+            loss = loss.sum()
         else:
-            return loss
+            loss = loss
+
+        if not self.return_z_loss:
+            return loss
+
+        if self.reduction == "mean":
+            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
+        elif self.reduction == "sum":
+            z_loss = z_loss.sum()
+        else:
+            z_loss = z_loss
+
+        return loss, z_loss
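The z-loss the module now optionally returns is lse_square_scale times the squared log-sum-exp of each logits row, reduced the same way as the main loss. A rough pure-PyTorch equivalent, a sketch that assumes the single-GPU case with no label smoothing and logit_scale=1.0 (not code from the commit):

```python
import torch
import torch.nn.functional as F

def reference_ce_and_z_loss(logits, target, lse_square_scale=0.0, ignore_index=-100):
    """Approximate what CrossEntropyLoss returns with reduction='mean' and return_z_loss=True."""
    lse = torch.logsumexp(logits.float(), dim=-1)               # (batch,)
    z_loss = lse_square_scale * lse.square()
    z_loss = z_loss.masked_fill(target == ignore_index, 0.0)    # ignored rows contribute nothing
    ce = F.cross_entropy(logits.float(), target, ignore_index=ignore_index, reduction="none")
    n_valid = (target != ignore_index).sum()
    return (ce + z_loss).sum() / n_valid, z_loss.sum() / n_valid
```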
flash_attn/ops/triton/cross_entropy.py (+24 −7)

@@ -26,6 +26,7 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
 def cross_entropy_fwd_kernel(
     loss_ptr,  # data ptrs
     lse_ptr,
+    z_loss_ptr,
     logits_ptr,
     labels_ptr,
     smoothing,
@@ -57,6 +58,7 @@ def cross_entropy_fwd_kernel(
     tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
     if label_idx == ignored_index:
         loss = 0.0
+        z_loss = 0.0
     else:
         label_idx -= class_start_idx
         if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
@@ -78,8 +80,13 @@ def cross_entropy_fwd_kernel(
             else:
                 loss = 0.0
         if not SPLIT:
-            loss += lse_square_scale * lse * lse
+            z_loss = lse_square_scale * lse * lse
+            loss += z_loss
+        else:
+            z_loss = 0.0
     tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
+    if not SPLIT:
+        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)


 @triton.heuristics(
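The kernel change only splits the existing lse_square_scale * lse * lse term out into its own z_loss value and stores it in the new z_loss_ptr buffer when SPLIT is false. A toy numeric illustration of that term (the row of logits and the 1e-2 scale are made up):

```python
import torch

row_logits = torch.tensor([2.0, 0.5, -1.0])   # one row of (already scaled) logits
lse = torch.logsumexp(row_logits, dim=-1)      # log-sum-exp of the row, ~2.24
lse_square_scale = 1e-2
z_loss = lse_square_scale * lse * lse          # ~0.050, added onto the CE loss for this row
print(lse.item(), z_loss.item())
```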
@@ -172,12 +179,14 @@ class CrossEntropyLoss(torch.autograd.Function):
         loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
         losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
         lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
+        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
         # Need this, otherwise Triton tries to launch from cuda:0 and we get
         # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
         with torch.cuda.device(logits.device.index):
             cross_entropy_fwd_kernel[(n_rows, n_splits)](
                 losses,  # data ptrs
                 lse,
+                z_losses,
                 logits,
                 labels,
                 smoothing,
@@ -219,10 +228,15 @@ class CrossEntropyLoss(torch.autograd.Function):
             # Again, we just have to add the (global) lse.
             losses += lse
             if lse_square_scale != 0.0:
-                losses += lse_square_scale * lse.square()
+                z_losses = lse_square_scale * lse.square()
+                z_losses.masked_fill_(labels == ignored_index, 0.0)
+                losses += z_losses
+            else:
+                z_losses = torch.zeros_like(losses)
             losses.masked_fill_(labels == ignored_index, 0.0)

         ctx.save_for_backward(logits, lse, labels)
+        ctx.mark_non_differentiable(z_losses)
         ctx.smoothing = smoothing
         ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
@@ -230,10 +244,13 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx.total_classes = total_classes
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward
-        return losses
+        return losses, z_losses

     @staticmethod
-    def backward(ctx, grad_losses):
+    def backward(ctx, grad_losses, grad_z_losses):
+        del grad_z_losses  # z_losses are only for logging.
+
         logits, lse, labels = ctx.saved_tensors
         dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
         n_rows, n_cols = logits.shape
@@ -262,8 +279,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             BLOCK_SIZE=BLOCK_SIZE,  # constants
             num_warps=num_warps,
         )
-        return dlogits, None, None, None, None, None, None, None
+        return dlogits, None, None, None, None, None, None, None, None


 def cross_entropy_loss(
     logits: torch.Tensor,
@@ -287,9 +303,10 @@ def cross_entropy_loss(
         inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
             This saves memory.
         process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
     Returns:
         losses: (batch,), float
+        z_losses: (batch,), float
     """
     return CrossEntropyLoss.apply(
         logits,
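The autograd-side pattern above, a second logging-only output marked non-differentiable whose incoming gradient is discarded in backward, is standard torch.autograd.Function machinery. A self-contained toy example of the same pattern (deliberately not the flash-attn kernel):

```python
import torch

class SquareWithAux(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        aux = x.abs()                      # auxiliary value, for logging only
        ctx.mark_non_differentiable(aux)   # same trick as ctx.mark_non_differentiable(z_losses)
        return x * x, aux

    @staticmethod
    def backward(ctx, grad_out, grad_aux):
        del grad_aux                       # mirrors `del grad_z_losses`
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(4, requires_grad=True)
y, aux = SquareWithAux.apply(x)
y.sum().backward()                         # aux contributes nothing to x.grad
print(x.grad, aux)
```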
tests/losses/test_cross_entropy.py (+11 −3)

@@ -16,6 +16,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+@pytest.mark.parametrize("return_z_loss", [False, True])
 # @pytest.mark.parametrize("lse_square_scale", [1e-2])
 @pytest.mark.parametrize("logit_scale", [1.0, 0.7])
 # @pytest.mark.parametrize("logit_scale", [1.0])
@@ -24,7 +25,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
 def test_cross_entropy_loss(
-    vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+    vocab_size, smoothing, logit_scale, lse_square_scale, return_z_loss, inplace_backward, dtype
 ):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
@@ -44,14 +45,21 @@ def test_cross_entropy_loss(
         label_smoothing=smoothing,
         logit_scale=logit_scale,
         lse_square_scale=lse_square_scale,
+        return_z_loss=return_z_loss,
         inplace_backward=inplace_backward,
     )
-    out = model(x, y)
+    if return_z_loss:
+        out, out_z_loss = model(x, y)
+    else:
+        out = model(x, y)
     x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
     out_pt = model_pt(x_pt_scaled, y)
     if lse_square_scale > 0.0:
         lse_pt = torch.logsumexp(x_pt_scaled, dim=-1)
-        out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
+        z_loss_pt = lse_square_scale * (lse_pt[y != -100] ** 2).mean()
+        if return_z_loss:
+            assert torch.allclose(out_z_loss, z_loss_pt, rtol=rtol, atol=atol)
+        out_pt += z_loss_pt
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

     g = torch.randn_like(out)
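Outside the test suite, the same check can be reproduced by hand. A quick self-check mirroring what the updated test asserts (it assumes a single GPU, float32 inputs, no smoothing, picks the 1e-2 scale from the parametrization, and uses arbitrary shapes):

```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

torch.manual_seed(0)
x = torch.randn(4, 128, device="cuda")
y = torch.randint(0, 128, (4,), device="cuda")

model = CrossEntropyLoss(lse_square_scale=1e-2, return_z_loss=True)
out, out_z_loss = model(x, y)

lse = torch.logsumexp(x.float(), dim=-1)
z_loss_ref = 1e-2 * (lse[y != -100] ** 2).mean()   # same reference the test computes
assert torch.allclose(out_z_loss, z_loss_ref, rtol=1e-3, atol=1e-4)
```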