Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
c79de85f
Commit
c79de85f
authored
Oct 24, 2023
by
Tri Dao
Browse files
[CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements
parent
02ac572f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
flash_attn/ops/triton/cross_entropy.py
flash_attn/ops/triton/cross_entropy.py
+3
-3
No files found.
flash_attn/ops/triton/cross_entropy.py
View file @
c79de85f
...
...
@@ -43,7 +43,7 @@ def cross_entropy_fwd_kernel(
):
row_idx
=
tl
.
program_id
(
0
)
col_block_idx
=
tl
.
program_id
(
1
)
logits_ptr
=
logits_ptr
+
row_idx
*
logits_row_stride
logits_ptr
=
logits_ptr
+
row_idx
*
logits_row_stride
.
to
(
tl
.
int64
)
col_offsets
=
col_block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
label_idx
=
tl
.
load
(
labels_ptr
+
row_idx
)
logits
=
tl
.
load
(
logits_ptr
+
col_offsets
,
mask
=
col_offsets
<
n_cols
,
other
=-
float
(
"inf"
)).
to
(
...
...
@@ -107,8 +107,8 @@ def cross_entropy_bwd_kernel(
):
row_idx
=
tl
.
program_id
(
0
)
col_block_idx
=
tl
.
program_id
(
1
)
logits_ptr
=
logits_ptr
+
row_idx
*
logits_row_stride
dlogits_ptr
=
dlogits_ptr
+
row_idx
*
dlogits_row_stride
logits_ptr
=
logits_ptr
+
row_idx
*
logits_row_stride
.
to
(
tl
.
int64
)
dlogits_ptr
=
dlogits_ptr
+
row_idx
*
dlogits_row_stride
.
to
(
tl
.
int64
)
col_offsets
=
col_block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
label_idx
=
tl
.
load
(
labels_ptr
+
row_idx
)
if
label_idx
!=
ignored_index
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment