Unverified Commit cbf982af authored by Ian Timmis, committed by GitHub

README syntax highlighting (#365)

* README syntax highlighting

Adds syntax highlighting to README

* Update README.md
parent 425dbcb6
@@ -50,7 +50,7 @@ cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
 pip install flash-attn --no-build-isolation
 ```
 Alternatively you can compile from source:
-```
+```sh
 python setup.py install
 ```
@@ -58,7 +58,7 @@ If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
 run too many parallel compilation jobs that could exhaust the amount of RAM. To
 limit the number of parallel compilation jobs, you can set the environment
 variable `MAX_JOBS`:
-```
+```sh
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
@@ -76,11 +76,11 @@ FlashAttention-2 currently supports:
 The main functions implement scaled dot product attention (softmax(Q @ K^T *
 softmax_scale) @ V):
-```
+```python
 from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 ```
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -94,9 +94,10 @@ Arguments:
 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
 out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -114,6 +115,7 @@ Arguments:
 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
 Return:
 out: (batch_size, seqlen, nheads, headdim).
+"""
 ```
 To see how these functions are used in a multi-head attention layer (which
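As a short usage sketch of the two functions whose code fences are annotated above: the packed layout `(batch_size, seqlen, 3, nheads, headdim)` for `qkv` and the use of fp16 CUDA tensors are assumptions made for illustration; they are not spelled out in this excerpt.

```python
# Hedged sketch of calling the two functions documented above.
# Assumptions (not stated in this excerpt): qkv is packed as
# (batch, seqlen, 3, nheads, headdim) and inputs are fp16 CUDA tensors.
import torch
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
device, dtype = "cuda", torch.float16

# Packed path: Q, K, V stacked into one tensor.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device=device, dtype=dtype)
out_packed = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)

# Unpacked path: separate Q, K, V tensors.
q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# Output shape matches the Return entry quoted above.
assert out.shape == (batch, seqlen, nheads, headdim)
```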
@@ -128,10 +130,10 @@ These functions have been renamed:
 If the inputs have the same sequence lengths in the same batch, it is simpler
 and faster to use these functions:
-```
+```python
 flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
 ```
-```
+```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
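The `flash_attn_func` docstring quoted earlier mentions MQA/GQA support by passing K and V with fewer heads; here is a hedged sketch of that call. The requirement that the number of query heads be a multiple of the number of KV heads, and the fp16 CUDA tensors, are assumptions for illustration rather than statements from this excerpt.

```python
# Hedged sketch of the MQA/GQA path mentioned in the flash_attn_func docstring above:
# K and V carry fewer heads than Q. Assumed (not quoted here): nheads must be a
# multiple of nheads_k, and inputs are fp16 CUDA tensors.
import torch
from flash_attn import flash_attn_func

batch, seqlen, headdim = 2, 1024, 64
nheads, nheads_k = 8, 2  # grouped-query attention: 4 query heads share each KV head

q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
assert out.shape == (batch, seqlen, nheads, headdim)
```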
@@ -205,7 +207,7 @@ of a baseline implementation in Pytorch (for different head dimensions, input
 dtype, sequence length, causal / non-causal).
 To run the tests:
-```
+```sh
 pytest -q -s tests/test_flash_attn.py
 ```
 ## When you encounter issues