vllm · Commit 8ce9c50d

Authored Sep 02, 2023 by Woosuk Kwon; committed by GitHub on Sep 02, 2023

Avoid compiling kernels for double data type (#933)

Parent: 32b6816e
Showing 5 changed files with 28 additions and 21 deletions.
csrc/activation_kernels.cu    +4  -6
csrc/cache_kernels.cu         +5  -9
csrc/dispatch_utils.h         +14 -0
csrc/layernorm_kernels.cu     +2  -3
csrc/pos_encoding_kernels.cu  +3  -3
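All five .cu files make the same substitution: the stock PyTorch macro AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ...) dispatches over Double and Float plus the two named extra types, so every kernel wrapped in it was instantiated four times, including a double variant that vLLM never exercises. The new VLLM_DISPATCH_FLOATING_TYPES macro covers only Float, Half, and BFloat16, so nvcc skips the double instantiations entirely. A minimal sketch of the effective difference, written as an explicit switch (the launcher name and body below are hypothetical, for illustration only):

#include <torch/extension.h>

// Hypothetical launcher: one template instantiation is emitted per dtype
// the dispatch actually reaches.
template <typename scalar_t>
void launch_my_kernel(const torch::Tensor& input) {
  // ... would configure and launch a CUDA kernel templated on scalar_t ...
}

void dispatch_example(const torch::Tensor& input) {
  // What the new macro effectively expands to: a switch over the runtime
  // dtype with no Double case, so the double kernel is never compiled.
  switch (input.scalar_type()) {
    case at::ScalarType::Float:    launch_my_kernel<float>(input);        break;
    case at::ScalarType::Half:     launch_my_kernel<at::Half>(input);     break;
    case at::ScalarType::BFloat16: launch_my_kernel<at::BFloat16>(input); break;
    default: TORCH_CHECK(false, "unsupported dtype for dispatch_example");
  }
}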
csrc/activation_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include "dispatch_utils.h"

 namespace vllm {

 template<typename T>
...
@@ -34,9 +36,7 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
...
@@ -71,9 +71,7 @@ __global__ void activation_kernel(
   dim3 grid(num_tokens);                                           \
   dim3 block(std::min(d, 1024));                                   \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();    \
-  AT_DISPATCH_FLOATING_TYPES_AND2(                                 \
-    at::ScalarType::Half,                                          \
-    at::ScalarType::BFloat16,                                      \
+  VLLM_DISPATCH_FLOATING_TYPES(                                    \
     input.scalar_type(),                                           \
     "activation_kernel",                                           \
     [&] {                                                          \
...
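The trailing backslashes in this hunk show that the dispatch sits inside a C-preprocessor launch macro rather than a plain function body, so the type-list change had to be made inside the macro text itself. A sketch of that pattern (the macro and kernel names below are hypothetical, not taken from the diff):

// Hypothetical launch macro in the style of the hunk above: the caller's
// scope must provide `out`, `input`, `num_tokens`, and `d`. Every line ends
// in a backslash because a #define must be one logical line. Inside the
// dispatch lambda, scalar_t is defined per dtype case by the macro expansion.
#define LAUNCH_ELEMENTWISE_KERNEL(KERNEL)                           \
  dim3 grid(num_tokens);                                            \
  dim3 block(std::min(d, 1024));                                    \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();     \
  VLLM_DISPATCH_FLOATING_TYPES(                                     \
    input.scalar_type(), #KERNEL, [&] {                             \
      vllm::KERNEL<scalar_t><<<grid, block, 0, stream>>>(           \
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);   \
    });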
csrc/cache_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include "dispatch_utils.h"

 #include <algorithm>
 #include <cassert>
 #include <map>
...
@@ -125,9 +127,7 @@ void copy_blocks(
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
...
@@ -202,9 +202,7 @@ void reshape_and_cache(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "reshape_and_cache_kernel",
     [&] {
...
@@ -364,9 +362,7 @@ void gather_cached_kv(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
...
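One detail worth noting in the copy_blocks hunk above: the kernel receives key_cache_ptrs_tensor.data_ptr<int64_t>(), i.e. the device pointers of all per-layer caches packed into a single int64 tensor so that one launch can walk every layer. A sketch of that packing step (function name hypothetical; the pattern is assumed from the visible call, the packing code itself is not shown in this diff):

#include <torch/extension.h>
#include <vector>

// Pack each layer's raw device pointer into an int64 tensor and move the
// table to the GPU; the kernel later reinterprets each element back into
// a scalar_t* for its layer.
torch::Tensor pack_cache_ptrs(const std::vector<torch::Tensor>& caches) {
  const int64_t num_layers = static_cast<int64_t>(caches.size());
  std::vector<int64_t> ptrs(num_layers);
  for (int64_t i = 0; i < num_layers; ++i) {
    ptrs[i] = reinterpret_cast<int64_t>(caches[i].data_ptr());
  }
  // from_blob views the host vector; clone() gives the tensor its own
  // storage before the vector goes out of scope.
  return torch::from_blob(ptrs.data(), {num_layers}, torch::kInt64)
      .clone()
      .to(caches[0].device());
}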
csrc/dispatch_utils.h (new file, 0 → 100644)

+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#include <torch/extension.h>
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)       \
+  AT_DISPATCH_SWITCH(                                       \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
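AT_DISPATCH_SWITCH and AT_DISPATCH_CASE are the building blocks PyTorch's own AT_DISPATCH_FLOATING_TYPES* macros are composed of; each case defines a local scalar_t alias before running the lambda body. A usage sketch for the new macro (the scale kernel below is hypothetical, not part of this commit):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"

namespace vllm {
// Hypothetical elementwise kernel, templated on the element type.
template <typename scalar_t>
__global__ void scale_kernel(scalar_t* __restrict__ out,
                             const scalar_t* __restrict__ in,
                             float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = (scalar_t)((float)in[i] * factor);
}
}  // namespace vllm

void scale(torch::Tensor& out, const torch::Tensor& in, float factor) {
  int n = static_cast<int>(in.numel());
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // The lambda is compiled once for float, at::Half, and at::BFloat16;
  // no double instantiation is ever emitted.
  VLLM_DISPATCH_FLOATING_TYPES(
    in.scalar_type(), "scale_kernel", [&] {
      vllm::scale_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), in.data_ptr<scalar_t>(), factor, n);
    });
}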
csrc/layernorm_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include "dispatch_utils.h"
 #include "reduction_utils.cuh"

 namespace vllm {
...
@@ -46,9 +47,7 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "rms_norm_kernel",
     [&] {
...
csrc/pos_encoding_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include "dispatch_utils.h"

 namespace vllm {

 template<typename scalar_t>
...
@@ -83,9 +85,7 @@ void rotary_embedding_neox(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
...