"test/gtest-1.8.0/googlemock/scripts/generator/cpp/ast.py" did not exist on "f6a6f46ba47d8c45934870c4a37d6121818eb93f"
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>

#define THREADS_PER_BLOCK 128

template <typename T>
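// One thread block handles one (batch, head) pair. The kernel performs a single
// lightning-attention decode step:
//   new_kv = exp(-slope) * past_kv + k^T v   (rank-1 update of the running KV state)
//   output = q @ new_kv                      (read-out for the current token)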
__global__ void lightning_attention_decode_kernel(
    const T* __restrict__ q,            // [b, h, 1, d]
    const T* __restrict__ k,            // [b, h, 1, d]
    const T* __restrict__ v,            // [b, h, 1, e]
    const float* __restrict__ past_kv,  // [b, h, d, e]
    const float* __restrict__ slope,    // [h, 1, 1]
    T* __restrict__ output,             // [b, h, 1, e]
    float* __restrict__ new_kv,         // [b, h, d, e]
    const int batch_size,
    const int num_heads,
    const int qk_dim,
    const int v_dim) {
  extern __shared__ char smem[];
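  // Dynamic shared memory layout, in order:
  //   q_shared      [qk_dim]            (T)
  //   k_shared      [qk_dim]            (T)
  //   v_shared      [v_dim]             (T)
  //   new_kv_shared [qk_dim, v_dim + 1] (float)
  //   output_shared [v_dim]             (T)
  // The extra padding column in new_kv_shared presumably reduces shared-memory
  // bank conflicts when the output loop reads a column of the KV state.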
  T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
  T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
  T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
  float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
  T* __restrict__ output_shared =
      reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));

  const int32_t tid = threadIdx.x;
  const int32_t current_head = blockIdx.x;
  const int32_t b = current_head / num_heads;
  const int32_t h = current_head % num_heads;

  if (b >= batch_size) return;

  // Flat offsets of this (batch, head) slice into the contiguous inputs.
  const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
  const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
  const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;

  // Load q, k, v into shared memory
  for (int d = tid; d < qk_dim; d += blockDim.x) {
    q_shared[d] = q[qk_offset + d];
    k_shared[d] = k[qk_offset + d];
  }
  for (int e = tid; e < v_dim; e += blockDim.x) {
    v_shared[e] = v[v_offset + e];
  }

  __syncthreads();

  // Per-head exponential decay factor applied to the running KV state.
  const float ratio = expf(-1.0f * slope[h]);

  // Update the KV state: new_kv[d][e] = ratio * past_kv[d][e] + k[d] * v[e]
  for (int d = tid; d < qk_dim; d += blockDim.x) {
    const T k_val = k_shared[d];
    for (int e = 0; e < v_dim; ++e) {
      const int past_kv_idx = kv_offset + d * v_dim + e;
      const T v_val = v_shared[e];
      const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
      const int shared_idx = d * (v_dim + 1) + e;
      new_kv_shared[shared_idx] = new_val;
    }
  }

  __syncthreads();

  // Store new_kv to global memory
  for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
    const int d = idx / v_dim;
    const int e = idx % v_dim;
    const int shared_idx = d * (v_dim + 1) + e;
    const int global_idx = kv_offset + idx;
    new_kv[global_idx] = new_kv_shared[shared_idx];
  }

  __syncthreads();

  // Compute output: output[e] = sum_d q[d] * new_kv[d][e]
  for (int e = tid; e < v_dim; e += blockDim.x) {
    float sum = 0.0f;
    for (int d = 0; d < qk_dim; ++d) {
      const int shared_idx = d * (v_dim + 1) + e;
      sum += q_shared[d] * new_kv_shared[shared_idx];
    }
    output_shared[e] = static_cast<T>(sum);
  }

  __syncthreads();

  // Store output to global memory
  if (tid == 0) {
    for (int e = 0; e < v_dim; ++e) {
      output[v_offset + e] = output_shared[e];
    }
  }
}

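// Host-side launcher. Expected layouts and dtypes, as implied by the kernel above:
//   q, k: [b, h, 1, qk_dim] and v: [b, h, 1, v_dim] in half/bfloat16/float,
//   past_kv, new_kv: [b, h, qk_dim, v_dim] in float32, slope: [h, 1, 1] in float32,
//   output: [b, h, 1, v_dim] in the same dtype as q.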
void lightning_attention_decode(
    const torch::Tensor& q,
    const torch::Tensor& k,
    const torch::Tensor& v,
    const torch::Tensor& past_kv,
    const torch::Tensor& slope,
    torch::Tensor output,
    torch::Tensor new_kv) {
  TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
  TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");

  auto batch_size = q.size(0);
  auto num_heads = q.size(1);
  auto qk_dim = q.size(3);
  auto v_dim = v.size(3);

  // Launch one thread block per (batch, head) pair.
  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(batch_size * num_heads);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
        // Dynamic shared memory: q, k, v, output in scalar_t plus the padded new_kv buffer in float.
        size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
            q.data_ptr<scalar_t>(),
            k.data_ptr<scalar_t>(),
            v.data_ptr<scalar_t>(),
            past_kv.data_ptr<float>(),
            slope.data_ptr<float>(),
            output.data_ptr<scalar_t>(),
            new_kv.data_ptr<float>(),
            batch_size,
            num_heads,
            qk_dim,
            v_dim);
      }));
}
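
// A minimal sketch of how this launcher could be exposed to Python through a
// torch extension. The module and binding names here are illustrative
// assumptions, not necessarily how the project registers the op:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("lightning_attention_decode", &lightning_attention_decode,
//           "Lightning attention single-token decode (CUDA)");
//   }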