Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
...@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int): ...@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id, input_seq_id,
target_seq_id, target_seq_id,
token_ids, token_ids,
input_seq_group_metadata.sampling_params,
) )
assert output.request_id == input_seq_group_metadata.request_id assert output.request_id == input_seq_group_metadata.request_id
......
...@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, ...@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
target_worker = mock_worker() target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop' exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
...@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int, ...@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
vocab_size = 32_000 vocab_size = 32_000
...@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, ...@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int, ...@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, worker = SpecDecodeWorker(draft_worker,
metrics_collector) target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device() worker.init_device()
proposal_token_ids = torch.randint(low=0, proposal_token_ids = torch.randint(low=0,
...@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int, ...@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, proposer_worker=draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), False, scorer_worker=target_worker,
metrics_collector) spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int, ...@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
set_random_seed(1) set_random_seed(1)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(
draft_worker, target_worker, proposer_worker=draft_worker,
mock_spec_decode_sampler(acceptance_sampler_method), False, scorer_worker=target_worker,
metrics_collector) spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
...@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str): ...@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, worker = SpecDecodeWorker(
False, metrics_collector) proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector,
)
worker.init_device() worker.init_device()
draft_worker.init_device.assert_called_once() draft_worker.init_device.assert_called_once()
...@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method): ...@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
target_worker = mock_worker() target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector) metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker( worker = SpecDecodeWorker(proposer_worker=draft_worker,
draft_worker, target_worker, scorer_worker=target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector)
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs) worker.initialize_cache(**kwargs)
...@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens(): ...@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids, accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs, target_logprobs=target_token_logprobs,
k=k) k=k,
stage_times=(0, 0, 0))
# Verify that _seq_with_bonus_token_in_last_step contains the following: # Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in # 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current # _seq_with_bonus_token_in_last_step but were not part of the current
......
import contextlib
import functools
import gc
import pytest
import ray
import torch
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@pytest.fixture(autouse=True)
def cleanup():
    """Tear down distributed state and release memory around every test.

    The fixture is ``autouse``, so pytest applies it to each test in scope
    without the test having to request it. There is no ``yield``, so every
    step below executes during fixture setup, i.e. before the test body.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # NOTE(review): presumably raised when no process group was ever
        # initialised -- confirm against torch.distributed. Either way the
        # error is deliberately ignored so cleanup always continues.
        pass
    ray.shutdown()
    # Drop unreachable Python objects first so that any CUDA memory they
    # still reference can be handed back by the cache flush that follows.
    gc.collect()
    torch.cuda.empty_cache()
def retry_until_skip(n):
    """Build a decorator that retries a test up to ``n`` times, then skips.

    The wrapped callable is invoked until it returns without raising
    ``AssertionError``; its return value is passed through. After every
    failed attempt the garbage collector is run and the CUDA cache is
    emptied. If the ``n``-th attempt also fails, the test is reported as
    skipped (via ``pytest.skip``) instead of failed. Exceptions other than
    ``AssertionError`` propagate immediately and are not retried.

    Args:
        n: Maximum number of attempts; must be at least 1.

    Raises:
        ValueError: If ``n`` is smaller than 1. Without this check the retry
            loop would never run and the decorated test would silently
            "pass" without being executed at all.
    """
    if n < 1:
        raise ValueError(
            f"retry_until_skip requires at least one attempt, got n={n}")

    def decorator_retry(func):

        @functools.wraps(func)
        def wrapper_retry(*args, **kwargs):
            for i in range(n):
                try:
                    return func(*args, **kwargs)
                except AssertionError:
                    # Free as much memory as possible before trying again.
                    gc.collect()
                    torch.cuda.empty_cache()
                    if i == n - 1:
                        # Out of attempts: downgrade the failure to a skip.
                        pytest.skip("Skipping test after attempts..")

        return wrapper_retry

    return decorator_retry
@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` whose ``tensorizer_uri`` is ``"vllm"``.

    Declared ``autouse``, so it is set up for every test in scope; tests that
    need the object can still request it by name.
    """
    return TensorizerConfig(tensorizer_uri="vllm")
import gc
import json import json
import os import os
import pathlib import pathlib
...@@ -20,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, ...@@ -20,13 +21,13 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
serialize_vllm_model, serialize_vllm_model,
tensorize_vllm_model) tensorize_vllm_model)
from ..conftest import VllmRunner, cleanup from ..conftest import VllmRunner
from ..utils import RemoteOpenAIServer from ..utils import RemoteOpenAIServer
from .conftest import retry_until_skip
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -48,14 +49,16 @@ def is_curl_installed(): ...@@ -48,14 +49,16 @@ def is_curl_installed():
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
return False return False
def get_torch_model(vllm_runner: VllmRunner): def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \ return vllm_runner \
.model \ .model \
.llm_engine \ .llm_engine \
.model_executor \ .model_executor \
.driver_worker \ .driver_worker \
.model_runner \ .model_runner \
.model .model
def write_keyfile(keyfile_path: str): def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random() encryption_params = EncryptionParams.random()
...@@ -63,11 +66,6 @@ def write_keyfile(keyfile_path: str): ...@@ -63,11 +66,6 @@ def write_keyfile(keyfile_path: str):
with open(keyfile_path, 'wb') as f: with open(keyfile_path, 'wb') as f:
f.write(encryption_params.key) f.write(encryption_params.key)
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm")
return config
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent') @patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config): def test_load_with_tensorizer(mock_agent, tensorizer_config):
...@@ -90,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner): ...@@ -90,14 +88,15 @@ def test_can_deserialize_s3(vllm_runner):
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path, tensorizer_uri=tensorized_path,
num_readers=1, num_readers=1,
s3_endpoint="object.ord1.coreweave.com", s3_endpoint="object.ord1.coreweave.com",
)) as loaded_hf_model: )) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate(prompts,
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) # noqa: E501 sampling_params)
# noqa: E501
assert deserialized_outputs assert deserialized_outputs
...@@ -117,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( ...@@ -117,18 +116,19 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
encryption_keyfile=key_path encryption_keyfile=key_path
) )
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing) config_for_serializing)
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path) encryption_keyfile=key_path)
with vllm_runner( with vllm_runner(
model_ref, model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501 model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
# noqa: E501
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
...@@ -144,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, ...@@ -144,12 +144,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
serializer.write_module(hf_model.model) serializer.write_module(hf_model.model)
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
num_readers=1, num_readers=1,
)) as loaded_hf_model: )) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate_greedy( deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens) prompts, max_tokens=max_tokens)
...@@ -171,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): ...@@ -171,21 +170,21 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
with vllm_runner( with vllm_runner(
model_ref, model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=TensorizerConfig( model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path, tensorizer_uri=model_path,
num_readers=1, num_readers=1,
), ),
enable_lora=True, enable_lora=True,
max_loras=1, max_loras=1,
max_lora_rank=8, max_lora_rank=8,
max_cpu_loras=2, max_cpu_loras=2,
max_num_seqs=50, max_num_seqs=50,
max_model_len=1000, max_model_len=1000,
) as loaded_vllm_model: ) as loaded_vllm_model:
process_requests(loaded_vllm_model.model.llm_engine, test_prompts) process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
...@@ -193,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): ...@@ -193,10 +192,14 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
def test_load_without_tensorizer_load_format(vllm_runner): def test_load_without_tensorizer_load_format(vllm_runner):
model = None
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner( model = vllm_runner(
model_ref, model_ref,
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
...@@ -206,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ...@@ -206,7 +209,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model), serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path)) TensorizerConfig(tensorizer_uri=model_path))
model_loader_extra_config = { model_loader_extra_config = {
"tensorizer_uri": str(model_path), "tensorizer_uri": str(model_path),
...@@ -224,9 +227,9 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ...@@ -224,9 +227,9 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
client = server.get_client() client = server.get_client()
completion = client.completions.create(model=model_ref, completion = client.completions.create(model=model_ref,
prompt="Hello, my name is", prompt="Hello, my name is",
max_tokens=5, max_tokens=5,
temperature=0.0) temperature=0.0)
assert completion.id is not None assert completion.id is not None
assert len(completion.choices) == 1 assert len(completion.choices) == 1
...@@ -237,11 +240,15 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ...@@ -237,11 +240,15 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
def test_raise_value_error_on_invalid_load_format(vllm_runner): def test_raise_value_error_on_invalid_load_format(vllm_runner):
model = None
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_runner( model = vllm_runner(
model_ref, model_ref,
load_format="safetensors", load_format="safetensors",
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
del model
gc.collect()
torch.cuda.empty_cache()
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
...@@ -263,22 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner): ...@@ -263,22 +270,20 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner):
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
) )
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs") reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner, def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tmp_path): tmp_path):
model_ref = "EleutherAI/pythia-1.4b" model_ref = "EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model # record outputs from un-sharded un-tensorized model
base_model = vllm_runner( with vllm_runner(
model_ref, model_ref,
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
enforce_eager=True, enforce_eager=True,
) ) as base_model:
outputs = base_model.generate(prompts, sampling_params) outputs = base_model.generate(prompts, sampling_params)
base_model.model.llm_engine.model_executor.shutdown()
base_model.model.llm_engine.model_executor.shutdown()
del base_model
cleanup()
# load model with two shards and serialize with encryption # load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors")) model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
...@@ -291,31 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner, ...@@ -291,31 +296,34 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tensorize_vllm_model( tensorize_vllm_model(
engine_args=EngineArgs( engine_args=EngineArgs(
model=model_ref, model=model_ref,
tensor_parallel_size=2, tensor_parallel_size=2,
disable_custom_all_reduce=True, disable_custom_all_reduce=True,
enforce_eager=True, enforce_eager=True,
), ),
tensorizer_config=tensorizer_config, tensorizer_config=tensorizer_config,
) )
assert os.path.isfile(model_path % 0), "Serialization subprocess failed" assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
assert os.path.isfile(model_path % 1), "Serialization subprocess failed" assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
cleanup()
loaded_vllm_model = vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) with vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config) as loaded_vllm_model:
deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
@retry_until_skip(3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
gc.collect()
torch.cuda.empty_cache()
model_ref = "facebook/opt-125m" model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors") model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path)) config = TensorizerConfig(tensorizer_uri=str(model_path))
...@@ -327,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): ...@@ -327,8 +335,10 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
assert is_vllm_tensorized(config) assert is_vllm_tensorized(config)
with vllm_runner(model_ref, with vllm_runner(model_ref,
load_format="tensorizer", load_format="tensorizer",
model_loader_extra_config=config) as loaded_vllm_model: model_loader_extra_config=config) as loaded_vllm_model:
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 deserialized_outputs = loaded_vllm_model.generate(prompts,
sampling_params)
# noqa: E501
assert outputs == deserialized_outputs assert outputs == deserialized_outputs
...@@ -104,8 +104,10 @@ def test_rope_customization(): ...@@ -104,8 +104,10 @@ def test_rope_customization():
dtype="float16", dtype="float16",
seed=0, seed=0,
) )
assert getattr(longchat_model_config.hf_config, "rope_scaling", # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
None) == LONGCHAT_ROPE_SCALING assert all(
longchat_model_config.hf_config.rope_scaling.get(key) == value
for key, value in LONGCHAT_ROPE_SCALING.items())
assert longchat_model_config.max_model_len == 16384 assert longchat_model_config.max_model_len == 16384
longchat_model_config = ModelConfig( longchat_model_config = ModelConfig(
......
import pytest
import torch
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("type_tuple", (
    (-8, 7, scalar_types.int4),
    (0, 15, scalar_types.uint4),
    (-8, 7, scalar_types.uint4b8),
    (-128, 127, scalar_types.uint8b128),
    (-28., 28., scalar_types.float6_e3m2f),
    (torch.int8, scalar_types.int8),
    (torch.uint8, scalar_types.uint8),
    (torch.float8_e5m2, scalar_types.float8_e5m2),
    (torch.float8_e4m3fn, scalar_types.float8_e4m3fn),
    (torch.bfloat16, scalar_types.float16_e8m7),
    (torch.float16, scalar_types.float16_e5m10),
),
                         ids=str)
def test_scalar_type_min_max(type_tuple):
    """Check that a scalar type reports the expected ``min()`` and ``max()``.

    Each parameter is either ``(expected_min, expected_max, scalar_type)``
    or ``(torch_dtype, scalar_type)``. In the second form the expected
    bounds are taken from ``torch.finfo`` / ``torch.iinfo`` of the dtype.
    """
    print(type_tuple)
    if len(type_tuple) == 3:
        expected_min, expected_max, scalar_type = type_tuple
    else:
        torch_type, scalar_type = type_tuple
        # Floating point dtypes are described by finfo, integer ones by iinfo.
        describe = (torch.finfo
                    if torch_type.is_floating_point else torch.iinfo)
        expected_min = describe(torch_type).min
        expected_max = describe(torch_type).max

    print(scalar_type, expected_min, expected_max, scalar_type.min(),
          scalar_type.max())
    assert expected_min == scalar_type.min()
    assert expected_max == scalar_type.max()
import functools
import os import os
import signal
import subprocess import subprocess
import sys import sys
import time import time
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
import openai import openai
import ray import ray
...@@ -48,13 +50,14 @@ VLLM_PATH = Path(__file__).parent.parent ...@@ -48,13 +50,14 @@ VLLM_PATH = Path(__file__).parent.parent
class RemoteOpenAIServer: class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds MAX_SERVER_START_WAIT_S = 120 # wait for server to start for 120 seconds
def __init__( def __init__(
self, self,
model: str, model: str,
cli_args: List[str], cli_args: List[str],
*, *,
env_dict: Optional[Dict[str, str]] = None,
auto_port: bool = True, auto_port: bool = True,
) -> None: ) -> None:
if auto_port: if auto_port:
...@@ -75,6 +78,8 @@ class RemoteOpenAIServer: ...@@ -75,6 +78,8 @@ class RemoteOpenAIServer:
# the current process might initialize cuda, # the current process might initialize cuda,
# to be safe, we should use spawn method # to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args, self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
env=env, env=env,
stdout=sys.stdout, stdout=sys.stdout,
...@@ -87,6 +92,11 @@ class RemoteOpenAIServer: ...@@ -87,6 +92,11 @@ class RemoteOpenAIServer:
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate() self.proc.terminate()
try:
self.proc.wait(3)
except subprocess.TimeoutExpired:
# force kill if needed
self.proc.kill()
def _wait_for_server(self, *, url: str, timeout: float): def _wait_for_server(self, *, url: str, timeout: float):
# run health check # run health check
...@@ -125,10 +135,21 @@ class RemoteOpenAIServer: ...@@ -125,10 +135,21 @@ class RemoteOpenAIServer:
) )
def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): def compare_two_settings(model: str,
arg1: List[str],
arg2: List[str],
env1: Optional[Dict[str, str]] = None,
env2: Optional[Dict[str, str]] = None):
""" """
Launch API server with two different sets of arguments and compare the Launch API server with two different sets of arguments/environments
results of the API calls. The arguments are after the model name. and compare the results of the API calls.
Args:
model: The model to test.
arg1: The first set of arguments to pass to the API server.
arg2: The second set of arguments to pass to the API server.
env1: The first set of environment variables to pass to the API server.
env2: The second set of environment variables to pass to the API server.
""" """
tokenizer = AutoTokenizer.from_pretrained(model) tokenizer = AutoTokenizer.from_pretrained(model)
...@@ -136,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): ...@@ -136,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
prompt = "Hello, my name is" prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"] token_ids = tokenizer(prompt)["input_ids"]
results = [] results = []
for args in (arg1, arg2): for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model, args) as server: with RemoteOpenAIServer(model, args, env_dict=env) as server:
client = server.get_client() client = server.get_client()
# test models list # test models list
...@@ -178,6 +199,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): ...@@ -178,6 +199,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage": completion.usage, "usage": completion.usage,
}) })
# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling with multiple prompts
completion = client.completions.create(model=model,
prompt=[prompt, prompt],
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})
# test simple list # test simple list
batch = client.completions.create( batch = client.completions.create(
model=model, model=model,
...@@ -305,3 +357,43 @@ def wait_for_gpu_memory_to_clear(devices: List[int], ...@@ -305,3 +357,43 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
f'{dur_s=:.02f} ({threshold_bytes/2**30=})') f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
time.sleep(5) time.sleep(5)
def fork_new_process_for_each_test(f):
    """Decorator to fork a new process for each test function.

    See https://github.com/vllm-project/vllm/issues/7053 for more details.

    The decorated function runs in a forked child process. The parent waits
    for the child, sends SIGTERM to the process group so that anything the
    child left running is terminated, and then asserts that the child's wait
    status is zero. A test that skips inside the child counts as success
    (exit code 0); the wrapper itself returns nothing.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped
        pid = os.fork()
        if pid == 0:
            # Child process. It must never return from this branch, or it
            # would keep running as a second copy of the test session.
            exit_code = 1
            try:
                f(*args, **kwargs)
                exit_code = 0
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                exit_code = 0
            except BaseException:
                # BaseException rather than Exception: pytest's outcome
                # exceptions (e.g. the one raised by pytest.fail), as well as
                # KeyboardInterrupt and SystemExit, do not derive from
                # Exception and would otherwise escape the child.
                import traceback
                traceback.print_exc()
            finally:
                os._exit(exit_code)
        else:
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            try:
                # kill all child processes
                os.killpg(pgid, signal.SIGTERM)
            finally:
                # restore the signal handler, even if killpg raised
                signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper
...@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
for _ in range(expected_bs - len(seq_lens)): for _ in range(expected_bs - len(seq_lens)):
seq_lens.append(1) seq_lens.append(1)
assert attn_metadata.seq_lens == seq_lens assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.num_decode_tokens == len(seq_lens)
start_idx = 0 start_idx = 0
start_loc = [start_idx] start_loc = [start_idx]
for _ in context_lens: for _ in context_lens:
......
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union
import torch
from vllm.logger import init_logger
# Logger for this module, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)
# True when an import spec for the `_core_C` submodule of the `vllm` package
# can be found (the name is resolved relative to the package 'vllm');
# `find_spec` returns None when no such module is available.
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    """Ways a scalar type can encode NaN; `ScalarType.nan_repr` stores the
    integer `.value` of one of these members."""
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
if TYPE_CHECKING or not core_C_available:
# On platforms where we cannot use/build the C++ core extension (namely
# neuron and tpu), we define the mock ScalarType class here that partially
# mimics the C++ ScalarType class.
#
# We also use this to provide type signatures to the Python LSP for the
# methods in the C++ ScalarType class. So these type signatures should be
# kept in sync with csrc/core/scalar_type.hpp
from dataclasses import dataclass
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.

    This Python class is only a partial mock: `min`, `max`, `__str__` and
    `__repr__` are declared for their signatures and raise
    `NotImplementedError`.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    _finite_values_only: bool = False
    """
    Private: True if the type has no infinities; use `has_infs()` instead.
    """

    nan_repr: int = NanRepr.IEEE_754.value
    """
    How NaNs are represented in this scalar type, holds a NanRepr value.
    (not applicable for integer types)
    """

    @property
    def size_bits(self) -> int:
        "Total width in bits: exponent + mantissa + sign bit (if present)."
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)

        Not implemented by this mock.
        """
        raise NotImplementedError

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)

        Not implemented by this mock.
        """
        raise NotImplementedError

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        # Previously a `...` stub that returned None instead of a bool.
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return self.is_floating_point() and not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type is floating point and has a NaN representation"
        # `nan_repr` is not applicable to integer types, so they never
        # report NaN support regardless of the stored value.
        return self.is_floating_point() and \
            self.nan_repr != NanRepr.NONE.value

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.is_floating_point() and \
            self.nan_repr == NanRepr.IEEE_754.value and \
            not self._finite_values_only

    def __str__(self) -> str:
        # Not implemented by this mock.
        raise NotImplementedError

    def __repr__(self) -> str:
        # Not implemented by this mock; this explicit definition takes the
        # place of the repr that `@dataclass` would otherwise generate.
        raise NotImplementedError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        # Integer types have no exponent bits; the value bits are everything
        # except the sign bit, so that `size_bits` adds back up correctly.
        return cls(0, size_bits - 1, bias or 0, True)

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create a unsigned integer scalar type."""
        # No exponent and no sign bit: all `size_bits` bits are value bits.
        return cls(0, size_bits, bias or 0, False)

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        return cls(exponent, mantissa, 0, True)

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: int) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        return cls(exponent, mantissa, 0, True, finite_values_only,
                   nan_repr)
elif core_C_available:
try:
import vllm._core_C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._core_C with %r", e)
ScalarType = torch.classes._core_C.ScalarType
import contextlib import contextlib
import functools import functools
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Union
import torch import torch
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
try: try:
from lmslim import quant_ops from lmslim import quant_ops
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq model.\n")
from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
...@@ -17,12 +19,9 @@ try: ...@@ -17,12 +19,9 @@ try:
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError):
import vllm._moe_C
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
# ruff: noqa: F401 # ruff: noqa: F401
import vllm._punica_C import vllm._moe_C
def is_custom_op_supported(op_name: str) -> bool: def is_custom_op_supported(op_name: str) -> bool:
...@@ -264,10 +263,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -264,10 +263,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# marlin_24 # marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int, workspace: torch.Tensor, b_q_type: ScalarType,
size_n: int, size_k: int) -> torch.Tensor: size_m: int, size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, num_bits, size_m, workspace, b_q_type, size_m,
size_n, size_k) size_n, size_k)
...@@ -280,7 +279,7 @@ def cutlass_scaled_mm(a: torch.Tensor, ...@@ -280,7 +279,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype], out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
...@@ -323,16 +322,24 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, ...@@ -323,16 +322,24 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor,
b_scales: torch.Tensor, b_zeros: torch.Tensor, b_q_weight: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int, b_zeros: torch.Tensor,
size_n: int, size_k: int, is_k_full: bool, g_idx: torch.Tensor,
has_zp: bool) -> torch.Tensor: perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, num_bits, g_idx, perm, workspace, b_q_type,
size_m, size_n, size_k, is_k_full, size_m, size_n, size_k, is_k_full,
has_zp) has_zp, use_fp32_reduce)
# fp8 marlin # fp8 marlin
...@@ -348,7 +355,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -348,7 +355,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# def scaled_fp8_quant( # def scaled_fp8_quant(
# input: torch.Tensor, # input: torch.Tensor,
# scale: Optional[torch.Tensor] = None, # scale: Optional[torch.Tensor] = None,
# batch_dim_padding: Optional[int] = None, # num_token_padding: Optional[int] = None,
# scale_ub: Optional[torch.Tensor] = None, # scale_ub: Optional[torch.Tensor] = None,
# use_per_token_if_dynamic: bool = False, # use_per_token_if_dynamic: bool = False,
# ) -> Tuple[torch.Tensor, torch.Tensor]: # ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -358,7 +365,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -358,7 +365,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# This function supports both static and dynamic quantization: If you # This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it, # provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows # the scale will be determined dynamically. The function also allows
# optional padding of the output tensor for downstream kernels that # optional padding of the output tensors for downstream kernels that
# will benefit from padding. # will benefit from padding.
# Args: # Args:
...@@ -366,7 +373,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -366,7 +373,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# scale: Optional scaling factor for the FP8 quantization # scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic # scale_ub: Optional upper bound for scaling factor in dynamic
# per token case # per token case
# batch_dim_padding: If specified, pad the first dimension # num_token_padding: If specified, pad the first dimension
# of the output to at least this value. # of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token # use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case. # in the dynamic quantization case.
...@@ -375,16 +382,16 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -375,16 +382,16 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and # Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor. # scaling factor.
# """ # """
# if batch_dim_padding: # # This code assumes batch_dim and num_tokens are flattened
# shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) # assert (input.ndim == 2)
# output = torch.empty(shape, # shape: Union[Tuple[int, int], torch.Size] = input.shape
# device=input.device, # if num_token_padding:
# dtype=torch.float8_e4m3fn) # shape = (max(num_token_padding, input.shape[0]), shape[1])
# else: # output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
# output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# if scale is None: # if scale is None:
# if use_per_token_if_dynamic: # if use_per_token_if_dynamic:
# scale = torch.empty((input.numel() // input.shape[-1], 1), # scale = torch.empty((shape[0], 1),
# device=input.device, # device=input.device,
# dtype=torch.float32) # dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant( # torch.ops._C.dynamic_per_token_scaled_fp8_quant(
...@@ -393,6 +400,8 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -393,6 +400,8 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# scale = torch.zeros(1, device=input.device, dtype=torch.float32) # scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) # torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else: # else:
# # num_token_padding not implemented for this case
# assert (scale.numel() == 1 or num_token_padding is None)
# torch.ops._C.static_scaled_fp8_quant(output, input, scale) # torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale # return output, scale
...@@ -428,6 +437,15 @@ def scaled_int8_quant( ...@@ -428,6 +437,15 @@ def scaled_int8_quant(
return output, input_scales return output, input_scales
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Invoke the ``_C.marlin_qqq_gemm`` custom op and return its result.

    Every argument is forwarded positionally, unchanged and in declaration
    order; no validation or conversion happens on the Python side.

    NOTE(review): ``s_tok`` / ``s_ch`` / ``s_group`` are presumably
    per-token, per-channel and per-group scales - confirm against the kernel.
    """
    # The op is resolved at call time, so a missing extension only fails
    # when this wrapper is actually used.
    qqq_gemm_op = torch.ops._C.marlin_qqq_gemm
    forwarded = (a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                 size_n, size_k)
    return qqq_gemm_op(*forwarded)
# moe # moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor, block_size: int, sorted_token_ids: torch.Tensor,
...@@ -467,10 +485,13 @@ def reshape_and_cache_flash( ...@@ -467,10 +485,13 @@ def reshape_and_cache_flash(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None: ) -> None:
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
value_cache, slot_mapping, value_cache, slot_mapping,
kv_cache_dtype) kv_cache_dtype, k_scale,
v_scale)
def copy_blocks(key_caches: List[torch.Tensor], def copy_blocks(key_caches: List[torch.Tensor],
...@@ -546,43 +567,6 @@ def register_graph_buffers(fa: int, handles: List[str], ...@@ -546,43 +567,6 @@ def register_graph_buffers(fa: int, handles: List[str],
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
# punica
def dispatch_bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
) -> None:
    """Invoke the ``_punica_C.dispatch_bgmv`` custom op.

    All arguments are forwarded positionally in declaration order. The op's
    return value is discarded and this wrapper returns ``None``.

    NOTE(review): ``y`` is presumably written in place by the kernel -
    confirm against the punica extension.
    """
    # Resolved at call time so that importing this module does not require
    # the punica extension to be present.
    bgmv_op = torch.ops._punica_C.dispatch_bgmv
    forwarded = (y, x, w_t_all, indicies, layer_idx, scale)
    bgmv_op(*forwarded)
def dispatch_bgmv_low_level(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
    h_in: int,
    h_out: int,
    y_offset: int,
) -> None:
    """Invoke the ``_punica_C.dispatch_bgmv_low_level`` custom op.

    All nine arguments are forwarded positionally in declaration order. The
    op's return value is discarded and this wrapper returns ``None``.

    NOTE(review): ``h_in`` / ``h_out`` / ``y_offset`` presumably select a
    slice of the hidden dimensions - confirm against the punica extension.
    """
    # Resolved at call time so that importing this module does not require
    # the punica extension to be present.
    low_level_op = torch.ops._punica_C.dispatch_bgmv_low_level
    forwarded = (y, x, w_t_all, indicies, layer_idx, scale, h_in, h_out,
                 y_offset)
    low_level_op(*forwarded)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0 # TODO: remove this in v0.6.0
names_and_values = globals() names_and_values = globals()
......
...@@ -25,27 +25,33 @@ class ipex_ops: ...@@ -25,27 +25,33 @@ class ipex_ops:
x2 = x2.reshape(num, d) x2 = x2.reshape(num, d)
return x1, x2 return x1, x2
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.silu_mul(x1, x2, out) ipex.llm.functional.silu_mul(x1, x2, out)
@staticmethod
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "none") ipex.llm.functional.gelu_mul(x1, x2, out, "none")
@staticmethod
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "tanh") ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
@staticmethod
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x)) out.copy_(torch.nn.functional.gelu(x))
@staticmethod
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x)) out.copy_(torch.nn.functional.gelu(x))
# TODO add implementation of gelu_quick here # TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
@staticmethod
def paged_attention_v1( def paged_attention_v1(
out: torch.Tensor, out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
...@@ -78,12 +84,21 @@ class ipex_ops: ...@@ -78,12 +84,21 @@ class ipex_ops:
).view(num_kv_heads, ).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten() 1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace # todo: ipex will refactor namespace
torch.xpu.paged_attention_v1(out, query.contiguous(), torch.xpu.paged_attention_v1( # type: ignore
key_cache.view_as(value_cache), out,
value_cache, head_mapping, scale, query.contiguous(),
block_tables, context_lens, block_size, key_cache.view_as(value_cache),
max_context_len, alibi_slopes) value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def paged_attention_v2( def paged_attention_v2(
out: torch.Tensor, out: torch.Tensor,
exp_sum: torch.Tensor, exp_sum: torch.Tensor,
...@@ -119,13 +134,24 @@ class ipex_ops: ...@@ -119,13 +134,24 @@ class ipex_ops:
).view(num_kv_heads, ).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten() 1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace # todo: ipex will refactor namespace
torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out, torch.xpu.paged_attention_v2( # type: ignore
query.contiguous(), out,
key_cache.view_as(value_cache), exp_sum,
value_cache, head_mapping, block_tables, max_logits,
context_lens, scale, block_size, tmp_out,
max_context_len, alibi_slopes) query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
block_tables,
context_lens,
scale,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def rotary_embedding( def rotary_embedding(
positions: torch.Tensor, # [batch_size, seq_len] positions: torch.Tensor, # [batch_size, seq_len]
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
...@@ -158,6 +184,7 @@ class ipex_ops: ...@@ -158,6 +184,7 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions) rotary_dim, is_neox, positions)
@staticmethod
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int, key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool, cos_sin_cache: torch.Tensor, is_neox: bool,
...@@ -189,17 +216,20 @@ class ipex_ops: ...@@ -189,17 +216,20 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions) rotary_dim, is_neox, positions)
@staticmethod
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None: epsilon: float) -> None:
tmp = ipex.llm.functional.rms_norm(input, weight, epsilon) tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
out.copy_(tmp) out.copy_(tmp)
@staticmethod
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None: weight: torch.Tensor, epsilon: float) -> None:
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
epsilon, True) epsilon, True)
input.copy_(tmp) input.copy_(tmp)
@staticmethod
def varlen_attention( def varlen_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
...@@ -222,6 +252,7 @@ class ipex_ops: ...@@ -222,6 +252,7 @@ class ipex_ops:
softmax_scale, zero_tensors, softmax_scale, zero_tensors,
is_causal, return_softmax, gen_) is_causal, return_softmax, gen_)
@staticmethod
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
...@@ -240,8 +271,13 @@ class ipex_ops: ...@@ -240,8 +271,13 @@ class ipex_ops:
def copy_blocks(key_caches: List[torch.Tensor], def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor], value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None: block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks(key_caches, value_caches, block_mapping) torch.xpu.copy_blocks( # type: ignore
key_caches,
value_caches,
block_mapping,
)
@staticmethod
def swap_blocks(src: torch.Tensor, dst: torch.Tensor, def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None: block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping) torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
...@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]): ...@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
super().__init__(capacity) super().__init__(capacity)
self.deactivate_fn = deactivate_fn self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: T): def _on_remove(self, key: Hashable, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key) logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key) self.deactivate_fn(key)
return super()._on_remove(key, value) return super()._on_remove(key, value)
...@@ -59,46 +59,46 @@ class AdapterModelManager(ABC): ...@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
@property @property
@abstractmethod @abstractmethod
def adapter_slots(self): def adapter_slots(self) -> int:
... raise NotImplementedError
@property @property
@abstractmethod @abstractmethod
def capacity(self): def capacity(self) -> int:
... raise NotImplementedError
@abstractmethod @abstractmethod
def activate_adapter(self, adapter_id: int) -> bool: def activate_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool: def deactivate_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def add_adapter(self, adapter: Any) -> bool: def add_adapter(self, adapter: Any) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None: def set_adapter_mapping(self, mapping: Any) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_adapter(self, adapter_id: int) -> bool: def remove_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_all_adapters(self): def remove_all_adapters(self) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]: def get_adapter(self, adapter_id: int) -> Optional[Any]:
... raise NotImplementedError
@abstractmethod @abstractmethod
def list_adapters(self) -> Dict[int, Any]: def list_adapters(self) -> Dict[int, Any]:
... raise NotImplementedError
@abstractmethod @abstractmethod
def pin_adapter(self, adapter_id: int) -> bool: def pin_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
from abc import abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
class AdapterRequest: class AdapterRequest(ABC):
""" """
Base class for adapter requests. Base class for adapter requests.
""" """
@property @property
@abstractmethod @abstractmethod
def adapter_id(self): def adapter_id(self) -> int:
... raise NotImplementedError
def __post_init__(self): def __post_init__(self) -> None:
if self.adapter_id < 1: if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}") raise ValueError(f"id must be > 0, got {self.adapter_id}")
......
...@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC): ...@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
@property @property
@abstractmethod @abstractmethod
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def set_active_adapters(self, requests: Set[Any], def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None: mapping: Optional[Any]) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def add_adapter(self, adapter_request: Any) -> bool: def add_adapter(self, adapter_request: Any) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_adapter(self, adapter_id: int) -> bool: def remove_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_all_adapters(self): def remove_all_adapters(self) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def list_adapters(self) -> Set[int]: def list_adapters(self) -> Set[int]:
... raise NotImplementedError
...@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto", kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
assert blocksparse_params is not None assert blocksparse_params is not None
assert alibi_slopes is None, ValueError( assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.") "Alibi not support for blocksparse flash attention.")
assert sliding_window is None, ValueError( assert sliding_window is None, ValueError(
"sliding_window is invalid for blocksparse attention.") "sliding_window is invalid for blocksparse attention.")
assert logits_soft_cap is None, ValueError(
"logits_soft_cap is invalid for blocksparse attention.")
if "num_heads" not in blocksparse_params: if "num_heads" not in blocksparse_params:
blocksparse_params["num_heads"] = num_heads blocksparse_params["num_heads"] = num_heads
......
...@@ -209,6 +209,7 @@ class FlashAttentionMetadataBuilder( ...@@ -209,6 +209,7 @@ class FlashAttentionMetadataBuilder(
self.num_prefills = 0 self.num_prefills = 0
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
self.input_builder = input_builder self.input_builder = input_builder
self.runner = input_builder.runner self.runner = input_builder.runner
...@@ -219,7 +220,7 @@ class FlashAttentionMetadataBuilder( ...@@ -219,7 +220,7 @@ class FlashAttentionMetadataBuilder(
def _add_seq_group( def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool): chunked_prefill_enabled: bool, prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append """Add a sequence group to the metadata. Specifically update/append
1. context length. 1. context length.
2. block table. 2. block table.
...@@ -252,7 +253,7 @@ class FlashAttentionMetadataBuilder( ...@@ -252,7 +253,7 @@ class FlashAttentionMetadataBuilder(
# only allowing multiple of block_size chunk size. # only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
block_table = [] block_table = []
if inter_data.prefix_cache_hit: if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should # NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens. # include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id] block_table = block_tables[seq_id]
...@@ -272,23 +273,27 @@ class FlashAttentionMetadataBuilder( ...@@ -272,23 +273,27 @@ class FlashAttentionMetadataBuilder(
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int): cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.""" """Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list: for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data, self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled) self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
device = self.runner.device device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens) max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
...@@ -297,7 +302,7 @@ class FlashAttentionMetadataBuilder( ...@@ -297,7 +302,7 @@ class FlashAttentionMetadataBuilder(
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size num_decode_tokens = batch_size
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
...@@ -397,9 +402,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -397,9 +402,11 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
assert blocksparse_params is None, ValueError( if blocksparse_params is not None:
"FlashAttention does not support block-sparse attention.") raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
...@@ -410,6 +417,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -410,6 +417,10 @@ class FlashAttentionImpl(AttentionImpl):
self.sliding_window = ((sliding_window, sliding_window) self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1)) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
...@@ -478,6 +489,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -478,6 +489,8 @@ class FlashAttentionImpl(AttentionImpl):
value_cache, value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale,
v_scale,
) )
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
...@@ -515,6 +528,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -515,6 +528,7 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
) )
assert output[:num_prefill_tokens].shape == out.shape assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out output[:num_prefill_tokens] = out
...@@ -534,6 +548,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -534,6 +548,7 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables, block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
......
...@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache # The data type of the paged kv cache
data_type: torch.dtype = None data_type: torch.dtype = None
device: torch.device = torch.device("cuda") device: torch.device = torch.device("cuda")
# Only used by gemma2 model
logits_soft_cap: Optional[float] = None
def __post_init__(self): def __post_init__(self):
# Refer to # Refer to
...@@ -135,13 +133,20 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -135,13 +133,20 @@ class FlashInferMetadata(AttentionMetadata):
return return
assert self.prefill_wrapper is not None assert self.prefill_wrapper is not None
assert self.query_start_loc is not None
assert self.paged_kv_indices is not None assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device) batch_size = self.query_start_loc.shape[0] - 1
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) assert batch_size >= 0
# The prefill stage does not read kv cache.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
self.paged_kv_indptr = torch.zeros(batch_size + 1,
device=self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device) self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward() self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward( self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr, self.query_start_loc, self.paged_kv_indptr,
...@@ -297,26 +302,38 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -297,26 +302,38 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if is_profile_run: if is_profile_run:
return return
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
block_table = block_tables[seq_id] block_table = block_tables[seq_id]
self.paged_kv_indices.extend(block_table[:block_table_bound]) self._update_paged_kv_tensors(block_table, seq_len)
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound) def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
# Get the number of valid blocks based on sequence length.
last_page_len = seq_len % self.block_size # If seq_len = 16, block_size = 16,
if last_page_len == 0: # block_table_bound is 1 with 1 valid block.
last_page_len = self.block_size # If seq_len = 15, block_size = 16,
self.paged_kv_last_page_len.append(last_page_len) # block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int): cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list: for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data, self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled) self.input_builder.chunked_prefill_enabled)
...@@ -331,7 +348,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -331,7 +348,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size num_decode_tokens = batch_size
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
...@@ -379,9 +396,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -379,9 +396,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.long, dtype=torch.long,
device=device) device=device)
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if len(self.paged_kv_indptr) > 0: if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu", device="cpu",
...@@ -418,8 +432,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -418,8 +432,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
device=device, device=device,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph)
logits_soft_cap=logits_soft_cap)
class FlashInferImpl(AttentionImpl): class FlashInferImpl(AttentionImpl):
...@@ -434,6 +447,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -434,6 +447,7 @@ class FlashInferImpl(AttentionImpl):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
...@@ -446,6 +460,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -446,6 +460,7 @@ class FlashInferImpl(AttentionImpl):
raise ValueError("Sliding window is not supported in FlashInfer.") raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1) self.sliding_window = (-1, -1)
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
...@@ -489,6 +504,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -489,6 +504,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache[:, 1], kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale,
v_scale,
) )
query = query.contiguous( query = query.contiguous(
...@@ -518,7 +535,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -518,7 +535,7 @@ class FlashInferImpl(AttentionImpl):
output = prefill_meta.prefill_wrapper.forward( output = prefill_meta.prefill_wrapper.forward(
query, query,
kv_cache, kv_cache,
logits_soft_cap=attn_metadata.logits_soft_cap, logits_soft_cap=self.logits_soft_cap,
causal=True) causal=True)
else: else:
assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata is not None
...@@ -527,5 +544,5 @@ class FlashInferImpl(AttentionImpl): ...@@ -527,5 +544,5 @@ class FlashInferImpl(AttentionImpl):
query, query,
kv_cache, kv_cache,
sm_scale=self.scale, sm_scale=self.scale,
logits_soft_cap=attn_metadata.logits_soft_cap) logits_soft_cap=self.logits_soft_cap)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
...@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
assert blocksparse_params is None, ValueError( if blocksparse_params is not None:
"Torch SPDA does not support block-sparse attention.") raise ValueError(
"IPEX backend does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError("IPEX backend does not support logits_soft_cap.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
......
...@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type ...@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops. import torch_xla.experimental.custom_kernel # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
...@@ -55,8 +54,8 @@ class PallasMetadata(AttentionMetadata): ...@@ -55,8 +54,8 @@ class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills # Currently, input sequences can only contain all prefills
# or all decoding. # or all decoding.
block_tables: Optional[torch.Tensor] block_tables: Optional[torch.Tensor] = None
context_lens: Optional[torch.Tensor] context_lens: Optional[torch.Tensor] = None
@property @property
def prefill_metadata(self) -> Optional["PallasMetadata"]: def prefill_metadata(self) -> Optional["PallasMetadata"]:
...@@ -92,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -92,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
...@@ -110,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -110,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("FP8 KV cache dtype is not supported.") raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None: if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.") raise NotImplementedError("Blocksparse is not supported.")
if logits_soft_cap is not None:
raise NotImplementedError(
"Attention logits soft-capping is not supported.")
if torch_xla.tpu.version() < 4: if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.") raise NotImplementedError("TPU version must be 4 or higher.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment