Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9da5a40
Unverified
Commit
e9da5a40
authored
Apr 11, 2024
by
Kunshang Ji
Committed by
GitHub
Apr 10, 2024
Browse files
[Misc] Add indirection layer for custom ops (#3913)
parent
e42df722
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
224 additions
and
32 deletions
+224
-32
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+1
-1
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+3
-3
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+12
-13
vllm/_custom_ops.py
vllm/_custom_ops.py
+193
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+5
-5
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+1
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+1
-1
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+1
-1
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+1
-1
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+1
-1
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+1
-1
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+1
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+1
-1
vllm/utils.py
vllm/utils.py
+2
-2
No files found.
benchmarks/kernels/benchmark_paged_attention.py
View file @
e9da5a40
...
...
@@ -5,7 +5,7 @@ from typing import Optional
import
torch
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
create_kv_caches_with_random
NUM_BLOCKS
=
1024
...
...
tests/kernels/test_attention.py
View file @
e9da5a40
...
...
@@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
vllm
._C
import
cache
_ops
,
ops
from
vllm
import
_custom
_ops
as
ops
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
...
...
@@ -237,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
dtype
,
device
=
device
)
cache_
ops
.
convert_fp8
(
key_cache
,
dequantized_key_cache
)
ops
.
convert_fp8
(
key_cache
,
dequantized_key_cache
)
key_cache
=
dequantized_key_cache
value_cache_shape
=
value_cache
.
shape
dequantized_value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
dtype
,
device
=
device
)
cache_
ops
.
convert_fp8
(
value_cache
,
dequantized_value_cache
)
ops
.
convert_fp8
(
value_cache
,
dequantized_value_cache
)
value_cache
=
dequantized_value_cache
ref_output
=
torch
.
empty_like
(
query
)
...
...
tests/kernels/test_cache.py
View file @
e9da5a40
...
...
@@ -4,7 +4,7 @@ from typing import Tuple
import
pytest
import
torch
from
vllm
._C
import
cache_
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
is_hip
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
...
...
@@ -80,7 +80,7 @@ def test_copy_blocks(
cloned_value_caches
=
[
value_cache
.
clone
()
for
value_cache
in
value_caches
]
# Call the copy blocks kernel.
cache_
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
# Run the reference implementation.
for
src
,
dsts
in
block_mapping
.
items
():
...
...
@@ -145,9 +145,9 @@ def test_reshape_and_cache(
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cache_
ops
.
convert_fp8
(
key_cache
,
cloned_key_cache
)
ops
.
convert_fp8
(
key_cache
,
cloned_key_cache
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cache_
ops
.
convert_fp8
(
value_cache
,
cloned_value_cache
)
ops
.
convert_fp8
(
value_cache
,
cloned_value_cache
)
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
...
...
@@ -156,14 +156,14 @@ def test_reshape_and_cache(
kv_scale
=
1.0
# Call the reshape_and_cache kernel.
cache_
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cache_
ops
.
convert_fp8
(
key_cache
,
result_key_cache
)
ops
.
convert_fp8
(
key_cache
,
result_key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cache_
ops
.
convert_fp8
(
value_cache
,
result_value_cache
)
ops
.
convert_fp8
(
value_cache
,
result_value_cache
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
...
...
@@ -251,9 +251,8 @@ def test_swap_blocks(
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
# Call the swap_blocks kernel.
cache_ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping
)
cache_ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping
)
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping
)
for
src
,
dst
in
block_mapping
.
items
():
assert
torch
.
allclose
(
src_key_caches_clone
[
src
].
cpu
(),
...
...
@@ -291,9 +290,9 @@ def test_fp8_conversion(
cache
.
uniform_
(
low
,
high
)
cache_fp8
=
torch
.
empty_like
(
cache
,
dtype
=
torch
.
uint8
)
cache_
ops
.
convert_fp8
(
cache
,
cache_fp8
)
ops
.
convert_fp8
(
cache
,
cache_fp8
)
converted_cache
=
torch
.
empty_like
(
cache
)
cache_
ops
.
convert_fp8
(
cache_fp8
,
converted_cache
)
ops
.
convert_fp8
(
cache_fp8
,
converted_cache
)
assert
torch
.
allclose
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
vllm/_custom_ops.py
0 → 100644
View file @
e9da5a40
from typing import Dict, List, Optional

import torch

try:
    from vllm._C import cache_ops as vllm_cache_ops
    from vllm._C import ops as vllm_ops
except ImportError:
    # Best-effort import: when the vllm._C extension is unavailable this
    # module still imports, but the wrappers below will raise NameError
    # when called because vllm_ops / vllm_cache_ops were never bound.
    pass
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``silu_and_mul`` op exposed by ``vllm._C.ops``.

    Both tensors are forwarded positionally and unchanged. The wrapper
    itself returns nothing; ``out`` is presumably filled in place by the
    op (verify against the extension).
    """
    vllm_ops.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``gelu_and_mul`` op exposed by ``vllm._C.ops``.

    Both tensors are forwarded positionally and unchanged. The wrapper
    itself returns nothing; ``out`` is presumably filled in place by the
    op (verify against the extension).
    """
    vllm_ops.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``gelu_tanh_and_mul`` op exposed by ``vllm._C.ops``.

    Both tensors are forwarded positionally and unchanged. The wrapper
    itself returns nothing; ``out`` is presumably filled in place by the
    op (verify against the extension).
    """
    vllm_ops.gelu_tanh_and_mul(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``gelu_fast`` op exposed by ``vllm._C.ops``.

    Both tensors are forwarded positionally and unchanged. The wrapper
    itself returns nothing; ``out`` is presumably filled in place by the
    op (verify against the extension).
    """
    vllm_ops.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``gelu_new`` op exposed by ``vllm._C.ops``.

    Both tensors are forwarded positionally and unchanged. The wrapper
    itself returns nothing; ``out`` is presumably filled in place by the
    op (verify against the extension).
    """
    vllm_ops.gelu_new(out, x)
# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_size: int,
    max_context_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    """Invoke the ``paged_attention_v1`` op exposed by ``vllm._C.ops``.

    All thirteen arguments are handed to the op positionally, in the
    order they are declared here. ``alibi_slopes`` may be ``None``. The
    wrapper returns nothing; ``out`` is presumably the tensor the op
    writes its result into (verify against the extension).
    """
    vllm_ops.paged_attention_v1(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        kv_scale,
    )
def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_size: int,
    max_context_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    """Invoke the ``paged_attention_v2`` op exposed by ``vllm._C.ops``.

    Takes the same arguments as ``paged_attention_v1`` plus three extra
    tensors (``exp_sum``, ``max_logits``, ``tmp_out``) placed right after
    ``out``. Everything is handed to the op positionally in declaration
    order, and the wrapper returns nothing.
    """
    vllm_ops.paged_attention_v2(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        kv_scale,
    )
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Invoke the ``rotary_embedding`` op exposed by ``vllm._C.ops``.

    The six arguments are forwarded positionally in declaration order.
    Nothing is returned, so any result is presumably written into
    ``query`` / ``key`` in place (verify against the extension).
    """
    vllm_ops.rotary_embedding(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox,
    )
def batched_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    rot_dim: int,
    cos_sin_cache_offsets: torch.Tensor,
) -> None:
    """Invoke the ``batched_rotary_embedding`` op from ``vllm._C.ops``.

    Takes the ``rotary_embedding`` arguments followed by ``rot_dim`` and
    ``cos_sin_cache_offsets``; all eight are forwarded positionally in
    declaration order. The wrapper returns nothing.
    """
    vllm_ops.batched_rotary_embedding(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox,
        rot_dim,
        cos_sin_cache_offsets,
    )
# layer norm ops
def rms_norm(
    out: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """Invoke the ``rms_norm`` op exposed by ``vllm._C.ops``.

    The four arguments are forwarded positionally in declaration order.
    The wrapper returns nothing; ``out`` is presumably filled in place by
    the op (verify against the extension).
    """
    vllm_ops.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """Invoke the ``fused_add_rms_norm`` op exposed by ``vllm._C.ops``.

    The four arguments are forwarded positionally in declaration order.
    There is no separate output argument and nothing is returned, so the
    op presumably updates ``input`` and/or ``residual`` in place (verify
    against the extension).
    """
    vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
# quantization ops
# awq
def awq_dequantize(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    split_k_iters: int,
    thx: int,
    thy: int,
) -> torch.Tensor:
    """Invoke the ``awq_dequantize`` op exposed by ``vllm._C.ops``.

    The six arguments are forwarded positionally in declaration order.

    Returns:
        Whatever tensor the underlying op returns, passed back unchanged.
    """
    dequantized = vllm_ops.awq_dequantize(
        qweight,
        scales,
        zeros,
        split_k_iters,
        thx,
        thy,
    )
    return dequantized
def awq_gemm(
    input: torch.Tensor,
    qweight: torch.Tensor,
    qzeros: torch.Tensor,
    scales: torch.Tensor,
    split_k_iters: int,
) -> torch.Tensor:
    """Invoke the ``awq_gemm`` op exposed by ``vllm._C.ops``.

    The five arguments are forwarded positionally in declaration order.

    Returns:
        Whatever tensor the underlying op returns, passed back unchanged.
    """
    product = vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
    return product
# gptq
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_exllama: bool,
    bit: int,
) -> torch.Tensor:
    """Invoke the ``gptq_gemm`` op exposed by ``vllm._C.ops``.

    The seven arguments are forwarded positionally in declaration order.

    Returns:
        Whatever tensor the underlying op returns, passed back unchanged.
    """
    product = vllm_ops.gptq_gemm(
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        b_g_idx,
        use_exllama,
        bit,
    )
    return product
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Invoke the ``gptq_shuffle`` op exposed by ``vllm._C.ops``.

    The three arguments are forwarded positionally in declaration order.
    Nothing is returned, so the op presumably rearranges ``q_weight`` in
    place (verify against the extension).
    """
    vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(
    vec: torch.Tensor,
    mat: torch.Tensor,
    mul: torch.Tensor,
    lookup_table: torch.Tensor,
) -> None:
    """Invoke the ``squeezellm_gemm`` op exposed by ``vllm._C.ops``.

    The four tensors are forwarded positionally in declaration order.
    Nothing is returned; which argument receives the result is not
    visible from this wrapper (presumably ``mul`` -- verify against the
    extension).
    """
    vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Invoke the ``marlin_gemm`` op exposed by ``vllm._C.ops``.

    The seven arguments are forwarded positionally in declaration order.

    Returns:
        Whatever tensor the underlying op returns, passed back unchanged.
    """
    product = vllm_ops.marlin_gemm(
        a,
        b_q_weight,
        b_scales,
        workspace,
        size_m,
        size_n,
        size_k,
    )
    return product
# moe
def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Invoke the ``moe_align_block_size`` op exposed by ``vllm._C.ops``.

    The six arguments are forwarded positionally in declaration order.
    Nothing is returned, so the trailing tensor arguments are presumably
    output buffers filled by the op (verify against the extension).
    """
    vllm_ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    """Invoke ``reshape_and_cache`` from ``vllm._C.cache_ops``.

    Unlike the wrappers above, this one dispatches to the cache-ops
    module rather than the general ops module. The seven arguments are
    forwarded positionally in declaration order and nothing is returned.
    """
    vllm_cache_ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        kv_scale,
    )
def copy_blocks(
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    block_mapping: Dict[int, List[int]],
) -> None:
    """Invoke ``copy_blocks`` from ``vllm._C.cache_ops``.

    The three arguments are forwarded positionally and unchanged, and
    nothing is returned.

    Args:
        key_caches: Key-cache tensors, one list entry per cache (callers
            build this with a list comprehension over their KV caches).
        value_caches: Value-cache tensors, parallel to ``key_caches``.
        block_mapping: Mapping from a source block number to the
            destination block numbers -- callers iterate it as
            ``for src, dsts in block_mapping.items()``.
    """
    # The annotations previously declared all three parameters as a single
    # torch.Tensor, which did not match how the op is actually called.
    vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(
    src: torch.Tensor,
    dst: torch.Tensor,
    block_mapping: Dict[int, int],
) -> None:
    """Invoke ``swap_blocks`` from ``vllm._C.cache_ops``.

    ``src``, ``dst`` and ``block_mapping`` are forwarded positionally and
    unchanged; the wrapper returns nothing.
    """
    vllm_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
    """Invoke ``convert_fp8`` from ``vllm._C.cache_ops``.

    Both tensors are forwarded positionally, ``output`` first and
    ``input`` second; the wrapper returns nothing.

    NOTE(review): call sites of this wrapper (the kernel tests and
    ``vllm/utils.py``) pass the already-populated tensor first and the
    freshly allocated one second, i.e. source then destination -- the
    opposite of what the parameter names suggest. Confirm against the
    kernel signature before passing these arguments by keyword.
    """
    vllm_cache_ops.convert_fp8(output, input)
# TODO: add wrappers for the cuda_utils and custom_ar ops
vllm/attention/ops/paged_attn.py
View file @
e9da5a40
...
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import
torch
from
vllm
._C
import
cache
_ops
,
ops
from
vllm
import
_custom
_ops
as
ops
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
...
...
@@ -69,7 +69,7 @@ class PagedAttention:
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
cache_
ops
.
reshape_and_cache
(
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
...
...
@@ -199,11 +199,11 @@ class PagedAttention:
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
cache_
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
cache_
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
...
...
@@ -212,4 +212,4 @@ class PagedAttention:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
cache_
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
vllm/model_executor/layers/activation.py
View file @
e9da5a40
...
...
@@ -6,7 +6,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
e9da5a40
...
...
@@ -8,7 +8,7 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_hip
...
...
vllm/model_executor/layers/layernorm.py
View file @
e9da5a40
...
...
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
import
torch
import
torch.nn
as
nn
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
class
RMSNorm
(
nn
.
Module
):
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
e9da5a40
...
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
e9da5a40
...
...
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
vllm/model_executor/layers/quantization/marlin.py
View file @
e9da5a40
...
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
e9da5a40
...
...
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
e9da5a40
...
...
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import
torch
import
torch.nn
as
nn
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/utils.py
View file @
e9da5a40
...
...
@@ -279,10 +279,10 @@ def _generate_random_fp8(
#-----|-------------|-------------------
# Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
from
vllm
._C
import
cache_
ops
from
vllm
import
_custom_ops
as
ops
tensor_tmp
=
torch
.
empty_like
(
tensor
,
dtype
=
torch
.
float16
)
tensor_tmp
.
uniform_
(
low
,
high
)
cache_
ops
.
convert_fp8
(
tensor_tmp
,
tensor
)
ops
.
convert_fp8
(
tensor_tmp
,
tensor
)
del
tensor_tmp
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment