Commit bcfa7c97 authored by Tri Dao

[FusedDense] Run black on fused_dense.py

parent 2286d7ce
@@ -822,7 +822,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Putting this causal masking right after acc_s is *much* slower for some reason.
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
-// But we still want to mask out elements not beyond actual_seqlen_k.
+// But we still want to mask out elements beyond actual_seqlen_k.
if (m_block * kBlockM < (n_block + 1) * kBlockN
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
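The two-part condition in this hunk can be checked with a small host-side sketch. This is not the kernel code: the function names `crosses_diagonal` and `needs_mask` are hypothetical, the block sizes of 128 are an assumption chosen to match the "128 to 255" range in the comment, and `is_even_mn` is assumed to be false whenever the sequence lengths are not multiples of the block sizes.

```python
def crosses_diagonal(m_block, n_block, k_block_m, k_block_n):
    """First condition: the query block starts before the end of the key block."""
    return m_block * k_block_m < (n_block + 1) * k_block_n


def needs_mask(m_block, n_block, k_block_m, k_block_n, actual_seqlen_k, is_even_mn):
    """Both conditions from the hunk, OR-ed together."""
    # Second condition: the key block reaches or passes the real key length,
    # so some of its columns may lie beyond actual_seqlen_k.
    overruns_keys = (not is_even_mn) and (n_block + 1) * k_block_n >= actual_seqlen_k
    return crosses_diagonal(m_block, n_block, k_block_m, k_block_n) or overruns_keys


# Example from the comment: seqlen_q = 256, seqlen_k = 2.
# Block sizes of 128 are an assumption made for this sketch.
for m_block in (0, 1):
    print(
        m_block,
        crosses_diagonal(m_block, 0, 128, 128),
        needs_mask(m_block, 0, 128, 128, 2, False),
    )
```

For the second query block (`m_block = 1`) the first condition alone is false, but `needs_mask` is still true, which is the case the comment describes: no causal masking is required there, yet columns beyond `actual_seqlen_k` still have to be masked out.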
[tool.black]
line-length = 100
target-version = ['py38']
\ No newline at end of file