gaoqiong / flash-attention / Commits

Commit 08124c8f
Authored Dec 16, 2023 by Tri Dao

[CrossEntropy] Implement logit_scale option

Parent: 9356a1c0
Showing 4 changed files with 32 additions and 11 deletions.
flash_attn/losses/cross_entropy.py            +3   -0
flash_attn/ops/triton/cross_entropy.py        +15  -5
tests/losses/test_cross_entropy.py            +8   -3
tests/losses/test_cross_entropy_parallel.py   +6   -3
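For orientation before the diffs: logit_scale multiplies the logits before the loss is computed, so the log-sum-exp (and therefore the z-loss term) is also taken over the scaled logits. A minimal plain-PyTorch sketch of the intended semantics, mirroring the reference computation used in the tests below (illustrative only, not the fused Triton implementation; the function name is hypothetical):

```python
import torch
import torch.nn.functional as F

def reference_cross_entropy(logits, labels, logit_scale=1.0,
                            lse_square_scale=0.0, ignored_index=-100):
    # Scale the logits *before* the loss, which is what the new option does.
    scaled = logits.float() * logit_scale
    loss = F.cross_entropy(scaled, labels, ignore_index=ignored_index)
    if lse_square_scale > 0.0:
        # The z-loss is computed on the lse of the *scaled* logits.
        lse = torch.logsumexp(scaled, dim=-1)
        loss = loss + lse_square_scale * (lse[labels != ignored_index] ** 2).mean()
    return loss
```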
flash_attn/losses/cross_entropy.py
@@ -12,6 +12,7 @@ class CrossEntropyLoss(nn.Module):
         ignore_index=-100,
         reduction="mean",
         label_smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
@@ -33,6 +34,7 @@ class CrossEntropyLoss(nn.Module):
         self.ignore_index = ignore_index
         self.reduction = reduction
         self.label_smoothing = label_smoothing
+        self.logit_scale = logit_scale
         self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
@@ -50,6 +52,7 @@ class CrossEntropyLoss(nn.Module):
             input,
             target,
             label_smoothing=self.label_smoothing,
+            logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
             ignored_index=self.ignore_index,
             inplace_backward=self.inplace_backward,
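Assuming flash-attn is installed with its Triton kernels available, the new option is used like any other constructor argument of the module wrapper (a usage sketch; tensor shapes and values are illustrative):

```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

criterion = CrossEntropyLoss(
    label_smoothing=0.0,
    logit_scale=0.7,        # new in this commit: logits are multiplied by 0.7
    lse_square_scale=1e-2,  # optional z-loss, taken over the (scaled) logits
    inplace_backward=True,
)
logits = torch.randn(8, 50257, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 50257, (8,), device="cuda")
loss = criterion(logits, labels)
loss.backward()
```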
flash_attn/ops/triton/cross_entropy.py
@@ -29,6 +29,7 @@ def cross_entropy_fwd_kernel(
     logits_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -48,7 +49,7 @@ def cross_entropy_fwd_kernel(
     label_idx = tl.load(labels_ptr + row_idx)
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     max_logits = tl.max(logits, 0)
     if HAS_SMOOTHING:
         sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
@@ -61,7 +62,7 @@ def cross_entropy_fwd_kernel(
     if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
         n_cols, (col_block_idx + 1) * BLOCK_SIZE
     ):
-        logits_label = tl.load(logits_ptr + label_idx)
+        logits_label = tl.load(logits_ptr + label_idx) * logit_scale
         if HAS_SMOOTHING:
             loss = (
                 (lse if not SPLIT else 0.0)
@@ -94,6 +95,7 @@ def cross_entropy_bwd_kernel(
     lse_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -117,7 +119,7 @@ def cross_entropy_bwd_kernel(
         dloss = 0.0
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     lse = tl.load(lse_ptr + row_idx)
     probs = tl.exp(logits - lse)
     probs += 2.0 * lse_square_scale * lse * probs
@@ -128,16 +130,18 @@ def cross_entropy_bwd_kernel(
         probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
     else:
         probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
-    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
+    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)


 class CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
         logits,
         labels,
-        smoothing,
+        smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         ignored_index=-100,
         inplace_backward=False,
@@ -177,6 +181,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             logits,
             labels,
             smoothing,
+            logit_scale,
             lse_square_scale,
             ignored_index,
             total_classes,
@@ -219,6 +224,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx.save_for_backward(logits, lse, labels)
         ctx.smoothing = smoothing
+        ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
         ctx.ignored_index = ignored_index
         ctx.total_classes = total_classes
@@ -244,6 +250,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             lse,
             labels,
             ctx.smoothing,
+            ctx.logit_scale,
             ctx.lse_square_scale,
             ctx.ignored_index,
             ctx.total_classes,
@@ -262,6 +269,7 @@ def cross_entropy_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     label_smoothing: float = 0.0,
+    logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
     ignored_index=-100,
     inplace_backward: bool = False,
@@ -272,6 +280,7 @@ def cross_entropy_loss(
         logits: (batch, vocab_size)
         labels: (batch,)
         label_smoothing: float
+        logit_scale: float. Multiply logits by this scale before calculating the loss.
         lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
             This is also referred to as "z-loss".
         ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
@@ -286,6 +295,7 @@ def cross_entropy_loss(
         logits,
         labels,
         label_smoothing,
+        logit_scale,
         lse_square_scale,
         ignored_index,
         inplace_backward,
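The forward and backward changes above are the two sides of the chain rule: with z = logit_scale * x, the gradient is dL/dx = logit_scale * (softmax(z) - onehot(y)) * dloss, which is why the kernels both load logits * logit_scale and store (dloss * logit_scale) * probs. A quick cross-check against autograd (a plain-PyTorch sketch, not the Triton code):

```python
import torch
import torch.nn.functional as F

batch, vocab, s = 4, 10, 0.7
x = torch.randn(batch, vocab, dtype=torch.float64, requires_grad=True)
y = torch.randint(0, vocab, (batch,))

F.cross_entropy(x * s, y).backward()

# Manual gradient: s * (softmax(s * x) - onehot(y)), averaged over the batch.
with torch.no_grad():
    probs = torch.softmax(x * s, dim=-1)
    probs[torch.arange(batch), y] -= 1.0
    manual = s * probs / batch

assert torch.allclose(x.grad, manual)
```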
tests/losses/test_cross_entropy.py
@@ -17,11 +17,15 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [1e-2])
+@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
-def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
+def test_cross_entropy_loss(
+    vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -38,13 +42,14 @@ def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_bac
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
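Note why the reference computation scales the inputs rather than the output: CE(s * x) is not s * CE(x), so the test must pass x_pt.float() * logit_scale into the PyTorch baseline instead of multiplying the resulting loss. A tiny illustration (values are arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[2.0, 0.0, -1.0]])
y = torch.tensor([0])
s = 0.7
print(F.cross_entropy(x * s, y))  # what logit_scale computes
print(s * F.cross_entropy(x, y))  # scaling the loss instead: a different value
```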
tests/losses/test_cross_entropy_parallel.py
@@ -19,6 +19,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [0.0])
+@pytest.mark.parametrize("logit_scale", [0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
@@ -26,7 +28,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
-    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
+    vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
 ):
     assert vocab_size % world_size == 0
     rtol, atol = (
@@ -59,15 +61,16 @@ def test_cross_entropy_loss_parallel(
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         reduction="none",
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * lse_pt.square()
         out_pt.masked_fill_(y == -100, 0.0)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
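The parallel test only needs to pass logit_scale through, because scaling each vocab-parallel shard locally before its local log-sum-exp is equivalent to scaling the full logits: scaling and concatenation commute, and shard-wise lse values combine exactly. A single-process sketch of that identity (illustrative shapes):

```python
import torch

s = 0.7
x = torch.randn(4, 1024)
x1, x2 = x.chunk(2, dim=-1)  # two vocab-parallel shards

lse_full = torch.logsumexp(x * s, dim=-1)
# Combine per-shard lse values the way the parallel reduction would.
lse_shards = torch.logsumexp(
    torch.stack([torch.logsumexp(x1 * s, dim=-1),
                 torch.logsumexp(x2 * s, dim=-1)]),
    dim=0,
)
assert torch.allclose(lse_full, lse_shards, atol=1e-6)
```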