Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
8ce9c50d
Unverified
Commit
8ce9c50d
authored
Sep 02, 2023
by
Woosuk Kwon
Committed by
GitHub
Sep 02, 2023
Browse files
Avoid compiling kernels for double data type (#933)
parent
32b6816e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
21 deletions
+28
-21
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+4
-6
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+5
-9
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+14
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+2
-3
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+3
-3
No files found.
csrc/activation_kernels.cu
View file @
8ce9c50d
#include <torch/extension.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
namespace
vllm
{
namespace
vllm
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -34,9 +36,7 @@ void silu_and_mul(
...
@@ -34,9 +36,7 @@ void silu_and_mul(
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
d
,
1024
));
dim3
block
(
std
::
min
(
d
,
1024
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
VLLM_DISPATCH_FLOATING_TYPES
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
input
.
scalar_type
(),
"silu_and_mul_kernel"
,
"silu_and_mul_kernel"
,
[
&
]
{
[
&
]
{
...
@@ -71,9 +71,7 @@ __global__ void activation_kernel(
...
@@ -71,9 +71,7 @@ __global__ void activation_kernel(
dim3 grid(num_tokens); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
AT_DISPATCH_FLOATING_TYPES_AND2( \
VLLM_DISPATCH_FLOATING_TYPES( \
at::ScalarType::Half, \
at::ScalarType::BFloat16, \
input.scalar_type(), \
input.scalar_type(), \
"activation_kernel", \
"activation_kernel", \
[&] { \
[&] { \
...
...
csrc/cache_kernels.cu
View file @
8ce9c50d
#include <torch/extension.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
#include <algorithm>
#include <algorithm>
#include <cassert>
#include <cassert>
#include <map>
#include <map>
...
@@ -125,9 +127,7 @@ void copy_blocks(
...
@@ -125,9 +127,7 @@ void copy_blocks(
dim3
grid
(
num_layers
,
num_pairs
);
dim3
grid
(
num_layers
,
num_pairs
);
dim3
block
(
std
::
min
(
1024
,
numel_per_block
));
dim3
block
(
std
::
min
(
1024
,
numel_per_block
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
VLLM_DISPATCH_FLOATING_TYPES
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
key_caches
[
0
].
scalar_type
(),
"copy_blocks_kernel"
,
([
&
]
{
key_caches
[
0
].
scalar_type
(),
"copy_blocks_kernel"
,
([
&
]
{
vllm
::
copy_blocks_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
copy_blocks_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
key_cache_ptrs_tensor
.
data_ptr
<
int64_t
>
(),
key_cache_ptrs_tensor
.
data_ptr
<
int64_t
>
(),
...
@@ -202,9 +202,7 @@ void reshape_and_cache(
...
@@ -202,9 +202,7 @@ void reshape_and_cache(
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
VLLM_DISPATCH_FLOATING_TYPES
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
key
.
scalar_type
(),
key
.
scalar_type
(),
"reshape_and_cache_kernel"
,
"reshape_and_cache_kernel"
,
[
&
]
{
[
&
]
{
...
@@ -364,9 +362,7 @@ void gather_cached_kv(
...
@@ -364,9 +362,7 @@ void gather_cached_kv(
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
VLLM_DISPATCH_FLOATING_TYPES
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
key
.
scalar_type
(),
key
.
scalar_type
(),
"gather_cached_kv_kernel_optimized"
,
"gather_cached_kv_kernel_optimized"
,
[
&
]
{
[
&
]
{
...
...
csrc/dispatch_utils.h
0 → 100644
View file @
8ce9c50d
/*
* Adapted from
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
*/
#include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
csrc/layernorm_kernels.cu
View file @
8ce9c50d
#include <torch/extension.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#include "reduction_utils.cuh"
namespace
vllm
{
namespace
vllm
{
...
@@ -46,9 +47,7 @@ void rms_norm(
...
@@ -46,9 +47,7 @@ void rms_norm(
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
VLLM_DISPATCH_FLOATING_TYPES
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
input
.
scalar_type
(),
"rms_norm_kernel"
,
"rms_norm_kernel"
,
[
&
]
{
[
&
]
{
...
...
csrc/pos_encoding_kernels.cu
View file @
8ce9c50d
#include <torch/extension.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
namespace
vllm
{
namespace
vllm
{
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
@@ -83,9 +85,7 @@ void rotary_embedding_neox(
...
@@ -83,9 +85,7 @@ void rotary_embedding_neox(
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
VLLM_DISPATCH_FLOATING_TYPES
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
query
.
scalar_type
(),
query
.
scalar_type
(),
"rotary_embedding_neox"
,
"rotary_embedding_neox"
,
[
&
]
{
[
&
]
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment