"magic_pdf/pdf_parse_by_txt.py" did not exist on "fcea39d36b323de22af7161e5aa90e9b1b1affbd"
activation_kernels_impl.cuh 3.55 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
#include "utils.cuh"
#include "reduction_utils.cuh"

namespace vllm {

// SiLU (swish): x * sigmoid(x), computed in float for accuracy.
template <typename T> __device__ __forceinline__ T silu(const T &x) {
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Fused SiLU-and-mul for gated MLPs: the last dimension of `input` holds the
// gate and up projections back to back; out[i] = silu(gate[i]) * up[i].
// One thread block processes one token.
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., 2 * d]
  const int d) {

  const int token_idx = blockIdx.x;
  // 64-bit offsets avoid int overflow when num_tokens * d exceeds INT_MAX.
  const int64_t token_idx_d = token_idx * int64_t(d);
  const int64_t token_idx_2d = token_idx_d * 2;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx_2d + idx]);
    const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
    out[token_idx_d + idx] = silu(x) * y;
  }
}
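
// Example launch (a sketch, not part of this header): one block per token,
// with threads striding over the hidden size d. `num_tokens`, `out_ptr`,
// `in_ptr`, and `stream` are illustrative names, not defined here.
//
//   const dim3 grid(num_tokens);
//   const dim3 block(std::min(d, 1024));
//   vllm::silu_and_mul_kernel<float><<<grid, block, 0, stream>>>(
//       out_ptr, in_ptr, d);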

// Dequantize the int32 input, apply SiLU-and-mul, then quantize to int8,
// either per token (use_per_token_quant = true) or with a fixed output scale.
template <typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(
    int8_t *__restrict__ out,          // [..., d]
    const int32_t *__restrict__ input, // [..., 2 * d]
    const int d, const float scale_gate, const float scale_up,
    scale_type scale_out,                  // [num_tokens]
    float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
  const int token_idx = blockIdx.x;
  // 64-bit offsets, as in silu_and_mul_kernel, avoid int overflow when
  // num_tokens * d exceeds INT_MAX.
  const int64_t token_idx_d = token_idx * int64_t(d);
  const int64_t token_idx_2d = token_idx_d * 2;
  if constexpr (use_per_token_quant) {
    float amax_val = 0.0f;
    const float zero = 0.0f;

    // First pass: dequantize, apply silu * up, stash the fp32 result in tmp,
    // and track this thread's running absolute maximum.
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
      const float x =
          (float)__ldg(&input[token_idx_2d + idx]) * scale_gate;
      const float y =
          (float)__ldg(&input[token_idx_2d + d + idx]) * scale_up;
      float t = silu(x) * y;
      tmp[token_idx_d + idx] = t;
      t = t > zero ? t : -t;
      if (t > amax_val)
        amax_val = t;
    }

    // Reduce the per-thread maxima to a per-token amax and derive the scale.
    __shared__ float s_amax;
    const float block_amax_val = blockReduceMax(amax_val);
    if (threadIdx.x == 0) {
      s_amax = block_amax_val;
      scale_out[token_idx] = block_amax_val / 127.0f;
    }
    __syncthreads();

    // Second pass: rescale the stashed values into int8.
    const float tmp_scale = 127.0f / s_amax;
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
      out[token_idx_d + idx] =
          float_to_int8_rn(tmp_scale * tmp[token_idx_d + idx]);
    }
  } else {
    // Per-tensor variant: scale_out is a fixed scalar scale.
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
      const float x =
          (float)__ldg(&input[token_idx_2d + idx]) * scale_gate;
      const float y =
          (float)__ldg(&input[token_idx_2d + d + idx]) * scale_up;
      out[token_idx_d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
    }
  }
}
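
// Example launch (sketch): the per-token variant needs a float scratch buffer
// `tmp` of shape [num_tokens, d] and writes one scale per token to
// `scale_out`. All names below are illustrative assumptions.
//
//   vllm::dequant_silu_and_mul_quant_kernel<float *, true>
//       <<<num_tokens, std::min(d, 1024), 0, stream>>>(
//           out_ptr, in_ptr, d, scale_gate, scale_up, scale_out_ptr, tmp_ptr);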
} // namespace vllm



namespace vllm {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., d]
  const int d) {
  const int token_idx = blockIdx.x;
  // 64-bit offset avoids int overflow for large num_tokens * d.
  const int64_t token_idx_d = int64_t(token_idx) * d;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx_d + idx]);
    out[token_idx_d + idx] = ACT_FN(x);
  }
}

} // namespace vllm



namespace vllm {

// "New" GELU approximation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.79788456.
template <typename T> __device__ __forceinline__ T gelu_new_kernel(const T &x) {
  const float x3 = (float)(x * x * x);
  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
  return ((T)0.5) * x * (((T)1.0) + t);
}

// "Fast" GELU approximation: the same tanh form with the product reassociated
// as tanh(0.79788456 * x * (1 + 0.044715 * x * x)).
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
  const float f = (float)x;
  const T t =
      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
  return ((T)0.5) * x * (((T)1.0) + t);
}

} // namespace vllm
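
// Example instantiation (sketch): the gelu variants plug into
// activation_kernel via the ACT_FN template parameter; the launch shape
// mirrors silu_and_mul_kernel. Names are illustrative, not defined here.
//
//   vllm::activation_kernel<float, vllm::gelu_new_kernel<float>>
//       <<<num_tokens, std::min(d, 1024)>>>(out_ptr, in_ptr, d);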