Commit 13403e81 authored by Tri Dao

Relax assert to allow both bf16 and fp16

parent 64f42cd0
@@ -34,7 +34,7 @@ class FlashAttention(nn.Module):
         key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert qkv.dtype == torch.float16
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
         if cu_seqlens is None:
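The change above widens the input-dtype guard so that bfloat16 tensors are accepted alongside float16. A minimal sketch of the relaxed check (the helper name `check_qkv_dtype` is illustrative, not part of the repository; the CUDA check from the original code is omitted so the sketch runs on CPU):

```python
import torch


def check_qkv_dtype(qkv: torch.Tensor) -> None:
    # Before this commit only fp16 passed; now fp16 and bf16 both do.
    assert qkv.dtype in [torch.float16, torch.bfloat16], (
        f"FlashAttention expects fp16 or bf16 input, got {qkv.dtype}"
    )


# Both half-precision dtypes now pass the assert.
check_qkv_dtype(torch.zeros(2, 8, 3, 4, 16, dtype=torch.float16))
check_qkv_dtype(torch.zeros(2, 8, 3, 4, 16, dtype=torch.bfloat16))
```

A full-precision tensor (e.g. `torch.float32`) still trips the assertion, since the FlashAttention kernels only support half-precision inputs.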