"vscode:/vscode.git/clone" did not exist on "f7e814093a788eff8095bd87c34df7d92abac9f6"
Commit bc28eacc authored by Tri Dao

Format flash_attn_interface.py

parent 0a146185
@@ -43,7 +43,9 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
     return (128, 64) if is_sm80 else (64, 64)


-def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+def _flash_attn_forward(
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -202,7 +204,9 @@ def _flash_attn_varlen_backward(

 class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -322,7 +326,9 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):

 class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -452,7 +458,9 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):

 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
+    def forward(
+        ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -1079,6 +1087,6 @@ def flash_attn_with_kvcache(
         window_size[1],
         rotary_interleaved,
         num_splits,
-        alibi_slopes
+        alibi_slopes,
     )
     return out
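
Note: the reformatting only rewraps the signatures; the positional argument order (q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax) is unchanged, so callers are unaffected. As a minimal sketch of how the reformatted FlashAttnFunc.forward is invoked, the snippet below calls it through autograd's apply. The tensor shapes, dtype, and argument values are illustrative assumptions, not part of this commit.

# Sketch only: shapes (batch, seqlen, nheads, head_dim) and values are assumed.
import torch
from flash_attn.flash_attn_interface import FlashAttnFunc

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Positional arguments mirror the forward() signature shown in the diff.
out = FlashAttnFunc.apply(
    q, k, v,
    0.0,        # dropout_p
    None,       # softmax_scale: None falls back to head_dim ** -0.5 inside forward()
    True,       # causal
    (-1, -1),   # window_size: no sliding-window restriction
    None,       # alibi_slopes
    False,      # return_softmax
)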