gaoqiong / flash-attention
Commit bcfa7c97, authored Aug 16, 2023 by Tri Dao
[FusedDense] Run black on fused_dense.py
Parent: 2286d7ce
Showing 3 changed files with 282 additions and 129 deletions
csrc/flash_attn/src/flash_bwd_kernel.h  +1 -1
flash_attn/ops/fused_dense.py  +278 -128
pyproject.toml  +3 -0
csrc/flash_attn/src/flash_bwd_kernel.h
...
...
@@ -822,7 +822,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 // Putting this causal masking right after acc_s is *much* slower for some reason.
 // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
 // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
-// But we still want to mask out elements not beyond actual_seqlen_k.
+// But we still want to mask out elements beyond actual_seqlen_k.
 if (m_block * kBlockM < (n_block + 1) * kBlockN || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
 flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
...
...
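The two-part condition in the hunk above can be checked in isolation. The following is only a sketch: `needs_mask` is a hypothetical standalone helper, not a function in the kernel, and the block sizes are assumptions (kBlockM = 128 is implied by the comment's "from 128 to 255"; kBlockN = 128 is chosen purely for illustration). It reproduces the comment's example of seqlen_q = 256 and seqlen_k = 2.

```cpp
// Hypothetical standalone helper (not part of the kernel) that mirrors the
// block-level condition in the hunk above. Block sizes are ordinary ints here;
// in the kernel they are compile-time constants.
constexpr bool needs_mask(int m_block, int n_block, int kBlockM, int kBlockN,
                          bool is_even_mn, int actual_seqlen_k) {
    return
        // Clause 1: the query block starts before the end of the key block.
        m_block * kBlockM < (n_block + 1) * kBlockN
        // Clause 2: sequence lengths are not block-aligned and the key block
        // reaches or passes actual_seqlen_k.
        || (!is_even_mn && (n_block + 1) * kBlockN >= actual_seqlen_k);
}

// Example from the comment: seqlen_q = 256, seqlen_k = 2, with assumed block
// sizes kBlockM = kBlockN = 128. For the second query block (rows 128..255)
// and the first key block, clause 1 is false but clause 2 is true, so the
// mask is still applied.
static_assert(needs_mask(1, 0, 128, 128, false, 2), "masked via clause 2");
// If only clause 1 were considered (is_even_mn forced to true), the same tile
// would skip masking.
static_assert(!needs_mask(1, 0, 128, 128, true, 2), "clause 1 alone skips it");
```

Under these assumed sizes, the second clause is what keeps the out-of-range key columns masked for the later query blocks.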
flash_attn/ops/fused_dense.py
This diff is collapsed.
pyproject.toml
0 → 100644
[tool.black]
line-length = 100
target-version = ['py38']
\ No newline at end of file