Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0aa2480f
Commit
0aa2480f
authored
Aug 21, 2024
by
zhuwenwen
Browse files
Refactoring the optimized kernel
parent
2f9e0bad
Changes
14
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
2416 additions
and
533 deletions
+2416
-533
CMakeLists.txt
CMakeLists.txt
+4
-1
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+14
-93
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+295
-371
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+1078
-0
csrc/ops.h
csrc/ops.h
+35
-2
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+166
-0
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+538
-0
csrc/opt/transpose_kernels_opt.cu
csrc/opt/transpose_kernels_opt.cu
+0
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+56
-4
vllm/_custom_ops.py
vllm/_custom_ops.py
+90
-4
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+94
-45
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+13
-3
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+27
-10
No files found.
CMakeLists.txt
View file @
0aa2480f
...
@@ -181,7 +181,10 @@ set(VLLM_EXT_SRC
...
@@ -181,7 +181,10 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/attention/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
...
...
csrc/activation_kernels.cu
View file @
0aa2480f
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
#include <cmath>
...
@@ -24,60 +23,6 @@ __global__ void act_and_mul_kernel(
...
@@ -24,60 +23,6 @@ __global__ void act_and_mul_kernel(
}
}
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_vectorize1
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
using
VecType
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
VEC
>
;
const
int64_t
token_idx
=
blockIdx
.
x
;
int
idx
=
threadIdx
.
x
*
VEC
;
if
(
idx
<
d
)
{
const
int64_t
x_index
=
token_idx
*
2
*
d
+
idx
;
const
int64_t
y_index
=
token_idx
*
d
+
idx
;
VecType
*
x1
=
(
VecType
*
)(
input
+
x_index
);
VecType
*
x2
=
(
VecType
*
)(
input
+
x_index
+
d
);
VecType
*
y
=
(
VecType
*
)(
out
+
y_index
);
scalar_t
r_x1
[
VEC
];
scalar_t
r_x2
[
VEC
];
scalar_t
r_y
[
VEC
];
*
(
VecType
*
)
r_x1
=
*
x1
;
*
(
VecType
*
)
r_x2
=
*
x2
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC
;
i
++
)
{
r_y
[
i
]
=
ACT_FN
(
r_x1
[
i
])
*
r_x2
[
i
];
}
*
y
=
*
(
VecType
*
)
r_y
;
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_vectorize2
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
using
VecType
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
VEC
>
;
const
int64_t
token_idx
=
blockIdx
.
x
;
int
idx
=
threadIdx
.
x
*
VEC
;
for
(;
idx
<
d
;
idx
+=
blockDim
.
x
*
VEC
)
{
const
int64_t
x_index
=
token_idx
*
2
*
d
+
idx
;
const
int64_t
y_index
=
token_idx
*
d
+
idx
;
VecType
*
x1
=
(
VecType
*
)(
input
+
x_index
);
VecType
*
x2
=
(
VecType
*
)(
input
+
x_index
+
d
);
VecType
*
y
=
(
VecType
*
)(
out
+
y_index
);
scalar_t
r_x1
[
VEC
];
scalar_t
r_x2
[
VEC
];
scalar_t
r_y
[
VEC
];
*
(
VecType
*
)
r_x1
=
*
x1
;
*
(
VecType
*
)
r_x2
=
*
x2
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC
;
i
++
)
{
r_y
[
i
]
=
ACT_FN
(
r_x1
[
i
])
*
r_x2
[
i
];
}
*
y
=
*
(
VecType
*
)
r_y
;
}
}
template
<
typename
T
>
template
<
typename
T
>
__device__
__forceinline__
T
silu_kernel
(
const
T
&
x
)
{
__device__
__forceinline__
T
silu_kernel
(
const
T
&
x
)
{
// x * sigmoid(x)
// x * sigmoid(x)
...
@@ -109,6 +54,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
...
@@ -109,6 +54,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
}
// namespace vllm
}
// namespace vllm
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
...
@@ -118,33 +64,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
...
@@ -118,33 +64,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
input.data_ptr<scalar_t>(), d); \
} \
});
});
void
silu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
void
silu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
...
@@ -238,4 +160,3 @@ void gelu_quick(torch::Tensor& out, // [..., d]
...
@@ -238,4 +160,3 @@ void gelu_quick(torch::Tensor& out, // [..., d]
{
{
LAUNCH_ACTIVATION_KERNEL
(
vllm
::
gelu_quick_kernel
);
LAUNCH_ACTIVATION_KERNEL
(
vllm
::
gelu_quick_kernel
);
}
}
\ No newline at end of file
csrc/attention/attention_kernels.cu
View file @
0aa2480f
This diff is collapsed.
Click to expand it.
csrc/attention/attention_kernels_opt.cu
0 → 100644
View file @
0aa2480f
This diff is collapsed.
Click to expand it.
csrc/ops.h
View file @
0aa2480f
...
@@ -26,12 +26,39 @@ void paged_attention_v2(
...
@@ -26,12 +26,39 @@ void paged_attention_v2(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
k_scale
,
double
v_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
double
epsilon
);
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rms_norm_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
...
@@ -55,12 +82,20 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
...
@@ -55,12 +82,20 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_quick
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_quick
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
);
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
...
@@ -151,8 +186,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
...
@@ -151,8 +186,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
// torch::Tensor const& scale);
...
...
csrc/opt/activation_kernels_opt.cu
0 → 100644
View file @
0aa2480f
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
#include "cuda_compat.h"
#include "../dispatch_utils.h"
namespace
vllm
{
// Scalar gated-activation kernel: out[t, i] = ACT_FN(input[t, i]) * input[t, d + i].
//
// blockIdx.x selects the token (row) t; the threads of the block walk the
// d columns of that row in steps of blockDim.x, so any block size >= 1
// covers the whole row.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  // The first d values of a row go through ACT_FN, the next d are the
  // multipliers.
  const scalar_t* act_row = input + token_idx * 2 * d;
  const scalar_t* mul_row = act_row + d;
  scalar_t* out_row = out + token_idx * d;

  int64_t col = threadIdx.x;
  while (col < d) {
    const scalar_t a = VLLM_LDG(&act_row[col]);
    const scalar_t b = VLLM_LDG(&mul_row[col]);
    out_row[col] = ACT_FN(a) * b;
    col += blockDim.x;
  }
}
// Single-pass vectorized gated activation:
// out[t, i] = ACT_FN(input[t, i]) * input[t, d + i].
//
// blockIdx.x selects the token t; thread k owns the VEC consecutive columns
// starting at k * VEC. There is no loop, so the launch has to supply at
// least d / VEC threads per block; threads whose first column is >= d exit
// without touching memory.
//
// NOTE(review): loads and stores are whole aligned_vector<scalar_t, VEC>
// accesses, so this presumably needs d % VEC == 0 and input/out pointers
// aligned to VEC * sizeof(scalar_t) -- confirm against the launcher.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_opt1(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;

  const int64_t token_idx = blockIdx.x;
  const int col = threadIdx.x * VEC;
  if (col >= d) {
    return;
  }

  const int64_t in_offset = token_idx * 2 * d + col;
  const int64_t out_offset = token_idx * d + col;

  // One vector load per operand.
  const VecType act_in = *(const VecType*)(input + in_offset);
  const VecType mul_in = *(const VecType*)(input + in_offset + d);

  VecType result;
#pragma unroll
  for (int lane = 0; lane < VEC; ++lane) {
    result.val[lane] = ACT_FN(act_in.val[lane]) * mul_in.val[lane];
  }

  // One vector store for the VEC results.
  *(VecType*)(out + out_offset) = result;
}
// Looping vectorized gated activation:
// out[t, i] = ACT_FN(input[t, i]) * input[t, d + i].
//
// blockIdx.x selects the token t. Each thread starts at column
// threadIdx.x * VEC and advances by blockDim.x * VEC until it passes d, so
// rows wider than blockDim.x * VEC are still fully covered.
//
// NOTE(review): loads and stores are whole aligned_vector<scalar_t, VEC>
// accesses, so this presumably needs d % VEC == 0 and input/out pointers
// aligned to VEC * sizeof(scalar_t) -- confirm against the launcher.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_opt2(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;

  const int64_t token_idx = blockIdx.x;
  for (int col = threadIdx.x * VEC; col < d; col += blockDim.x * VEC) {
    const int64_t in_offset = token_idx * 2 * d + col;
    const int64_t out_offset = token_idx * d + col;

    // One vector load per operand.
    const VecType act_in = *(const VecType*)(input + in_offset);
    const VecType mul_in = *(const VecType*)(input + in_offset + d);

    VecType result;
#pragma unroll
    for (int lane = 0; lane < VEC; ++lane) {
      result.val[lane] = ACT_FN(act_in.val[lane]) * mul_in.val[lane];
    }

    // One vector store for the VEC results.
    *(VecType*)(out + out_offset) = result;
  }
}
// SiLU activation, x * sigmoid(x), written as x / (1 + exp(-x)).
// The arithmetic is carried out in float and the result is cast back to T.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  const float numerator = (float)x;
  const float denominator = 1.0f + expf((float)-x);
  return (T)(numerator / denominator);
}
// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
// Matches PyTorch GELU with the 'none' approximation; see
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
// The arithmetic is carried out in float and the result is cast back to T.
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  constexpr float ALPHA = M_SQRT1_2;  // 1 / sqrt(2)
  const float f = (float)x;
  const float half_f = f * 0.5f;
  return (T)(half_f * (1.0f + ::erf(f * ALPHA)));
}
// Tanh-approximated GELU:
//   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
// Matches PyTorch GELU with the 'tanh' approximation; see
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
// The arithmetic is carried out in float and the result is cast back to T.
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  constexpr float KAPPA = 0.044715;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;  // sqrt(2 / pi)
  const float f = (float)x;
  const float f_cubed = f * f * f;
  const float tanh_arg = BETA * (f + KAPPA * f_cubed);
  const float half_f = 0.5f * f;
  return (T)(half_f * (1.0f + ::tanhf(tanh_arg)));
}
}
// namespace vllm
// Launch helper for the gated-activation ("act_and_mul") kernels.
//
// Must be expanded inside a function returning void that has two
// torch::Tensor variables in scope: `out` ([..., d]) and `input`
// ([..., 2 * d]). KERNEL names a device function template of the form
// `template <typename T> T KERNEL(const T&)`.
//
// One thread block is launched per token (row of `input`) on the current
// CUDA stream of `input`'s device. Kernel selection:
//   d % 8 == 0 && d <= 16384:
//     d <= 512   -> act_and_mul_kernel_opt1, VEC = 2,  256 threads
//     d <= 1024  -> act_and_mul_kernel_opt1, VEC = 8,  128 threads
//     d <= 2048  -> act_and_mul_kernel_opt1, VEC = 8,  256 threads
//     d <= 4096  -> act_and_mul_kernel_opt1, VEC = 8,  512 threads
//     otherwise  -> act_and_mul_kernel_opt2, VEC = 8, 1024 threads (looping)
//   anything else -> scalar act_and_mul_kernel with min(d, 1024) threads
// In every opt1 case threads * VEC >= d, so the single-pass kernel covers
// the whole row.
//
// Empty inputs return before anything is launched: a zero-block grid is an
// invalid launch configuration, and input.size(-1) == 0 would divide by
// zero when computing num_tokens.
//
// NOTE(review): the vectorized kernels reinterpret the data pointers as
// aligned_vector<scalar_t, VEC>; this presumably relies on contiguous
// tensors whose data_ptr() is aligned to VEC * sizeof(scalar_t) -- confirm
// that callers guarantee this.
//
// Only /* */ comments may appear inside the macro body: a // comment would
// swallow the line continuation that follows it.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                  \
  if (input.numel() == 0) {                                                    \
    /* Nothing to compute; avoids a 0-block launch and a division by 0. */     \
    return;                                                                    \
  }                                                                            \
  int d = input.size(-1) / 2;                                                  \
  int64_t num_tokens = input.numel() / input.size(-1);                         \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "act_and_mul_kernel", [&] {                         \
        if (0 == d % 8 && d <= 16384) {                                        \
          if (d <= 512) {                                                      \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 2>       \
                <<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(),           \
                                           input.data_ptr<scalar_t>(), d);     \
          } else if (d <= 1024) {                                              \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8>       \
                <<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(),           \
                                           input.data_ptr<scalar_t>(), d);     \
          } else if (d <= 2048) {                                              \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8>       \
                <<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(),           \
                                           input.data_ptr<scalar_t>(), d);     \
          } else if (d <= 4096) {                                              \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8>       \
                <<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(),           \
                                           input.data_ptr<scalar_t>(), d);     \
          } else {                                                             \
            vllm::act_and_mul_kernel_opt2<scalar_t, KERNEL<scalar_t>, 8>       \
                <<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(),          \
                                            input.data_ptr<scalar_t>(), d);    \
          }                                                                    \
        } else {                                                               \
          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>                 \
              <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),           \
                                           input.data_ptr<scalar_t>(), d);     \
        }                                                                      \
      });
// Gated SiLU entry point for the optimized activation kernels.
// Expands LAUNCH_ACTIVATION_GATE_KERNEL with vllm::silu_kernel, which writes
// out[..., i] = silu(input[..., i]) * input[..., d + i] for i in [0, d),
// where d = input.size(-1) / 2. The kernel is enqueued on the current CUDA
// stream of input's device; the function does not synchronize.
void
silu_and_mul_opt
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
// Gated GELU (erf-based, no approximation) entry point for the optimized
// activation kernels.
// Expands LAUNCH_ACTIVATION_GATE_KERNEL with vllm::gelu_kernel, which writes
// out[..., i] = gelu(input[..., i]) * input[..., d + i] for i in [0, d),
// where d = input.size(-1) / 2. The kernel is enqueued on the current CUDA
// stream of input's device; the function does not synchronize.
void
gelu_and_mul_opt
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_kernel
);
}
// Gated GELU (tanh approximation) entry point for the optimized activation
// kernels.
// Expands LAUNCH_ACTIVATION_GATE_KERNEL with vllm::gelu_tanh_kernel, which
// writes out[..., i] = gelu_tanh(input[..., i]) * input[..., d + i] for
// i in [0, d), where d = input.size(-1) / 2. The kernel is enqueued on the
// current CUDA stream of input's device; the function does not synchronize.
void
gelu_tanh_and_mul_opt
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
}
\ No newline at end of file
csrc/opt/layernorm_kernels_opt.cu
0 → 100644
View file @
0aa2480f
This diff is collapsed.
Click to expand it.
csrc/transpose_kernels.cu
→
csrc/
opt/
transpose_kernels
_opt
.cu
View file @
0aa2480f
File moved
csrc/torch_bindings.cpp
View file @
0aa2480f
...
@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
// The CUDA implementation is registered under the same "_opt" name as the
// schema so that torch.ops._C.paged_attention_v1_opt dispatches to the
// optimized kernel instead of re-registering paged_attention_v1.
ops.def(
    "paged_attention_v1_opt("
    " Tensor! out, Tensor query, Tensor key_cache,"
    " Tensor value_cache, int num_kv_heads, float scale,"
    " Tensor block_tables, Tensor seq_lens, int block_size,"
    " int max_seq_len, Tensor? alibi_slopes,"
    " str kv_cache_dtype, float k_scale, float v_scale,"
    " int tp_rank, int blocksparse_local_blocks,"
    " int blocksparse_vert_stride, int blocksparse_block_size,"
    " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
// The CUDA implementation is registered under the same "_opt" name as the
// schema so that torch.ops._C.paged_attention_v2_opt dispatches to the
// optimized kernel instead of re-registering paged_attention_v2.
ops.def(
    "paged_attention_v2_opt("
    " Tensor! out, Tensor exp_sums, Tensor max_logits,"
    " Tensor tmp_out, Tensor query, Tensor key_cache,"
    " Tensor value_cache, int num_kv_heads, float scale,"
    " Tensor block_tables, Tensor seq_lens, int block_size,"
    " int max_seq_len, Tensor? alibi_slopes,"
    " str kv_cache_dtype, float k_scale, float v_scale,"
    " int tp_rank, int blocksparse_local_blocks,"
    " int blocksparse_vert_stride, int blocksparse_block_size,"
    " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Activation ops
// Activation ops
// Activation function used in SwiGLU.
// Activation function used in SwiGLU.
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
...
@@ -60,6 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -60,6 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_tanh_and_mul"
,
torch
::
kCUDA
,
&
gelu_tanh_and_mul
);
ops
.
impl
(
"gelu_tanh_and_mul"
,
torch
::
kCUDA
,
&
gelu_tanh_and_mul
);
// Activation function used in SwiGLU. (opt)
// Bound to silu_and_mul_opt (declared in ops.h) so the "_opt" op actually
// runs the optimized kernel rather than aliasing silu_and_mul.
ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul_opt);
// Activation function used in GeGLU with `none` approximation. (opt)
// Bound to gelu_and_mul_opt (declared in ops.h) so the "_opt" op actually
// runs the optimized kernel rather than aliasing gelu_and_mul.
ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul_opt);
// Activation function used in GeGLU with `tanh` approximation. (opt)
// Bound to gelu_tanh_and_mul_opt (declared in ops.h) so the "_opt" op
// actually runs the optimized kernel rather than aliasing gelu_tanh_and_mul.
ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul_opt);
// GELU implementation used in GPT-2.
// GELU implementation used in GPT-2.
ops
.
def
(
"gelu_new(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"gelu_new(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_new"
,
torch
::
kCUDA
,
&
gelu_new
);
ops
.
impl
(
"gelu_new"
,
torch
::
kCUDA
,
&
gelu_new
);
...
@@ -89,6 +129,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -89,6 +129,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops
.
def
(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()"
);
ops
.
impl
(
"rms_norm_opt"
,
torch
::
kCUDA
,
&
rms_norm_opt
);
// In-place fused Add and RMS Normalization. (opt)
ops
.
def
(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm_opt"
,
torch
::
kCUDA
,
&
fused_add_rms_norm_opt
);
// Rotary embedding
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
ops
.
def
(
...
@@ -116,6 +168,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -116,6 +168,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()"
);
" Tensor cos_sin_cache_offsets) -> ()"
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
// trans w16
ops
.
def
(
"trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()"
);
ops
.
impl
(
"trans_w16_gemm"
,
torch
::
kCUDA
,
&
trans_w16_gemm
);
// Quantization ops
// Quantization ops
#ifndef USE_ROCM
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
// Quantized GEMM for AQLM.
...
@@ -185,10 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -185,10 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// trans w16
ops
.
def
(
"trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()"
);
ops
.
impl
(
"trans_w16_gemm"
,
torch
::
kCUDA
,
&
trans_w16_gemm
);
// Quantized GEMM for SqueezeLLM.
// Quantized GEMM for SqueezeLLM.
ops
.
def
(
ops
.
def
(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
...
...
vllm/_custom_ops.py
View file @
0aa2480f
...
@@ -61,6 +61,18 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
...
@@ -61,6 +61,18 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
def
silu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
silu_and_mul_opt
(
out
,
x
)
def
gelu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_and_mul_opt
(
out
,
x
)
def
gelu_tanh_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_tanh_and_mul_opt
(
out
,
x
)
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_fast
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_fast
(
out
,
x
)
...
@@ -135,6 +147,68 @@ def paged_attention_v2(
...
@@ -135,6 +147,68 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# page attention ops (opt)
def
paged_attention_v1_opt
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2_opt
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# pos encoding ops
# pos encoding ops
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -169,6 +243,17 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
...
@@ -169,6 +243,17 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# layer norm ops (opt)
def
rms_norm_opt
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
rms_norm_opt
(
out
,
input
,
weight
,
epsilon
)
def
fused_add_rms_norm_opt
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm_opt
(
input
,
residual
,
weight
,
epsilon
)
def
advance_step
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
def
advance_step
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
...
@@ -180,6 +265,11 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
...
@@ -180,6 +265,11 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
input_positions
,
seq_lens
,
slot_mapping
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
)
block_tables
)
# trans_w16
def
trans_w16_gemm
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
row
:
int
,
col
:
int
)
->
None
:
torch
.
ops
.
_C
.
trans_w16_gemm
(
dst
,
src
,
row
,
col
)
# quantization ops
# quantization ops
# awq
# awq
...
@@ -247,10 +337,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
...
@@ -247,10 +337,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
quant_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
quant_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def
trans_w16_gemm
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
row
:
int
,
col
:
int
)
->
None
:
torch
.
ops
.
_C
.
trans_w16_gemm
(
dst
,
src
,
row
,
col
)
# squeezellm
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
...
...
vllm/attention/ops/paged_attn.py
View file @
0aa2480f
...
@@ -134,6 +134,29 @@ class PagedAttention:
...
@@ -134,6 +134,29 @@ class PagedAttention:
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
paged_attention_v1_opt
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v1
(
ops
.
paged_attention_v1
(
output
,
output
,
query
,
query
,
...
@@ -176,6 +199,32 @@ class PagedAttention:
...
@@ -176,6 +199,32 @@ class PagedAttention:
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
paged_attention_v2_opt
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v2
(
ops
.
paged_attention_v2
(
output
,
output
,
exp_sums
,
exp_sums
,
...
...
vllm/envs.py
View file @
0aa2480f
...
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
...
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
LOCAL_RANK
:
int
=
0
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
...
@@ -188,6 +189,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -188,6 +189,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control if vllm print pa parameters
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM"
:
"VLLM_USE_PA_PRINT_PARAM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
...
...
vllm/model_executor/layers/activation.py
View file @
0aa2480f
...
@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
import
vllm.envs
as
envs
class
SiluAndMul
(
CustomOp
):
class
SiluAndMul
(
CustomOp
):
...
@@ -34,6 +35,9 @@ class SiluAndMul(CustomOp):
...
@@ -34,6 +35,9 @@ class SiluAndMul(CustomOp):
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
silu_and_mul_opt
(
out
,
x
)
else
:
ops
.
silu_and_mul
(
out
,
x
)
ops
.
silu_and_mul
(
out
,
x
)
return
out
return
out
...
@@ -75,8 +79,14 @@ class GeluAndMul(CustomOp):
...
@@ -75,8 +79,14 @@ class GeluAndMul(CustomOp):
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
self
.
approximate
==
"none"
:
if
self
.
approximate
==
"none"
:
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
gelu_and_mul_opt
(
out
,
x
)
else
:
ops
.
gelu_and_mul
(
out
,
x
)
ops
.
gelu_and_mul
(
out
,
x
)
elif
self
.
approximate
==
"tanh"
:
elif
self
.
approximate
==
"tanh"
:
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
gelu_tanh_and_mul_opt
(
out
,
x
)
else
:
ops
.
gelu_tanh_and_mul
(
out
,
x
)
ops
.
gelu_tanh_and_mul
(
out
,
x
)
return
out
return
out
...
...
vllm/model_executor/layers/layernorm.py
View file @
0aa2480f
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
import
vllm.envs
as
envs
class
RMSNorm
(
CustomOp
):
class
RMSNorm
(
CustomOp
):
...
@@ -51,6 +52,14 @@ class RMSNorm(CustomOp):
...
@@ -51,6 +52,14 @@ class RMSNorm(CustomOp):
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
if
residual
is
not
None
:
if
residual
is
not
None
:
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
fused_add_rms_norm_opt
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
else
:
ops
.
fused_add_rms_norm
(
ops
.
fused_add_rms_norm
(
x
,
x
,
residual
,
residual
,
...
@@ -59,6 +68,14 @@ class RMSNorm(CustomOp):
...
@@ -59,6 +68,14 @@ class RMSNorm(CustomOp):
)
)
return
x
,
residual
return
x
,
residual
out
=
torch
.
empty_like
(
x
)
out
=
torch
.
empty_like
(
x
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
rms_norm_opt
(
out
,
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
else
:
ops
.
rms_norm
(
ops
.
rms_norm
(
out
,
out
,
x
,
x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment