/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "sgl_flash_kernel_ops.h"

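// Register the flash-attention forward op in the shared `sgl_kernel` library
// namespace. TORCH_LIBRARY_FRAGMENT lets multiple translation units contribute
// ops to the same namespace; once this extension is loaded, the op is callable
// from Python as torch.ops.sgl_kernel.fwd.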
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  /*
   * From flash-attention
   */
  m.def(
      "fwd(Tensor   q,"                 // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
      "    Tensor   k,"                 // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged
      "    Tensor   v,"                 // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged
      "    Tensor?  k_new,"             // (b, s_k_new, h_k, d) or (total_k_new, h_k, d)
      "    Tensor?  v_new,"             // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv)
      "    Tensor?  q_v,"               // (b, s_q, h, dv) or (total_q_new, h, dv)
      "    Tensor?  out,"               // (b, s_q, h, dv) or (total_q, h, dv)
      "    Tensor?  cu_seqlens_q,"      // b+1
      "    Tensor?  cu_seqlens_k,"      // b+1
      "    Tensor?  cu_seqlens_k_new,"  // b+1
      "    Tensor?  seqused_q,"         // b
      "    Tensor?  seqused_k,"         // b
      "    int?     max_seqlen_q,"
      "    int?     max_seqlen_k,"    // TODO: check if needed
      "    Tensor?  page_table,"      // (b_k, max_num_pages_per_seq)
      "    Tensor?  kv_batch_idx,"    // b
      "    Tensor?  leftpad_k,"       // b
      "    Tensor?  rotary_cos,"      // seqlen_ro x (rotary_dim / 2)
      "    Tensor?  rotary_sin,"      // seqlen_ro x (rotary_dim / 2)
      "    Tensor?  seqlens_rotary,"  // b
      "    Tensor?  q_descale,"       // (b, h_k)
      "    Tensor?  k_descale,"       // (b, h_k)
      "    Tensor?  v_descale,"       // (b, h_k)
      "    float?   softmax_scale,"   // now optional
      "    bool     is_causal,"
      "    int      window_size_left,"
      "    int      window_size_right,"
      "    int      attention_chunk,"  // NEW
      "    float    softcap,"          // promoted to double in C++; schema float is fine
      "    bool     is_rotary_interleaved,"
      "    Tensor?  scheduler_metadata,"  // (b + 1)
      "    int      num_splits,"
      "    bool?    pack_gqa,"
      "    int      sm_margin,"
      "    Tensor?  sinks"
      ") -> (Tensor, Tensor, Tensor, Tensor)");  // NEW return type: tuple of 4 tensors

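  // Bind the "fwd" schema above to its CUDA implementation. make_pytorch_shim
  // (pulled in via sgl_flash_kernel_ops.h, along with the mha_fwd declaration)
  // adapts the C++ signature of mha_fwd to what the PyTorch dispatcher expects.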
  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

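// REGISTER_EXTENSION (defined in the project's shared extension headers) is
// expected to emit the PyInit_flash_ops entry point so this library can be
// imported as the `flash_ops` Python extension module.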
REGISTER_EXTENSION(flash_ops)