Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90ddfba8
Commit
90ddfba8
authored
Jan 21, 2026
by
zhuwenwen
Browse files
fix fa error and remove layernorm kernel
parent
7f7894c0
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
8 additions
and
434 deletions
+8
-434
CMakeLists.txt
CMakeLists.txt
+0
-1
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+0
-413
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-12
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+4
-3
vllm/v1/attention/backends/fa_utils.py
vllm/v1/attention/backends/fa_utils.py
+2
-4
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+1
-1
No files found.
CMakeLists.txt
View file @
90ddfba8
...
...
@@ -297,7 +297,6 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
# "csrc/opt/layernorm_kernels_opt.cu"
"csrc/fused_qknorm_rope_kernel.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
...
...
csrc/opt/layernorm_kernels_opt.cu
deleted
100644 → 0
View file @
7f7894c0
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
variance
+=
x
*
x
;
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
const
int
vec_hidden_size
=
hidden_size
/
width
;
const
int64_t
vec_input_stride
=
input_stride
/
width
;
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
strided_id
];
temp
+=
residual_v
[
id
];
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
temp
*=
s_variance
;
temp
*=
weight_v
[
idx
];
input_v
[
strided_id
]
=
temp
;
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
input_stride
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
input_stride
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
}
// namespace vllm
template
<
typename
T
,
int
reducesize
=
C10_WARP_SIZE
>
__inline__
__device__
T
WarpReduceSum_NEW
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
reducesize
/
2
;
offset
>
0
;
offset
>>=
1
)
{
val
+=
WARP_SHFL_DOWN
(
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
int
block_size
=
512
>
__inline__
__device__
T
BlockReduceSum_NEW
(
T
val
,
T
*
shared
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
val
=
WarpReduceSum_NEW
<
T
>
(
val
);
if
constexpr
(
block_size
==
C10_WARP_SIZE
)
{
return
val
;
}
else
{
const
int
lid
=
threadIdx
.
x
%
C10_WARP_SIZE
;
const
int
wid
=
threadIdx
.
x
/
C10_WARP_SIZE
;
if
(
lid
==
0
&&
wid
<
share_size
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
&&
lid
<
share_size
)
{
val
=
WarpReduceSum_NEW
<
T
,
share_size
>
(
shared
[
lid
]);
}
return
val
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_add_rms_kernel_opt
(
scalar_t
*
input
,
scalar_t
*
residual
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
scalar_t
residual_vec
[
Vec
];
T_ACC
trstd
;
int64_t
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
if
(
j
<
tcol
)
{
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
val
+=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
residual_vec
[
ii
]);
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
residual
+
idx
)
=*
(
LoadT
*
)
residual_vec
;
*
(
LoadT
*
)(
input
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_rms_kernel_opt
(
scalar_t
*
input
,
scalar_t
*
output
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
T_ACC
trstd
;
int64_t
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
if
(
j
<
tcol
)
{
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
output
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
void
rms_norm_opt
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input
.
stride
(
-
2
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
out_data
=
out
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
2048
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
8192
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
});
}
else
{
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), input_stride, \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \
});
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int64_t
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
other_data
=
residual
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
2048
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
8192
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
});
}
else
{
dim3
grid
(
num_tokens
);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
constexpr
int
vector_width
=
8
;
constexpr
int
req_alignment_bytes
=
vector_width
*
2
;
// vector_width * sizeof(bfloat16 or float16) (float32
// falls back to non-vectorized version anyway)
bool
ptrs_are_aligned
=
inp_ptr
%
req_alignment_bytes
==
0
&&
res_ptr
%
req_alignment_bytes
==
0
&&
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
}
\ No newline at end of file
vllm/_custom_ops.py
View file @
90ddfba8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
,
Literal
from
typing
import
TYPE_CHECKING
,
Literal
,
Optional
import
torch
...
...
@@ -348,17 +348,6 @@ def fused_add_rms_norm(
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# layer norm ops (opt)
# def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
# epsilon: float) -> None:
# torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
# def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
# weight: torch.Tensor, epsilon: float) -> None:
# torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def
fused_qk_norm_rope
(
qkv
:
torch
.
Tensor
,
num_heads_q
:
int
,
...
...
vllm/model_executor/layers/layernorm.py
View file @
90ddfba8
...
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.batch_invariant import (
)
from
vllm.platforms
import
current_platform
from
vllm
import
envs
import
lightop
as
op
def
rms_norm
(
...
...
@@ -27,10 +28,10 @@ def rms_norm(
out
=
torch
.
empty_like
(
x
)
# if envs.VLLM_USE_OPT_OP:
if
False
:
ops
.
rms_norm_opt
(
out
,
op
.
rmsnorm_forward
(
x
,
weight
,
out
,
variance_epsilon
,
)
else
:
...
...
@@ -57,7 +58,7 @@ def fused_add_rms_norm(
),
x
+
residual
# if envs.VLLM_USE_OPT_OP:
if
False
:
op
s
.
fused_add_rms_norm_opt
(
op
.
rn_add_forward_autograd
(
x
,
residual
,
weight
,
...
...
vllm/v1/attention/backends/fa_utils.py
View file @
90ddfba8
...
...
@@ -22,10 +22,8 @@ elif current_platform.is_xpu():
elif
current_platform
.
is_rocm
():
try
:
# from flash_attn import flash_attn_varlen_func # noqa: F401
from
vllm
import
_custom_ops
as
ops
from
vllm._custom_ops
import
reshape_and_cache_cuda
from
flash_attn
import
vllm_flash_attn_varlen_func
reshape_and_cache_cuda
=
ops
.
reshape_and_cache_cuda
except
ImportError
as
e
:
raise
ImportError
(
"Rocm platform requires upstream flash-attn "
...
...
@@ -41,7 +39,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
return
2
if
current_platform
.
is_rocm
():
# ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
return
None
return
2
#
None
try
:
from
vllm.vllm_flash_attn.flash_attn_interface
import
(
fa_version_unsupported_reason
,
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
90ddfba8
...
...
@@ -742,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_v_scale
)
else
:
from
vllm.attention.
util
s.fa_utils
import
reshape_and_cache_cuda
from
vllm.
v1.
attention.
backend
s.fa_utils
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment