Unverified Commit 8f48a546 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

use global function rather than lambda (#7)

parent 537f75eb
...@@ -11,6 +11,8 @@ import vllm_flash_attn_2_cuda as flash_attn_cuda ...@@ -11,6 +11,8 @@ import vllm_flash_attn_2_cuda as flash_attn_cuda
# isort: on # isort: on
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def _get_block_size_n(device, head_dim, is_dropout, is_causal): def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel # This should match the block sizes in the CUDA kernel
...@@ -46,7 +48,6 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): ...@@ -46,7 +48,6 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
def _flash_attn_forward( def _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
q, q,
...@@ -83,7 +84,6 @@ def _flash_attn_varlen_forward( ...@@ -83,7 +84,6 @@ def _flash_attn_varlen_forward(
*, *,
out=None out=None
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
q, q,
...@@ -129,7 +129,6 @@ def _flash_attn_backward( ...@@ -129,7 +129,6 @@ def _flash_attn_backward(
deterministic, deterministic,
rng_state=None, rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous # dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
...@@ -177,7 +176,6 @@ def _flash_attn_varlen_backward( ...@@ -177,7 +176,6 @@ def _flash_attn_varlen_backward(
deterministic, deterministic,
rng_state=None, rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous # dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
...@@ -1219,7 +1217,6 @@ def flash_attn_with_kvcache( ...@@ -1219,7 +1217,6 @@ def flash_attn_with_kvcache(
""" """
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment