Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time
import pytest
from vllm import LLM, envs
......@@ -15,60 +12,6 @@ if not envs.VLLM_USE_V1:
)
@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
    """
    Check that no recompilation happens despite changing sampling parameters.
    We can't read XLA metrics from the engine process, hence we measure time.

    The first request is used as the slow baseline; each later request must
    finish in less than 10% of that time. Timings are printed *before* each
    assertion so they are still visible in the log when the check fails.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
        # Compiling model init may still take some time, enforce_eager to skip.
        llm = LLM(model_name,
                  enforce_eager=True,
                  max_num_seqs=16,
                  max_model_len=1024,
                  gpu_memory_utilization=0.5)
        prompts = [
            "A robot may not injure a human being",
            "It is only with the heart that one can see rightly;",
        ]

        def _timed_generate(params) -> float:
            # Wall-clock seconds spent in one generate() call.
            start = time()
            _ = llm.generate(prompts, params)
            return time() - start

        # First inference should be slow
        run1 = _timed_generate(
            SamplingParams(
                temperature=0.7,
                # top_p=0.6, # TODO too slow!
                top_k=10,
                min_p=0.2,
                max_tokens=16))

        # Second request with different params, but for which we
        # compiled for in previous eager iteration.
        run2 = _timed_generate(
            SamplingParams(temperature=0.1,
                           top_k=12,
                           min_p=0.8,
                           max_tokens=24))
        print("TIMES", run1, run2)
        # Much faster after compiling
        assert run1 * 0.1 > run2, (
            f"second request not fast enough: run1={run1}, run2={run2}")

        # Third request with min_p set to "None". It will not trigger
        # recompilation as a default 0 value will be used.
        run3 = _timed_generate(SamplingParams(max_tokens=24, temperature=0.0))
        print("TIMES", run1, run3)
        assert run1 * 0.1 > run3, (
            f"third request not fast enough: run1={run1}, run3={run3}")
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
......@@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
Test significantly different sampling params to assert the model produces
different results.
"""
llm = LLM(
model_name,
enforce_eager=True,
max_num_seqs=1,
max_model_len=64,
# TODO: setting to 0.5 or it will go OOM
gpu_memory_utilization=0.5)
llm = LLM(model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=512,
max_num_batched_tokens=512)
prompts = [
"Write a short story about a robot that dreams for the first time."
]
......
# SPDX-License-Identifier: Apache-2.0
import math
import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
import torch_xla.core.xla_model as xm
# Number of rows in the random logits batch built by
# test_topp_result_sums_past_p.
BATCH_SIZE = 1024
# Number of columns (vocabulary entries) in that random logits batch; also
# used there as the "no-op" top-k value.
VOCAB_SIZE = 128 * 1024
# Slack added to the summed probability before comparing it against p.
TOLERANCE = 1e-6
def test_topp_result_sums_past_p():
    """Probability mass kept by top-p must be at least p for every row."""
    with torch.device(xm.xla_device()):
        xm.set_rng_state(seed=33)

        raw_logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
        token_probs = raw_logits.softmax(dim=-1)

        # Random top-p values between 0 and 1.
        p = torch.rand((BATCH_SIZE, ))

        # Set p=1 for ~50% of requests in the batch (top-p disabled).
        disable_mask = torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool)
        p.masked_fill_(disable_mask, 1)

        # k equal to the vocabulary size leaves top-k as a no-op.
        masked_logits = apply_top_k_top_p_tpu(
            logits=raw_logits.clone(),
            k=torch.tensor([VOCAB_SIZE]),
            p=p,
        )

        # Zero the probability of every masked-out token, then sum what is
        # left; that sum has to reach p.
        token_probs.masked_fill_(masked_logits.isinf(), 0)
        kept_prob_sum = token_probs.sum(dim=-1)

        xm.mark_step()

    # Perform assertion on CPU.
    reaches_p = torch.ge(kept_prob_sum.cpu() + TOLERANCE, p.cpu())
    assert torch.all(reaches_p)
def test_topp_basic():
    """With p=0.79 the single smallest entry of each row is masked out."""
    row_probs = [[0.2, 0.3, 0.5], [0.5, 0.1, 0.4]]
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(prob) for prob in row]
                               for row in row_probs])

        masked = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([3, 3]),
                                       p=torch.tensor([0.79, 0.79]))

        xm.mark_step()

    # Expect the smallest elements to be dropped.
    expected = logits.clone().cpu()
    for row, col in ((0, 0), (1, 1)):
        expected[row, col] = float("-inf")
    assert torch.allclose(expected, masked.cpu())
def test_topp_select_all():
    """With p=1.0 and k equal to the row width nothing is masked."""
    row_probs = [[0.2, 0.3, 0.5], [0.5, 0.1, 0.4]]
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(prob) for prob in row]
                               for row in row_probs])

        untouched = apply_top_k_top_p_tpu(logits=logits.clone(),
                                          k=torch.tensor([3, 3]),
                                          p=torch.tensor([1.0, 1.0]))

        xm.mark_step()

    assert torch.allclose(logits.cpu(), untouched.cpu())
def test_topp_with_ties():
    """Tied top values are all kept; only the smaller entry is masked."""
    with torch.device(xm.xla_device()):
        # Input has multiple math.log(0.3).
        tied_row = [math.log(prob) for prob in (0.3, 0.3, 0.3, 0.1)]
        logits = torch.tensor([tied_row])

        masked = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([4]),
                                       p=torch.tensor([0.2]))

        xm.mark_step()

    # All tie values are included in the top-p set. Tie breaking is left
    # to be done during final sampling (all tie tokens have equal
    # probability of being chosen).
    expected = logits.clone().cpu()
    expected[0, 3] = float("-inf")
    assert torch.allclose(expected, masked.cpu())
def test_both_topk_topp():
    """Top-k and top-p are applied together, per row."""
    row_probs = [[0.2, 0.3, 0.5], [0.5, 0.1, 0.4]]
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(prob) for prob in row]
                               for row in row_probs])

        # Set k=1 for the first batch.
        masked = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([1, 3]),
                                       p=torch.tensor([0.79, 0.79]))

        xm.mark_step()

    # Since for the first batch k=1, expect only the largest element gets
    # selected. The second row only loses its smallest entry.
    expected = logits.clone().cpu()
    for row, col in ((0, 0), (0, 1), (1, 1)):
        expected[row, col] = float("-inf")
    assert torch.allclose(expected, masked.cpu())
# SPDX-License-Identifier: Apache-2.0
import unittest.mock as mock
import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
_get_padded_token_len,
_get_paddings)
# Mock torch_xla module since it may not be available in the test environments
# The patcher replaces the listed "sys.modules" entries with MagicMock objects
# and is started immediately at import time; it is stopped again by the
# session-scoped cleanup_patches fixture.
# NOTE(review): the tpu_model_runner import above runs before this patch is
# started — confirm that import does not itself need torch_xla.
torch_xla_patcher = mock.patch.dict(
    "sys.modules", {
        "torch_xla": mock.MagicMock(),
        "torch_xla.core.xla_model": mock.MagicMock(),
        "torch_xla.runtime": mock.MagicMock(),
    })
torch_xla_patcher.start()
# Mock the PallasAttentionBackend
# Replaces the name as looked up inside vllm.v1.worker.tpu_model_runner; also
# started at import time and stopped by cleanup_patches.
pallas_attention_backend_patcher = mock.patch(
    "vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
pallas_attention_backend_patcher.start()
@pytest.fixture
def model_runner():
    """Build a TPUModelRunner with torch/xm/xr patched out of its module."""
    # Patchers have already been started at module level.
    scheduler_config = SchedulerConfig(max_num_seqs=10,
                                       max_num_batched_tokens=512,
                                       max_model_len=512)
    model_config = ModelConfig(
        model="facebook/opt-125m",
        task="generate",
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="bfloat16",  # TPUs typically use bfloat16
        seed=42,
    )
    cache_config = CacheConfig(block_size=16,
                               gpu_memory_utilization=0.9,
                               swap_space=0,
                               cache_dtype="auto")
    vllm_config = VllmConfig(model_config=model_config,
                             cache_config=cache_config,
                             scheduler_config=scheduler_config)

    mocked_device = "xla:0"  # Mocking TPU device
    # The three module attributes are only replaced while the runner is being
    # constructed.
    with mock.patch("vllm.v1.worker.tpu_model_runner.torch"):
        with mock.patch("vllm.v1.worker.tpu_model_runner.xm"):
            with mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
                return TPUModelRunner(vllm_config, mocked_device)
@pytest.fixture(autouse=True, scope="session")
def cleanup_patches():
    """Stop the module-level patchers once the test session is over."""
    yield
    for patcher in (torch_xla_patcher, pallas_attention_backend_patcher):
        patcher.stop()
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
    """Return a SchedulerOutput that schedules each id as a brand-new request.

    Every request gets the prompt tokens ``[1, 2, 3]``, the single block
    ``[0]`` and three scheduled tokens.
    """
    new_reqs = [
        NewRequestData(
            req_id=req_id,
            prompt_token_ids=[1, 2, 3],
            prompt="test",
            mm_inputs=[],
            mm_hashes=[],
            mm_positions=[],
            sampling_params=SamplingParams(),
            block_ids=[0],
            num_computed_tokens=0,
            lora_request=None,
        ) for req_id in req_ids
    ]
    num_scheduled_tokens = {req_id: 3 for req_id in req_ids}
    # Summed per argument (not per dict key), matching a running total.
    total_num_scheduled_tokens = sum(num_scheduled_tokens[req_id]
                                     for req_id in req_ids)

    return SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=[],
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=total_num_scheduled_tokens,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
def _is_req_scheduled(model_runner, req_id: str) -> bool:
    """Tell whether ``req_id`` has an index in the runner's input batch."""
    scheduled_ids = model_runner.input_batch.req_id_to_index
    return req_id in scheduled_ids
def _is_req_added(model_runner, req_id: str) -> bool:
    """Tell whether ``req_id`` is present in the runner's ``requests``."""
    known_requests = model_runner.requests
    return req_id in known_requests
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
    """Compare a request's cached block ids with its block-table row.

    Returns False when the row's block count differs from the number of
    block ids held in the request state; otherwise the result of the
    element-wise comparison over the used part of the row.
    """
    batch = model_runner.input_batch
    row = batch.req_id_to_index[req_id]
    table = batch.block_table
    expected_ids = model_runner.requests[req_id].block_ids

    used = table.num_blocks_per_row[row]
    if used != len(expected_ids):
        return False
    row_ids = table.block_table_np[row, :used]
    return (row_ids == expected_ids).all()
def test_update_states_new_request(model_runner):
    """A newly scheduled request is added, scheduled and block-consistent."""
    req_id = "req_0"

    # new req
    model_runner._update_states(_schedule_new_request(req_id))

    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_finished(model_runner):
    """A finished request is dropped from both the runner and the batch."""
    req_id = "req_0"

    # new req
    model_runner._update_states(_schedule_new_request(req_id))
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # finish req
    finish_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids={req_id},
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
    model_runner._update_states(finish_step)

    assert not _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)
def test_update_states_request_resumed(model_runner):
    """A request left out of one step and rescheduled later is restored."""
    req_id = "req_0"

    # new req
    model_runner._update_states(_schedule_new_request(req_id))
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # unschedule req
    idle_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
    model_runner._update_states(idle_step)
    # Still known to the runner, but no longer part of the batch.
    assert _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)

    # resume req
    resume_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[
            CachedRequestData(
                req_id=req_id,
                resumed_from_preemption=False,
                new_token_ids=[],
                new_block_ids=[],
                num_computed_tokens=0,
            )
        ],
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
    model_runner._update_states(resume_step)

    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_no_changes(model_runner):
    """Scheduling an already-batched request again leaves its state intact."""
    req_id = "req_0"

    # new req
    model_runner._update_states(_schedule_new_request(req_id))
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # schedule req
    next_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
    model_runner._update_states(next_step)

    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_unscheduled(model_runner):
    """Only the request omitted from a step is removed from the batch."""
    req_ids = ("req_0", "req_1")
    kept_id, dropped_id = req_ids

    # new reqs
    model_runner._update_states(_schedule_new_request(*req_ids))
    for req_id in req_ids:
        assert _is_req_added(model_runner, req_id)
        assert _is_req_scheduled(model_runner, req_id)

    # unschedule req_1
    partial_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={kept_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
    model_runner._update_states(partial_step)

    assert _is_req_added(model_runner, kept_id)
    assert _is_req_scheduled(model_runner, kept_id)

    assert _is_req_added(model_runner, dropped_id)
    assert not _is_req_scheduled(model_runner, dropped_id)
def test_get_paddings():
    """_get_paddings(16, 512, 64) must yield the expected bucket list."""
    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    min_token_size = 16
    max_token_size = 512
    padding_gap = 64

    actual_paddings = _get_paddings(min_token_size, max_token_size,
                                    padding_gap)

    assert actual_paddings == expected_paddings
def test_get_padded_token_len():
    """Each token count is padded to the expected bucket size."""
    min_token_size, max_token_size, padding_gap = 16, 512, 64
    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)

    # (number of tokens, expected padded length)
    cases = ((1, 16), (16, 16), (20, 32), (300, 320), (512, 512))
    for num_tokens, expected_len in cases:
        assert _get_padded_token_len(paddings, num_tokens) == expected_len
......@@ -18,5 +18,5 @@ if ! [ -x "$(command -v shellcheck)" ]; then
export PATH="$PATH:$(pwd)/shellcheck-${scversion}"
fi
# TODO - fix warnings in .buildkite/run-amd-test.sh
find . -name "*.sh" ".git" -prune -not -path "./.buildkite/run-amd-test.sh" -print0 | xargs -0 -I {} sh -c 'git check-ignore -q "{}" || shellcheck -s bash "{}"'
# TODO - fix warnings in .buildkite/scripts/hardware_ci/run-amd-test.sh
find . -name "*.sh" ".git" -prune -not -path "./.buildkite/scripts/hardware_ci/run-amd-test.sh" -print0 | xargs -0 -I {} sh -c 'git check-ignore -q "{}" || shellcheck -s bash "{}"'
......@@ -4,9 +4,10 @@
# version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__ # isort:skip
import os
import torch
# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override # isort:skip # noqa: F401
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
......@@ -25,19 +26,6 @@ from vllm.sampling_params import SamplingParams
from vllm.version import __version__, __version_tuple__, __hcu_version__
# set some common config/environment variables that should be set
# for all processes created by vllm and all processes
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1
__all__ = [
"__version__",
"__version_tuple__",
......
......@@ -148,6 +148,7 @@ def paged_attention_v2_with_mask(
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
......@@ -440,6 +441,7 @@ def paged_attention_v2_opt_tc_with_mask(
# scale: float,
# block_tables: torch.Tensor,
# seq_lens: torch.Tensor,
# query_start_loc: Optional[torch.Tensor],
# block_size: int,
# max_seq_len: int,
# alibi_slopes: Optional[torch.Tensor],
......@@ -450,8 +452,21 @@ def paged_attention_v2_opt_tc_with_mask(
# torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
# key_cache, value_cache, num_kv_heads,
# scale, block_tables, seq_lens,
# block_size, max_seq_len, alibi_slopes,
# kv_cache_dtype, k_scale, v_scale)
# query_start_loc, block_size, max_seq_len,
# alibi_slopes, kv_cache_dtype, k_scale,
# v_scale)
def mla_decode_kvcache_cpu(
    out: torch.Tensor,
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    """Forward all arguments to ``torch.ops._C_cpu.mla_decode_kvcache``.

    Nothing is returned; the op's return value, if any, is discarded.
    """
    torch.ops._C_cpu.mla_decode_kvcache(
        out,
        query,
        kv_cache,
        scale,
        block_tables,
        seq_lens,
    )
# pos encoding ops
......@@ -792,7 +807,6 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# memory_format=torch.contiguous_format)
# if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
# @register_fake("_C::allspark_w8a16_gemm")
......@@ -810,13 +824,16 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# if hasattr(torch.ops._C, "ggml_dequantize"):
# @register_fake("_C::ggml_dequantize")
# def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
# m: torch.SymInt,
# n: torch.SymInt) -> torch.Tensor:
# def _ggml_dequantize_fake(
# W: torch.Tensor,
# quant_type: int,
# m: torch.SymInt,
# n: torch.SymInt,
# dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# return torch.empty((m, n), dtype=torch.float16, device=W.device)
# @register_fake("_C::ggml_mul_mat_vec_a8")
# def _ggml_mul_mat_vec_a8_fake(
# W: torch.Tensor,
......@@ -995,6 +1012,9 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
cuda_device_capability)
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
    """Return what ``torch.ops._C.cutlass_group_gemm_supported`` reports
    for the given device capability."""
    is_supported = torch.ops._C.cutlass_group_gemm_supported(
        cuda_device_capability)
    return is_supported
def cutlass_sparse_compress(a: torch.Tensor) \
-> tuple[torch.Tensor, torch.Tensor]:
"""
......@@ -1085,6 +1105,56 @@ def cutlass_scaled_sparse_mm(
return out
def get_cutlass_moe_mm_data(
        topk_ids: torch.Tensor, expert_offsets: torch.Tensor,
        problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor,
        input_permutation: torch.Tensor, output_permutation: torch.Tensor,
        num_experts: int, n: int, k: int):
    """
    Prepare the data needed for the CUTLASS grouped matrix multiplications
    of the CUTLASS-based fused MoE.

    ``topk_ids`` (the token-expert mapping) is the input; from it the
    following are computed:
    - expert_offsets: Indices that mark at which token index each expert
                      begins its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used
                                      in the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                         before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the
                          output after executing the MMs.

    This wrapper returns nothing itself.
    """
    torch.ops._C.get_cutlass_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                   b_tensors: torch.Tensor, a_scales: torch.Tensor,
                   b_scales: torch.Tensor, expert_offsets: torch.Tensor,
                   problem_sizes: torch.Tensor, a_strides: torch.Tensor,
                   b_strides: torch.Tensor, c_strides: torch.Tensor):
    """
    Run one grouped matrix multiplication of the CUTLASS-based fused MoE.

    Executes the fp8-quantized OUT = AB matrix multiplication.

    - expert_offsets: Indices that mark at which token index each expert
                      begins its computation. The number of tokens computed
                      with expert E is
                      expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes: MxNxK sizes of each expert's multiplication in two
                     grouped MMs used in the fused MoE operation.
    - a/b/c_strides: The data strides passed to grouped matrix
                     multiplication.

    This wrapper returns nothing itself.
    """
    torch.ops._C.cutlass_moe_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
    )
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
......@@ -1452,9 +1522,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf
# def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
# n: int) -> torch.Tensor:
# return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
# def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
# dtype: Optional[torch.dtype]) -> torch.Tensor:
# return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)
def ggml_mul_mat_vec_a8(
......@@ -1579,7 +1649,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies: torch.Tensor,
gating_output: float) -> None:
gating_output: torch.Tensor) -> None:
torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
token_expert_indicies, gating_output)
......@@ -1692,9 +1762,9 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# custom ar
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
rank: int, full_nvlink: bool) -> int:
rank: int, fully_connected: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
full_nvlink)
fully_connected)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
......@@ -1760,6 +1830,7 @@ def write_cache_multi_layers(
value_caches, slot_mapping,
kv_cache_dtype)
def get_flash_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
......
......@@ -187,15 +187,28 @@ class ipex_ops:
gen_: torch.Generator,
logits_soft_cap: float,
) -> None:
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(), seqlen_k.int(),
max_seqlen_q, max_seqlen_k,
pdropout, softmax_scale,
zero_tensors, is_causal,
return_softmax, gen_,
logits_soft_cap)
if ipex.__version__.endswith("cpu"):
if logits_soft_cap != 0.0:
raise ValueError("IPEX CPU does not support logits_soft_cap")
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(),
seqlen_k.int(), max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax,
gen_)
else: # XPU build
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(),
seqlen_k.int(), max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax,
gen_, logits_soft_cap)
@staticmethod
def reshape_and_cache(
......
......@@ -10,8 +10,6 @@ import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image
from vllm.multimodal.video import sample_frames_from_video
from .base import get_cache_dir
......@@ -43,14 +41,19 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frames = []
for i in range(total_frames):
ret, frame = cap.read()
if ret:
frames.append(frame)
cap.release()
num_frames = num_frames if num_frames > 0 else total_frames
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
for idx in range(total_frames):
ok = cap.grab() # next img
if not ok:
break
if idx in frame_indices: # only decompress needed
ret, frame = cap.retrieve()
if ret:
frames.append(frame)
frames = np.stack(frames)
frames = sample_frames_from_video(frames, num_frames)
if len(frames) < num_frames:
raise ValueError(f"Could not read enough frames from video file {path}"
f" (expected {num_frames} frames, got {len(frames)})")
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
class CPUMLABackend(AttentionBackend):
    """Attention backend descriptor tying together the CPU MLA classes."""

    @staticmethod
    def get_name() -> str:
        """Return the backend's identifier string."""
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        """Return the metadata class used by this backend."""
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        """Return the attention state class used by this backend."""
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape; ``num_kv_heads`` is not part of it."""
        cache_shape = (num_blocks, block_size, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate to ``ops.swap_blocks`` with the same arguments."""
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate to ``ops.copy_blocks_mla`` with the same arguments."""
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of supported head sizes (only 576)."""
        supported = [576]
        return supported
@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    """Attention metadata for the CPU MLA backend.

    Adds the two fields below on top of ``TorchSDPAMetadata``.
    """
    # New for MLA
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend.
    # Defaults to None, so the annotation is Optional.
    input_positions: Optional[torch.Tensor] = None

    # required by MLACommonImpl
    is_profile_run: bool = False
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
    """Builds ``CPUMLAMetadata`` from a CPU model-input builder's data."""

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        """Grab the input data currently held by the input builder."""
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        """Assemble the metadata object for the current batch.

        ``cuda_graph_pad_size`` and ``batch_size`` are accepted but not used
        here. The leading ``num_prefills`` entries of ``seq_lens`` and
        ``query_lens`` are treated as the prefill requests.
        """
        data = self.input_data
        prefill_seq_lens = seq_lens[0:data.num_prefills]
        prefill_query_lens = query_lens[0:data.num_prefills]
        slot_mapping = torch.tensor(data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # metadata for prefill: everything stays None when the batch has no
        # prefill requests.
        query_start_loc = None
        kv_start_loc = None
        max_query_len = None
        max_kv_len = None
        prefill_block_tables = None
        if data.num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            # Start offsets: a leading zero followed by the running sums of
            # the lengths.
            query_start_loc = torch.zeros(data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # for chunked-prefill
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    self.input_data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )

        # metadata for decode
        if data.num_decode_tokens == 0:
            # No decode tokens: empty block table, and the sequence-length
            # tensor is built from the prefill entries instead.
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                data.seq_lens[:data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )
        else:
            seq_lens_tensor = torch.tensor(
                data.seq_lens[data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                data.multi_modal_placeholder_maps.items()
            }

        input_positions = torch.tensor([self.input_data.input_positions])
        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=data.max_decode_seq_len,
            num_prefills=data.num_prefills,
            num_prefill_tokens=data.num_prefill_tokens,
            num_decode_tokens=data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=input_positions)
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
    """MLA attention implementation for CPU (decoder attention only)."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state, then reject unsupported options.

        Raises:
            NotImplementedError: if any of alibi_slopes, sliding_window,
                blocksparse_params or logits_soft_cap is truthy, if
                ``attn_type`` is not ``AttentionType.DECODER``, or if the
                KV cache dtype is a quantized one.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        if (alibi_slopes or sliding_window or blocksparse_params
                or logits_soft_cap):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run prefill attention through ``ipex_ops.varlen_attention``.

        ``kv_c_and_k_pe_cache`` is not read here.
        """
        prefill_metadata = attn_metadata.prefill_metadata
        assert prefill_metadata is not None

        # Project the normalised latent KV and split the result into the
        # no-position key part and the value part.
        projected = self.kv_b_proj(kv_c_normed)[0]
        kv_nope = projected.view(-1, self.num_heads,
                                 self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim],
                                  dim=-1)

        # Expand k_pe to k_nope's leading dimensions and append it.
        k_pe_expanded = k_pe.expand((*k_nope.shape[:-1], -1))
        k = torch.cat((k_nope, k_pe_expanded), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        qk_dim = q.shape[-1]
        v_dim = v.shape[-1]
        v_padded = torch.nn.functional.pad(v, [0, qk_dim - v_dim], value=0)

        output = torch.empty_like(q)
        # The query offsets / max length are passed for the key side as well.
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=prefill_metadata.query_start_loc,
            seqlen_k=prefill_metadata.query_start_loc,
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=prefill_metadata.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
        )

        # remove padding
        unpadded = output.view(-1, self.num_heads, qk_dim)[..., :v_dim]
        flattened = unpadded.reshape(-1, self.num_heads * v_dim)
        return self.o_proj(flattened)[0]

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run decode attention via ``ops.mla_decode_kvcache_cpu``."""
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        query = torch.cat([q_nope, q_pe], dim=-1)
        attn_out = query.new_empty(query.shape[0], self.num_heads,
                                   self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(attn_out, query, kv_c_and_k_pe_cache,
                                   self.scale, decode_meta.block_tables,
                                   decode_meta.seq_lens_tensor)
        return self._v_up_proj_and_o_proj(attn_out)
......@@ -204,7 +204,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
......@@ -212,18 +211,27 @@ from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
else:
merge_attn_states = None
triton_attention = None
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
is_vllm_fa = False
from vllm.attention.ops.triton_flash_attention import triton_attention
try:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
......
......@@ -18,16 +18,13 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
from vllm.logger import init_logger
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])
class ROCmFlashAttentionBackend(AttentionBackend):
......@@ -804,9 +801,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
# use_custom = _use_rocm_custom_paged_attention(
# use_custom = use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len)
# decode_meta.max_decode_seq_len, self.sliding_window)
use_custom = False
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
......@@ -832,6 +829,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
out = output[num_prefill_tokens:]
else:
out = output
query_start_loc = None
ops.paged_attention_rocm(
out,
exp_sums,
......@@ -848,6 +847,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
query_start_loc,
block_size,
max_seq_len,
self.alibi_slopes,
......@@ -902,9 +902,8 @@ def _sdpa_attention(
for i, seq_len in enumerate(seq_lens):
end = start + seq_len
with torch.backends.cuda.sdp_kernel(enable_math=True,
enable_flash=False,
enable_mem_efficient=False):
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.MATH):
sub_out = torch.nn.functional.scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
......@@ -917,14 +916,3 @@ def _sdpa_attention(
start = end
return output
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    """Decide whether the ROCm custom paged-attention path is eligible.

    Every condition below must hold:

    * the module-level flags say we are on MI250/MI300 and not on Navi;
    * ``qtype`` is ``torch.half`` or ``torch.bfloat16``;
    * ``head_size`` is 64 or 128;
    * ``block_size`` is 16 or 32;
    * ``gqa_ratio`` lies in ``[1, 16]``;
    * ``max_seq_len`` does not exceed 32768.

    Returns:
        ``True`` when all conditions are met, ``False`` otherwise.
    """
    # ROCm custom paged attention is not supported on Navi (gfx1*).
    if not _ON_MI250_MI300 or _ON_NAVI:
        return False
    if qtype not in (torch.half, torch.bfloat16):
        return False
    if head_size not in (64, 128):
        return False
    if block_size not in (16, 32):
        return False
    if not 1 <= gqa_ratio <= 16:
        return False
    return max_seq_len <= 32768
......@@ -10,6 +10,9 @@ import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from .prefix_prefill import context_attention_fwd
......@@ -33,26 +36,26 @@ def kernel_paged_attention_2d(
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.constexpr, # int
query_stride_0: tl.constexpr, # int
query_stride_1: tl.constexpr, # int, should be equal to head_size
output_stride_0: tl.constexpr, # int
output_stride_1: tl.constexpr, # int, should be equal to head_size
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.constexpr, # int
stride_k_cache_1: tl.constexpr, # int
stride_k_cache_2: tl.constexpr, # int
stride_k_cache_3: tl.constexpr, # int
stride_k_cache_4: tl.constexpr, # int
stride_v_cache_0: tl.constexpr, # int
stride_v_cache_1: tl.constexpr, # int
stride_v_cache_2: tl.constexpr, # int
stride_v_cache_3: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
):
......@@ -212,6 +215,7 @@ def chunked_prefill_paged_decode(
block_table,
query_start_loc,
seq_lens,
max_seq_len,
max_query_len,
k_scale,
v_scale,
......@@ -240,6 +244,7 @@ def chunked_prefill_paged_decode(
b_loc=block_table,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_seq_len=max_seq_len,
max_input_len=max_query_len,
k_scale=k_scale,
v_scale=v_scale,
......@@ -275,43 +280,87 @@ def chunked_prefill_paged_decode(
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
16)
kernel_paged_attention_2d[(
num_seqs,
num_kv_heads,
)](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
)
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
block_size,
num_queries_per_kv,
max_seq_len, sliding_window)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
assert _PARTITION_SIZE_ROCM % block_size == 0
total_num_seq = query.shape[0]
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale=sm_scale,
block_tables=block_table,
seq_lens=seq_lens,
query_start_loc=query_start_loc,
block_size=block_size,
max_seq_len=max_seq_len,
alibi_slopes=alibi_slopes,
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
)
else:
kernel_paged_attention_2d[(
num_seqs,
num_kv_heads,
)](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
)
......@@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
def load_kv_tile_from_cache(
cur_k_tile,
cur_v_tile,
key_cache,
value_cache,
kv_cache,
block_tables,
large_k_tile_idx,
num_blocks_per_large_tile,
......@@ -169,8 +168,8 @@ def load_kv_tile_from_cache(
for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
loaded = nl.load(key_cache[block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_k_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
# Transpose SBUF tensor using PE
......@@ -185,7 +184,7 @@ def load_kv_tile_from_cache(
# load value cache
for load_idx in nl.affine_range(num_loads):
loaded = nl.load(value_cache[block_tables[load_idx, i_p,
loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
......@@ -418,8 +417,7 @@ def flash_paged_attention(
query,
key,
value,
key_cache,
value_cache,
kv_cache,
block_tables,
mask,
softmax_scale=None,
......@@ -434,8 +432,7 @@ def flash_paged_attention(
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- key_cache: (num_blocks, n_kv_heads, block_size, d)
- value_cache: (num_blocks, n_kv_heads, block_size, d)
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
......@@ -444,7 +441,7 @@ def flash_paged_attention(
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
dimension.
- We use paged cache blocks (key_cache, value_cache) to store KV cache.
- We use paged cache blocks (kv_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
......@@ -475,15 +472,13 @@ def flash_paged_attention(
b, h, d, seqlen_q = query.shape
B_D_SIZE = d
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
num_blocks, k_h, block_size, _ = key_cache.shape
_, num_blocks, k_h, block_size, _ = kv_cache.shape
q_h_per_k_h = h // k_h
assert b == 1, f"invalid batch size {b=}"
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
cache_shape = (num_blocks, k_h, block_size, d)
assert (tuple(key_cache.shape) == cache_shape
), f"{key_cache.shape=} mismatch, expect {cache_shape}"
assert (tuple(value_cache.shape) == cache_shape
), f"{value_cache.shape=} mismatch, expect {cache_shape}"
cache_shape = (2, num_blocks, k_h, block_size, d)
assert (tuple(kv_cache.shape) == cache_shape
), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
assert key is None or tuple(key.shape) == (
1,
k_h,
......@@ -580,13 +575,13 @@ def flash_paged_attention(
head_id=head_id,
)
# Flatten KV cache to be 2D for loading into SBUF
# Flatten KV cache to be 3D for loading into SBUF
new_cache_shape = (
2,
num_blocks * k_h * block_size_tiling_factor,
tiled_block_size * d,
)
key_cache = key_cache.reshape(new_cache_shape)
value_cache = value_cache.reshape(new_cache_shape)
kv_cache = kv_cache.reshape(new_cache_shape)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
......@@ -621,8 +616,7 @@ def flash_paged_attention(
load_kv_tile_from_cache(
cur_k_tile=cur_k_tile,
cur_v_tile=cur_v_tile,
key_cache=key_cache,
value_cache=value_cache,
kv_cache=kv_cache,
block_tables=block_tables_sbuf,
large_k_tile_idx=large_k_tile_idx,
num_blocks_per_large_tile=num_blocks_per_large_tile,
......@@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
query,
key,
value,
key_cache,
value_cache,
kv_cache,
block_table,
attn_mask,
n_kv_head=None,
......@@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
- query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d)
- key_cache: (n_blocks, n_kv_heads, block_size, d)
- value_cache: (n_blocks, n_kv_heads, block_size, d)
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
......@@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
for better DMA throughput
"""
if n_kv_head is None:
n_kv_head = key_cache.shape[1]
assert key_cache.shape[1] == n_kv_head
n_kv_head = kv_cache.shape[2]
assert kv_cache.shape[0] == 2
assert kv_cache.shape[2] == n_kv_head
if head_size is None:
head_size = key_cache.shape[-1]
head_size = kv_cache.shape[-1]
kwargs = dict(
query=query,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
kv_cache=kv_cache,
block_tables=block_table,
mask=attn_mask,
softmax_scale=1.0 / (head_size**0.5),
......@@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""
......@@ -886,29 +877,29 @@ def reshape_and_cache(
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
key_cache (torch.Tensor): Key cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
value_cache (torch.Tensor): Value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
kv_cache (torch.Tensor): Key/value cache tensor with shape
(2, num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)
Returns:
None: Updates the key_cache and value_cache tensors in-place
None: Updates the kv_cache tensor in-place
"""
block_size = key_cache.size(2)
block_size = kv_cache.size(3)
n_kv_head = key.size(1)
# Calculate indices with explicit floor division
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size
# Create the head indices tensor
head_indices = torch.arange(n_kv_head, device=key.device)
# Update caches using index_put_
key_cache.index_put_(
(block_indices.unsqueeze(1),
torch.arange(key_cache.size(1),
device=key.device), block_offsets.unsqueeze(1)), key)
value_cache.index_put_(
(block_indices.unsqueeze(1),
torch.arange(value_cache.size(1),
device=value.device), block_offsets.unsqueeze(1)), value)
kv_cache.index_put_(
(torch.tensor([0], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), key)
kv_cache.index_put_(
(torch.tensor([1], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), value)
......@@ -435,6 +435,7 @@ class PagedAttention:
v_scale: torch.Tensor,
) -> torch.Tensor:
output = torch.empty_like(query)
max_seq_len = None
context_attention_fwd(
query,
key,
......@@ -447,6 +448,7 @@ class PagedAttention:
# query_start_loc is (batch_size + 1,)
query_start_loc,
seq_lens_tensor,
max_seq_len,
max_query_len,
k_scale,
v_scale,
......
......@@ -729,6 +729,7 @@ if triton.__version__ >= "2.1.0":
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
......@@ -756,7 +757,7 @@ if triton.__version__ >= "2.1.0":
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = torch.float8_e4m3fn
target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
......
......@@ -54,6 +54,15 @@ def merge_attn_states_kernel(
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
# FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
# If we see an inf assume FA2 and convert inf to -inf for consistency
# and correctness. Inf generally doesn't make sense in this context outside
# of undefined-behavior/FA2-case, so I think this is a safe assumption.
p_lse = float('-inf') if p_lse == float('inf') else p_lse
s_lse = float('-inf') if s_lse == float('inf') else s_lse
max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
......
......@@ -219,7 +219,15 @@ async def async_request_deepspeed_mii(
if response.status == 200:
parsed_resp = await response.json()
output.latency = time.perf_counter() - st
output.generated_text = parsed_resp["text"][0]
if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][
"text"]
elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0]
else:
output.error = ("Unexpected response format: "
"neither 'choices' nor 'text' found")
output.success = False
output.success = True
else:
output.error = response.reason or ""
......
......@@ -7,9 +7,6 @@ On the server side, run one of the following commands:
--swap-space 16 \
--disable-log-requests
(TGI backend)
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run:
python benchmarks/benchmark_serving.py \
--backend <backend> \
......@@ -52,9 +49,11 @@ try:
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
ConversationDataset, HuggingFaceDataset,
InstructCoderDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
......@@ -586,19 +585,39 @@ def main(args: argparse.Namespace):
return_prompt_formatted=True)
elif args.dataset_name == "hf":
# Choose between VisionArenaDataset
# and HuggingFaceDataset based on provided parameters.
dataset_class = (VisionArenaDataset if args.dataset_path
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
and args.hf_subset is None else HuggingFaceDataset)
# all following datasets are implemented from the
# HuggingFaceDataset base class
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_class = VisionArenaDataset
args.hf_split = "train"
args.hf_subset = None
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_class = InstructCoderDataset
args.hf_split = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_class = ConversationDataset
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_class = AIMODataset
args.hf_split = "train"
else:
supported_datasets = set([
dataset_name for cls in HuggingFaceDataset.__subclasses__()
for dataset_name in cls.SUPPORTED_DATASET_PATHS
])
raise ValueError(
f"Unsupported dataset path: {args.dataset_path}. "
"Huggingface dataset only supports dataset_path"
f" from one of following: {supported_datasets}. "
"Please consider contributing if you would "
"like to add support for additional dataset formats.")
input_requests = dataset_class(
dataset_path=args.dataset_path,
dataset_subset=args.hf_subset,
dataset_split=args.hf_split,
random_seed=args.seed,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
random_seed=args.seed,
output_len=args.hf_output_len,
)
......
......@@ -14,7 +14,8 @@ from typing import Any, Optional, Union
import numpy as np
import torch
import uvloop
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
ConversationDataset, InstructCoderDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
......@@ -347,6 +348,7 @@ def get_requests(args, tokenizer):
"input_len": args.input_len,
"output_len": args.output_len,
}
if args.dataset_path is None or args.dataset_name == "random":
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
......@@ -364,18 +366,23 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
if args.backend != "vllm-chat":
raise ValueError(
"hf datasets only are supported by vllm-chat backend")
# Choose between VisionArenaDataset and HuggingFaceDataset based on
# provided parameters.
dataset_cls = (VisionArenaDataset if args.dataset_path
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
and args.hf_subset is None else HuggingFaceDataset)
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_cls = AIMODataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
......@@ -509,9 +516,17 @@ def validate_args(args):
warnings.warn("--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.",
stacklevel=2)
elif args.dataset_name == "hf" and args.backend != "vllm-chat":
raise ValueError(
"When --dataset-name is 'hf', backend must be 'vllm-chat'")
elif args.dataset_name == "hf":
if args.dataset_path in (
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| ConversationDataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
else:
raise ValueError(
f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment