Unverified commit f792e3c5 authored by Yineng Zhang, committed by GitHub

Revert "[NVIDIA] BUMP FA3 (#11444)" (#11582)

parent 28f80b12
@@ -90,7 +90,7 @@ FetchContent_Populate(repo-flashinfer)
 FetchContent_Declare(
   repo-flash-attention
   GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
-  GIT_TAG 36f9456cd48ec57c8d75d8d6b90933d4bedffb6b
+  GIT_TAG f9af0c2a1d82ab1812e6987e9338363cc2bf0f8d
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flash-attention)
@@ -99,7 +99,7 @@ FetchContent_Populate(repo-flash-attention)
 FetchContent_Declare(
   repo-flash-attention-origin
   GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git
-  GIT_TAG 5a5a65b48dc99fc7483d2a7d5cfb1d8befa89389
+  GIT_TAG 203b9b3dba39d5d08dffb49c09aa622984dff07d
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flash-attention-origin)
......
@@ -23,43 +23,40 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  * From flash-attention
  */
 m.def(
-    "fwd(Tensor q,"  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    " Tensor k,"  // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged
-    " Tensor v,"  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged
-    " Tensor? k_new,"  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d)
-    " Tensor? v_new,"  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv)
-    " Tensor? q_v,"  // (b, s_q, h, dv) or (total_q_new, h, dv)
-    " Tensor? out,"  // (b, s_q, h, dv) or (total_q, h, dv)
-    " Tensor? cu_seqlens_q,"  // b+1
-    " Tensor? cu_seqlens_k,"  // b+1
-    " Tensor? cu_seqlens_k_new,"  // b+1
-    " Tensor? seqused_q,"  // b
-    " Tensor? seqused_k,"  // b
+    "fwd(Tensor! q,"
+    " Tensor k,"
+    " Tensor v,"
+    " Tensor? k_new,"
+    " Tensor? v_new,"
+    " Tensor? q_v,"
+    " Tensor!? out,"
+    " Tensor? cu_seqlens_q,"
+    " Tensor? cu_seqlens_k,"
+    " Tensor? cu_seqlens_k_new,"
+    " Tensor? seqused_q,"
+    " Tensor? seqused_k,"
     " int? max_seqlen_q,"
-    " int? max_seqlen_k,"  // TODO: check if needed
-    " Tensor? page_table,"  // (b_k, max_num_pages_per_seq)
-    " Tensor? kv_batch_idx,"  // b
-    " Tensor? leftpad_k,"  // b
-    " Tensor? rotary_cos,"  // seqlen_ro x (rotary_dim / 2)
-    " Tensor? rotary_sin,"  // seqlen_ro x (rotary_dim / 2)
-    " Tensor? seqlens_rotary,"  // b
-    " Tensor? q_descale,"  // (b, h_k)
-    " Tensor? k_descale,"  // (b, h_k)
-    " Tensor? v_descale,"  // (b, h_k)
-    " float? softmax_scale,"  // now optional
+    " int? max_seqlen_k,"
+    " Tensor? page_table,"
+    " Tensor? kv_batch_idx,"
+    " Tensor? leftpad_k,"
+    " Tensor? rotary_cos,"
+    " Tensor? rotary_sin,"
+    " Tensor? seqlens_rotary,"
+    " Tensor? q_descale,"
+    " Tensor? k_descale,"
+    " Tensor? v_descale,"
+    " float softmax_scale,"
     " bool is_causal,"
     " int window_size_left,"
     " int window_size_right,"
-    " int attention_chunk,"  // NEW
-    " float softcap,"  // promoted to double in C++; schema float is fine
+    " float softcap,"
     " bool is_rotary_interleaved,"
-    " Tensor? scheduler_metadata,"  // (b + 1)
+    " Tensor? scheduler_metadata,"
     " int num_splits,"
     " bool? pack_gqa,"
     " int sm_margin,"
-    " Tensor? sinks"
-    ") -> (Tensor, Tensor, Tensor, Tensor)");  // NEW return type: tuple of 4 tensors
+    " Tensor? sinks) -> Tensor[]");
 m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
......
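Illustrative only (not part of this patch): the revert restores the pre-#11444 `fwd` schema, which takes a required `float softmax_scale`, has no `attention_chunk` argument, and returns `Tensor[]` rather than a 4-tensor tuple. A minimal sketch for checking which schema a locally installed sgl_kernel build registers, assuming the wheel is importable and registers its ops on import:

import torch
import sgl_kernel  # noqa: F401 -- imported for its op-registration side effect

# OpOverload objects expose the TorchScript schema they were registered with.
schema = str(torch.ops.sgl_kernel.fwd.default._schema)
print(schema)
assert "attention_chunk" not in schema, "still running the FA3-bumped build"
assert "Tensor[]" in schema, "expected the reverted Tensor[] return type"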
@@ -42,44 +42,45 @@ limitations under the License.
 /*
  * From flash-attention
  */
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_fwd(
-    at::Tensor q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    at::Tensor k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+std::vector<at::Tensor> mha_fwd(
+    at::Tensor& q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                           // h_k, d) if there is page_table.
-    at::Tensor v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                           // page_size, h_k, dv) if there is page_table.
-    std::optional<at::Tensor> k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-    std::optional<at::Tensor> v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
-    std::optional<at::Tensor> q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor> out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor> cu_seqlens_q_,  // b+1
-    std::optional<at::Tensor> cu_seqlens_k_,  // b+1
-    std::optional<at::Tensor> cu_seqlens_k_new_,  // b+1
-    std::optional<at::Tensor>
+    std::optional<const at::Tensor>&
+        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>&
+        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor>& out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    std::optional<const at::Tensor>& cu_seqlens_q_,  // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_,  // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
+    std::optional<const at::Tensor>&
        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<at::Tensor>
+    std::optional<const at::Tensor>&
        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
-    std::optional<int64_t> max_seqlen_q_,
+    std::optional<int> max_seqlen_q_,
     // TODO: check if we need max_seqlen_k
-    std::optional<int64_t> max_seqlen_k_,
-    std::optional<at::Tensor> page_table_,  // (b_k, max_num_pages_per_seq)
-    std::optional<at::Tensor> kv_batch_idx_,  // b. indices to index into the KV cache
-    std::optional<at::Tensor> leftpad_k_,  // b
-    std::optional<at::Tensor> rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor> rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor> seqlens_rotary_,  // b
-    std::optional<at::Tensor> q_descale_,  // (b, h_k), not (b, h)
-    std::optional<at::Tensor> k_descale_,  // (b, h_k)
-    std::optional<at::Tensor> v_descale_,  // (b, h_k)
-    std::optional<double> softmax_scale_,
+    std::optional<int> max_seqlen_k_,
+    std::optional<const at::Tensor>& page_table_,  // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
+    std::optional<const at::Tensor>& leftpad_k_,  // b
+    std::optional<const at::Tensor>& rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& seqlens_rotary_,  // b
+    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
+    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
+    float const softmax_scale,
     bool is_causal,
-    int64_t window_size_left,
-    int64_t window_size_right,
-    int64_t attention_chunk,
-    double softcap,
-    bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-    std::optional<at::Tensor> scheduler_metadata_,  // (b + 1)
-    int64_t num_splits,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
+    int num_splits,
     std::optional<bool> pack_gqa_,
-    int64_t sm_margin,
-    std::optional<const at::Tensor>& sinks_);  // (h)
+    int const sm_margin,
+    std::optional<const at::Tensor>& sinks_);
@@ -43,7 +43,7 @@ def flash_attn_with_kvcache(
     qv=None,
     rotary_cos=None,
     rotary_sin=None,
-    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
     cache_leftpad: Optional[torch.Tensor] = None,
     page_table: Optional[torch.Tensor] = None,
@@ -57,7 +57,6 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
-    attention_chunk: Optional[int] = None,
     softcap=0.0,  # 0.0 means deactivated
     rotary_interleaved=True,
     scheduler_metadata=None,
@@ -136,7 +135,6 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-        attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory.
         softcap: float. Anything > 0 activates softcapping attention.
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
@@ -216,7 +214,6 @@ def flash_attn_with_kvcache(
     ]
     rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
     rotary_seqlens = maybe_contiguous(rotary_seqlens)
-    attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
@@ -246,7 +243,6 @@
         causal,
         window_size[0],
         window_size[1],
-        attention_chunk,
         softcap,
         rotary_interleaved,
         scheduler_metadata,
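Illustrative only (not from this commit): after the revert, callers of flash_attn_with_kvcache simply drop the attention_chunk keyword; the rest of the decode path is unchanged. A minimal paged-KV decode sketch, assuming the wrapper is importable from sgl_kernel.flash_attn and that its leading positional parameters are q, k_cache, v_cache (those lines fall outside the excerpt above), run on a CUDA GPU with bf16 inputs:

import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache  # assumed import path

b, h, d, page_size, pages_per_seq = 2, 8, 128, 16, 4
q = torch.randn(b, 1, h, d, device="cuda", dtype=torch.bfloat16)  # one new query token per sequence
k_cache = torch.randn(b * pages_per_seq, page_size, h, d, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)
# page_table: (b, max_num_pages_per_seq), as documented in the header diff above
page_table = torch.arange(b * pages_per_seq, device="cuda", dtype=torch.int32).reshape(b, pages_per_seq)
cache_seqlens = torch.full((b,), 37, device="cuda", dtype=torch.int32)  # tokens already in the cache

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    page_table=page_table,
    causal=True,
    # attention_chunk=...  # removed by this revert; passing it would now raise a TypeError
)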
@@ -276,7 +272,6 @@ def flash_attn_varlen_func(
     k_descale=None,
     v_descale=None,
     window_size=(-1, -1),
-    attention_chunk: Optional[int] = None,
     softcap=0.0,
     num_splits=1,
     pack_gqa=None,
@@ -326,7 +321,6 @@ def flash_attn_varlen_func(
         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
             -0.5
         )
-    attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
@@ -356,7 +350,6 @@
         causal,
         window_size[0],
         window_size[1],
-        attention_chunk,
         softcap,
         is_rotary_interleaved=False,
         scheduler_metadata=None,
......
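Likewise illustrative (not from this commit): flash_attn_varlen_func loses its attention_chunk keyword as well. A minimal packed-sequence sketch, assuming positional parameters q, k, v plus the cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k arguments that precede the keywords shown in the excerpt:

import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func  # assumed import path

h, d = 8, 128
seqlens = [5, 9]                                          # two variable-length sequences packed together
total = sum(seqlens)
q = torch.randn(total, h, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, 5, 14], device="cuda", dtype=torch.int32)  # prefix sums of seqlens

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
    # attention_chunk=...  # no longer accepted after this revert
)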