/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
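
// Single-token (decode) step of lightning / linear attention. For each
// (batch, head) pair the kernel performs
//
//   new_kv = exp(-slope) * past_kv + k^T v   // [d, e] state update
//   output = q @ new_kv                      // [1, e] attention output
//
// with the recurrent KV state kept in fp32. One thread block handles one
// (batch, head) pair.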

#define THREADS_PER_BLOCK 128

template <typename T>
__global__ void lightning_attention_decode_kernel(const T* __restrict__ q,            // [b, h, 1, d]
                                                  const T* __restrict__ k,            // [b, h, 1, d]
                                                  const T* __restrict__ v,            // [b, h, 1, e]
                                                  const float* __restrict__ past_kv,  // [b, h, d, e]
                                                  const float* __restrict__ slope,    // [h, 1, 1]
                                                  T* __restrict__ output,             // [b, h, 1, e]
                                                  float* __restrict__ new_kv,         // [b, h, d, e]
                                                  const int batch_size, const int num_heads, const int qk_dim,
                                                  const int v_dim) {
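  // Layout of the dynamic shared-memory buffer:
  //   q_shared, k_shared : qk_dim elements of T each
  //   v_shared           : v_dim elements of T
  //   new_kv_shared      : qk_dim rows of (v_dim + 1) floats; the extra column
  //                        is padding, presumably to avoid shared-memory bank
  //                        conflicts on column-wise reads
  //   output_shared      : v_dim elements of T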
  extern __shared__ char smem[];
  T* q_shared = reinterpret_cast<T*>(smem);
  T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
  T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
  float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
  T* output_shared =
      reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));

  const int32_t tid = threadIdx.x;
  const int32_t current_head = blockIdx.x;
  const int32_t b = current_head / num_heads;
  const int32_t h = current_head % num_heads;

  if (b >= batch_size) return;

  const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
  const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
  const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;

  // Stage this step's q and k vectors in shared memory; v follows below.
  for (int d = tid; d < qk_dim; d += blockDim.x) {
    q_shared[d] = q[qk_offset + d];
    k_shared[d] = k[qk_offset + d];
  }
  for (int e = tid; e < v_dim; e += blockDim.x) {
    v_shared[e] = v[v_offset + e];
  }

  __syncthreads();

  // Per-head exponential decay factor applied to the previous KV state.
  const float ratio = expf(-1.0f * slope[h]);

  // Rank-1 state update: new_kv[d][e] = ratio * past_kv[d][e] + k[d] * v[e].
  for (int d = tid; d < qk_dim; d += blockDim.x) {
    T k_val = k_shared[d];
    for (int e = 0; e < v_dim; ++e) {
      int past_kv_idx = kv_offset + d * v_dim + e;
      T v_val = v_shared[e];
      float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
      int shared_idx = d * (v_dim + 1) + e;
      new_kv_shared[shared_idx] = new_val;
    }
  }

  __syncthreads();

  // Write the updated KV state back to global memory, dropping the padding column.
  for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
    int d = idx / v_dim;
    int e = idx % v_dim;
    int shared_idx = d * (v_dim + 1) + e;
    int global_idx = kv_offset + idx;
    new_kv[global_idx] = new_kv_shared[shared_idx];
  }

  __syncthreads();

  // Attention output for this step: output[e] = sum over d of q[d] * new_kv[d][e].
  for (int e = tid; e < v_dim; e += blockDim.x) {
    float sum = 0.0f;
    for (int d = 0; d < qk_dim; ++d) {
      int shared_idx = d * (v_dim + 1) + e;
      sum += q_shared[d] * new_kv_shared[shared_idx];
    }
    output_shared[e] = static_cast<T>(sum);
  }

  __syncthreads();

  // A single thread copies the finished output vector to global memory.
  if (tid == 0) {
    for (int e = 0; e < v_dim; ++e) {
      output[v_offset + e] = output_shared[e];
    }
  }
}

// Host-side launcher: one block of THREADS_PER_BLOCK threads per (batch, head) pair.
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                torch::Tensor new_kv) {
  TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
  TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");

  // Launch on the device that owns the inputs (multi-GPU safety).
  const c10::cuda::CUDAGuard device_guard(q.device());

  auto batch_size = q.size(0);
  auto num_heads = q.size(1);
  auto qk_dim = q.size(3);
  auto v_dim = v.size(3);

  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(batch_size * num_heads);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
        // Staging buffers (q, k, v, output) in the input dtype plus the padded fp32 KV tile.
        size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
            q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
            slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
            qk_dim, v_dim);
      }));
}
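
// Minimal binding sketch, assuming this file is compiled on its own with
// torch.utils.cpp_extension (which defines TORCH_EXTENSION_NAME); the in-tree
// sgl-kernel build registers this op through its own binding file instead.
#ifdef TORCH_EXTENSION_NAME
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("lightning_attention_decode", &lightning_attention_decode,
        "One decode step of lightning attention (CUDA)");
}
#endif
// Example call from Python, with tensor shapes as documented above and "ext"
// standing in for the hypothetical compiled module:
//   ext.lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)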