Commit 2ed471ec authored by Tri Dao

Add tests for numerical error

parent 42f54d88
@@ -116,6 +116,17 @@ T4 GPUs are commonly used for inference, so we also measure speedup on the forward pass.
We see speedups between 2.5x and 4.5x on the forward pass.
## Tests
We test that FlashAttention produces the same output and gradient as a reference
implementation, up to some numerical tolerance. In particular, we check that the
maximum numerical error of FlashAttention is at most twice the numerical error
of a baseline implementation in PyTorch (across different head dimensions, input
dtypes, sequence lengths, and causal/non-causal settings).
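In spirit, each test follows the pattern sketched below. This is a minimal illustration, not the actual test code: `flash_attn_func` is an assumed entry point (the alpha API may differ), `attention_ref` is a hypothetical baseline, and the real tests in `tests/test_flash_attn.py` additionally parametrize shapes, dtypes, and causality, and check gradients as well as outputs.
```
import torch
from flash_attn import flash_attn_func  # assumed import; the alpha API may differ

def attention_ref(q, k, v, causal=False):
    """Plain PyTorch attention (hypothetical baseline for illustration)."""
    softmax_scale = q.shape[-1] ** (-0.5)
    scores = torch.einsum("bthd,bshd->bhts", q, k) * softmax_scale
    if causal:
        mask = torch.triu(
            torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)

batch, seqlen, nheads, headdim = 2, 128, 4, 64  # illustrative shapes
q, k, v = [
    torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    for _ in range(3)
]

out = flash_attn_func(q, k, v)        # FlashAttention output in fp16
out_ref = attention_ref(              # reference computed in fp32, cast back
    q.float(), k.float(), v.float()
).to(q.dtype)
out_pt = attention_ref(q, k, v)       # same baseline attention, run in fp16

# FlashAttention's max error (vs. the fp32 reference) should be at most
# twice the max error of the fp16 PyTorch baseline.
assert (out - out_ref).abs().max() <= 2 * (out_pt - out_ref).abs().max()
```
Comparing against a baseline's error, rather than a fixed tolerance, keeps the check meaningful across dtypes and sequence lengths, since the unavoidable floating-point error itself varies with those parameters.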
To run the tests:
```
pytest -q -s tests/test_flash_attn.py
```
## When you encounter issues
This alpha release of FlashAttention contains code written for a research