Commit c65b5106 authored by Tri Dao

Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal

parent 0f7853c6
@@ -43,7 +43,7 @@ jobs:
         # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
         # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
         os: [ubuntu-20.04]
-        python-version: ['3.7', '3.8', '3.9', '3.10']
+        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
         torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230731']
         cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0']
         # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
@@ -52,6 +52,9 @@ jobs:
         # when building without C++11 ABI and using it on nvcr images.
         cxx11_abi: ['FALSE', 'TRUE']
         exclude:
+          # Pytorch <= 1.12 does not support Python 3.11
+          - torch-version: '1.12.1'
+            python-version: '3.11'
           # Pytorch >= 2.0 only supports Python >= 3.8
           - torch-version: '2.0.1'
            python-version: '3.7'
......
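To make the matrix change above concrete, here is a small Python sketch (not part of the commit) of how the two exclude entries visible in this hunk prune the python/torch combinations; the real workflow contains further exclusions in the lines elided above.

```python
# Illustration only: prune the python/torch grid with the two exclusions visible above.
from itertools import product

python_versions = ['3.7', '3.8', '3.9', '3.10', '3.11']
torch_versions = ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230731']
excluded = {('3.11', '1.12.1'), ('3.7', '2.0.1')}  # only the pairs shown in this hunk

combos = [(py, tv) for py, tv in product(python_versions, torch_versions)
          if (py, tv) not in excluded]
print(len(combos))  # 18 of the 20 python/torch pairs remain, before the cuda/ABI axes
```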
@@ -820,7 +820,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             }
         } else {
             // Putting this causal masking right after acc_s is *much* slower for some reason.
-            if (m_block * kBlockM < (n_block + 1) * kBlockN) {
+            // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
+            // (e.g., 256 and 2), then for the 2nd block of seqlen_q (rows 128 to 255) we're not doing causal masking.
+            // But we still want to mask out elements beyond actual_seqlen_k.
+            if (m_block * kBlockM < (n_block + 1) * kBlockN
+                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
                 flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                                          binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                          // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
......
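The added condition is easiest to see with the numbers from the comment (seqlen_q = 256, seqlen_k = 2). The sketch below is not part of the commit; it mirrors the block-level predicate in plain Python, where kBlockM, kBlockN, and is_even_mn stand in for the kernel's block sizes and the Is_even_MN template parameter, with block sizes assumed to be 128.

```python
# Illustration only: the old vs. new "apply the causal/seqlen mask to this tile?" predicate.
kBlockM, kBlockN = 128, 128                  # assumed block sizes
seqlen_q, actual_seqlen_k = 256, 2           # the seqlen_q >> seqlen_k case from the commit message
is_even_mn = actual_seqlen_k % kBlockN == 0  # stand-in for the Is_even_MN template parameter

def needs_mask_old(m_block, n_block):
    # Pre-fix: mask only when the query block can overlap the causal boundary.
    return m_block * kBlockM < (n_block + 1) * kBlockN

def needs_mask_new(m_block, n_block):
    # Post-fix: also mask when the key block extends past actual_seqlen_k.
    return (m_block * kBlockM < (n_block + 1) * kBlockN
            or (not is_even_mn and (n_block + 1) * kBlockN >= actual_seqlen_k))

# Query block 1 covers rows 128..255; key block 0 covers columns 0..127, of which
# only columns 0..1 are real. The old predicate skips masking this tile entirely.
print(needs_mask_old(1, 0), needs_mask_new(1, 0))  # False True
```

With the old predicate, the columns beyond actual_seqlen_k feed into the backward pass unmasked, which is the NaN in dQ that the new test below checks for.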
__version__ = "2.0.7"
__version__ = "2.0.8"
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
......
@@ -924,3 +924,36 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
     assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item()
     assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item()
     assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item()
+
+
+@pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [16, 32, 64])
+# @pytest.mark.parametrize('d', [16])
+def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
+    """ We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
+    in the case where seqlen % 128 != 0 or varlen.
+    """
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    nheads = 5
+    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
+    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
+    Mq = 256
+    Mk = 3
+
+    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
+    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+
+    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
+    g = torch.randn_like(out)
+    out.backward(g)
+
+    assert not q.grad.isnan().any()
+    assert not k.grad.isnan().any()
+    assert not v.grad.isnan().any()
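The test drives the bug through flash_attn_varlen_func with key sequences of length 1 against much longer query sequences. The same seqlen_q >> seqlen_k, causal, seqlen_k % 128 != 0 regime mentioned in the docstring can also be reached through the dense flash_attn_func imported in __init__.py above; the sketch below is a rough companion check, not part of the commit, with illustrative shapes.

```python
# Illustration only: dense-path analogue of the regime the varlen test exercises.
import torch
from flash_attn import flash_attn_func

batch, nheads, d = 2, 5, 64
q = torch.randn(batch, 256, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, 3, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, 3, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)   # seqlen_q=256 >> seqlen_k=3, causal
out.backward(torch.randn_like(out))
assert not q.grad.isnan().any() and not k.grad.isnan().any() and not v.grad.isnan().any()
```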