Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
e7234756
"vscode:/vscode.git/clone" did not exist on "7d0c2729399c3ce019a30fc175b973e892fd5fc3"
Unverified
Commit
e7234756
authored
Oct 11, 2022
by
shenggan
Committed by
GitHub
Oct 11, 2022
Browse files
[hotfix] fix row_idx overflow in triton softmax (#80)
parent
ec9352d1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
fastfold/model/fastnn/kernel/triton/softmax.py
fastfold/model/fastnn/kernel/triton/softmax.py
+4
-4
No files found.
fastfold/model/fastnn/kernel/triton/softmax.py
View file @
e7234756
...
@@ -50,7 +50,7 @@ def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_
...
@@ -50,7 +50,7 @@ def _softmax_grad_core(output_ptrs, d_output_ptrs, d_input_ptrs, mask_ptrs, col_
def
softmax_mask_bias_kernel
(
output_ptr
,
input_ptr
,
mask_ptr
,
bias_ptr
,
input_row_stride
,
def
softmax_mask_bias_kernel
(
output_ptr
,
input_ptr
,
mask_ptr
,
bias_ptr
,
input_row_stride
,
output_row_stride
,
n_cols
,
n_heads
,
BLOCK_SIZE
:
tl
.
constexpr
,
output_row_stride
,
n_cols
,
n_heads
,
BLOCK_SIZE
:
tl
.
constexpr
,
use_mask
:
tl
.
constexpr
,
use_bias
:
tl
.
constexpr
):
use_mask
:
tl
.
constexpr
,
use_bias
:
tl
.
constexpr
):
row_idx
=
tl
.
program_id
(
0
)
row_idx
=
tl
.
program_id
(
0
)
.
to
(
tl
.
int64
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
input_row_ptr
=
input_ptr
+
row_idx
*
input_row_stride
input_row_ptr
=
input_ptr
+
row_idx
*
input_row_stride
...
@@ -77,7 +77,7 @@ def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_ro
...
@@ -77,7 +77,7 @@ def softmax_mask_bias_kernel(output_ptr, input_ptr, mask_ptr, bias_ptr, input_ro
def
softmax_mask_bias_kernel_two_rows
(
output_ptr
,
input_ptr
,
mask_ptr
,
bias_ptr
,
input_row_stride
,
def
softmax_mask_bias_kernel_two_rows
(
output_ptr
,
input_ptr
,
mask_ptr
,
bias_ptr
,
input_row_stride
,
output_row_stride
,
n_cols
,
n_heads
,
BLOCK_SIZE
:
tl
.
constexpr
,
output_row_stride
,
n_cols
,
n_heads
,
BLOCK_SIZE
:
tl
.
constexpr
,
use_mask
:
tl
.
constexpr
,
use_bias
:
tl
.
constexpr
):
use_mask
:
tl
.
constexpr
,
use_bias
:
tl
.
constexpr
):
row_idx
=
tl
.
program_id
(
0
)
row_idx
=
tl
.
program_id
(
0
)
.
to
(
tl
.
int64
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
input_row_ptr
=
input_ptr
+
2
*
row_idx
*
input_row_stride
input_row_ptr
=
input_ptr
+
2
*
row_idx
*
input_row_stride
...
@@ -119,7 +119,7 @@ def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_
...
@@ -119,7 +119,7 @@ def softmax_mask_grad_kernel(d_output_ptr, output_ptr, d_input_ptr, mask_ptr, d_
BLOCK_SIZE
:
tl
.
constexpr
,
is_bf16
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
is_bf16
:
tl
.
constexpr
,
use_mask
:
tl
.
constexpr
):
use_mask
:
tl
.
constexpr
):
row_idx
=
tl
.
program_id
(
0
)
row_idx
=
tl
.
program_id
(
0
)
.
to
(
tl
.
int64
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
output_row_ptr
=
output_ptr
+
row_idx
*
output_row_stride
output_row_ptr
=
output_ptr
+
row_idx
*
output_row_stride
...
@@ -145,7 +145,7 @@ def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mas
...
@@ -145,7 +145,7 @@ def softmax_mask_grad_kernel_two_rows(d_output_ptr, output_ptr, d_input_ptr, mas
n_cols
,
n_heads
,
BLOCK_SIZE
:
tl
.
constexpr
,
n_cols
,
n_heads
,
BLOCK_SIZE
:
tl
.
constexpr
,
is_bf16
:
tl
.
constexpr
,
use_mask
:
tl
.
constexpr
):
is_bf16
:
tl
.
constexpr
,
use_mask
:
tl
.
constexpr
):
row_idx
=
tl
.
program_id
(
0
)
row_idx
=
tl
.
program_id
(
0
)
.
to
(
tl
.
int64
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
output_row_ptr
=
output_ptr
+
2
*
row_idx
*
output_row_stride
output_row_ptr
=
output_ptr
+
2
*
row_idx
*
output_row_stride
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment