norm/vllm: commit 77af974b (unverified)

[FIX] Support non-zero CUDA devices in custom kernels (#1959)

Authored Jan 03, 2024 by Jee Li; committed by GitHub Jan 02, 2024
Parent: 4934d492

Showing 12 changed files with 74 additions and 30 deletions (+74, -30):
  csrc/activation_kernels.cu                          +4  -1
  csrc/attention/attention_kernels.cu                 +3  -0
  csrc/cache_kernels.cu                               +5  -0
  csrc/layernorm_kernels.cu                           +3  -0
  csrc/pos_encoding_kernels.cu                        +2  -0
  csrc/quantization/squeezellm/quant_cuda_kernel.cu   +2  -1
  tests/kernels/conftest.py                           +3  -2
  tests/kernels/test_activation.py                    +13 -3
  tests/kernels/test_attention.py                     +15 -10
  tests/kernels/test_cache.py                         +11 -6
  tests/kernels/test_layernorm.py                     +6  -3
  tests/kernels/test_pos_encoding.py                  +7  -4
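Every custom-op entry point touched by this commit gains the same two-line pattern before its kernel launch: construct an at::cuda::OptionalCUDAGuard from the device of an input tensor, then query the current CUDA stream, so the launch targets the GPU the tensors actually live on rather than the process-default cuda:0. The following is a minimal sketch of that pattern in isolation, not vLLM code; the op and kernel names (scale_inplace, scale_kernel) are made up for illustration.

// Minimal sketch of the device-guard + current-stream pattern this commit
// applies before each existing kernel launch (illustrative names only).
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace {

__global__ void scale_kernel(float* __restrict__ data, float factor, int64_t n) {
  const int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx < n) {
    data[idx] *= factor;
  }
}

}  // namespace

void scale_inplace(torch::Tensor& input, float factor) {
  TORCH_CHECK(input.is_cuda() && input.scalar_type() == torch::kFloat);
  const int64_t n = input.numel();
  const dim3 block(256);
  const dim3 grid((n + block.x - 1) / block.x);
  // Without this guard, the stream query and the launch below would target
  // the current (usually cuda:0) device even if `input` sits on cuda:1.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<grid, block, 0, stream>>>(input.data_ptr<float>(), factor, n);
}

The test half of the diff exercises this path by parameterizing each kernel test over DEVICES, which contains only cuda:0 on single-GPU machines and adds cuda:1 when a second GPU is visible.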
csrc/activation_kernels.cu

-#include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

@@ -36,6 +37,7 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),

@@ -71,6 +73,7 @@ __global__ void activation_kernel(
   int64_t num_tokens = input.numel() / d;                                  \
   dim3 grid(num_tokens);                                                   \
   dim3 block(std::min(d, 1024));                                           \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));        \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();            \
   VLLM_DISPATCH_FLOATING_TYPES(                                            \
     input.scalar_type(),                                                   \
csrc/attention/attention_kernels.cu

@@ -21,6 +21,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"

@@ -616,6 +617,7 @@ void paged_attention_v1_launcher(
   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the

@@ -784,6 +786,7 @@ void paged_attention_v2_launcher(
   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
   dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
csrc/cache_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

@@ -33,6 +34,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());

   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {

@@ -127,6 +129,7 @@ void copy_blocks(
   const int numel_per_block = key_caches[0][0].numel();
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {

@@ -207,6 +210,7 @@ void reshape_and_cache(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),

@@ -367,6 +371,7 @@ void gather_cached_kv(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
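Note that, unlike the other launchers, swap_blocks and copy_blocks build the guard from a device value they already hold (src_device, cache_device) rather than from device_of() on a tensor argument; swap_blocks in particular copies between CPU and GPU memory, so the CUDA side of the transfer is named explicitly. The sketch below is illustrative only (the function name copy_to_host is not part of vLLM) and shows that variant under the assumption that src is a CUDA tensor and dst a CPU tensor.

// Minimal sketch of guarding on an explicit device for a device-to-host copy
// (illustrative name, not vLLM code).
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

void copy_to_host(const torch::Tensor& src, torch::Tensor& dst) {
  TORCH_CHECK(src.is_cuda() && dst.device().is_cpu());
  TORCH_CHECK(src.is_contiguous() && dst.is_contiguous());
  TORCH_CHECK(src.nbytes() == dst.nbytes());
  // Guard on the CUDA side of the transfer and use that device's stream.
  const torch::Device src_device = src.device();
  const at::cuda::OptionalCUDAGuard device_guard(src_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  cudaMemcpyAsync(dst.data_ptr(), src.data_ptr(), src.nbytes(),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
}

OptionalCUDAGuard accepts a plain c10::Device as well as an optional one, and does nothing when no device is supplied, which is why the same guard type works for both the device_of() and the explicit-device forms.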
csrc/layernorm_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "dispatch_utils.h"
 #include "reduction_utils.cuh"

@@ -76,6 +77,7 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),

@@ -101,6 +103,7 @@ void fused_add_rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
csrc/pos_encoding_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

@@ -94,6 +95,7 @@ void rotary_embedding(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
csrc/quantization/squeezellm/quant_cuda_kernel.cu

@@ -7,6 +7,7 @@
 // half-tensor
 #include <c10/cuda/CUDAStream.h>
 #include <ATen/cuda/CUDATensorMethods.cuh>
+#include <c10/cuda/CUDAGuard.h>

 #define BLOCKWIDTH 128
 #define BLOCKHEIGHT4 16

@@ -199,7 +200,7 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
   vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
tests/kernels/conftest.py

@@ -12,6 +12,7 @@ def create_kv_caches(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: str,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)

@@ -23,7 +24,7 @@ def create_kv_caches(
     for _ in range(num_layers):
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=dtype,
-                                device='cuda')
+                                device=device)
         key_cache.uniform_(-scale, scale)
         key_caches.append(key_cache)

@@ -32,7 +33,7 @@ def create_kv_caches(
     for _ in range(num_layers):
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=dtype,
-                                  device='cuda')
+                                  device=device)
         value_cache.uniform_(-scale, scale)
         value_caches.append(value_cache)
     return key_caches, value_caches
tests/kernels/test_activation.py

@@ -7,22 +7,26 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
 D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
     layer = SiluAndMul()
     out = layer(x)
     ref_out = layer._forward(x)

@@ -33,16 +37,19 @@ def test_silu_and_mul(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
     layer = NewGELU()
     out = layer(x)
     ref_out = layer._forward(x)

@@ -53,15 +60,18 @@ def test_gelu_new(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
     layer = FastGELU()
     out = layer(x)
     ref_out = layer._forward(x)
tests/kernels/test_attention.py

@@ -24,6 +24,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 def ref_masked_attention(

@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
         alibi_bias = None
         if alibi_slopes is not None:
             # Create the ALiBi bias used in the paged attention kernel.
-            position_ids = torch.arange(context_len, device="cuda").int()
+            position_ids = torch.arange(context_len, device=query.device).int()
             alibi_bias = (position_ids - context_len + 1).float()
             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

@@ -105,6 +106,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,

@@ -115,18 +117,19 @@ def test_paged_attention(
     block_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
     query = torch.empty(num_seqs,
                         num_query_heads,
                         head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     query.uniform_(-scale, scale)

     assert num_query_heads % num_kv_heads == 0

@@ -135,12 +138,12 @@ def test_paged_attention(
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
                                    dtype=torch.float,
-                                   device="cuda")
+                                   device=gpu_id)

     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)

     # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size

@@ -151,12 +154,12 @@ def test_paged_attention(
             for _ in range(max_num_blocks_per_seq)
         ]
         block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)

     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                 num_kv_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Call the paged attention kernel.

@@ -249,7 +252,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
+        attn_mask = attn_mask.to(dtype=dtype, device=query.device)

         ref_output = ref_masked_attention(
             query[start_idx:end_idx],

@@ -269,6 +272,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,

@@ -276,11 +280,12 @@ def test_multi_query_kv_attention(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
     # a smaller MAX_SEQ_LEN here.

@@ -294,7 +299,7 @@ def test_multi_query_kv_attention(
                       num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     qkv.uniform_(-scale, scale)
     query, key, value = qkv.split(
         [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
tests/kernels/test_cache.py

@@ -14,6 +14,7 @@ BLOCK_SIZES = [8, 16, 32]
 NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)

@@ -24,6 +25,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,

@@ -35,11 +37,12 @@ def test_copy_blocks(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks

@@ -56,7 +59,7 @@ def test_copy_blocks(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed)
+                                                head_size, dtype, seed, gpu_id)

     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]

@@ -88,6 +91,7 @@ def test_copy_blocks(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_reshape_and_cache(
     kv_cache_factory,

@@ -98,28 +102,29 @@ def test_reshape_and_cache(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)

     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     _, key, value = qkv.unbind(dim=1)

     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Clone the KV caches.
tests/kernels/test_layernorm.py

@@ -8,6 +8,7 @@ NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
 HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)

@@ -15,6 +16,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,

@@ -22,14 +24,15 @@ def test_rms_norm(
     add_residual: bool,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    gpu_id = f"cuda:{device}"
+    layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
     x *= scale
     residual = torch.randn_like(x) * scale if add_residual else None
tests/kernels/test_pos_encoding.py

@@ -13,6 +13,7 @@ NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
 SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)

@@ -23,6 +24,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,

@@ -33,6 +35,7 @@ def test_rotary_embedding(
     rotary_dim: Optional[int],
     dtype: torch.dtype,
     seed: int,
+    device: int,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:

@@ -40,20 +43,20 @@ def test_rotary_embedding(
         rotary_dim = head_size
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype).cuda()
+    rope = rope.to(dtype=dtype, device=gpu_id)
     positions = torch.randint(0, max_position, (batch_size, seq_len),
-                              device="cuda")
+                              device=gpu_id)
     query = torch.randn(batch_size,
                         seq_len,
                         num_heads * head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     key = torch.randn_like(query)

     # NOTE(woosuk): The reference implementation should be executed first