Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71b1be50
Commit
71b1be50
authored
Jul 18, 2024
by
zhuwenwen
Browse files
back to pa and rn
parent
c628c6ec
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
754 deletions
+79
-754
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+39
-497
csrc/attention/attention_utils.cuh
csrc/attention/attention_utils.cuh
+6
-67
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+34
-190
No files found.
csrc/attention/attention_kernels.cu
View file @
71b1be50
This diff is collapsed.
Click to expand it.
csrc/attention/attention_utils.cuh
View file @
71b1be50
...
...
@@ -26,75 +26,19 @@
namespace
vllm
{
inline
__device__
void
v_dot2_f32_f16
(
float
&
a
,
const
uint32_t
&
b
,
const
uint32_t
&
c
)
{
asm
volatile
(
"v_dot2_f32_f16 %0, %1, %2, %0;"
:
"=v"
(
a
)
:
"v"
(
b
),
"v"
(
c
),
"0"
(
a
));
}
inline
__device__
void
v_pk_fma_f16
(
uint32_t
&
a
,
const
uint32_t
&
b
,
const
uint32_t
&
c
){
asm
volatile
(
"v_pk_fma_f16 %0, %1, %2, %3;"
:
"=v"
(
a
)
:
"v"
(
b
),
"v"
(
c
),
"v"
(
a
));
}
inline
__device__
void
v_dot2_f32_f16
(
float
&
a
,
const
uint2
&
b
,
const
uint2
&
c
)
{
v_dot2_f32_f16
(
a
,
b
.
x
,
c
.
x
);
v_dot2_f32_f16
(
a
,
b
.
y
,
c
.
y
);
}
inline
__device__
void
v_dot2_f32_f16
(
float
&
a
,
const
uint4
&
b
,
const
uint4
&
c
)
{
v_dot2_f32_f16
(
a
,
b
.
x
,
c
.
x
);
v_dot2_f32_f16
(
a
,
b
.
y
,
c
.
y
);
v_dot2_f32_f16
(
a
,
b
.
z
,
c
.
z
);
v_dot2_f32_f16
(
a
,
b
.
w
,
c
.
w
);
}
inline
__device__
float
add_half2
(
uint32_t
a
){
union
{
uint32_t
u32
;
half
u16
[
2
];
}
tmp
;
tmp
.
u32
=
a
;
return
static_cast
<
float
>
(
tmp
.
u16
[
0
]
+
tmp
.
u16
[
1
]);
}
inline
__device__
void
v_pk_fma_f16x8
(
float
&
a
,
const
uint4
&
b
,
const
uint4
&
c
)
{
uint32_t
tmp
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
b
.
x
,
c
.
x
);
v_pk_fma_f16
(
tmp
,
b
.
y
,
c
.
y
);
v_pk_fma_f16
(
tmp
,
b
.
z
,
c
.
z
);
v_pk_fma_f16
(
tmp
,
b
.
w
,
c
.
w
);
a
+=
add_half2
(
tmp
);
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
// Q*K^T operation.
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
>
inline
__device__
float
qk_dot_
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
float
qk
=
0
;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
#pragma unroll
for
(
int
ii
=
0
;
ii
<
N
;
++
ii
)
{
v_dot2_f32_f16
(
qk
,
q
[
ii
],
k
[
ii
]);
}
// Finalize the reduction across lanes.
#pragma unroll
for
(
int
mask
=
THREAD_GROUP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
VLLM_SHFL_XOR_SYNC
(
qk
,
mask
);
}
return
qk
;
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
>
inline
__device__
float
qk_dot_v1
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
using
A_vec
=
typename
FloatVec
<
Vec
>::
Type
;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec
qk_vec
=
mul
<
A_vec
,
Vec
,
Vec
>
(
q
[
0
],
k
[
0
]);
#pragma unroll
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
}
float
qk
=
sum
(
qk_vec
);
// Finalize the reduction across lanes.
float
qk
=
sum
(
qk_vec
);
#pragma unroll
for
(
int
mask
=
THREAD_GROUP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
VLLM_SHFL_XOR_SYNC
(
qk
,
mask
);
...
...
@@ -102,17 +46,12 @@ inline __device__ float qk_dot_v1(const Vec (&q)[N], const Vec (&k)[N]) {
return
qk
;
}
template
<
typename
T
,
int
THREAD_GROUP_SIZE
>
struct
Qk_dot
{
template
<
typename
Vec
,
int
N
>
static
inline
__device__
float
dot
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
return
qk_dot_
<
THREAD_GROUP_SIZE
>
(
q
,
k
);
}
template
<
typename
Vec
,
int
N
>
static
inline
__device__
float
dot_v1
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
return
qk_dot_v1
<
THREAD_GROUP_SIZE
>
(
q
,
k
);
}
};
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/layernorm_kernels.cu
View file @
71b1be50
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
...
...
@@ -291,149 +288,22 @@ fused_add_rms_norm_kernel(
}
// namespace vllm
template
<
typename
T
,
int
reducesize
=
C10_WARP_SIZE
>
__inline__
__device__
T
WarpReduceSum_NEW
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
reducesize
/
2
;
offset
>
0
;
offset
>>=
1
)
{
val
+=
WARP_SHFL_DOWN
(
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
int
block_size
=
512
>
__inline__
__device__
T
BlockReduceSum_NEW
(
T
val
,
T
*
shared
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
val
=
WarpReduceSum_NEW
<
T
>
(
val
);
if
constexpr
(
block_size
==
C10_WARP_SIZE
)
{
return
val
;
}
else
{
const
int
lid
=
threadIdx
.
x
%
C10_WARP_SIZE
;
const
int
wid
=
threadIdx
.
x
/
C10_WARP_SIZE
;
__syncthreads
();
if
(
lid
==
0
&&
wid
<
share_size
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
&&
lid
<
share_size
)
{
val
=
WarpReduceSum_NEW
<
T
,
share_size
>
(
shared
[
lid
]);
}
return
val
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_add_rms_kernel_eval
(
scalar_t
*
input
,
scalar_t
*
residual
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
if
(
j
>=
tcol
)
return
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
scalar_t
residual_vec
[
Vec
];
T_ACC
trstd
;
int
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
val
+=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
residual_vec
[
ii
]);
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
residual
+
idx
)
=*
(
LoadT
*
)
residual_vec
;
*
(
LoadT
*
)(
input
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_rms_kernel_eval
(
scalar_t
*
input
,
scalar_t
*
output
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
if
(
j
>=
tcol
)
return
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
T_ACC
trstd
;
int
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
output
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
hidden_size
%
16
==
0
&&
hidden_size
>=
2048
&&
hidden_size
<=
8192
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
data_ptr
<
scalar_t
>
();
scalar_t
*
out_data
=
out
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
if
(
hidden_size
==
2048
){
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
2
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
});
}
else
{
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
...
...
@@ -446,63 +316,37 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
num_tokens, hidden_size); \
});
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
hidden_size
%
16
==
0
&&
hidden_size
>=
2048
&&
hidden_size
<=
8192
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
data_ptr
<
scalar_t
>
();
scalar_t
*
other_data
=
residual
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
if
(
hidden_size
==
2048
){
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
2
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_eval
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
});
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
else
{
dim3
grid
(
num_tokens
);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment