Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
......@@ -6,6 +6,7 @@ import pytest
from vllm.attention.layer import Attention
from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
......@@ -71,6 +72,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=PoolingParams(),
block_ids=([0], ), # block_ids should be tuple[list[int]]
num_computed_tokens=0,
lora_request=None,
......@@ -80,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
......@@ -159,7 +161,7 @@ def test_update_states_request_finished(model_runner):
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
......@@ -189,7 +191,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
......@@ -207,16 +209,16 @@ def test_update_states_request_resumed(model_runner):
# resume req
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=([], ),
num_computed_tokens=0,
req_ids=[req_id],
resumed_from_preemption=[False],
new_token_ids=[[]],
new_block_ids=[([], )],
num_computed_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[cached_req_data],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
......@@ -247,7 +249,7 @@ def test_update_states_no_changes(model_runner):
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
......@@ -282,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
......@@ -585,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
    """Check the request limits derived when VLLM_TPU_MOST_MODEL_LEN is set.

    The env var is set to 2048 while the config uses max_model_len=32000 and
    max_num_seqs=1200. The runner is then expected to expose 1200 requests
    for the "most" model length and 524 for the max model length.
    """
    monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")

    config = get_vllm_config()
    config.model_config.max_model_len = 32000
    config.scheduler_config.max_num_seqs = 1200
    runner = get_model_runner(config)

    # verify model runner will adjust num_reqs to avoid SMEM OOM.
    assert runner.num_reqs_most_model_len == 1200
    # Expected cap for the max model length:
    #   num_page_per_req = 32k // 128
    #   num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
    assert runner.num_reqs_max_model_len == 524
......@@ -2,14 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Sequence
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
......@@ -18,18 +22,24 @@ VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
f"{current_platform.device_type}:{i}"
for i in range(min(current_platform.device_count(), 2))
]
MAX_NUM_PROMPT_TOKENS = 64
def _compare_objs(obj1, obj2):
def _compare_objs(obj1,
obj2,
skip: Sequence = ("logitsprocs", "batch_update_builder")):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
a[0] for a in attrs
if not (a[0].startswith('__') and a[0].endswith('__'))
])
for attr_name in attr_names:
if attr_name in skip:
continue
a = getattr(obj1, attr_name)
b = getattr(obj2, attr_name)
......@@ -46,7 +56,7 @@ def _compare_objs(obj1, obj2):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
elif a == b:
......@@ -55,13 +65,11 @@ def _compare_objs(obj1, obj2):
f" in {obj1} and {obj2}: {a} != {b}"
def _remove_requests(
input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
def _remove_requests(input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> set[str]:
"""
Remove some requests randomly from the batch and returns a tuple
of 1) set of request removed 2) indices of the requests removed
ordered in descending order
Remove some requests randomly from the batch and returns
set of request removed
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
......@@ -70,13 +78,11 @@ def _remove_requests(
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_indices_to_remove_list = list(req_indices_to_remove)
req_indices_to_remove_list.sort(reverse=True)
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return req_ids_to_remove, req_indices_to_remove_list
return req_ids_to_remove
def _construct_expected_sampling_metadata(
......@@ -97,7 +103,6 @@ def _construct_expected_sampling_metadata(
repetition_penalties = [1.0 for _ in range(num_reqs)]
top_k = [0 for _ in range(num_reqs)]
top_p = [0.0 for _ in range(num_reqs)]
min_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
......@@ -120,7 +125,6 @@ def _construct_expected_sampling_metadata(
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
min_p[index_in_input_batch] = req.sampling_params.min_p
temperature[index_in_input_batch] = req.sampling_params.temperature
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
......@@ -142,8 +146,6 @@ def _construct_expected_sampling_metadata(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
min_p, dtype=torch.float, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
......@@ -162,13 +164,12 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
min_tokens=min_tokens,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
logitsprocs=LogitsProcessorManager(),
)
......@@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int):
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_inputs=[],
mm_positions=[],
block_ids=([], ),
......@@ -221,6 +223,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
......@@ -234,21 +238,22 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
assigned_req_index = input_batch.add_request(req)
assert req_index == assigned_req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
# Remove some requests
req_ids_to_remove, req_indices_to_remove = _remove_requests(
input_batch, batch_size, reqs)
req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
# Compact the input batch
input_batch.condense(req_indices_to_remove)
input_batch.condense()
# Generate the sampling metadata
sampling_metadata = input_batch._make_sampling_metadata()
......@@ -286,10 +291,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
......@@ -311,6 +314,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
......@@ -337,7 +342,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
assigned_req_index = input_batch.add_request(req)
assert assigned_req_index == req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
......@@ -350,9 +356,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
for req_index in range(batch_size):
req = reordered_reqs[req_index]
ref_input_batch.add_request(req, req_index)
assigned_req_index = ref_input_batch.add_request(req)
assert assigned_req_index == req_index
input_batch.refresh_sampling_metadata()
ref_input_batch.refresh_sampling_metadata()
input_batch.refresh_metadata()
ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch)
......@@ -4,10 +4,12 @@
import random
import pytest
import torch
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig, set_current_vllm_config)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
......@@ -22,7 +24,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
BLOCK_SIZE = 16
NUM_BLOCKS = 10
DEVICE = "cuda"
DEVICE = current_platform.device_type
def initialize_kv_cache(runner: GPUModelRunner):
......@@ -122,6 +124,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=None,
block_ids=([0], ),
num_computed_tokens=0,
lora_request=None,
......@@ -131,7 +134,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
......@@ -170,7 +173,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_state.block_ids[0]).all()
def test_update_states_new_request(model_runner):
def test_update_states_new_request(model_runner, dist_init):
req_id = "req_0"
# new req
......@@ -184,7 +187,7 @@ def test_update_states_new_request(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_finished(model_runner):
def test_update_states_request_finished(model_runner, dist_init):
req_id = "req_0"
# new req
......@@ -197,7 +200,7 @@ def test_update_states_request_finished(model_runner):
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
......@@ -216,7 +219,7 @@ def test_update_states_request_finished(model_runner):
assert not _is_req_scheduled(model_runner, req_id)
def test_update_states_request_resumed(model_runner):
def test_update_states_request_resumed(model_runner, dist_init):
req_id = "req_0"
# new req
......@@ -229,7 +232,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
......@@ -247,16 +250,16 @@ def test_update_states_request_resumed(model_runner):
# resume req
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=([], ),
num_computed_tokens=0,
req_ids=[req_id],
resumed_from_preemption=[False],
new_token_ids=[[]],
new_block_ids=([[0]], ),
num_computed_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[cached_req_data],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
......@@ -276,7 +279,55 @@ def test_update_states_request_resumed(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_no_changes(model_runner):
def test_get_nans_in_logits(model_runner, dist_init):
    """Check the per-request NaN counts reported by `_get_nans_in_logits`.

    Two requests ("req_0", "req_1") are scheduled first. The assertions then
    cover: no NaNs, NaNs in both rows, NaNs in one row only, `logits=None`,
    fewer logits rows than requests, and more logits rows than requests.
    """
    nan = float('nan')

    scheduler_output = _schedule_new_request("req_0", "req_1")
    model_runner._update_states(scheduler_output)

    def count_nans(rows):
        # Build the logits tensor on the test device and query the runner.
        return model_runner._get_nans_in_logits(
            torch.tensor(rows, device=DEVICE))

    # No NaNs anywhere.
    assert count_nans([
        [1.0, 2.0, 3.0],
        [3.0, 2.0, 1.0],
    ]) == {"req_0": 0, "req_1": 0}

    # NaNs in both rows.
    assert count_nans([
        [1.0, nan, 3.0],
        [4.0, nan, nan],
    ]) == {"req_0": 1, "req_1": 2}

    # NaNs only in the second row.
    assert count_nans([
        [1.0, 2.0, 3.0],
        [4.0, nan, nan],
    ]) == {"req_0": 0, "req_1": 2}

    # Missing logits are expected to yield a zero count for every request.
    assert model_runner._get_nans_in_logits(logits=None) == {
        "req_0": 0,
        "req_1": 0,
    }

    # One logits row for two requests: req_1 is expected to report 0.
    assert count_nans([
        [1.0, nan, 3.0],
    ]) == {'req_0': 1, 'req_1': 0}

    # Three logits rows for two requests: only the first two rows are
    # expected to be attributed to requests.
    assert count_nans([
        [nan, nan, 2.0],
        [1.0, 2.0, 3.0],
        [nan, 2.0, 3.0],
    ]) == {'req_0': 2, 'req_1': 0}
def test_update_states_no_changes(model_runner, dist_init):
req_id = "req_0"
# new req
......@@ -289,7 +340,7 @@ def test_update_states_no_changes(model_runner):
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
......@@ -309,7 +360,7 @@ def test_update_states_no_changes(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_unscheduled(model_runner):
def test_update_states_request_unscheduled(model_runner, dist_init):
req_ids = ("req_0", "req_1")
# new reqs
......@@ -326,7 +377,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
......@@ -399,6 +450,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
......@@ -427,6 +479,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
......@@ -455,6 +508,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
......@@ -483,6 +537,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_without_kv_sharing():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
......@@ -550,6 +605,7 @@ def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_with_kv_sharing_valid():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
......
......@@ -209,32 +209,32 @@ def test_multi_step_model_runner_input():
received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
receieved_frozen_input = received_model_input.frozen_model_input
received_frozen_input = received_model_input.frozen_model_input
# Check that received copy has correct values.
assert isinstance(received_model_input, StatefulModelInput)
assert receieved_frozen_input.input_tokens is not None
assert (receieved_frozen_input.input_tokens ==
assert received_frozen_input.input_tokens is not None
assert (received_frozen_input.input_tokens ==
frozen_model_input.input_tokens).all()
assert receieved_frozen_input.input_positions is not None
assert (receieved_frozen_input.input_positions ==
assert received_frozen_input.input_positions is not None
assert (received_frozen_input.input_positions ==
frozen_model_input.input_positions).all()
assert receieved_frozen_input.multi_modal_kwargs is None
assert received_frozen_input.multi_modal_kwargs is None
assert (frozen_model_input.multi_modal_kwargs ==
frozen_model_input.multi_modal_kwargs)
assert receieved_frozen_input.lora_requests is None
assert (receieved_frozen_input.lora_requests ==
assert received_frozen_input.lora_requests is None
assert (received_frozen_input.lora_requests ==
frozen_model_input.lora_requests)
assert receieved_frozen_input.lora_mapping is None
assert received_frozen_input.lora_mapping is None
assert (
receieved_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
for field in dataclasses.fields(AttentionMetadata):
assert getattr(receieved_frozen_input.attn_metadata, field.name,
assert getattr(received_frozen_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# For sampling metadata, only selected_token_indices is copied.
assert (receieved_frozen_input.sampling_metadata.selected_token_indices ==
assert (received_frozen_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert receieved_frozen_input.sampling_metadata.seq_groups is None
assert received_frozen_input.sampling_metadata.seq_groups is None
# check non frozen fields
assert received_model_input.is_last_step == model_input.is_last_step
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Ensure we perform lazy loading in vllm/__init__.py.
i.e. every internal (``vllm``) import appears only within the
``if typing.TYPE_CHECKING:`` guard, **except** for a short whitelist.
"""
from __future__ import annotations
import ast
import pathlib
import sys
from collections.abc import Iterable
from typing import Final
REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py"
# If you need to add items to whitelist, do it here.
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({
"vllm.env_override",
})
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({
".version",
})
def _is_internal(name: str | None, *, level: int = 0) -> bool:
    """Return True if an import target refers to the ``vllm`` package.

    Args:
        name: Module name as written in the import statement, or ``None``
            (e.g. ``from . import x`` has no module name).
        level: Relative-import level; any value above 0 counts as internal.

    Returns:
        True for relative imports and for ``vllm`` or any ``vllm.*`` module,
        False otherwise (including a ``None`` name at level 0).
    """
    if level > 0:
        return True
    return name is not None and (name == "vllm" or name.startswith("vllm."))
def _fail(violations: Iterable[tuple[int, str]]) -> None:
    """Print every violation to stderr and exit with status 1.

    Args:
        violations: ``(line number, message)`` pairs to report.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    report = ["ERROR: Disallowed eager imports in vllm/__init__.py:\n"]
    report.extend(f" Line {lineno}: {msg}" for lineno, msg in violations)
    # One print call emits the header, a blank line, then one line per
    # violation.
    print("\n".join(report), file=sys.stderr)
    sys.exit(1)
def main() -> None:
    """Parse ``vllm/__init__.py`` and fail on eager internal imports.

    The module's AST is walked and every ``import`` / ``from ... import`` of
    an internal module is recorded unless it sits in the body of a
    ``TYPE_CHECKING`` guard or is whitelisted. Any recorded violations are
    handed to ``_fail``.
    """
    tree = ast.parse(INIT_PATH.read_text(encoding="utf-8"),
                     filename=str(INIT_PATH))

    violations: list[tuple[int, str]] = []

    def is_type_checking_guard(test: ast.expr) -> bool:
        """Recognise ``typing.TYPE_CHECKING`` and bare ``TYPE_CHECKING``."""
        if isinstance(test, ast.Attribute) and isinstance(
                test.value, ast.Name):
            return test.value.id == "typing" and test.attr == "TYPE_CHECKING"
        if isinstance(test, ast.Name):
            return test.id == "TYPE_CHECKING"
        return False

    class EagerImportFinder(ast.NodeVisitor):
        """Collect disallowed imports into the enclosing ``violations``."""

        def __init__(self) -> None:
            super().__init__()
            self._in_type_checking = False

        def visit_If(self, node: ast.If) -> None:
            if not is_type_checking_guard(node.test):
                self.generic_visit(node)
                return

            # Only the guarded body is exempt. The else branch is visited
            # with whatever state was active before this ``if``.
            outer_state = self._in_type_checking
            self._in_type_checking = True
            for stmt in node.body:
                self.visit(stmt)
            self._in_type_checking = outer_state
            for stmt in node.orelse:
                self.visit(stmt)

        def visit_Import(self, node: ast.Import) -> None:
            if self._in_type_checking:
                return

            for alias in node.names:
                module_name = alias.name
                if (module_name in ALLOWED_IMPORTS
                        or not _is_internal(module_name)):
                    continue
                violations.append((
                    node.lineno,
                    f"import '{module_name}' must be inside typing.TYPE_CHECKING",  # noqa: E501
                ))

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            if self._in_type_checking:
                return

            # Rebuild the module as written, e.g. ".version" for a
            # level-1 relative import.
            module_as_written = ("." * node.level) + (node.module or "")
            if module_as_written in ALLOWED_FROM_MODULES:
                return
            if _is_internal(node.module, level=node.level):
                violations.append((
                    node.lineno,
                    f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING",  # noqa: E501
                ))

    EagerImportFinder().visit(tree)

    if violations:
        _fail(violations)


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
import regex as re
try:
import pathspec
except ImportError:
print(
"ERROR: The 'pathspec' library is required. "
"Install it with 'pip install pathspec'.",
file=sys.stderr)
sys.exit(2)
# List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle
#
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
# The pickle and cloudpickle modules are known to be unsafe when deserializing
# data from potentially untrusted parties. They have resulted in multiple CVEs
# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
# Before adding new uses of pickle/cloudpickle, please consider safer
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review.
# Each path is listed once, grouped by which module(s) the file imports.
ALLOWED_FILES = {
    # pickle
    'vllm/multimodal/hasher.py',
    'tests/test_utils.py',
    'tests/tokenization/test_cached_tokenizer.py',
    'tests/model_executor/test_guided_processors.py',
    'vllm/distributed/utils.py',
    'vllm/distributed/parallel_state.py',
    'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
    'vllm/distributed/device_communicators/shm_broadcast.py',
    'benchmarks/kernels/graph_machete_bench.py',
    'benchmarks/kernels/benchmark_lora.py',
    'benchmarks/kernels/benchmark_machete.py',
    'benchmarks/fused_kernels/layernorm_rms_benchmarks.py',
    'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py',
    'benchmarks/cutlass_benchmarks/sparse_benchmarks.py',
    # cloudpickle
    'vllm/worker/worker_base.py',
    'vllm/executor/mp_distributed_executor.py',
    'vllm/executor/ray_distributed_executor.py',
    'vllm/entrypoints/llm.py',
    'tests/utils.py',
    # pickle and cloudpickle
    'vllm/utils/__init__.py',
    'vllm/v1/serial_utils.py',
    'vllm/v1/executor/multiproc_executor.py',
    'vllm/transformers_utils/config.py',
    'vllm/model_executor/models/registry.py',
    'vllm/engine/multiprocessing/client.py',
    'vllm/engine/multiprocessing/engine.py',
}
# Matches a line that, after optional leading whitespace, starts with either
#   * ``import pickle`` / ``import cloudpickle`` followed by whitespace or the
#     end of the line (this also covers ``import pickle as ...``), or
#   * ``from pickle import`` / ``from cloudpickle import``.
# NOTE: a comma directly after the module name (``import pickle, os``) is
# neither whitespace nor end of line, so such a line is not matched.
PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)")
def is_python_file(path):
    """Return True if *path* ends with the ``.py`` suffix."""
    python_suffix = '.py'
    return path.endswith(python_suffix)
def scan_file(path):
    """Return True if any line of the file at *path* matches ``PICKLE_RE``.

    The file is read as UTF-8, line by line, and scanning stops at the first
    matching line.
    """
    with open(path, encoding='utf-8') as source:
        return any(PICKLE_RE.match(line) for line in source)
def load_gitignore(repo_root):
    """Build a ``gitwildmatch`` PathSpec for the repository's ignore rules.

    The patterns come from ``<repo_root>/.gitignore`` when that file exists;
    ``.git/`` is always appended so the git directory is ignored as well.
    """
    gitignore_path = os.path.join(repo_root, '.gitignore')
    if os.path.exists(gitignore_path):
        with open(gitignore_path, encoding='utf-8') as f:
            patterns = f.read().splitlines()
    else:
        patterns = []

    # Always ignore .git directory
    patterns.append('.git/')
    return pathspec.PathSpec.from_lines('gitwildmatch', patterns)
def main():
    """Scan the repository for disallowed pickle/cloudpickle imports.

    The repository root is taken as the parent of this script's directory.
    Every ``.py`` file not matched by the ignore spec is scanned; files that
    import pickle/cloudpickle and are not in ``ALLOWED_FILES`` are reported.
    Exits with status 1 if any such file is found, otherwise with status 0.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.dirname(script_dir)
    spec = load_gitignore(repo_root)

    offenders = []
    for dirpath, _, filenames in os.walk(repo_root):
        for filename in filter(is_python_file, filenames):
            abs_path = os.path.join(dirpath, filename)
            rel_path = os.path.relpath(abs_path, repo_root)

            # Skip ignored files
            if spec.match_file(rel_path):
                continue

            # The file is scanned before the allow-list lookup.
            if not scan_file(abs_path):
                continue
            if rel_path not in ALLOWED_FILES:
                offenders.append(rel_path)

    if not offenders:
        sys.exit(0)

    print("\nERROR: The following files import 'pickle' or 'cloudpickle' "
          "but are not in the allowed list:")
    for f in offenders:
        print(f" {f}")
    print("\nIf this is intentional, update the allowed list in "
          "tools/check_pickle_imports.py.")
    sys.exit(1)
def test_regex():
    """Self-test ``PICKLE_RE`` against lines that must and must not match.

    Raises:
        AssertionError: If any sample line gives the wrong result; the
            message names the failing case index and line.
    """
    matching = [
        "import pickle",
        "import cloudpickle",
        "import pickle as pkl",
        "import cloudpickle as cpkl",
        "from pickle import *",
        "from cloudpickle import dumps",
        "from pickle import dumps, loads",
        "from cloudpickle import (dumps, loads)",
        " import pickle",
        "\timport cloudpickle",
        "from pickle import loads",
    ]
    non_matching = [
        "import somethingelse",
        "from somethingelse import pickle",
        "# import pickle",
        "print('import pickle')",
        "import pickleas as asdf",
    ]
    # Matching cases come first so case indices in failure messages stay
    # stable.
    test_cases = ([(line, True) for line in matching] +
                  [(line, False) for line in non_matching])

    for i, (line, should_match) in enumerate(test_cases):
        result = PICKLE_RE.match(line) is not None
        assert result == should_match, (
            f"Test case {i} failed: '{line}' "
            f"(expected {should_match}, got {result})")

    print("All regex tests passed.")
if __name__ == '__main__':
    # ``--test-regex`` anywhere on the command line runs the regex self-test;
    # otherwise the repository scan is run.
    entry_point = test_regex if '--test-regex' in sys.argv else main
    entry_point()
......@@ -2,51 +2,146 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from enum import Enum
SPDX_HEADER = (
class SPDXStatus(Enum):
    """States an SPDX header check can report for a file."""

    # empty __init__.py
    EMPTY = "empty"
    COMPLETE = "complete"
    # Only has copyright line
    MISSING_LICENSE = "missing_license"
    # Only has license line
    MISSING_COPYRIGHT = "missing_copyright"
    # Completely missing
    MISSING_BOTH = "missing_both"
FULL_SPDX_HEADER = (
"# SPDX-License-Identifier: Apache-2.0\n"
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project")
SPDX_HEADER_PREFIX = "# SPDX-License-Identifier:"
LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0"
COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501
def check_spdx_header(file_path):
with open(file_path, encoding='UTF-8') as file:
def check_spdx_header_status(file_path):
"""Check SPDX header status of the file"""
with open(file_path, encoding="UTF-8") as file:
lines = file.readlines()
if not lines:
# Empty file like __init__.py
return True
for line in lines:
if line.strip().startswith(SPDX_HEADER_PREFIX):
return True
return False
# Empty file
return SPDXStatus.EMPTY
# Skip shebang line
start_idx = 0
if lines and lines[0].startswith("#!"):
start_idx = 1
has_license = False
has_copyright = False
def add_header(file_path):
with open(file_path, 'r+', encoding='UTF-8') as file:
# Check all lines for SPDX headers (not just the first two)
for i in range(start_idx, len(lines)):
line = lines[i].strip()
if line == LICENSE_LINE:
has_license = True
elif line == COPYRIGHT_LINE:
has_copyright = True
# Determine status based on what we found
if has_license and has_copyright:
return SPDXStatus.COMPLETE
elif has_license and not has_copyright:
# Only has license line
return SPDXStatus.MISSING_COPYRIGHT
# Only has copyright line
elif not has_license and has_copyright:
return SPDXStatus.MISSING_LICENSE
else:
# Completely missing both lines
return SPDXStatus.MISSING_BOTH
def add_header(file_path, status):
"""Add or supplement SPDX header based on status"""
with open(file_path, "r+", encoding="UTF-8") as file:
lines = file.readlines()
file.seek(0, 0)
if lines and lines[0].startswith("#!"):
file.write(lines[0])
file.write(SPDX_HEADER + '\n')
file.writelines(lines[1:])
else:
file.write(SPDX_HEADER + '\n')
file.truncate()
if status == SPDXStatus.MISSING_BOTH:
# Completely missing, add complete header
if lines and lines[0].startswith("#!"):
# Preserve shebang line
file.write(lines[0])
file.write(FULL_SPDX_HEADER + "\n")
file.writelines(lines[1:])
else:
# Add header directly
file.write(FULL_SPDX_HEADER + "\n")
file.writelines(lines)
elif status == SPDXStatus.MISSING_COPYRIGHT:
# Only has license line, need to add copyright line
# Find the license line and add copyright line after it
for i, line in enumerate(lines):
if line.strip() == LICENSE_LINE:
# Insert copyright line after license line
lines.insert(
i + 1,
f"{COPYRIGHT_LINE}\n",
)
break
file.writelines(lines)
elif status == SPDXStatus.MISSING_LICENSE:
# Only has copyright line, need to add license line
# Find the copyright line and add license line before it
for i, line in enumerate(lines):
if line.strip() == COPYRIGHT_LINE:
# Insert license line before copyright line
lines.insert(i, f"{LICENSE_LINE}\n")
break
file.writelines(lines)
def main():
    """Check every file given on the command line and fix missing SPDX lines.

    Each file is classified with ``check_spdx_header_status``; files missing
    one or both SPDX lines are listed and repaired with ``add_header``.
    Exits with status 1 if any file needed fixing, otherwise 0.
    """
    # Insertion order defines the order in which files are reported/fixed.
    files_by_status = {
        SPDXStatus.MISSING_BOTH: [],
        SPDXStatus.MISSING_COPYRIGHT: [],
        SPDXStatus.MISSING_LICENSE: [],
    }

    for file_path in sys.argv[1:]:
        status = check_spdx_header_status(file_path)
        # Any other status (e.g. a complete header) needs no action.
        if status in files_by_status:
            files_by_status[status].append(file_path)

    # Collect all files that need fixing
    all_files_to_fix = [
        file_path for paths in files_by_status.values() for file_path in paths
    ]

    if all_files_to_fix:
        print("The following files are missing the SPDX header:")
        for status, paths in files_by_status.items():
            for file_path in paths:
                print(f" {file_path}")
                add_header(file_path, status)

    sys.exit(1 if all_files_to_fix else 0)
if __name__ == "__main__":
......
......@@ -14,6 +14,12 @@ ALLOWED_LINES = {
"from vllm.triton_utils import tl, triton",
}
ALLOWED_FILES = {"vllm/triton_utils/importing.py"}
def is_allowed_file(current_file: str) -> bool:
    """Return whether ``current_file`` is one of the ``ALLOWED_FILES``."""
    allowed = current_file in ALLOWED_FILES
    return allowed
def is_forbidden_import(line: str) -> bool:
stripped = line.strip()
......@@ -25,10 +31,14 @@ def parse_diff(diff: str) -> list[str]:
violations = []
current_file = None
current_lineno = None
skip_allowed_file = False
for line in diff.splitlines():
if line.startswith("+++ b/"):
current_file = line[6:]
skip_allowed_file = is_allowed_file(current_file)
elif skip_allowed_file:
continue
elif line.startswith("@@"):
match = re.search(r"\+(\d+)", line)
if match:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import multiprocessing
import os
import sys
from shutil import which
try:
# Try to get CUDA_HOME from PyTorch installation, which is the
# most reliable source of truth for vLLM's build.
from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
print("Warning: PyTorch not found. "
"Falling back to CUDA_HOME environment variable.")
CUDA_HOME = os.environ.get("CUDA_HOME")
def get_python_executable():
    """Return the path of the interpreter running this script."""
    interpreter = sys.executable
    return interpreter
def get_cpu_cores():
    """Return the CPU count reported by ``multiprocessing.cpu_count()``."""
    core_count = multiprocessing.cpu_count()
    return core_count
def generate_presets(output_path="CMakeUserPresets.json"):
    """Generate the CMakeUserPresets.json file.

    Detects nvcc, the Python executable, the CPU core count, an optional
    compiler cache (sccache/ccache) and Ninja, prompting on stdin when nvcc
    cannot be found or when the output file already exists.

    Args:
        output_path: Output file name, joined onto the project root (the
            parent of this script's directory).

    Raises:
        ValueError: If no Python executable could be determined.
    """
    print("Attempting to detect your system configuration...")

    # Detect NVCC: CUDA_HOME first, then PATH, then ask the user.
    nvcc_path = None
    if CUDA_HOME:
        candidate = os.path.join(CUDA_HOME, "bin", "nvcc")
        if os.path.exists(candidate):
            nvcc_path = candidate
            print("Found nvcc via torch.utils.cpp_extension.CUDA_HOME: "
                  f"{nvcc_path}")

    if not nvcc_path:
        nvcc_path = which("nvcc")
        if nvcc_path:
            print(f"Found nvcc in PATH: {nvcc_path}")

    if not nvcc_path:
        nvcc_path = input(
            "Could not automatically find 'nvcc'. Please provide the full "
            "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): ").strip()
    print(f"Using NVCC path: {nvcc_path}")

    # Detect Python executable
    python_executable = get_python_executable()
    if python_executable:
        print(f"Found Python via sys.executable: {python_executable}")
    else:
        python_executable = input(
            "Could not automatically find Python executable. Please provide "
            "the full path to your Python executable for vLLM development "
            "(typically from your virtual environment, e.g., "
            "/home/user/venvs/vllm/bin/python): ").strip()
        if not python_executable:
            raise ValueError(
                "Could not determine Python executable. Please provide it "
                "manually.")
    print(f"Using Python executable: {python_executable}")

    # Get CPU cores and split them between nvcc threads and CMake jobs.
    cpu_cores = get_cpu_cores()
    nvcc_threads = min(4, cpu_cores)
    cmake_jobs = max(1, cpu_cores // nvcc_threads)
    print(f"Detected {cpu_cores} CPU cores. "
          f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}.")

    # Get vLLM project root (assuming this script is in vllm/tools/)
    script_dir = os.path.dirname(__file__)
    project_root = os.path.abspath(os.path.join(script_dir, ".."))
    print(f"VLLM project root detected as: {project_root}")

    # Ensure python_executable path is absolute or resolvable
    if not os.path.isabs(python_executable):
        resolved = which(python_executable)
        if resolved:
            python_executable = os.path.abspath(resolved)
        else:
            print(
                f"Warning: Python executable '{python_executable}' is not an "
                "absolute path and not found in PATH. CMake might not find it."
            )

    cache_variables = {
        "CMAKE_CUDA_COMPILER": nvcc_path,
        "CMAKE_BUILD_TYPE": "Release",
        "VLLM_PYTHON_EXECUTABLE": python_executable,
        "CMAKE_INSTALL_PREFIX": "${sourceDir}",
        "CMAKE_CUDA_FLAGS": "",
        "NVCC_THREADS": str(nvcc_threads),
    }

    # Detect compiler cache; sccache takes precedence over ccache.
    cache_tool = next(
        (tool for tool in ("sccache", "ccache") if which(tool)), None)
    if cache_tool is None:
        print("No compiler cache ('ccache' or 'sccache') found.")
    else:
        print(f"Using {cache_tool} for compiler caching.")
        for launcher in ("C", "CXX", "CUDA", "HIP"):
            cache_variables[f"CMAKE_{launcher}_COMPILER_LAUNCHER"] = cache_tool

    configure_preset = {
        "name": "release",
        "binaryDir": "${sourceDir}/cmake-build-release",
        "cacheVariables": cache_variables,
    }
    if which("ninja"):
        print("Using Ninja generator.")
        configure_preset["generator"] = "Ninja"
        cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}"
    else:
        print("Ninja not found, using default generator. "
              "Build may be slower.")

    build_preset = {
        "name": "release",
        "configurePreset": "release",
        "jobs": cmake_jobs,
    }
    presets = {
        "version": 6,
        # Keep in sync with CMakeLists.txt and requirements/build.txt
        "cmakeMinimumRequired": {
            "major": 3,
            "minor": 26,
            "patch": 1
        },
        "configurePresets": [configure_preset],
        "buildPresets": [build_preset],
    }

    output_file_path = os.path.join(project_root, output_path)

    # Never clobber an existing presets file without confirmation.
    if os.path.exists(output_file_path):
        answer = input(
            f"'{output_file_path}' already exists. Overwrite? (y/N): ")
        if answer.strip().lower() != 'y':
            print("Generation cancelled.")
            return

    try:
        with open(output_file_path, "w") as f:
            json.dump(presets, f, indent=4)
        print(f"Successfully generated '{output_file_path}'")
        print("\nTo use this preset:")
        print(
            f"1. Ensure you are in the vLLM root directory: cd {project_root}")
        print("2. Initialize CMake: cmake --preset release")
        print("3. Build+install: cmake --build --preset release "
              "--target install")
    except OSError as e:
        print(f"Error writing file: {e}")
if __name__ == "__main__":
generate_presets()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Generates specialized requirements files for nightly PyTorch testing.
This script reads the main test requirements input file (`requirements/test.in`)
and splits its content into two files:
1. `requirements/nightly_torch_test.txt`: Contains dependencies
except PyTorch-related.
2. `torch_nightly_test.txt`: Contains only PyTorch-related packages.
"""
input_file = "requirements/test.in"
output_file = "requirements/nightly_torch_test.txt"
# White list of packages that are not directly compatible with PyTorch nightly
# when installed with pip. Please add your package to this list if it is not
# compatible or if it makes the dependency test fail.
white_list = ["torch", "torchaudio", "torchvision", "mamba_ssm"]
with open(input_file) as f:
lines = f.readlines()
skip_next = False
for line in lines:
if skip_next:
if line.startswith((" ", "\t")) or line.strip() == "":
continue
skip_next = False
if any(k in line.lower() for k in white_list):
skip_next = True
continue
......@@ -116,7 +116,7 @@ def ReadTargets(log, show_all):
# If ninja.exe is rudely halted then the .ninja_log file may be
# corrupt. Silently continue.
continue
start, end, _, name, cmdhash = parts # Ignore restat.
start, end, _, name, cmdhash = parts # Ignore restart.
# Convert from integral milliseconds to float seconds.
start = int(start) / 1000.0
end = int(end) / 1000.0
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Ensures all fields in a config dataclass have default values
and that each field has a docstring.
"""
import ast
import inspect
import sys
def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
    """
    Collect the docstrings written directly after attribute assignments in a
    class body, keyed by attribute name.

    Adapted from https://davidism.com/attribute-docstrings/
    https://davidism.com/mit-license/
    """
    docs: dict[str, str] = {}
    body = cls_node.body

    # Look at every pair of adjacent statements in the class body.
    for stmt, following in zip(body, body[1:]):
        is_assignment = isinstance(stmt, (ast.Assign, ast.AnnAssign))
        is_string_expr = (isinstance(following, ast.Expr)
                          and isinstance(following.value, ast.Constant)
                          and isinstance(following.value.value, str))
        if not (is_assignment and is_string_expr):
            continue

        text = inspect.cleandoc(following.value.value)

        # A plain assignment may have several targets (a = b = v); an
        # annotated assignment has exactly one.
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        else:
            targets = [stmt.target]

        # Only plain names count; attribute/subscript targets are ignored.
        for target in targets:
            if isinstance(target, ast.Name):
                docs[target.id] = text

    return docs
class ConfigValidator(ast.NodeVisitor):
    """AST visitor that validates classes decorated with ``@config``."""

    def __init__(self):
        ...

    def visit_ClassDef(self, node):
        """Validate ``node`` if it is decorated with @config and @dataclass.

        Recognised decorator forms are the bare names ``config`` and
        ``dataclass`` and the call form ``dataclass(...)``.
        """
        found = set()
        for decorator in node.decorator_list:
            if isinstance(decorator, ast.Name):
                if decorator.id in ('config', 'dataclass'):
                    found.add(decorator.id)
            elif (isinstance(decorator, ast.Call)
                  and isinstance(decorator.func, ast.Name)
                  and decorator.func.id == 'dataclass'):
                found.add('dataclass')

        if found == {'config', 'dataclass'}:
            validate_class(node)
        elif found == {'config'}:
            fail(
                f"Class {node.name} with config decorator must be a dataclass.",
                node)

        # Keep descending so nested classes are checked as well.
        self.generic_visit(node)
def validate_class(class_node: ast.ClassDef):
    """Check every annotated field of a config dataclass.

    Each field (an annotated assignment to a plain name, excluding
    ``ClassVar[...]``) must have a default value and an attribute docstring,
    and must not be a ``Union`` of more than one ``Literal[...]``. Violations
    are reported through ``fail``.
    """

    def is_subscript_of(annotation, type_name: str) -> bool:
        # True for annotations of the form ``type_name[...]``.
        return (isinstance(annotation, ast.Subscript)
                and isinstance(annotation.value, ast.Name)
                and annotation.value.id == type_name)

    documented = get_attr_docs(class_node)

    for stmt in class_node.body:
        # A field is defined as a class variable that has a type annotation.
        if not isinstance(stmt, ast.AnnAssign):
            continue

        annotation = stmt.annotation
        # Skip ClassVar
        # see https://docs.python.org/3/library/dataclasses.html#class-variables
        if is_subscript_of(annotation, "ClassVar"):
            continue

        if not isinstance(stmt.target, ast.Name):
            continue
        field_name = stmt.target.id

        if stmt.value is None:
            fail(
                f"Field '{field_name}' in {class_node.name} must have "
                "a default value.", stmt)

        if field_name not in documented:
            fail(
                f"Field '{field_name}' in {class_node.name} must have "
                "a docstring.", stmt)

        if not (is_subscript_of(annotation, "Union")
                and isinstance(annotation.slice, ast.Tuple)):
            continue

        literal_count = sum(1 for arg in annotation.slice.elts
                            if is_subscript_of(arg, "Literal"))
        if literal_count > 1:
            fail(
                f"Field '{field_name}' in {class_node.name} must "
                "use a single "
                "Literal type. Please use 'Literal[Literal1, "
                "Literal2]' instead of 'Union[Literal1, Literal2]'"
                ".", stmt)
def validate_ast(tree: ast.stmt):
    """Run a ``ConfigValidator`` over ``tree``."""
    validator = ConfigValidator()
    validator.visit(tree)
def validate_file(file_path: str):
    """Parse ``file_path`` and validate the config dataclasses it defines.

    Prints a progress line followed by either a check mark or the failure
    message.

    Args:
        file_path: Path of the Python source file to validate.

    Raises:
        SystemExit: With exit code 2 when validation raises ``ValueError``.
    """
    try:
        print(f"validating {file_path} config dataclasses ", end="")
        with open(file_path, encoding="utf-8") as f:
            source = f.read()

        tree = ast.parse(source, filename=file_path)
        validate_ast(tree)
    except ValueError as e:
        print(e)
        # The exception must actually be raised; merely constructing
        # SystemExit(2) would let the script exit with status 0.
        raise SystemExit(2) from e
    else:
        print("✅")
def fail(message: str, node: ast.stmt):
    """Raise a ``ValueError`` carrying ``message`` and the line of ``node``."""
    line_number = node.lineno
    raise ValueError(f"❌ line({line_number}): {message}")
def main():
    """Validate every file named on the command line, in order."""
    file_paths = sys.argv[1:]
    for file_path in file_paths:
        validate_file(file_path)
if __name__ == "__main__":
main()
[files]
# These files may contain non-English words
extend-exclude = ["tests/models/fixtures/*", "tests/prompts/*",
"benchmarks/sonnet.txt", "tests/lora/data/*", "build/*",
"vllm/third_party/*"]
ignore-hidden = true
ignore-files = true
ignore-dot = true
ignore-vcs = true
ignore-global = true
ignore-parent = true
[default]
binary = false
check-filename = false
check-file = true
unicode = true
ignore-hex = true
identifier-leading-digits = false
locale = "en"
extend-ignore-identifiers-re = ["NVML_*", ".*Unc.*", ".*_thw",
".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*", ".*fo.*", ".*ba.*",
".*ot.*", ".*[Tt]h[rR].*"]
extend-ignore-words-re = []
extend-ignore-re = []
[default.extend-identifiers]
bbc5b7ede = "bbc5b7ede"
womens_doubles = "womens_doubles"
v_2nd = "v_2nd"
splitted_input = "splitted_input"
NOOPs = "NOOPs"
typ = "typ"
nin_shortcut = "nin_shortcut"
UperNetDecoder = "UperNetDecoder"
subtile = "subtile"
cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin"
SFOuput = "SFOuput"
# huggingface transformers repo uses these words
depthwise_seperable_out_channel = "depthwise_seperable_out_channel"
DepthWiseSeperableConv1d = "DepthWiseSeperableConv1d"
depthwise_seperable_CNN = "depthwise_seperable_CNN"
[default.extend-words]
iy = "iy"
tendencias = "tendencias"
# intel cpu features
tme = "tme"
dout = "dout"
Pn = "Pn"
arange = "arange"
[type.py]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.py.extend-identifiers]
arange = "arange"
NDArray = "NDArray"
EOFError = "EOFError"
[type.py.extend-words]
[type.cpp]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.cpp.extend-identifiers]
countr_one = "countr_one"
[type.cpp.extend-words]
[type.rust]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.rust.extend-identifiers]
flate2 = "flate2"
[type.rust.extend-words]
ser = "ser"
[type.lock]
extend-glob = []
check-file = false
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.lock.extend-identifiers]
[type.lock.extend-words]
[type.jl]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.jl.extend-identifiers]
[type.jl.extend-words]
modul = "modul"
egals = "egals"
usig = "usig"
egal = "egal"
[type.go]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.go.extend-identifiers]
flate = "flate"
[type.go.extend-words]
[type.css]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.css.extend-identifiers]
nd = "nd"
[type.css.extend-words]
[type.man]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.man.extend-identifiers]
Nd = "Nd"
[type.man.extend-words]
[type.cert]
extend-glob = []
check-file = false
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.cert.extend-identifiers]
[type.cert.extend-words]
[type.sh]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.sh.extend-identifiers]
stap = "stap"
ot = "ot"
[type.sh.extend-words]
[type.vimscript]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []
[type.vimscript.extend-identifiers]
windo = "windo"
[type.vimscript.extend-words]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
# The version.py should be independent library, and we always import the
# version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__ # isort:skip
import typing
# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override # isort:skip # noqa: F401
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
import vllm.env_override # noqa: F401
MODULE_ATTRS = {
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
"LLMEngine": ".engine.llm_engine:LLMEngine",
"LLM": ".entrypoints.llm:LLM",
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
"PromptType": ".inputs:PromptType",
"TextPrompt": ".inputs:TextPrompt",
"TokensPrompt": ".inputs:TokensPrompt",
"ModelRegistry": ".model_executor.models:ModelRegistry",
"SamplingParams": ".sampling_params:SamplingParams",
"PoolingParams": ".pooling_params:PoolingParams",
"ClassificationOutput": ".outputs:ClassificationOutput",
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
"CompletionOutput": ".outputs:CompletionOutput",
"EmbeddingOutput": ".outputs:EmbeddingOutput",
"EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
"PoolingOutput": ".outputs:PoolingOutput",
"PoolingRequestOutput": ".outputs:PoolingRequestOutput",
"RequestOutput": ".outputs:RequestOutput",
"ScoringOutput": ".outputs:ScoringOutput",
"ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}
if typing.TYPE_CHECKING:
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput,
ClassificationRequestOutput, CompletionOutput,
EmbeddingOutput, EmbeddingRequestOutput,
PoolingOutput, PoolingRequestOutput,
RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
else:
def __getattr__(name: str) -> typing.Any:
    """Lazily resolve a public attribute listed in ``MODULE_ATTRS``.

    Each entry maps an attribute name to ``"<relative module>:<attribute>"``;
    the module is imported relative to this package on first access.

    Raises:
        AttributeError: If ``name`` is not listed in ``MODULE_ATTRS``.
    """
    from importlib import import_module

    if name not in MODULE_ATTRS:
        raise AttributeError(
            f'module {__package__} has no attribute {name}')

    module_name, attr_name = MODULE_ATTRS[name].split(":")
    target_module = import_module(module_name, __package__)
    return getattr(target_module, attr_name)
__all__ = [
"__version__",
......
......@@ -2,11 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import importlib
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.library
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -601,7 +599,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
quant_type: int,
row: torch.SymInt,
) -> torch.Tensor:
return torch.empty((1, row), dtype=X.dtype, device=W.device)
return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device)
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
......@@ -654,6 +652,20 @@ def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability)
def cutlass_blockwise_scaled_grouped_mm(
    output: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
):
    """Forward all arguments, in order, to the
    ``_C.cutlass_blockwise_scaled_grouped_mm`` custom op.

    Nothing is returned by this wrapper.
    """
    kernel = torch.ops._C.cutlass_blockwise_scaled_grouped_mm
    kernel(output, a, b, scales_a, scales_b, problem_sizes, expert_offsets)
def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor, alpha: torch.Tensor,
......@@ -712,10 +724,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
if current_platform.is_rocm() or not cutlass_compatible_b:
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
triton_scaled_mm)
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
......@@ -1234,6 +1244,7 @@ def scaled_fp8_quant(
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
......@@ -1265,7 +1276,12 @@ def scaled_fp8_quant(
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
......@@ -1273,13 +1289,12 @@ def scaled_fp8_quant(
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input, scale, scale_ub)
output, input.contiguous(), scale, scale_ub)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
# num_token_padding not implemented for this case
assert (scale.numel() == 1 or num_token_padding is None)
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
......@@ -1382,8 +1397,8 @@ def scaled_int8_quant(
dtype=torch.float32)
input_azp = None if symmetric else torch.empty_like(input_scales,
dtype=torch.int32)
torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
input_azp)
torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(),
input_scales, input_azp)
return output, input_scales, input_azp
......@@ -1527,15 +1542,6 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
num_tokens_post_pad)
def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                             block_size: int, sorted_token_ids: torch.Tensor,
                             experts_ids: torch.Tensor,
                             num_tokens_post_pad: torch.Tensor) -> None:
    """Forward all arguments, in order, to the
    ``_moe_C.sgl_moe_align_block_size`` custom op."""
    kernel = torch.ops._moe_C.sgl_moe_align_block_size
    kernel(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids,
           num_tokens_post_pad)
def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
b_qweight: torch.Tensor, b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
......@@ -1556,10 +1562,10 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indices: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Forward all arguments, in order, to the ``_moe_C.topk_softmax``
    custom op.

    Nothing is returned by this wrapper.
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, token_expert_indices,
                                  gating_output)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
......@@ -1761,6 +1767,38 @@ def free_shared_buffer(ptr: int) -> None:
torch.ops._C_custom_ar.free_shared_buffer(ptr)
# quick all reduce
def init_custom_qr(rank: int,
                   world_size: int,
                   qr_max_size: Optional[int] = None) -> int:
    """Call the ``_C_custom_ar.init_custom_qr`` custom op and return its
    result."""
    handle = torch.ops._C_custom_ar.init_custom_qr(rank, world_size,
                                                   qr_max_size)
    return handle
def qr_destroy(fa: int) -> None:
    """Call the ``_C_custom_ar.qr_destroy`` custom op with ``fa``."""
    destroy = torch.ops._C_custom_ar.qr_destroy
    destroy(fa)
def qr_all_reduce(fa: int,
                  inp: torch.Tensor,
                  out: torch.Tensor,
                  quant_level: int,
                  cast_bf2half: bool = False) -> None:
    """Forward all arguments, in order, to the ``_C_custom_ar.qr_all_reduce``
    custom op."""
    all_reduce = torch.ops._C_custom_ar.qr_all_reduce
    all_reduce(fa, inp, out, quant_level, cast_bf2half)
def qr_get_handle(fa: int) -> torch.Tensor:
    """Return the result of the ``_C_custom_ar.qr_get_handle`` custom op."""
    handle = torch.ops._C_custom_ar.qr_get_handle(fa)
    return handle
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
    """Call the ``_C_custom_ar.qr_open_handles`` custom op and pass its
    result through."""
    result = torch.ops._C_custom_ar.qr_open_handles(fa, handles)
    return result
def qr_max_size() -> int:
    """Return the result of the ``_C_custom_ar.qr_max_size`` custom op."""
    max_size = torch.ops._C_custom_ar.qr_max_size()
    return max_size
def get_flash_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
......@@ -1825,10 +1863,59 @@ def flash_mla_with_kvcache(
return out, softmax_lse
# def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
# q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
# seq_lens: torch.Tensor, page_table: torch.Tensor,
# scale: float) -> torch.Tensor:
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# return out
def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
                       q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
                       seq_lens: torch.Tensor, page_table: torch.Tensor,
                       scale: float) -> torch.Tensor:
    """Call the ``_C.cutlass_mla_decode`` custom op with all arguments in
    order, then return ``out``."""
    decode = torch.ops._C.cutlass_mla_decode
    decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table,
           scale)
    return out
if hasattr(torch.ops._C, "weight_packed_linear"):

    @register_fake("_C::weight_packed_linear")
    def weight_packed_linear_fake(mat1: torch.Tensor, mat2: torch.Tensor,
                                  bias: Optional[torch.Tensor],
                                  is_vnni: bool) -> torch.Tensor:
        """Fake implementation of ``_C::weight_packed_linear``.

        Returns an uninitialised ``(mat1.size(0), mat2.size(0))`` tensor with
        ``mat1``'s dtype on ``mat2``'s device.
        """
        out_shape = (mat1.size(0), mat2.size(0))
        return torch.empty(out_shape, dtype=mat1.dtype, device=mat2.device)
if hasattr(torch.ops._C, "fused_experts_cpu"):

    @register_fake("_C::fused_experts_cpu")
    def fused_experts_cpu_fake(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        inplace: bool,
        use_int8_w8a8: bool,
        use_fp8_w8a16: bool,
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        block_size: Optional[list[int]],
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        is_vnni: bool,
    ) -> torch.Tensor:
        """Fake implementation of ``_C::fused_experts_cpu``.

        Returns an uninitialised tensor shaped like ``hidden_states``; every
        other argument is ignored.
        """
        placeholder = torch.empty_like(hidden_states)
        return placeholder
if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):

    @register_fake("_C::int8_scaled_mm_with_quant")
    def int8_scaled_mm_with_quant_fake(
        mat1: torch.Tensor,
        mat2: torch.Tensor,
        scales2: torch.Tensor,
        bias: Optional[torch.Tensor],
        out_dtype: torch.dtype,
        is_vnni: bool,
    ) -> torch.Tensor:
        """Fake implementation of ``_C::int8_scaled_mm_with_quant``.

        Returns an uninitialised ``(mat1.size(0), mat2.size(0))`` tensor of
        dtype ``out_dtype`` on ``mat1``'s device.
        """
        M = mat1.size(0)
        N = mat2.size(0)
        # Pin the device to the input's so the result does not depend on the
        # ambient default device.
        return torch.empty((M, N), dtype=out_dtype, device=mat1.device)
......@@ -228,6 +228,112 @@ class ipex_ops:
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: Optional[torch.Tensor] = None,
    v_scale: Optional[torch.Tensor] = None,
    k_scale_float: float = 1.0,
    v_scale_float: float = 1.0,
) -> None:
    """Delegate to ipex ``PagedAttention.reshape_and_cache_flash``.

    Only ``kv_cache_dtype == "auto"`` is accepted; the scale arguments are
    not forwarded.
    """
    # TODO: support FP8 kv cache.
    assert kv_cache_dtype == "auto"
    paged_attention = ipex.llm.modules.PagedAttention
    paged_attention.reshape_and_cache_flash(key, value, key_cache,
                                            value_cache, slot_mapping)
@staticmethod
def flash_attn_varlen_func(
    out: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    seqused_k: torch.Tensor,  # we don't support this in ipex kernel
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    causal: bool,
    block_table: torch.Tensor,
    alibi_slopes: Optional[torch.Tensor],
    window_size: Optional[list[int]] = None,
    softcap: Optional[float] = 0.0,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    # The following parameters are not used in ipex kernel currently,
    # we keep API compatible to CUDA's.
    scheduler_metadata=None,
    fa_version: int = 2,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    num_splits=0,
):
    """Delegate to ipex ``PagedAttention.flash_attn_varlen_func``.

    When ``cu_seqlens_k`` is not given it is built as the int32 cumulative
    sum of ``seqused_k`` with a leading zero. A missing ``window_size`` is
    passed on as ``(-1, -1)``. ``k_scale``/``v_scale`` are fixed to 1.0.
    Returns whatever the ipex call returns.
    """
    if cu_seqlens_k is None:
        # cu_seqlens_k is not used in ipex kernel.
        running_total = torch.cumsum(seqused_k, dim=0)
        leading_zero = torch.tensor([0],
                                    device=seqused_k.device,
                                    dtype=torch.int32)
        cu_seqlens_k = torch.cat([leading_zero,
                                  running_total]).to(torch.int32)

    if window_size is None:
        window_left, window_right = -1, -1
    else:
        assert len(window_size) == 2
        window_left, window_right = window_size[0], window_size[1]

    paged_attention = ipex.llm.modules.PagedAttention
    return paged_attention.flash_attn_varlen_func(
        out,
        q.contiguous(),
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        block_table,
        alibi_slopes,
        softcap=softcap,
        window_size_left=window_left,
        window_size_right=window_right,
        k_scale=1.0,
        v_scale=1.0,
    )
@staticmethod
def get_scheduler_metadata(
    batch_size,
    max_seqlen_q,
    max_seqlen_k,
    num_heads_q,
    num_heads_kv,
    headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
) -> None:
    """Placeholder: logs a one-time warning and returns ``None``.

    All arguments are accepted and ignored.
    """
    message = ("get_scheduler_metadata is not implemented for ipex_ops, "
               "returning None.")
    logger.warning_once(message)
    return None
@staticmethod
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
......
......@@ -3,7 +3,7 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import ClassVar, Literal, Optional
from typing import Any, ClassVar, Literal, Optional
import cv2
import numpy as np
......@@ -77,6 +77,24 @@ def video_to_pil_images_list(path: str,
]
def video_get_metadata(path: str) -> dict[str, Any]:
    """Read basic metadata of the video at ``path`` using OpenCV.

    Args:
        path: Path of the video file.

    Returns:
        A dict with ``total_num_frames``, ``fps``, ``duration`` (frame count
        divided by fps, or 0 when fps is not positive) and ``video_backend``
        (always ``"opencv"``).

    Raises:
        ValueError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Always release the capture so the underlying file handle/decoder
        # is not leaked, including on the error path.
        cap.release()

    # Guard against a zero/negative fps to avoid a division by zero.
    duration = total_frames / fps if fps > 0 else 0

    metadata = {
        "total_num_frames": total_frames,
        "fps": fps,
        "duration": duration,
        "video_backend": "opencv"
    }
    return metadata
VideoAssetName = Literal["baby_reading"]
......@@ -105,6 +123,12 @@ class VideoAsset:
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
@property
def metadata(self) -> dict[str, Any]:
    """Metadata of this asset's video, as returned by
    ``video_get_metadata`` for the path from ``download_video_asset``."""
    return video_get_metadata(download_video_asset(self.filename))
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
......
......@@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]):
    """
    Report whether this attention implementation can fuse output
    quantization. The AttnFusionPass uses this to fuse output quantization
    only onto implementations that support it.

    This base implementation always answers ``False``.

    TODO(luka) merge parameters into QuantDescriptor
    :param dtype: quantized dtype
    :param static: static or dynamic quantization
    :param group_shape: quant group shape. (-1, -1) for per-tensor.
    :return: is fusion supported for this type of quantization
    """
    supported = False
    return supported
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
......@@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
......
......@@ -65,7 +65,6 @@ class BlocksparseParams:
assert self.block_size > 0
assert self.local_blocks >= 0
assert self.vert_stride >= 1
assert self.num_heads % self.num_kv_heads == 0
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
......@@ -329,9 +328,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
self.head_size = head_size
self.scale = float(scale)
self.alibi_slopes = alibi_slopes
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.local_blocks = self.blocksparse_params.local_blocks
......@@ -374,6 +372,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -388,6 +387,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......
......@@ -307,7 +307,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if sliding_window is not None:
# NOTE(woosuk): flash-attn's sliding window does not work with
......@@ -370,6 +369,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
......@@ -383,6 +384,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
(
query,
query_succ,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment