Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
724 additions
and
266 deletions
+724
-266
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+1
-0
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+48
-20
tests/tensorizer_loader/conftest.py
tests/tensorizer_loader/conftest.py
+48
-0
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+90
-80
tests/test_config.py
tests/test_config.py
+4
-2
tests/test_scalartype.py
tests/test_scalartype.py
+36
-0
tests/utils.py
tests/utils.py
+99
-7
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+1
-0
vllm/_core_ext.py
vllm/_core_ext.py
+177
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+52
-68
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+49
-13
vllm/adapter_commons/models.py
vllm/adapter_commons/models.py
+15
-15
vllm/adapter_commons/request.py
vllm/adapter_commons/request.py
+5
-5
vllm/adapter_commons/worker_manager.py
vllm/adapter_commons/worker_manager.py
+7
-7
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-0
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+3
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+31
-16
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+45
-28
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-2
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+6
-3
No files found.
tests/spec_decode/test_batch_expansion.py
View file @
e661d594
...
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
...
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id
,
input_seq_id
,
target_seq_id
,
target_seq_id
,
token_ids
,
token_ids
,
input_seq_group_metadata
.
sampling_params
,
)
)
assert
output
.
request_id
==
input_seq_group_metadata
.
request_id
assert
output
.
request_id
==
input_seq_group_metadata
.
request_id
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
e661d594
...
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
...
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
exception_secret
=
'artificial stop'
exception_secret
=
'artificial stop'
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
draft_worker
.
get_spec_proposals
.
side_effect
=
ValueError
(
exception_secret
)
...
@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
...
@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
vocab_size
=
32_000
vocab_size
=
32_000
...
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
...
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
metrics_collector
)
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
...
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
...
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
worker
=
SpecDecodeWorker
(
draft_worker
,
metrics_collector
)
target_worker
,
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
worker
.
init_device
()
worker
.
init_device
()
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
proposal_token_ids
=
torch
.
randint
(
low
=
0
,
...
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
...
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
proposer_worker
=
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
False
,
scorer_worker
=
target_worker
,
metrics_collector
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
...
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
...
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
set_random_seed
(
1
)
set_random_seed
(
1
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
proposer_worker
=
draft_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
False
,
scorer_worker
=
target_worker
,
metrics_collector
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
,
k
,
...
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
...
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
draft_worker
,
target_worker
,
spec_decode_sampler
,
worker
=
SpecDecodeWorker
(
False
,
metrics_collector
)
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
,
spec_decode_sampler
=
spec_decode_sampler
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
,
)
worker
.
init_device
()
worker
.
init_device
()
draft_worker
.
init_device
.
assert_called_once
()
draft_worker
.
init_device
.
assert_called_once
()
...
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
...
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
target_worker
=
mock_worker
()
target_worker
=
mock_worker
()
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
metrics_collector
=
MagicMock
(
spec
=
AsyncMetricsCollector
)
worker
=
SpecDecodeWorker
(
worker
=
SpecDecodeWorker
(
proposer_worker
=
draft_worker
,
draft_worker
,
target_worker
,
scorer_worker
=
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
)
spec_decode_sampler
=
mock_spec_decode_sampler
(
acceptance_sampler_method
),
metrics_collector
=
metrics_collector
)
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
kwargs
=
{
"num_gpu_blocks"
:
1024
,
"num_cpu_blocks"
:
1023
}
worker
.
initialize_cache
(
**
kwargs
)
worker
.
initialize_cache
(
**
kwargs
)
...
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
...
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
accepted_token_ids
=
accepted_token_ids
,
accepted_token_ids
=
accepted_token_ids
,
target_logprobs
=
target_token_logprobs
,
target_logprobs
=
target_token_logprobs
,
k
=
k
)
k
=
k
,
stage_times
=
(
0
,
0
,
0
))
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
# _seq_with_bonus_token_in_last_step but were not part of the current
...
...
tests/tensorizer_loader/conftest.py
0 → 100644
View file @
e661d594
import
contextlib
import
functools
import
gc
import
pytest
import
ray
import
torch
from
vllm.distributed
import
(
destroy_distributed_environment
,
destroy_model_parallel
)
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
@
pytest
.
fixture
(
autouse
=
True
)
def
cleanup
():
destroy_model_parallel
()
destroy_distributed_environment
()
with
contextlib
.
suppress
(
AssertionError
):
torch
.
distributed
.
destroy_process_group
()
ray
.
shutdown
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
def
retry_until_skip
(
n
):
def
decorator_retry
(
func
):
@
functools
.
wraps
(
func
)
def
wrapper_retry
(
*
args
,
**
kwargs
):
for
i
in
range
(
n
):
try
:
return
func
(
*
args
,
**
kwargs
)
except
AssertionError
:
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
if
i
==
n
-
1
:
pytest
.
skip
(
"Skipping test after attempts.."
)
return
wrapper_retry
return
decorator_retry
@
pytest
.
fixture
(
autouse
=
True
)
def
tensorizer_config
():
config
=
TensorizerConfig
(
tensorizer_uri
=
"vllm"
)
return
config
tests/tensorizer_loader/test_tensorizer.py
View file @
e661d594
import
gc
import
json
import
json
import
os
import
os
import
pathlib
import
pathlib
...
@@ -20,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
...
@@ -20,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
serialize_vllm_model
,
serialize_vllm_model
,
tensorize_vllm_model
)
tensorize_vllm_model
)
from
..conftest
import
VllmRunner
,
cleanup
from
..conftest
import
VllmRunner
from
..utils
import
RemoteOpenAIServer
from
..utils
import
RemoteOpenAIServer
from
.conftest
import
retry_until_skip
# yapf conflicts with isort for this docstring
# yapf conflicts with isort for this docstring
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
...
@@ -48,14 +49,16 @@ def is_curl_installed():
...
@@ -48,14 +49,16 @@ def is_curl_installed():
except
(
subprocess
.
CalledProcessError
,
FileNotFoundError
):
except
(
subprocess
.
CalledProcessError
,
FileNotFoundError
):
return
False
return
False
def
get_torch_model
(
vllm_runner
:
VllmRunner
):
def
get_torch_model
(
vllm_runner
:
VllmRunner
):
return
vllm_runner
\
return
vllm_runner
\
.
model
\
.
model
\
.
llm_engine
\
.
llm_engine
\
.
model_executor
\
.
model_executor
\
.
driver_worker
\
.
driver_worker
\
.
model_runner
\
.
model_runner
\
.
model
.
model
def
write_keyfile
(
keyfile_path
:
str
):
def
write_keyfile
(
keyfile_path
:
str
):
encryption_params
=
EncryptionParams
.
random
()
encryption_params
=
EncryptionParams
.
random
()
...
@@ -63,11 +66,6 @@ def write_keyfile(keyfile_path: str):
...
@@ -63,11 +66,6 @@ def write_keyfile(keyfile_path: str):
with
open
(
keyfile_path
,
'wb'
)
as
f
:
with
open
(
keyfile_path
,
'wb'
)
as
f
:
f
.
write
(
encryption_params
.
key
)
f
.
write
(
encryption_params
.
key
)
@
pytest
.
fixture
(
autouse
=
True
)
def
tensorizer_config
():
config
=
TensorizerConfig
(
tensorizer_uri
=
"vllm"
)
return
config
@
patch
(
'vllm.model_executor.model_loader.tensorizer.TensorizerAgent'
)
@
patch
(
'vllm.model_executor.model_loader.tensorizer.TensorizerAgent'
)
def
test_load_with_tensorizer
(
mock_agent
,
tensorizer_config
):
def
test_load_with_tensorizer
(
mock_agent
,
tensorizer_config
):
...
@@ -90,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner):
...
@@ -90,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner):
tensorized_path
=
f
"s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
tensorized_path
=
f
"s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
with
vllm_runner
(
model_ref
,
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
tensorized_path
,
tensorizer_uri
=
tensorized_path
,
num_readers
=
1
,
num_readers
=
1
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
))
as
loaded_hf_model
:
))
as
loaded_hf_model
:
deserialized_outputs
=
loaded_hf_model
.
generate
(
prompts
,
deserialized_outputs
=
loaded_hf_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
sampling_params
)
# noqa: E501
assert
deserialized_outputs
assert
deserialized_outputs
...
@@ -117,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
...
@@ -117,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
encryption_keyfile
=
key_path
encryption_keyfile
=
key_path
)
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
config_for_serializing
)
config_for_serializing
)
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
encryption_keyfile
=
key_path
)
encryption_keyfile
=
key_path
)
with
vllm_runner
(
with
vllm_runner
(
model_ref
,
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
config_for_deserializing
)
as
loaded_vllm_model
:
# noqa: E501
model_loader_extra_config
=
config_for_deserializing
)
as
loaded_vllm_model
:
# noqa: E501
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
assert
outputs
==
deserialized_outputs
assert
outputs
==
deserialized_outputs
...
@@ -144,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
...
@@ -144,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
serializer
.
write_module
(
hf_model
.
model
)
serializer
.
write_module
(
hf_model
.
model
)
with
vllm_runner
(
model_ref
,
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
tensorizer_uri
=
model_path
,
num_readers
=
1
,
num_readers
=
1
,
))
as
loaded_hf_model
:
))
as
loaded_hf_model
:
deserialized_outputs
=
loaded_hf_model
.
generate_greedy
(
deserialized_outputs
=
loaded_hf_model
.
generate_greedy
(
prompts
,
max_tokens
=
max_tokens
)
prompts
,
max_tokens
=
max_tokens
)
...
@@ -171,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
...
@@ -171,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
TensorizerConfig
(
tensorizer_uri
=
model_path
))
TensorizerConfig
(
tensorizer_uri
=
model_path
))
with
vllm_runner
(
with
vllm_runner
(
model_ref
,
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
tensorizer_uri
=
model_path
,
num_readers
=
1
,
num_readers
=
1
,
),
),
enable_lora
=
True
,
enable_lora
=
True
,
max_loras
=
1
,
max_loras
=
1
,
max_lora_rank
=
8
,
max_lora_rank
=
8
,
max_cpu_loras
=
2
,
max_cpu_loras
=
2
,
max_num_seqs
=
50
,
max_num_seqs
=
50
,
max_model_len
=
1000
,
max_model_len
=
1000
,
)
as
loaded_vllm_model
:
)
as
loaded_vllm_model
:
process_requests
(
loaded_vllm_model
.
model
.
llm_engine
,
test_prompts
)
process_requests
(
loaded_vllm_model
.
model
.
llm_engine
,
test_prompts
)
...
@@ -193,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
...
@@ -193,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
def
test_load_without_tensorizer_load_format
(
vllm_runner
):
def
test_load_without_tensorizer_load_format
(
vllm_runner
):
model
=
None
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
vllm_runner
(
model
=
vllm_runner
(
model_ref
,
model_ref
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
))
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
))
del
model
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
...
@@ -206,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
...
@@ -206,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
TensorizerConfig
(
tensorizer_uri
=
model_path
))
TensorizerConfig
(
tensorizer_uri
=
model_path
))
model_loader_extra_config
=
{
model_loader_extra_config
=
{
"tensorizer_uri"
:
str
(
model_path
),
"tensorizer_uri"
:
str
(
model_path
),
...
@@ -224,9 +227,9 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
...
@@ -224,9 +227,9 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
client
=
server
.
get_client
()
client
=
server
.
get_client
()
completion
=
client
.
completions
.
create
(
model
=
model_ref
,
completion
=
client
.
completions
.
create
(
model
=
model_ref
,
prompt
=
"Hello, my name is"
,
prompt
=
"Hello, my name is"
,
max_tokens
=
5
,
max_tokens
=
5
,
temperature
=
0.0
)
temperature
=
0.0
)
assert
completion
.
id
is
not
None
assert
completion
.
id
is
not
None
assert
len
(
completion
.
choices
)
==
1
assert
len
(
completion
.
choices
)
==
1
...
@@ -237,11 +240,15 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
...
@@ -237,11 +240,15 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
def
test_raise_value_error_on_invalid_load_format
(
vllm_runner
):
def
test_raise_value_error_on_invalid_load_format
(
vllm_runner
):
model
=
None
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
vllm_runner
(
model
=
vllm_runner
(
model_ref
,
model_ref
,
load_format
=
"safetensors"
,
load_format
=
"safetensors"
,
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
))
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
"test"
))
del
model
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
...
@@ -263,22 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
...
@@ -263,22 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
disable_custom_all_reduce
=
True
,
disable_custom_all_reduce
=
True
,
)
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Requires 2 GPUs"
)
reason
=
"Requires 2 GPUs"
)
def
test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs
(
vllm_runner
,
def
test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs
(
vllm_runner
,
tmp_path
):
tmp_path
):
model_ref
=
"EleutherAI/pythia-1.4b"
model_ref
=
"EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model
# record outputs from un-sharded un-tensorized model
base_model
=
vllm_runner
(
with
vllm_runner
(
model_ref
,
model_ref
,
disable_custom_all_reduce
=
True
,
disable_custom_all_reduce
=
True
,
enforce_eager
=
True
,
enforce_eager
=
True
,
)
)
as
base_model
:
outputs
=
base_model
.
generate
(
prompts
,
sampling_params
)
outputs
=
base_model
.
generate
(
prompts
,
sampling_params
)
base_model
.
model
.
llm_engine
.
model_executor
.
shutdown
()
base_model
.
model
.
llm_engine
.
model_executor
.
shutdown
()
del
base_model
cleanup
()
# load model with two shards and serialize with encryption
# load model with two shards and serialize with encryption
model_path
=
str
(
tmp_path
/
(
model_ref
+
"-%02d.tensors"
))
model_path
=
str
(
tmp_path
/
(
model_ref
+
"-%02d.tensors"
))
...
@@ -291,31 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
...
@@ -291,31 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tensorize_vllm_model
(
tensorize_vllm_model
(
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
model
=
model_ref
,
model
=
model_ref
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
disable_custom_all_reduce
=
True
,
disable_custom_all_reduce
=
True
,
enforce_eager
=
True
,
enforce_eager
=
True
,
),
),
tensorizer_config
=
tensorizer_config
,
tensorizer_config
=
tensorizer_config
,
)
)
assert
os
.
path
.
isfile
(
model_path
%
0
),
"Serialization subprocess failed"
assert
os
.
path
.
isfile
(
model_path
%
0
),
"Serialization subprocess failed"
assert
os
.
path
.
isfile
(
model_path
%
1
),
"Serialization subprocess failed"
assert
os
.
path
.
isfile
(
model_path
%
1
),
"Serialization subprocess failed"
cleanup
()
loaded_vllm_model
=
vllm_runner
(
model_ref
,
tensor_parallel_size
=
2
,
load_format
=
"tensorizer"
,
disable_custom_all_reduce
=
True
,
enforce_eager
=
True
,
model_loader_extra_config
=
tensorizer_config
)
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
with
vllm_runner
(
model_ref
,
tensor_parallel_size
=
2
,
load_format
=
"tensorizer"
,
disable_custom_all_reduce
=
True
,
enforce_eager
=
True
,
model_loader_extra_config
=
tensorizer_config
)
as
loaded_vllm_model
:
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
assert
outputs
==
deserialized_outputs
assert
outputs
==
deserialized_outputs
@
retry_until_skip
(
3
)
def
test_vllm_tensorized_model_has_same_outputs
(
vllm_runner
,
tmp_path
):
def
test_vllm_tensorized_model_has_same_outputs
(
vllm_runner
,
tmp_path
):
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
model_ref
=
"facebook/opt-125m"
model_ref
=
"facebook/opt-125m"
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
config
=
TensorizerConfig
(
tensorizer_uri
=
str
(
model_path
))
config
=
TensorizerConfig
(
tensorizer_uri
=
str
(
model_path
))
...
@@ -327,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
...
@@ -327,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
assert
is_vllm_tensorized
(
config
)
assert
is_vllm_tensorized
(
config
)
with
vllm_runner
(
model_ref
,
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
config
)
as
loaded_vllm_model
:
model_loader_extra_config
=
config
)
as
loaded_vllm_model
:
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
assert
outputs
==
deserialized_outputs
assert
outputs
==
deserialized_outputs
tests/test_config.py
View file @
e661d594
...
@@ -104,8 +104,10 @@ def test_rope_customization():
...
@@ -104,8 +104,10 @@ def test_rope_customization():
dtype
=
"float16"
,
dtype
=
"float16"
,
seed
=
0
,
seed
=
0
,
)
)
assert
getattr
(
longchat_model_config
.
hf_config
,
"rope_scaling"
,
# Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
None
)
==
LONGCHAT_ROPE_SCALING
assert
all
(
longchat_model_config
.
hf_config
.
rope_scaling
.
get
(
key
)
==
value
for
key
,
value
in
LONGCHAT_ROPE_SCALING
.
items
())
assert
longchat_model_config
.
max_model_len
==
16384
assert
longchat_model_config
.
max_model_len
==
16384
longchat_model_config
=
ModelConfig
(
longchat_model_config
=
ModelConfig
(
...
...
tests/test_scalartype.py
0 → 100644
View file @
e661d594
import
pytest
import
torch
from
vllm.scalar_type
import
scalar_types
@
pytest
.
mark
.
parametrize
(
"type_tuple"
,
(
(
-
8
,
7
,
scalar_types
.
int4
),
(
0
,
15
,
scalar_types
.
uint4
),
(
-
8
,
7
,
scalar_types
.
uint4b8
),
(
-
128
,
127
,
scalar_types
.
uint8b128
),
(
-
28.
,
28.
,
scalar_types
.
float6_e3m2f
),
(
torch
.
int8
,
scalar_types
.
int8
),
(
torch
.
uint8
,
scalar_types
.
uint8
),
(
torch
.
float8_e5m2
,
scalar_types
.
float8_e5m2
),
(
torch
.
float8_e4m3fn
,
scalar_types
.
float8_e4m3fn
),
(
torch
.
bfloat16
,
scalar_types
.
float16_e8m7
),
(
torch
.
float16
,
scalar_types
.
float16_e5m10
),
),
ids
=
lambda
x
:
str
(
x
))
def
test_scalar_type_min_max
(
type_tuple
):
print
(
type_tuple
)
if
len
(
type_tuple
)
==
3
:
min
,
max
,
t
=
type_tuple
else
:
torch_type
,
t
=
type_tuple
if
torch_type
.
is_floating_point
:
min
=
torch
.
finfo
(
torch_type
).
min
max
=
torch
.
finfo
(
torch_type
).
max
else
:
min
=
torch
.
iinfo
(
torch_type
).
min
max
=
torch
.
iinfo
(
torch_type
).
max
print
(
t
,
min
,
max
,
t
.
min
(),
t
.
max
())
assert
min
==
t
.
min
()
assert
max
==
t
.
max
()
tests/utils.py
View file @
e661d594
import
functools
import
os
import
os
import
signal
import
subprocess
import
subprocess
import
sys
import
sys
import
time
import
time
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Optional
import
openai
import
openai
import
ray
import
ray
...
@@ -48,13 +50,14 @@ VLLM_PATH = Path(__file__).parent.parent
...
@@ -48,13 +50,14 @@ VLLM_PATH = Path(__file__).parent.parent
class
RemoteOpenAIServer
:
class
RemoteOpenAIServer
:
DUMMY_API_KEY
=
"token-abc123"
# vLLM's OpenAI server does not need API key
DUMMY_API_KEY
=
"token-abc123"
# vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S
=
60
0
# wait for server to start for
6
0 seconds
MAX_SERVER_START_WAIT_S
=
12
0
# wait for server to start for
12
0 seconds
def
__init__
(
def
__init__
(
self
,
self
,
model
:
str
,
model
:
str
,
cli_args
:
List
[
str
],
cli_args
:
List
[
str
],
*
,
*
,
env_dict
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
auto_port
:
bool
=
True
,
auto_port
:
bool
=
True
,
)
->
None
:
)
->
None
:
if
auto_port
:
if
auto_port
:
...
@@ -75,6 +78,8 @@ class RemoteOpenAIServer:
...
@@ -75,6 +78,8 @@ class RemoteOpenAIServer:
# the current process might initialize cuda,
# the current process might initialize cuda,
# to be safe, we should use spawn method
# to be safe, we should use spawn method
env
[
'VLLM_WORKER_MULTIPROC_METHOD'
]
=
'spawn'
env
[
'VLLM_WORKER_MULTIPROC_METHOD'
]
=
'spawn'
if
env_dict
is
not
None
:
env
.
update
(
env_dict
)
self
.
proc
=
subprocess
.
Popen
([
"vllm"
,
"serve"
]
+
[
model
]
+
cli_args
,
self
.
proc
=
subprocess
.
Popen
([
"vllm"
,
"serve"
]
+
[
model
]
+
cli_args
,
env
=
env
,
env
=
env
,
stdout
=
sys
.
stdout
,
stdout
=
sys
.
stdout
,
...
@@ -87,6 +92,11 @@ class RemoteOpenAIServer:
...
@@ -87,6 +92,11 @@ class RemoteOpenAIServer:
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
self
.
proc
.
terminate
()
self
.
proc
.
terminate
()
try
:
self
.
proc
.
wait
(
3
)
except
subprocess
.
TimeoutExpired
:
# force kill if needed
self
.
proc
.
kill
()
def
_wait_for_server
(
self
,
*
,
url
:
str
,
timeout
:
float
):
def
_wait_for_server
(
self
,
*
,
url
:
str
,
timeout
:
float
):
# run health check
# run health check
...
@@ -125,10 +135,21 @@ class RemoteOpenAIServer:
...
@@ -125,10 +135,21 @@ class RemoteOpenAIServer:
)
)
def
compare_two_settings
(
model
:
str
,
arg1
:
List
[
str
],
arg2
:
List
[
str
]):
def
compare_two_settings
(
model
:
str
,
arg1
:
List
[
str
],
arg2
:
List
[
str
],
env1
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
env2
:
Optional
[
Dict
[
str
,
str
]]
=
None
):
"""
"""
Launch API server with two different sets of arguments and compare the
Launch API server with two different sets of arguments/environments
results of the API calls. The arguments are after the model name.
and compare the results of the API calls.
Args:
model: The model to test.
arg1: The first set of arguments to pass to the API server.
arg2: The second set of arguments to pass to the API server.
env1: The first set of environment variables to pass to the API server.
env2: The second set of environment variables to pass to the API server.
"""
"""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
...
@@ -136,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
...
@@ -136,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
prompt
=
"Hello, my name is"
prompt
=
"Hello, my name is"
token_ids
=
tokenizer
(
prompt
)[
"input_ids"
]
token_ids
=
tokenizer
(
prompt
)[
"input_ids"
]
results
=
[]
results
=
[]
for
args
in
(
arg1
,
arg2
):
for
args
,
env
in
(
(
arg1
,
env1
),
(
arg2
,
env2
)
):
with
RemoteOpenAIServer
(
model
,
args
)
as
server
:
with
RemoteOpenAIServer
(
model
,
args
,
env_dict
=
env
)
as
server
:
client
=
server
.
get_client
()
client
=
server
.
get_client
()
# test models list
# test models list
...
@@ -178,6 +199,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
...
@@ -178,6 +199,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage"
:
completion
.
usage
,
"usage"
:
completion
.
usage
,
})
})
# test seeded random sampling
completion
=
client
.
completions
.
create
(
model
=
model
,
prompt
=
prompt
,
max_tokens
=
5
,
seed
=
33
,
temperature
=
1.0
)
results
.
append
({
"test"
:
"seeded_sampling"
,
"text"
:
completion
.
choices
[
0
].
text
,
"finish_reason"
:
completion
.
choices
[
0
].
finish_reason
,
"usage"
:
completion
.
usage
,
})
# test seeded random sampling with multiple prompts
completion
=
client
.
completions
.
create
(
model
=
model
,
prompt
=
[
prompt
,
prompt
],
max_tokens
=
5
,
seed
=
33
,
temperature
=
1.0
)
results
.
append
({
"test"
:
"seeded_sampling"
,
"text"
:
[
choice
.
text
for
choice
in
completion
.
choices
],
"finish_reason"
:
[
choice
.
finish_reason
for
choice
in
completion
.
choices
],
"usage"
:
completion
.
usage
,
})
# test simple list
# test simple list
batch
=
client
.
completions
.
create
(
batch
=
client
.
completions
.
create
(
model
=
model
,
model
=
model
,
...
@@ -305,3 +357,43 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
...
@@ -305,3 +357,43 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold_bytes
/
2
**
30
=
}
)'
)
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold_bytes
/
2
**
30
=
}
)'
)
time
.
sleep
(
5
)
time
.
sleep
(
5
)
def fork_new_process_for_each_test(f):
    """Decorator to fork a new process for each test function.

    See https://github.com/vllm-project/vllm/issues/7053 for more details.

    The wrapped test body runs in a forked child process. The child reports
    its outcome only through its exit status (0 for pass or skip, 1 for any
    failure). The parent reaps the child, sends SIGTERM to everything left in
    the process group, and then asserts on the child's status.

    Args:
        f: The test function to run in a forked child.

    Returns:
        A wrapper accepting the same arguments as ``f``. It raises
        ``AssertionError`` in the parent when the child's wait status is
        not 0.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        import sys

        def _flush_stdio():
            # sys.stdout / sys.stderr may be None in some embeddings.
            for stream in (sys.stdout, sys.stderr):
                if stream is not None:
                    stream.flush()

        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped

        # Flush before forking so text already buffered in this process is
        # not emitted a second time when the child flushes its own copy.
        _flush_stdio()

        pid = os.fork()
        if pid == 0:
            # Child process: it must always leave through os._exit() so that
            # it never returns into the caller's (test runner's) stack.
            exit_code = 1
            try:
                f(*args, **kwargs)
                exit_code = 0
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                exit_code = 0
            except BaseException:
                # BaseException rather than Exception: pytest outcome
                # exceptions such as Failed, as well as KeyboardInterrupt and
                # SystemExit, do not derive from Exception and would
                # otherwise escape the child.
                import traceback
                traceback.print_exc()
            finally:
                try:
                    # os._exit() does not flush stdio buffers, so do it by
                    # hand or the skip message / traceback can be lost.
                    _flush_stdio()
                finally:
                    os._exit(exit_code)
        else:
            pgid = os.getpgid(pid)
            # `status` is the raw wait status; it is 0 only when the child
            # exited normally with exit code 0.
            _pid, status = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            try:
                # kill all child processes
                os.killpg(pgid, signal.SIGTERM)
            finally:
                # restore the signal handler, even if killpg() raised
                signal.signal(signal.SIGTERM, old_signal_handler)
            assert status == 0, (f"function {f} failed when called with"
                                 f" args {args} and kwargs {kwargs}")

    return wrapper
tests/worker/test_model_runner.py
View file @
e661d594
...
@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
for
_
in
range
(
expected_bs
-
len
(
seq_lens
)):
for
_
in
range
(
expected_bs
-
len
(
seq_lens
)):
seq_lens
.
append
(
1
)
seq_lens
.
append
(
1
)
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
num_decode_tokens
==
len
(
seq_lens
)
start_idx
=
0
start_idx
=
0
start_loc
=
[
start_idx
]
start_loc
=
[
start_idx
]
for
_
in
context_lens
:
for
_
in
context_lens
:
...
...
vllm/_core_ext.py
0 → 100644
View file @
e661d594
import
importlib.util
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
core_C_available
=
importlib
.
util
.
find_spec
(
'._core_C'
,
'vllm'
)
is
not
None
class NanRepr(Enum):
    """Encoding used for NaN values by a scalar type.

    Mirrors the enum in `core/scalar_type.hpp`.
    """

    # NaNs are not supported.
    NONE = 0
    # NaNs are: exponent all 1s, mantissa not all 0s.
    IEEE_754 = 1
    # NaNs are: exponent all 1s, mantissa all 1s.
    EXTD_RANGE_MAX_MIN = 2
if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (i.e.
    # namely neuron and tpu), we define the mock ScalarType class here that
    # partially mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class. So these type signatures should be
    # kept in sync with csrc/core/scalar_type.hpp

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types, in particular it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
        capable of representing types with a bias, i.e.:
          `stored_value = value + bias`,
        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
        of 8). The implementation for this class can be found in
        csrc/core/scalar_type.hpp, these type signatures should be kept in sync
        with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer excluding the sign bit
        if this is an integer type.
        """

        bias: int
        """
        bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0) for example if we store the
        type as an unsigned integer with a bias of 128 then the value 0 will be
        stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: whether the type is restricted to finite values (i.e. has no
        infs); use `has_infs()` instead.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type, holds a NanRepr value.
        (not applicable for integer types)
        """

        @property
        def size_bits(self):
            "Total width in bits: exponent + mantissa + sign bit (if signed)"
            return self.exponent + self.mantissa + int(self.signed)

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)

            Not implemented by this Python mock.
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)

            Not implemented by this Python mock.
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            # The previous body was a bare `...`, which returned None and
            # contradicted both the annotation and the docstring.
            return self.signed

        def is_floating_point(self):
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self):
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self):
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self):
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

        def has_nans(self):
            "If the type has a NaN representation (nan_repr is not NONE)"
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return (self.nan_repr == NanRepr.IEEE_754.value
                    and not self._finite_values_only)

        def __str__(self) -> str:
            # Not implemented by this Python mock.
            raise NotImplementedError

        def __repr__(self) -> str:
            # Not implemented by this Python mock.
            raise NotImplementedError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            # Integer types have a zero-width exponent; the mantissa holds
            # every bit except the sign bit, so that is_integer() is True and
            # the size_bits property equals the requested size_bits.
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create an unsigned integer scalar type."""
            # Zero-width exponent (integer type); all bits are value bits.
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int,
                   finite_values_only: bool, nan_repr: int) -> 'ScalarType':
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)

elif core_C_available:
    try:
        import vllm._core_C  # noqa: F401
    except ImportError as e:
        logger.warning("Failed to import from vllm._core_C with %r", e)

    # NOTE(review): presumably importing vllm._core_C is what registers
    # ScalarType under torch.classes._core_C, in which case this lookup fails
    # after the warning above -- confirm.
    ScalarType = torch.classes._core_C.ScalarType
vllm/_custom_ops.py
View file @
e661d594
import
contextlib
import
contextlib
import
functools
import
functools
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
vllm._core_ext
import
ScalarType
from
vllm.logger
import
init_logger
try
:
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_ops
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq model.
\n
"
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
try
:
try
:
...
@@ -17,12 +19,9 @@ try:
...
@@ -17,12 +19,9 @@ try:
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
with
contextlib
.
suppress
(
ImportError
):
import
vllm._moe_C
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
# ruff: noqa: F401
# ruff: noqa: F401
import
vllm._
punica
_C
import
vllm._
moe
_C
def
is_custom_op_supported
(
op_name
:
str
)
->
bool
:
def
is_custom_op_supported
(
op_name
:
str
)
->
bool
:
...
@@ -264,10 +263,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -264,10 +263,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# marlin_24
# marlin_24
def
gptq_marlin_24_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
gptq_marlin_24_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
return
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
workspace
,
num_bits
,
size_m
,
workspace
,
b_q_type
,
size_m
,
size_n
,
size_k
)
size_n
,
size_k
)
...
@@ -280,7 +279,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
...
@@ -280,7 +279,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
]
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
...
@@ -323,16 +322,24 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
...
@@ -323,16 +322,24 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
return
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
,
size_k
,
size_n
,
num_bits
)
return
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
,
size_k
,
size_n
,
num_bits
)
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_zeros
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
b_zeros
:
torch
.
Tensor
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
g_idx
:
torch
.
Tensor
,
has_zp
:
bool
)
->
torch
.
Tensor
:
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
g_idx
,
perm
,
workspace
,
num_bits
,
g_idx
,
perm
,
workspace
,
b_q_type
,
size_m
,
size_n
,
size_k
,
is_k_full
,
size_m
,
size_n
,
size_k
,
is_k_full
,
has_zp
)
has_zp
,
use_fp32_reduce
)
# fp8 marlin
# fp8 marlin
...
@@ -348,7 +355,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -348,7 +355,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# def scaled_fp8_quant(
# def scaled_fp8_quant(
# input: torch.Tensor,
# input: torch.Tensor,
# scale: Optional[torch.Tensor] = None,
# scale: Optional[torch.Tensor] = None,
#
batch_dim
_padding: Optional[int] = None,
#
num_token
_padding: Optional[int] = None,
# scale_ub: Optional[torch.Tensor] = None,
# scale_ub: Optional[torch.Tensor] = None,
# use_per_token_if_dynamic: bool = False,
# use_per_token_if_dynamic: bool = False,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# ) -> Tuple[torch.Tensor, torch.Tensor]:
...
@@ -358,7 +365,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -358,7 +365,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# This function supports both static and dynamic quantization: If you
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensor for downstream kernels that
# optional padding of the output tensor
s
for downstream kernels that
# will benefit from padding.
# will benefit from padding.
# Args:
# Args:
...
@@ -366,7 +373,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -366,7 +373,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# scale: Optional scaling factor for the FP8 quantization
# scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# per token case
#
batch_dim
_padding: If specified, pad the first dimension
#
num_token
_padding: If specified, pad the first dimension
# of the output to at least this value.
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# in the dynamic quantization case.
...
@@ -375,16 +382,16 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -375,16 +382,16 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# scaling factor.
# """
# """
#
if batch_dim_padding:
#
# This code assumes batch_dim and num_tokens are flattened
#
shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]
)
#
assert (input.ndim == 2
)
#
output = torch.empty(
shape
,
#
shape: Union[Tuple[int, int], torch.Size] = input.
shape
#
device=input.device,
#
if num_token_padding:
#
dtype=torch.float8_e4m3fn
)
#
shape = (max(num_token_padding, input.shape[0]), shape[1]
)
#
else:
#
output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
# output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# if scale is None:
# if scale is None:
# if use_per_token_if_dynamic:
# if use_per_token_if_dynamic:
# scale = torch.empty((
input.numel() // input.
shape[
-1
], 1),
# scale = torch.empty((shape[
0
], 1),
# device=input.device,
# device=input.device,
# dtype=torch.float32)
# dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
...
@@ -393,6 +400,8 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -393,6 +400,8 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# else:
# # num_token_padding not implemented for this case
# assert (scale.numel() == 1 or num_token_padding is None)
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
# return output, scale
...
@@ -428,6 +437,15 @@ def scaled_int8_quant(
...
@@ -428,6 +437,15 @@ def scaled_int8_quant(
return
output
,
input_scales
return
output
,
input_scales
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Call the `marlin_qqq_gemm` custom op in the `torch.ops._C` namespace.

    Every argument is forwarded unchanged and in the same order, and the
    op's result is returned as-is.
    """
    qqq_gemm_op = torch.ops._C.marlin_qqq_gemm
    return qqq_gemm_op(
        a,
        b_q_weight,
        s_tok,
        s_ch,
        s_group,
        workspace,
        size_m,
        size_n,
        size_k,
    )
# moe
# moe
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
...
@@ -467,10 +485,13 @@ def reshape_and_cache_flash(
...
@@ -467,10 +485,13 @@ def reshape_and_cache_flash(
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
kv_cache_dtype
,
k_scale
,
v_scale
)
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
...
@@ -546,43 +567,6 @@ def register_graph_buffers(fa: int, handles: List[str],
...
@@ -546,43 +567,6 @@ def register_graph_buffers(fa: int, handles: List[str],
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
# punica
def
dispatch_bgmv
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
Tensor
,
layer_idx
:
int
,
scale
:
float
,
)
->
None
:
torch
.
ops
.
_punica_C
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
def
dispatch_bgmv_low_level
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
Tensor
,
layer_idx
:
int
,
scale
:
float
,
h_in
:
int
,
h_out
:
int
,
y_offset
:
int
,
)
->
None
:
torch
.
ops
.
_punica_C
.
dispatch_bgmv_low_level
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
,
h_in
,
h_out
,
y_offset
,
)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
# TODO: remove this in v0.6.0
names_and_values
=
globals
()
names_and_values
=
globals
()
...
...
vllm/_ipex_ops.py
View file @
e661d594
...
@@ -25,27 +25,33 @@ class ipex_ops:
...
@@ -25,27 +25,33 @@ class ipex_ops:
x2
=
x2
.
reshape
(
num
,
d
)
x2
=
x2
.
reshape
(
num
,
d
)
return
x1
,
x2
return
x1
,
x2
@
staticmethod
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
silu_mul
(
x1
,
x2
,
out
)
ipex
.
llm
.
functional
.
silu_mul
(
x1
,
x2
,
out
)
@
staticmethod
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"none"
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"none"
)
@
staticmethod
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
x1
,
x2
=
ipex_ops
.
_reshape_activation_tensor
(
x
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"tanh"
)
ipex
.
llm
.
functional
.
gelu_mul
(
x1
,
x2
,
out
,
"tanh"
)
@
staticmethod
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
@
staticmethod
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
out
.
copy_
(
torch
.
nn
.
functional
.
gelu
(
x
))
# TODO add implementation of gelu_quick here
# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
@
staticmethod
def
paged_attention_v1
(
def
paged_attention_v1
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -78,12 +84,21 @@ class ipex_ops:
...
@@ -78,12 +84,21 @@ class ipex_ops:
).
view
(
num_kv_heads
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
# todo: ipex will refactor namespace
# todo: ipex will refactor namespace
torch
.
xpu
.
paged_attention_v1
(
out
,
query
.
contiguous
(),
torch
.
xpu
.
paged_attention_v1
(
# type: ignore
key_cache
.
view_as
(
value_cache
),
out
,
value_cache
,
head_mapping
,
scale
,
query
.
contiguous
(),
block_tables
,
context_lens
,
block_size
,
key_cache
.
view_as
(
value_cache
),
max_context_len
,
alibi_slopes
)
value_cache
,
head_mapping
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
)
@
staticmethod
def
paged_attention_v2
(
def
paged_attention_v2
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
...
@@ -119,13 +134,24 @@ class ipex_ops:
...
@@ -119,13 +134,24 @@ class ipex_ops:
).
view
(
num_kv_heads
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
1
).
repeat_interleave
(
num_queries_per_tokens
).
flatten
()
# todo: ipex will refactor namespace
# todo: ipex will refactor namespace
torch
.
xpu
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
torch
.
xpu
.
paged_attention_v2
(
# type: ignore
query
.
contiguous
(),
out
,
key_cache
.
view_as
(
value_cache
),
exp_sum
,
value_cache
,
head_mapping
,
block_tables
,
max_logits
,
context_lens
,
scale
,
block_size
,
tmp_out
,
max_context_len
,
alibi_slopes
)
query
.
contiguous
(),
key_cache
.
view_as
(
value_cache
),
value_cache
,
head_mapping
,
block_tables
,
context_lens
,
scale
,
block_size
,
max_context_len
,
alibi_slopes
,
)
@
staticmethod
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
# [batch_size, seq_len]
positions
:
torch
.
Tensor
,
# [batch_size, seq_len]
query
:
torch
.
Tensor
,
# [batch_size, seq_len, num_heads*head_size]
query
:
torch
.
Tensor
,
# [batch_size, seq_len, num_heads*head_size]
...
@@ -158,6 +184,7 @@ class ipex_ops:
...
@@ -158,6 +184,7 @@ class ipex_ops:
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
rotary_dim
,
is_neox
,
positions
)
rotary_dim
,
is_neox
,
positions
)
@
staticmethod
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
head_size
:
int
,
key
:
torch
.
Tensor
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
...
@@ -189,17 +216,20 @@ class ipex_ops:
...
@@ -189,17 +216,20 @@ class ipex_ops:
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
ipex
.
llm
.
functional
.
rotary_embedding
(
query_rot
,
key_rot
,
sin
,
cos
,
rotary_dim
,
is_neox
,
positions
)
rotary_dim
,
is_neox
,
positions
)
@
staticmethod
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
epsilon
:
float
)
->
None
:
tmp
=
ipex
.
llm
.
functional
.
rms_norm
(
input
,
weight
,
epsilon
)
tmp
=
ipex
.
llm
.
functional
.
rms_norm
(
input
,
weight
,
epsilon
)
out
.
copy_
(
tmp
)
out
.
copy_
(
tmp
)
@
staticmethod
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
tmp
=
ipex
.
llm
.
functional
.
add_rms_norm
(
residual
,
input
,
weight
,
None
,
tmp
=
ipex
.
llm
.
functional
.
add_rms_norm
(
residual
,
input
,
weight
,
None
,
epsilon
,
True
)
epsilon
,
True
)
input
.
copy_
(
tmp
)
input
.
copy_
(
tmp
)
@
staticmethod
def
varlen_attention
(
def
varlen_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -222,6 +252,7 @@ class ipex_ops:
...
@@ -222,6 +252,7 @@ class ipex_ops:
softmax_scale
,
zero_tensors
,
softmax_scale
,
zero_tensors
,
is_causal
,
return_softmax
,
gen_
)
is_causal
,
return_softmax
,
gen_
)
@
staticmethod
def
reshape_and_cache
(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
@@ -240,8 +271,13 @@ class ipex_ops:
...
@@ -240,8 +271,13 @@ class ipex_ops:
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
value_caches
:
List
[
torch
.
Tensor
],
value_caches
:
List
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
xpu
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
torch
.
xpu
.
copy_blocks
(
# type: ignore
key_caches
,
value_caches
,
block_mapping
,
)
@
staticmethod
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
xpu
.
swap_blocks
(
src
,
dst
,
block_mapping
)
torch
.
xpu
.
swap_blocks
(
src
,
dst
,
block_mapping
)
# type: ignore
vllm/adapter_commons/models.py
View file @
e661d594
...
@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
...
@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
super
().
__init__
(
capacity
)
super
().
__init__
(
capacity
)
self
.
deactivate_fn
=
deactivate_fn
self
.
deactivate_fn
=
deactivate_fn
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
T
):
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
]
):
logger
.
debug
(
"Removing adapter int id: %d"
,
key
)
logger
.
debug
(
"Removing adapter int id: %d"
,
key
)
self
.
deactivate_fn
(
key
)
self
.
deactivate_fn
(
key
)
return
super
().
_on_remove
(
key
,
value
)
return
super
().
_on_remove
(
key
,
value
)
...
@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
...
@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
@
property
@
property
@
abstractmethod
@
abstractmethod
def
adapter_slots
(
self
):
def
adapter_slots
(
self
)
->
int
:
...
raise
NotImplementedError
@
property
@
property
@
abstractmethod
@
abstractmethod
def
capacity
(
self
):
def
capacity
(
self
)
->
int
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
activate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
activate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
add_adapter
(
self
,
adapter
:
Any
)
->
bool
:
def
add_adapter
(
self
,
adapter
:
Any
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
set_adapter_mapping
(
self
,
mapping
:
Any
)
->
None
:
def
set_adapter_mapping
(
self
,
mapping
:
Any
)
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_all_adapters
(
self
):
def
remove_all_adapters
(
self
)
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
get_adapter
(
self
,
adapter_id
:
int
)
->
Optional
[
Any
]:
def
get_adapter
(
self
,
adapter_id
:
int
)
->
Optional
[
Any
]:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
list_adapters
(
self
)
->
Dict
[
int
,
Any
]:
def
list_adapters
(
self
)
->
Dict
[
int
,
Any
]:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
pin_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
pin_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
vllm/adapter_commons/request.py
View file @
e661d594
from
abc
import
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
@
dataclass
@
dataclass
class
AdapterRequest
:
class
AdapterRequest
(
ABC
)
:
"""
"""
Base class for adapter requests.
Base class for adapter requests.
"""
"""
@
property
@
property
@
abstractmethod
@
abstractmethod
def
adapter_id
(
self
):
def
adapter_id
(
self
)
->
int
:
...
raise
NotImplementedError
def
__post_init__
(
self
):
def
__post_init__
(
self
)
->
None
:
if
self
.
adapter_id
<
1
:
if
self
.
adapter_id
<
1
:
raise
ValueError
(
f
"id must be > 0, got
{
self
.
adapter_id
}
"
)
raise
ValueError
(
f
"id must be > 0, got
{
self
.
adapter_id
}
"
)
...
...
vllm/adapter_commons/worker_manager.py
View file @
e661d594
...
@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
...
@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
@
property
@
property
@
abstractmethod
@
abstractmethod
def
is_enabled
(
self
)
->
bool
:
def
is_enabled
(
self
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
set_active_adapters
(
self
,
requests
:
Set
[
Any
],
def
set_active_adapters
(
self
,
requests
:
Set
[
Any
],
mapping
:
Optional
[
Any
])
->
None
:
mapping
:
Optional
[
Any
])
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
add_adapter
(
self
,
adapter_request
:
Any
)
->
bool
:
def
add_adapter
(
self
,
adapter_request
:
Any
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
remove_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
remove_all_adapters
(
self
):
def
remove_all_adapters
(
self
)
->
None
:
...
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
list_adapters
(
self
)
->
Set
[
int
]:
def
list_adapters
(
self
)
->
Set
[
int
]:
...
raise
NotImplementedError
vllm/attention/backends/abstract.py
View file @
e661d594
...
@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
=
"auto"
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
e661d594
...
@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
not
None
assert
blocksparse_params
is
not
None
assert
alibi_slopes
is
None
,
ValueError
(
assert
alibi_slopes
is
None
,
ValueError
(
"Alibi not support for blocksparse flash attention."
)
"Alibi not support for blocksparse flash attention."
)
assert
sliding_window
is
None
,
ValueError
(
assert
sliding_window
is
None
,
ValueError
(
"sliding_window is invalid for blocksparse attention."
)
"sliding_window is invalid for blocksparse attention."
)
assert
logits_soft_cap
is
None
,
ValueError
(
"logits_soft_cap is invalid for blocksparse attention."
)
if
"num_heads"
not
in
blocksparse_params
:
if
"num_heads"
not
in
blocksparse_params
:
blocksparse_params
[
"num_heads"
]
=
num_heads
blocksparse_params
[
"num_heads"
]
=
num_heads
...
...
vllm/attention/backends/flash_attn.py
View file @
e661d594
...
@@ -209,6 +209,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -209,6 +209,7 @@ class FlashAttentionMetadataBuilder(
self
.
num_prefills
=
0
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
has_prefix_cache_hit
=
False
self
.
input_builder
=
input_builder
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
runner
=
input_builder
.
runner
...
@@ -219,7 +220,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -219,7 +220,7 @@ class FlashAttentionMetadataBuilder(
def
_add_seq_group
(
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
1. context length.
2. block table.
2. block table.
...
@@ -252,7 +253,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -252,7 +253,7 @@ class FlashAttentionMetadataBuilder(
# only allowing multiple of block_size chunk size.
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
block_table
=
[]
if
inter_data
.
prefix_cache_hit
:
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
block_table
=
block_tables
[
seq_id
]
...
@@ -272,23 +273,27 @@ class FlashAttentionMetadataBuilder(
...
@@ -272,23 +273,27 @@ class FlashAttentionMetadataBuilder(
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors."""
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit
=
any
([
inter_data
.
prefix_cache_hit
for
inter_data
in
self
.
input_builder
.
inter_data_list
])
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
self
.
input_builder
.
chunked_prefill_enabled
,
prefix_cache_hit
)
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
...
@@ -297,7 +302,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -297,7 +302,7 @@ class FlashAttentionMetadataBuilder(
if
use_captured_graph
:
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
...
@@ -397,9 +402,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -397,9 +402,11 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"FlashAttention does not support block-sparse attention."
)
raise
ValueError
(
"FlashAttention does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
@@ -410,6 +417,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -410,6 +417,10 @@ class FlashAttentionImpl(AttentionImpl):
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -478,6 +489,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -478,6 +489,8 @@ class FlashAttentionImpl(AttentionImpl):
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
...
@@ -515,6 +528,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -515,6 +528,7 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
...
@@ -534,6 +548,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -534,6 +548,7 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
vllm/attention/backends/flashinfer.py
View file @
e661d594
...
@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
# Only used by gemma2 model
logits_soft_cap
:
Optional
[
float
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Refer to
# Refer to
...
@@ -135,13 +133,20 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -135,13 +133,20 @@ class FlashInferMetadata(AttentionMetadata):
return
return
assert
self
.
prefill_wrapper
is
not
None
assert
self
.
prefill_wrapper
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
assert
batch_size
>=
0
# The prefill stage does not read kv cache.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
self
.
paged_kv_indptr
=
torch
.
zeros
(
batch_size
+
1
,
device
=
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
begin_forward
(
self
.
prefill_wrapper
.
begin_forward
(
self
.
query_start_loc
,
self
.
paged_kv_indptr
,
self
.
query_start_loc
,
self
.
paged_kv_indptr
,
...
@@ -297,26 +302,38 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -297,26 +302,38 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
is_profile_run
:
if
is_profile_run
:
return
return
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
block_table
=
block_tables
[
seq_id
]
block_table
=
block_tables
[
seq_id
]
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
_update_paged_kv_tensors
(
block_table
,
seq_len
)
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
def
_update_paged_kv_tensors
(
self
,
block_table
:
List
[
int
],
seq_len
:
int
):
# Get the number of valid blocks based on sequence length.
last_page_len
=
seq_len
%
self
.
block_size
# If seq_len = 16, block_size = 16,
if
last_page_len
==
0
:
# block_table_bound is 1 with 1 valid block.
last_page_len
=
self
.
block_size
# If seq_len = 15, block_size = 16,
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
self
.
input_builder
.
chunked_prefill_enabled
)
...
@@ -331,7 +348,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -331,7 +348,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
use_captured_graph
:
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
num_decode_tokens
=
batch_size
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
...
@@ -379,9 +396,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -379,9 +396,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
if
len
(
self
.
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -418,8 +432,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -418,8 +432,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
device
,
device
=
device
,
data_type
=
kv_cache_dtype
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
)
logits_soft_cap
=
logits_soft_cap
)
class
FlashInferImpl
(
AttentionImpl
):
class
FlashInferImpl
(
AttentionImpl
):
...
@@ -434,6 +447,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -434,6 +447,7 @@ class FlashInferImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -446,6 +460,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -446,6 +460,7 @@ class FlashInferImpl(AttentionImpl):
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -489,6 +504,8 @@ class FlashInferImpl(AttentionImpl):
...
@@ -489,6 +504,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache
[:,
1
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
)
query
=
query
.
contiguous
(
query
=
query
.
contiguous
(
...
@@ -518,7 +535,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -518,7 +535,7 @@ class FlashInferImpl(AttentionImpl):
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
query
,
kv_cache
,
kv_cache
,
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
causal
=
True
)
causal
=
True
)
else
:
else
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
...
@@ -527,5 +544,5 @@ class FlashInferImpl(AttentionImpl):
...
@@ -527,5 +544,5 @@ class FlashInferImpl(AttentionImpl):
query
,
query
,
kv_cache
,
kv_cache
,
sm_scale
=
self
.
scale
,
sm_scale
=
self
.
scale
,
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
)
logits_soft_cap
=
self
.
logits_soft_cap
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/ipex_attn.py
View file @
e661d594
...
@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"Torch SPDA does not support block-sparse attention."
)
raise
ValueError
(
"IPEX backend does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"IPEX backend does not support logits_soft_cap."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/pallas.py
View file @
e661d594
...
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type
...
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
import
torch_xla.experimental.dynamo_set_buffer_donor
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
...
@@ -55,8 +54,8 @@ class PallasMetadata(AttentionMetadata):
...
@@ -55,8 +54,8 @@ class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# Currently, input sequences can only contain all prefills
# or all decoding.
# or all decoding.
block_tables
:
Optional
[
torch
.
Tensor
]
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
context_lens
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
...
@@ -92,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -92,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -110,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -110,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
logits_soft_cap
is
not
None
:
raise
NotImplementedError
(
"Attention logits soft-capping is not supported."
)
if
torch_xla
.
tpu
.
version
()
<
4
:
if
torch_xla
.
tpu
.
version
()
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment