Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4405f82c
Commit
4405f82c
authored
Aug 20, 2024
by
zhuwenwen
Browse files
Refactoring the optimized kernel
parent
3a6764a4
Changes
18
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
2478 additions
and
725 deletions
+2478
-725
CMakeLists.txt
CMakeLists.txt
+4
-1
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+14
-92
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+293
-371
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+31
-217
csrc/ops.h
csrc/ops.h
+35
-2
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+228
-0
csrc/opt/attention_kernels_opt.cu
csrc/opt/attention_kernels_opt.cu
+1073
-0
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+538
-0
csrc/opt/static_switch.h
csrc/opt/static_switch.h
+0
-0
csrc/opt/transpose_kernels.cu
csrc/opt/transpose_kernels.cu
+0
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+56
-4
vllm/_custom_ops.py
vllm/_custom_ops.py
+87
-5
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+69
-21
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+13
-3
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+24
-7
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+6
-1
No files found.
CMakeLists.txt
View file @
4405f82c
...
...
@@ -155,7 +155,10 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/opt/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
...
...
csrc/activation_kernels.cu
View file @
4405f82c
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
...
...
@@ -24,60 +23,6 @@ __global__ void act_and_mul_kernel(
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_vectorize1
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
using
VecType
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
VEC
>
;
const
int64_t
token_idx
=
blockIdx
.
x
;
int
idx
=
threadIdx
.
x
*
VEC
;
if
(
idx
<
d
)
{
const
int64_t
x_index
=
token_idx
*
2
*
d
+
idx
;
const
int64_t
y_index
=
token_idx
*
d
+
idx
;
VecType
*
x1
=
(
VecType
*
)(
input
+
x_index
);
VecType
*
x2
=
(
VecType
*
)(
input
+
x_index
+
d
);
VecType
*
y
=
(
VecType
*
)(
out
+
y_index
);
scalar_t
r_x1
[
VEC
];
scalar_t
r_x2
[
VEC
];
scalar_t
r_y
[
VEC
];
*
(
VecType
*
)
r_x1
=
*
x1
;
*
(
VecType
*
)
r_x2
=
*
x2
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC
;
i
++
)
{
r_y
[
i
]
=
ACT_FN
(
r_x1
[
i
])
*
r_x2
[
i
];
}
*
y
=
*
(
VecType
*
)
r_y
;
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_vectorize2
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
using
VecType
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
VEC
>
;
const
int64_t
token_idx
=
blockIdx
.
x
;
int
idx
=
threadIdx
.
x
*
VEC
;
for
(;
idx
<
d
;
idx
+=
blockDim
.
x
*
VEC
)
{
const
int64_t
x_index
=
token_idx
*
2
*
d
+
idx
;
const
int64_t
y_index
=
token_idx
*
d
+
idx
;
VecType
*
x1
=
(
VecType
*
)(
input
+
x_index
);
VecType
*
x2
=
(
VecType
*
)(
input
+
x_index
+
d
);
VecType
*
y
=
(
VecType
*
)(
out
+
y_index
);
scalar_t
r_x1
[
VEC
];
scalar_t
r_x2
[
VEC
];
scalar_t
r_y
[
VEC
];
*
(
VecType
*
)
r_x1
=
*
x1
;
*
(
VecType
*
)
r_x2
=
*
x2
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC
;
i
++
)
{
r_y
[
i
]
=
ACT_FN
(
r_x1
[
i
])
*
r_x2
[
i
];
}
*
y
=
*
(
VecType
*
)
r_y
;
}
}
template
<
typename
T
>
__device__
__forceinline__
T
silu_kernel
(
const
T
&
x
)
{
// x * sigmoid(x)
...
...
@@ -109,6 +54,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
}
// namespace vllm
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
...
...
@@ -118,33 +64,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
});
void
silu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
...
...
csrc/attention/attention_kernels.cu
View file @
4405f82c
This diff is collapsed.
Click to expand it.
csrc/layernorm_kernels.cu
View file @
4405f82c
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
...
...
@@ -291,168 +288,22 @@ fused_add_rms_norm_kernel(
}
// namespace vllm
template
<
typename
T
,
int
reducesize
=
C10_WARP_SIZE
>
__inline__
__device__
T
WarpReduceSum_NEW
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
reducesize
/
2
;
offset
>
0
;
offset
>>=
1
)
{
val
+=
WARP_SHFL_DOWN
(
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
int
block_size
=
512
>
__inline__
__device__
T
BlockReduceSum_NEW
(
T
val
,
T
*
shared
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
val
=
WarpReduceSum_NEW
<
T
>
(
val
);
if
constexpr
(
block_size
==
C10_WARP_SIZE
)
{
return
val
;
}
else
{
const
int
lid
=
threadIdx
.
x
%
C10_WARP_SIZE
;
const
int
wid
=
threadIdx
.
x
/
C10_WARP_SIZE
;
if
(
lid
==
0
&&
wid
<
share_size
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
&&
lid
<
share_size
)
{
val
=
WarpReduceSum_NEW
<
T
,
share_size
>
(
shared
[
lid
]);
}
return
val
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_add_rms_kernel_eval
(
scalar_t
*
input
,
scalar_t
*
residual
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
scalar_t
residual_vec
[
Vec
];
T_ACC
trstd
;
int
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
val
+=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
residual_vec
[
ii
]);
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
residual
+
idx
)
=*
(
LoadT
*
)
residual_vec
;
*
(
LoadT
*
)(
input
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_rms_kernel_eval
(
scalar_t
*
input
,
scalar_t
*
output
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
T_ACC
trstd
;
int
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
output
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
data_ptr
<
scalar_t
>
();
scalar_t
*
out_data
=
out
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
2048
){
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
8192
){
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
});
}
else
{
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
...
...
@@ -465,55 +316,13 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
num_tokens, hidden_size); \
});
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
>=
2048
&&
hidden_size
<=
8192
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
data_ptr
<
scalar_t
>
();
scalar_t
*
other_data
=
residual
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
2048
){
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
8192
){
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
});
}
else
{
dim3
grid
(
num_tokens
);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
...
...
@@ -521,6 +330,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
hiding on global mem ops. */
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
...
...
@@ -528,11 +339,14 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
}
\ No newline at end of file
csrc/ops.h
View file @
4405f82c
...
...
@@ -23,12 +23,39 @@ void paged_attention_v2(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
kv_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int64_t
block_size
,
int64_t
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
double
kv_scale
,
const
int64_t
tp_rank
,
const
int64_t
blocksparse_local_blocks
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rms_norm_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
...
...
@@ -56,6 +83,14 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
);
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
...
...
@@ -119,8 +154,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
...
...
csrc/opt/activation_kernels_opt.cu
0 → 100644
View file @
4405f82c
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace
vllm
{
// Activation and gating kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
)>
__global__
void
act_and_mul_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
)
*
y
;
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_vectorize1
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
using
VecType
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
VEC
>
;
const
int64_t
token_idx
=
blockIdx
.
x
;
int
idx
=
threadIdx
.
x
*
VEC
;
if
(
idx
<
d
)
{
const
int64_t
x_index
=
token_idx
*
2
*
d
+
idx
;
const
int64_t
y_index
=
token_idx
*
d
+
idx
;
VecType
*
x1
=
(
VecType
*
)(
input
+
x_index
);
VecType
*
x2
=
(
VecType
*
)(
input
+
x_index
+
d
);
VecType
*
y
=
(
VecType
*
)(
out
+
y_index
);
scalar_t
r_x1
[
VEC
];
scalar_t
r_x2
[
VEC
];
scalar_t
r_y
[
VEC
];
*
(
VecType
*
)
r_x1
=
*
x1
;
*
(
VecType
*
)
r_x2
=
*
x2
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC
;
i
++
)
{
r_y
[
i
]
=
ACT_FN
(
r_x1
[
i
])
*
r_x2
[
i
];
}
*
y
=
*
(
VecType
*
)
r_y
;
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_vectorize2
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
using
VecType
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
VEC
>
;
const
int64_t
token_idx
=
blockIdx
.
x
;
int
idx
=
threadIdx
.
x
*
VEC
;
for
(;
idx
<
d
;
idx
+=
blockDim
.
x
*
VEC
)
{
const
int64_t
x_index
=
token_idx
*
2
*
d
+
idx
;
const
int64_t
y_index
=
token_idx
*
d
+
idx
;
VecType
*
x1
=
(
VecType
*
)(
input
+
x_index
);
VecType
*
x2
=
(
VecType
*
)(
input
+
x_index
+
d
);
VecType
*
y
=
(
VecType
*
)(
out
+
y_index
);
scalar_t
r_x1
[
VEC
];
scalar_t
r_x2
[
VEC
];
scalar_t
r_y
[
VEC
];
*
(
VecType
*
)
r_x1
=
*
x1
;
*
(
VecType
*
)
r_x2
=
*
x2
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC
;
i
++
)
{
r_y
[
i
]
=
ACT_FN
(
r_x1
[
i
])
*
r_x2
[
i
];
}
*
y
=
*
(
VecType
*
)
r_y
;
}
}
template
<
typename
T
>
__device__
__forceinline__
T
silu_kernel
(
const
T
&
x
)
{
// x * sigmoid(x)
return
(
T
)(((
float
)
x
)
/
(
1.0
f
+
expf
((
float
)
-
x
)));
}
template
<
typename
T
>
__device__
__forceinline__
T
gelu_kernel
(
const
T
&
x
)
{
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const
float
f
=
(
float
)
x
;
constexpr
float
ALPHA
=
M_SQRT1_2
;
return
(
T
)(
f
*
0.5
f
*
(
1.0
f
+
::
erf
(
f
*
ALPHA
)));
}
template
<
typename
T
>
__device__
__forceinline__
T
gelu_tanh_kernel
(
const
T
&
x
)
{
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const
float
f
=
(
float
)
x
;
constexpr
float
BETA
=
M_SQRT2
*
M_2_SQRTPI
*
0.5
f
;
constexpr
float
KAPPA
=
0.044715
;
float
x_cube
=
f
*
f
*
f
;
float
inner
=
BETA
*
(
f
+
KAPPA
*
x_cube
);
return
(
T
)(
0.5
f
*
f
*
(
1.0
f
+
::
tanhf
(
inner
)));
}
}
// namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
});
void
silu_and_mul_opt
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
void
gelu_and_mul_opt
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_kernel
);
}
void
gelu_tanh_and_mul_opt
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
}
namespace
vllm
{
// Element-wise activation kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
)>
__global__
void
activation_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., d]
const
int
d
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
);
}
}
}
// namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
namespace
vllm
{
template
<
typename
T
>
__device__
__forceinline__
T
gelu_new_kernel
(
const
T
&
x
)
{
const
float
x3
=
(
float
)(
x
*
x
*
x
);
const
T
t
=
(
T
)
tanhf
((
T
)(
0.79788456
f
*
(
float
)(
x
+
(
T
)(
0.044715
f
*
x3
))));
return
((
T
)
0.5
)
*
x
*
(((
T
)
1.0
)
+
t
);
}
template
<
typename
T
>
__device__
__forceinline__
T
gelu_fast_kernel
(
const
T
&
x
)
{
const
float
f
=
(
float
)
x
;
const
T
t
=
(
T
)
tanhf
(((
T
)(
f
*
0.79788456
f
))
*
(((
T
)
1.0
)
+
(
T
)(
0.044715
f
*
f
)
*
x
));
return
((
T
)
0.5
)
*
x
*
(((
T
)
1.0
)
+
t
);
}
}
// namespace vllm
void
gelu_new
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., d]
{
LAUNCH_ACTIVATION_KERNEL
(
vllm
::
gelu_new_kernel
);
}
void
gelu_fast
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., d]
{
LAUNCH_ACTIVATION_KERNEL
(
vllm
::
gelu_fast_kernel
);
}
csrc/opt/attention_kernels_opt.cu
0 → 100644
View file @
4405f82c
This diff is collapsed.
Click to expand it.
csrc/opt/layernorm_kernels_opt.cu
0 → 100644
View file @
4405f82c
This diff is collapsed.
Click to expand it.
csrc/
attention
/static_switch.h
→
csrc/
opt
/static_switch.h
View file @
4405f82c
File moved
csrc/transpose_kernels.cu
→
csrc/
opt/
transpose_kernels.cu
View file @
4405f82c
File moved
csrc/torch_bindings.cpp
View file @
4405f82c
...
...
@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops
.
def
(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
);
// PagedAttention V2 (opt).
ops
.
def
(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
);
// Activation ops
// Activation function used in SwiGLU.
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
...
...
@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCUDA
,
&
gelu_fast
);
// Activation function used in SwiGLU. (opt)
ops
.
def
(
"silu_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"silu_and_mul_opt"
,
torch
::
kCUDA
,
&
silu_and_mul_opt
);
// Activation function used in GeGLU with `none` approximation. (opt)
ops
.
def
(
"gelu_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_and_mul_opt"
,
torch
::
kCUDA
,
&
gelu_and_mul_opt
);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops
.
def
(
"gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_tanh_and_mul_opt"
,
torch
::
kCUDA
,
&
gelu_tanh_and_mul_opt
);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
...
...
@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops
.
def
(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()"
);
ops
.
impl
(
"rms_norm_opt"
,
torch
::
kCUDA
,
&
rms_norm_opt
);
// In-place fused Add and RMS Normalization. (opt)
ops
.
def
(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm_opt"
,
torch
::
kCUDA
,
&
fused_add_rms_norm_opt
);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
...
...
@@ -108,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()"
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
// trans w16
ops
.
def
(
"trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()"
);
ops
.
impl
(
"trans_w16_gemm"
,
torch
::
kCUDA
,
&
trans_w16_gemm
);
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
...
...
@@ -159,10 +215,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// trans w16
ops
.
def
(
"trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()"
);
ops
.
impl
(
"trans_w16_gemm"
,
torch
::
kCUDA
,
&
trans_w16_gemm
);
// Quantized GEMM for SqueezeLLM.
ops
.
def
(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
...
...
vllm/_custom_ops.py
View file @
4405f82c
...
...
@@ -40,6 +40,18 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
def
silu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
silu_and_mul_opt
(
out
,
x
)
def
gelu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_and_mul_opt
(
out
,
x
)
def
gelu_tanh_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_tanh_and_mul_opt
(
out
,
x
)
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_fast
(
out
,
x
)
...
...
@@ -107,6 +119,65 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# page attention ops (opt)
def
paged_attention_v1_opt
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2_opt
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# pos encoding ops
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
...
...
@@ -140,6 +211,21 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# layer norm ops (opt)
def
rms_norm_opt
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
rms_norm_opt
(
out
,
input
,
weight
,
epsilon
)
def
fused_add_rms_norm_opt
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm_opt
(
input
,
residual
,
weight
,
epsilon
)
# trans_w16
def
trans_w16_gemm
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
row
:
int
,
col
:
int
)
->
None
:
torch
.
ops
.
_C
.
trans_w16_gemm
(
dst
,
src
,
row
,
col
)
# quantization ops
# awq
...
...
@@ -207,10 +293,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
quant_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def
trans_w16_gemm
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
row
:
int
,
col
:
int
)
->
None
:
torch
.
ops
.
_C
.
trans_w16_gemm
(
dst
,
src
,
row
,
col
)
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
4405f82c
...
...
@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
if
prefill_meta
.
max_prefill_seq_len
>=
8000
:
if
prefill_meta
.
max_prefill_seq_len
>=
4096
:
out
=
self
.
attn_func_triton
(
q
=
query
,
k
=
key
,
...
...
vllm/attention/ops/paged_attn.py
View file @
4405f82c
...
...
@@ -5,6 +5,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
import
vllm.envs
as
envs
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
...
...
@@ -122,6 +123,28 @@ class PagedAttention:
if
use_v1
:
# Run PagedAttention V1.
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
paged_attention_v1_opt
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v1
(
output
,
query
,
...
...
@@ -156,6 +179,31 @@ class PagedAttention:
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
paged_attention_v2_opt
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v2
(
output
,
exp_sums
,
...
...
vllm/envs.py
View file @
4405f82c
...
...
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
USE_VLLM_OPT_OP
:
bool
=
False
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
...
...
@@ -139,6 +140,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to use optimized kernels
"USE_VLLM_OPT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"USE_VLLM_OPT_OP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK"
:
...
...
vllm/model_executor/layers/activation.py
View file @
4405f82c
...
...
@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.utils
import
set_weight_attrs
import
vllm.envs
as
envs
class
SiluAndMul
(
CustomOp
):
...
...
@@ -34,6 +35,9 @@ class SiluAndMul(CustomOp):
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
silu_and_mul
(
out
,
x
)
else
:
ops
.
silu_and_mul
(
out
,
x
)
return
out
...
...
@@ -66,8 +70,14 @@ class GeluAndMul(CustomOp):
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
self
.
approximate
==
"none"
:
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
gelu_and_mul_opt
(
out
,
x
)
else
:
ops
.
gelu_and_mul
(
out
,
x
)
elif
self
.
approximate
==
"tanh"
:
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
gelu_tanh_and_mul_opt
(
out
,
x
)
else
:
ops
.
gelu_tanh_and_mul
(
out
,
x
)
return
out
...
...
vllm/model_executor/layers/layernorm.py
View file @
4405f82c
...
...
@@ -5,6 +5,7 @@ import torch
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
import
vllm.envs
as
envs
class
RMSNorm
(
CustomOp
):
...
...
@@ -51,6 +52,14 @@ class RMSNorm(CustomOp):
from
vllm
import
_custom_ops
as
ops
if
residual
is
not
None
:
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
fused_add_rms_norm_opt
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
else
:
ops
.
fused_add_rms_norm
(
x
,
residual
,
...
...
@@ -59,6 +68,14 @@ class RMSNorm(CustomOp):
)
return
x
,
residual
out
=
torch
.
empty_like
(
x
)
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
rms_norm_opt
(
out
,
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
,
)
else
:
ops
.
rms_norm
(
out
,
x
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
4405f82c
...
...
@@ -23,6 +23,7 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
]
use_triton_fa_architectures
=
[
'DeepseekV2ForCausalLM'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
...
...
@@ -35,6 +36,10 @@ def get_model_architecture(
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
if
any
(
arch
in
architectures
for
arch
in
use_triton_fa_architectures
):
os
.
environ
[
'VLLM_USE_TRITON_FLASH_ATTN'
]
=
'1'
os
.
environ
[
'VLLM_USE_FLASH_ATTN_AUTO'
]
=
'0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment