Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
ec6d2214
"vscode:/vscode.git/clone" did not exist on "8091e3482d2fa8acdf3ffd4f8027c5fb5f298f1c"
Commit
ec6d2214
authored
Apr 26, 2024
by
Tri Dao
Browse files
[CrossEntropy] Change ignored_index -> ignore_index
parent
85881f54
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
15 deletions
+15
-15
flash_attn/losses/cross_entropy.py
flash_attn/losses/cross_entropy.py
+2
-2
flash_attn/ops/triton/cross_entropy.py
flash_attn/ops/triton/cross_entropy.py
+13
-13
No files found.
flash_attn/losses/cross_entropy.py
View file @
ec6d2214
...
@@ -20,7 +20,7 @@ class CrossEntropyLoss(nn.Module):
...
@@ -20,7 +20,7 @@ class CrossEntropyLoss(nn.Module):
):
):
"""
"""
Arguments:
Arguments:
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
label_smoothing: float
label_smoothing: float
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
This is also referred to as "z-loss".
...
@@ -60,7 +60,7 @@ class CrossEntropyLoss(nn.Module):
...
@@ -60,7 +60,7 @@ class CrossEntropyLoss(nn.Module):
label_smoothing
=
self
.
label_smoothing
,
label_smoothing
=
self
.
label_smoothing
,
logit_scale
=
self
.
logit_scale
,
logit_scale
=
self
.
logit_scale
,
lse_square_scale
=
self
.
lse_square_scale
,
lse_square_scale
=
self
.
lse_square_scale
,
ignore
d
_index
=
self
.
ignore_index
,
ignore_index
=
self
.
ignore_index
,
inplace_backward
=
self
.
inplace_backward
,
inplace_backward
=
self
.
inplace_backward
,
process_group
=
self
.
process_group
,
process_group
=
self
.
process_group
,
)
)
...
...
flash_attn/ops/triton/cross_entropy.py
View file @
ec6d2214
...
@@ -32,7 +32,7 @@ def cross_entropy_fwd_kernel(
...
@@ -32,7 +32,7 @@ def cross_entropy_fwd_kernel(
smoothing
,
smoothing
,
logit_scale
,
logit_scale
,
lse_square_scale
,
lse_square_scale
,
ignore
d
_index
,
ignore_index
,
total_classes
,
total_classes
,
class_start_idx
,
# Useful for tensor parallel when each rank only has a subset of classes
class_start_idx
,
# Useful for tensor parallel when each rank only has a subset of classes
n_cols
,
# shapes
n_cols
,
# shapes
...
@@ -56,7 +56,7 @@ def cross_entropy_fwd_kernel(
...
@@ -56,7 +56,7 @@ def cross_entropy_fwd_kernel(
sum_logits
=
tl
.
sum
(
tl
.
where
(
col_offsets
<
n_cols
,
logits
,
0.0
),
0
)
sum_logits
=
tl
.
sum
(
tl
.
where
(
col_offsets
<
n_cols
,
logits
,
0.0
),
0
)
lse
=
tl
.
log
(
tl
.
sum
(
tl
.
exp
(
logits
-
max_logits
),
0
))
+
max_logits
lse
=
tl
.
log
(
tl
.
sum
(
tl
.
exp
(
logits
-
max_logits
),
0
))
+
max_logits
tl
.
store
(
lse_ptr
+
col_block_idx
*
n_rows
+
row_idx
,
lse
)
tl
.
store
(
lse_ptr
+
col_block_idx
*
n_rows
+
row_idx
,
lse
)
if
label_idx
==
ignore
d
_index
:
if
label_idx
==
ignore_index
:
loss
=
0.0
loss
=
0.0
z_loss
=
0.0
z_loss
=
0.0
else
:
else
:
...
@@ -104,7 +104,7 @@ def cross_entropy_bwd_kernel(
...
@@ -104,7 +104,7 @@ def cross_entropy_bwd_kernel(
smoothing
,
smoothing
,
logit_scale
,
logit_scale
,
lse_square_scale
,
lse_square_scale
,
ignore
d
_index
,
ignore_index
,
total_classes
,
total_classes
,
class_start_idx
,
# Useful for tensor parallel when each rank only has a subset of classes
class_start_idx
,
# Useful for tensor parallel when each rank only has a subset of classes
n_cols
,
# shapes
n_cols
,
# shapes
...
@@ -120,7 +120,7 @@ def cross_entropy_bwd_kernel(
...
@@ -120,7 +120,7 @@ def cross_entropy_bwd_kernel(
dlogits_ptr
=
dlogits_ptr
+
row_idx
*
dlogits_row_stride
.
to
(
tl
.
int64
)
dlogits_ptr
=
dlogits_ptr
+
row_idx
*
dlogits_row_stride
.
to
(
tl
.
int64
)
col_offsets
=
col_block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
col_offsets
=
col_block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
label_idx
=
tl
.
load
(
labels_ptr
+
row_idx
)
label_idx
=
tl
.
load
(
labels_ptr
+
row_idx
)
if
label_idx
!=
ignore
d
_index
:
if
label_idx
!=
ignore_index
:
dloss
=
tl
.
load
(
dloss_ptr
+
row_idx
*
dloss_row_stride
)
dloss
=
tl
.
load
(
dloss_ptr
+
row_idx
*
dloss_row_stride
)
else
:
else
:
dloss
=
0.0
dloss
=
0.0
...
@@ -150,7 +150,7 @@ class CrossEntropyLoss(torch.autograd.Function):
...
@@ -150,7 +150,7 @@ class CrossEntropyLoss(torch.autograd.Function):
smoothing
=
0.0
,
smoothing
=
0.0
,
logit_scale
=
1.0
,
logit_scale
=
1.0
,
lse_square_scale
=
0.0
,
lse_square_scale
=
0.0
,
ignore
d
_index
=-
100
,
ignore_index
=-
100
,
inplace_backward
=
False
,
inplace_backward
=
False
,
process_group
=
None
,
process_group
=
None
,
):
):
...
@@ -192,7 +192,7 @@ class CrossEntropyLoss(torch.autograd.Function):
...
@@ -192,7 +192,7 @@ class CrossEntropyLoss(torch.autograd.Function):
smoothing
,
smoothing
,
logit_scale
,
logit_scale
,
lse_square_scale
,
lse_square_scale
,
ignore
d
_index
,
ignore_index
,
total_classes
,
total_classes
,
class_start_idx
,
class_start_idx
,
n_cols
,
# shapes
n_cols
,
# shapes
...
@@ -229,18 +229,18 @@ class CrossEntropyLoss(torch.autograd.Function):
...
@@ -229,18 +229,18 @@ class CrossEntropyLoss(torch.autograd.Function):
losses
+=
lse
losses
+=
lse
if
lse_square_scale
!=
0.0
:
if
lse_square_scale
!=
0.0
:
z_losses
=
lse_square_scale
*
lse
.
square
()
z_losses
=
lse_square_scale
*
lse
.
square
()
z_losses
.
masked_fill_
(
labels
==
ignore
d
_index
,
0.0
)
z_losses
.
masked_fill_
(
labels
==
ignore_index
,
0.0
)
losses
+=
z_losses
losses
+=
z_losses
else
:
else
:
z_losses
=
torch
.
zeros_like
(
losses
)
z_losses
=
torch
.
zeros_like
(
losses
)
losses
.
masked_fill_
(
labels
==
ignore
d
_index
,
0.0
)
losses
.
masked_fill_
(
labels
==
ignore_index
,
0.0
)
ctx
.
save_for_backward
(
logits
,
lse
,
labels
)
ctx
.
save_for_backward
(
logits
,
lse
,
labels
)
ctx
.
mark_non_differentiable
(
z_losses
)
ctx
.
mark_non_differentiable
(
z_losses
)
ctx
.
smoothing
=
smoothing
ctx
.
smoothing
=
smoothing
ctx
.
logit_scale
=
logit_scale
ctx
.
logit_scale
=
logit_scale
ctx
.
lse_square_scale
=
lse_square_scale
ctx
.
lse_square_scale
=
lse_square_scale
ctx
.
ignore
d
_index
=
ignore
d
_index
ctx
.
ignore_index
=
ignore_index
ctx
.
total_classes
=
total_classes
ctx
.
total_classes
=
total_classes
ctx
.
class_start_idx
=
class_start_idx
ctx
.
class_start_idx
=
class_start_idx
ctx
.
inplace_backward
=
inplace_backward
ctx
.
inplace_backward
=
inplace_backward
...
@@ -269,7 +269,7 @@ class CrossEntropyLoss(torch.autograd.Function):
...
@@ -269,7 +269,7 @@ class CrossEntropyLoss(torch.autograd.Function):
ctx
.
smoothing
,
ctx
.
smoothing
,
ctx
.
logit_scale
,
ctx
.
logit_scale
,
ctx
.
lse_square_scale
,
ctx
.
lse_square_scale
,
ctx
.
ignore
d
_index
,
ctx
.
ignore_index
,
ctx
.
total_classes
,
ctx
.
total_classes
,
ctx
.
class_start_idx
,
ctx
.
class_start_idx
,
n_cols
,
# shapes
n_cols
,
# shapes
...
@@ -287,7 +287,7 @@ def cross_entropy_loss(
...
@@ -287,7 +287,7 @@ def cross_entropy_loss(
label_smoothing
:
float
=
0.0
,
label_smoothing
:
float
=
0.0
,
logit_scale
:
float
=
1.0
,
logit_scale
:
float
=
1.0
,
lse_square_scale
:
float
=
0.0
,
lse_square_scale
:
float
=
0.0
,
ignore
d
_index
=-
100
,
ignore_index
=-
100
,
inplace_backward
:
bool
=
False
,
inplace_backward
:
bool
=
False
,
process_group
=
None
,
process_group
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -299,7 +299,7 @@ def cross_entropy_loss(
...
@@ -299,7 +299,7 @@ def cross_entropy_loss(
logit_scale: float. Multiply logits by this scale before calculating the loss.
logit_scale: float. Multiply logits by this scale before calculating the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
This is also referred to as "z-loss".
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
...
@@ -314,7 +314,7 @@ def cross_entropy_loss(
...
@@ -314,7 +314,7 @@ def cross_entropy_loss(
label_smoothing
,
label_smoothing
,
logit_scale
,
logit_scale
,
lse_square_scale
,
lse_square_scale
,
ignore
d
_index
,
ignore_index
,
inplace_backward
,
inplace_backward
,
process_group
,
process_group
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment