// activation_kernels.cu
#include "activation_kernels_impl.cuh"
#include "activation_kernels.h"
#include "dispatch_utils.h"

// Launch an element-wise activation kernel: one thread block per token,
// with up to 1024 threads striding over the hidden dimension d.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
  int d = input.size(-1);                                                                 \
  int num_tokens = input.numel() / d;                                                     \
  dim3 grid(num_tokens);                                                                  \
  dim3 block(std::min(d, 1024));                                                          \
  const cudaStream_t stream = getCurrentCUDAStream();                                     \
  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
    input.scalar_type(),                                                                  \
    "activation_kernel",                                                                  \
    [&] {                                                                                 \
      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
        out.data_ptr<scalar_t>(),                                                         \
        input.data_ptr<scalar_t>(),                                                       \
        d);                                                                               \
    });
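
// For reference, a sketch of the indexing the launch above assumes. The real
// kernel is declared in activation_kernels_impl.cuh; this is an illustration
// only, not the actual implementation:
//
//   template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
//   __global__ void activation_kernel(scalar_t* out, const scalar_t* input,
//                                     const int d) {
//     const int token_idx = blockIdx.x;                  // one block per token
//     for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
//       out[token_idx * d + idx] = ACT_FN(input[token_idx * d + idx]);
//     }
//   }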

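// Gated SiLU (SwiGLU-style) activation: the last dimension of `input` packs
// the gate and up projections, and for each token
//   out[..., i] = silu(input[..., i]) * input[..., d + i],
// where silu(x) = x * sigmoid(x).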
void silu_and_mul(
  Tensor& out,      // [..., d]
  Tensor& input)    // [..., 2 * d]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
    vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
  });
}
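
// A sketch of the per-thread work silu_and_mul_kernel presumably does (the
// actual kernel lives in activation_kernels_impl.cuh):
//
//   const int token_idx = blockIdx.x;
//   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
//     const scalar_t gate = input[token_idx * 2 * d + idx];
//     const scalar_t up   = input[token_idx * 2 * d + d + idx];
//     out[token_idx * d + idx] = silu(gate) * up;
//   }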

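// Fused dequantize -> SiLU-and-mul -> quantize, with a static output scale:
// the int32 gate/up GEMM outputs are dequantized with scale_gate and scale_up,
// silu(gate) * up is computed in fp32, and the result is requantized to int8
// using scale_out.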
void invoke_dequant_silu_and_mul_quant(
    Tensor &out,   // [..., d]
    Tensor &input, // [..., 2 * d]
    const float scale_gate, const float scale_up, const float scale_out) {
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = getCurrentCUDAStream();
  vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
      out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate,
      scale_up, scale_out);
}


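// Per-token (dynamic) variant: instead of a fixed output scale, the kernel
// presumably writes the fp32 activations to `tmp`, derives one quantization
// scale per token, and stores it in `scale_out` alongside the int8 output.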
void invoke_dequant_silu_and_mul_quant(
    Tensor &out,   // [..., d]
    Tensor &input, // [..., 2 * d]
    const float scale_gate, const float scale_up,
    Tensor &scale_out, // [num_tokens]
    Tensor &tmp // [..., d]
) {
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = getCurrentCUDAStream();
  vllm::dequant_silu_and_mul_quant_kernel<float*, true><<<grid, block, 0, stream>>>(
      out.data_ptr<int8_t>(), input.data_ptr<int32_t>(),
       d, scale_gate, scale_up, scale_out.data_ptr<float>(), tmp.data_ptr<float>());
}

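// SiLU: silu(x) = x * sigmoid(x).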
void silu(
  Tensor& out,     // [..., d]
  Tensor& input)   // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::silu);
}

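// GPT-2-style "new" GELU (tanh approximation):
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))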
void gelu_new(
  Tensor& out,     // [..., d]
  Tensor& input)   // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}


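// Fast GELU approximation:
//   0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))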
void gelu_fast(
  Tensor& out,     // [..., d]
  Tensor& input)   // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
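
// The Python bindings for these entry points presumably live elsewhere in the
// repo. A minimal sketch, assuming a standard torch-extension setup (the
// module layout and docstrings here are illustrative, not the project's
// actual binding code):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("silu_and_mul", &silu_and_mul, "Gated SiLU activation (CUDA)");
//     m.def("gelu_new", &gelu_new, "GELU (tanh approximation, CUDA)");
//     m.def("gelu_fast", &gelu_fast, "Fast GELU (CUDA)");
//   }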