Commit 081057de authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
...@@ -2,17 +2,13 @@ ...@@ -2,17 +2,13 @@
import pytest import pytest
from vllm.v1.structured_output.utils import ( from vllm.v1.structured_output.backend_xgrammar import (
has_xgrammar_unsupported_json_features) has_xgrammar_unsupported_json_features)
@pytest.fixture @pytest.fixture
def unsupported_string_schemas(): def unsupported_string_schemas():
return [ return [
{
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
{ {
"type": "string", "type": "string",
"format": "email" "format": "email"
...@@ -23,22 +19,6 @@ def unsupported_string_schemas(): ...@@ -23,22 +19,6 @@ def unsupported_string_schemas():
@pytest.fixture @pytest.fixture
def unsupported_integer_schemas(): def unsupported_integer_schemas():
return [ return [
{
"type": "integer",
"minimum": 0
},
{
"type": "integer",
"maximum": 120
},
{
"type": "integer",
"exclusiveMinimum": 120
},
{
"type": "integer",
"exclusiveMaximum": 120
},
{ {
"type": "integer", "type": "integer",
"multipleOf": 120 "multipleOf": 120
...@@ -49,22 +29,6 @@ def unsupported_integer_schemas(): ...@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
@pytest.fixture @pytest.fixture
def unsupported_number_schemas(): def unsupported_number_schemas():
return [ return [
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"maximum": 120
},
{
"type": "number",
"exclusiveMinimum": 120
},
{
"type": "number",
"exclusiveMaximum": 120
},
{ {
"type": "number", "type": "number",
"multipleOf": 120 "multipleOf": 120
...@@ -156,13 +120,28 @@ def supported_schema(): ...@@ -156,13 +120,28 @@ def supported_schema():
"type": "string", "type": "string",
"enum": ["sedan", "suv", "truck"] "enum": ["sedan", "suv", "truck"]
}, },
"car_brand": {
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
"short_description": { "short_description": {
"type": "string", "type": "string",
"maxLength": 50 "maxLength": 50
}, },
"mileage": {
"type": "number",
"minimum": 0,
"maximum": 1000000
},
"model_year": {
"type": "integer",
"exclusiveMinimum": 1900,
"exclusiveMaximum": 2100
},
"long_description": { "long_description": {
"type": "string", "type": "string",
"minLength": 50 "minLength": 50,
"maxLength": 2000
}, },
"address": { "address": {
"type": "object", "type": "object",
......
...@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind): ...@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so # the engines only synchronize stopping every N steps so
# allow a small amount of time here. # allow a small amount of time here.
for _ in range(10): for _ in range(10):
if core_client.num_engines_running == 0: if not core_client.engines_running:
break break
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
assert core_client.num_engines_running == 0 assert not core_client.engines_running
assert not core_client.reqs_in_flight assert not core_client.reqs_in_flight
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import msgspec
import numpy as np import numpy as np
import torch import torch
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
...@@ -26,6 +32,7 @@ class MyType: ...@@ -26,6 +32,7 @@ class MyType:
large_f_contig_tensor: torch.Tensor large_f_contig_tensor: torch.Tensor
small_non_contig_tensor: torch.Tensor small_non_contig_tensor: torch.Tensor
large_non_contig_tensor: torch.Tensor large_non_contig_tensor: torch.Tensor
empty_tensor: torch.Tensor
def test_encode_decode(): def test_encode_decode():
...@@ -41,6 +48,10 @@ def test_encode_decode(): ...@@ -41,6 +48,10 @@ def test_encode_decode():
torch.rand((1, 10), dtype=torch.float32), torch.rand((1, 10), dtype=torch.float32),
torch.rand((3, 5, 4000), dtype=torch.float64), torch.rand((3, 5, 4000), dtype=torch.float64),
torch.tensor(1984), # test scalar too torch.tensor(1984), # test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
torch.tensor([float("-inf"), float("inf")] * 1024,
dtype=torch.bfloat16),
], ],
numpy_array=np.arange(512), numpy_array=np.arange(512),
unrecognized=UnrecognizedType(33), unrecognized=UnrecognizedType(33),
...@@ -48,9 +59,10 @@ def test_encode_decode(): ...@@ -48,9 +59,10 @@ def test_encode_decode():
large_f_contig_tensor=torch.rand(1024, 4).t(), large_f_contig_tensor=torch.rand(1024, 4).t(),
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
empty_tensor=torch.empty(0),
) )
encoder = MsgpackEncoder() encoder = MsgpackEncoder(size_threshold=256)
decoder = MsgpackDecoder(MyType) decoder = MsgpackDecoder(MyType)
encoded = encoder.encode(obj) encoded = encoder.encode(obj)
...@@ -58,7 +70,7 @@ def test_encode_decode(): ...@@ -58,7 +70,7 @@ def test_encode_decode():
# There should be the main buffer + 4 large tensor buffers # There should be the main buffer + 4 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes. # + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline. # The two small tensors are encoded inline.
assert len(encoded) == 6 assert len(encoded) == 8
decoded: MyType = decoder.decode(encoded) decoded: MyType = decoder.decode(encoded)
...@@ -70,7 +82,7 @@ def test_encode_decode(): ...@@ -70,7 +82,7 @@ def test_encode_decode():
encoded2 = encoder.encode_into(obj, preallocated) encoded2 = encoder.encode_into(obj, preallocated)
assert len(encoded2) == 6 assert len(encoded2) == 8
assert encoded2[0] is preallocated assert encoded2[0] is preallocated
decoded2: MyType = decoder.decode(encoded2) decoded2: MyType = decoder.decode(encoded2)
...@@ -78,6 +90,97 @@ def test_encode_decode(): ...@@ -78,6 +90,97 @@ def test_encode_decode():
assert_equal(decoded2, obj) assert_equal(decoded2, obj)
class MyRequest(msgspec.Struct):
mm: Optional[list[MultiModalKwargs]]
def test_multimodal_kwargs():
d = {
"foo":
torch.zeros(20000, dtype=torch.float16),
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
"baz": [
torch.rand((256), dtype=torch.float16),
[
torch.rand((1, 12), dtype=torch.float32),
torch.rand((3, 5, 7), dtype=torch.float64),
], [torch.rand((4, 4), dtype=torch.float16)]
],
}
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest(mm=[MultiModalKwargs(d)])
encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MyRequest)
encoded = encoder.encode(req)
assert len(encoded) == 6
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 44559, +-20 for minor changes
assert total_len >= 44539 and total_len <= 44579
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
assert all(nested_equal(d[k], decoded[k]) for k in d)
def test_multimodal_items_by_modality():
e1 = MultiModalFieldElem("audio", "a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField())
e2 = MultiModalFieldElem(
"video",
"v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalBatchedField(),
)
e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
dtype=torch.int32),
MultiModalSharedField(4))
e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
dtype=torch.int32),
MultiModalBatchedField())
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargs.from_items([audio, video, image])
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm])
encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MyRequest)
encoded = encoder.encode(req)
assert len(encoded) == 8
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14255, +-20 for minor changes
assert total_len >= 14235 and total_len <= 14275
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
# check all modalities were recovered and do some basic sanity checks
assert len(decoded.modalities) == 3
images = decoded.get_items("image")
assert len(images) == 1
assert len(images[0].items()) == 2
assert list(images[0].keys()) == ["i0", "i1"]
# check the tensor contents and layout in the main dict
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
def nested_equal(a: NestedTensors, b: NestedTensors):
if isinstance(a, torch.Tensor):
return torch.equal(a, b)
else:
return all(nested_equal(x, y) for x, y in zip(a, b))
def assert_equal(obj1: MyType, obj2: MyType): def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.tensor1, obj2.tensor1) assert torch.equal(obj1.tensor1, obj2.tensor1)
assert obj1.a_string == obj2.a_string assert obj1.a_string == obj2.a_string
...@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType): ...@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2.small_non_contig_tensor) obj2.small_non_contig_tensor)
assert torch.equal(obj1.large_non_contig_tensor, assert torch.equal(obj1.large_non_contig_tensor,
obj2.large_non_contig_tensor) obj2.large_non_contig_tensor)
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
...@@ -22,6 +22,7 @@ MODELS = [ ...@@ -22,6 +22,7 @@ MODELS = [
] ]
TENSOR_PARALLEL_SIZES = [1] TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]
# TODO: Enable when CI/CD will have a multi-tpu instance # TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4] # TENSOR_PARALLEL_SIZES = [1, 4]
...@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1] ...@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic( def test_basic(
vllm_runner: type[VllmRunner], vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
tensor_parallel_size: int, tensor_parallel_size: int,
max_num_seqs: int,
) -> None: ) -> None:
prompt = "The next numbers of the sequence " + ", ".join( prompt = "The next numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:" str(i) for i in range(1024)) + " are:"
...@@ -51,9 +54,9 @@ def test_basic( ...@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to # Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt # actually test chunked prompt
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
max_model_len=8196, max_model_len=8192,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
max_num_seqs=16, max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size) as vllm_model: tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens) max_tokens)
......
# SPDX-License-Identifier: Apache-2.0
import openai
import pytest
from vllm import envs
from vllm.multimodal.utils import encode_image_base64, fetch_image
from vllm.platforms import current_platform
from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS
from ...utils import RemoteOpenAIServer
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.fixture(scope="session")
def base64_encoded_image() -> dict[str, str]:
return {
image_url: encode_image_base64(fetch_image(image_url))
for image_url in TEST_IMAGE_URLS
}
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str, base64_encoded_image: dict[str,
str]):
def whats_in_this_image_msg(b64):
return [{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
},
],
}]
server_args = [
"--max-model-len",
"1024",
"--max-num-seqs",
"16",
"--gpu-memory-utilization",
"0.95",
"--trust-remote-code",
"--max-num-batched-tokens",
"576",
# NOTE: max-num-batched-tokens>=mm_item_size
"--disable_chunked_mm_input",
"--chat-template",
"examples/template_llava.jinja"
]
# Server will pre-compile on first startup (takes a long time).
with RemoteOpenAIServer(model_name, server_args,
max_wait_seconds=600) as remote_server:
client: openai.AsyncOpenAI = remote_server.get_async_client()
# Other requests now should be much faster
for image_url in TEST_IMAGE_URLS:
image_base64 = base64_encoded_image[image_url]
chat_completion_from_base64 = await client.chat.completions\
.create(
model=model_name,
messages=whats_in_this_image_msg(image_base64),
max_completion_tokens=24,
temperature=0.0)
result = chat_completion_from_base64
assert result
choice = result.choices[0]
assert choice.finish_reason == "length"
message = choice.message
message = result.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random
import pytest import pytest
from vllm import LLM, envs from vllm import LLM, envs
...@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str): ...@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param. # Unsupported `seed` param.
sampling_params = SamplingParams(temperature=0.3, seed=42) sampling_params = SamplingParams(temperature=0.3, seed=42)
output2 = llm.generate(prompts, sampling_params) output2 = llm.generate(prompts, sampling_params)
# Batch-case with TopK/P
for B in [4, 16]:
p = prompts * B
sampling_params = [
SamplingParams(
temperature=0.1,
min_p=0.8,
max_tokens=64,
# Vary number of ks
top_k=random.randint(4, 12),
top_p=random.random()) for _ in range(B)
]
# Make sure first two reqs have the same K/P
sampling_params[0] = sampling_params[1]
output = llm.generate(p, sampling_params)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
...@@ -5,7 +5,8 @@ import pytest ...@@ -5,7 +5,8 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
apply_top_k_top_p_tpu)
if not current_platform.is_tpu(): if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True) pytest.skip("This test needs a TPU.", allow_module_level=True)
...@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024 ...@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6 TOLERANCE = 1e-6
def test_topk_equivalence_to_native_impl():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
# Random top-k values between 1 and 10.
k = torch.randint(1, 10, (BATCH_SIZE, ))
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
VOCAB_SIZE)
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
assert torch.allclose(result_native, result_tpu)
def test_topp_result_sums_past_p(): def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()): with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33) xm.set_rng_state(seed=33)
......
...@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
prompt="test",
mm_inputs=[], mm_inputs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
...@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner): ...@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
def test_get_paddings(): def test_get_paddings():
# Bucketed padding
min_token_size, max_token_size, padding_gap = 16, 512, 64 min_token_size, max_token_size, padding_gap = 16, 512, 64
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512] expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
# Bucketed padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 192, 256, 320]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding.
max_token_size, padding_gap = 1024, 0
expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 256, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap) padding_gap)
assert actual_paddings == expected_paddings assert actual_paddings == expected_paddings
......
...@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
return CachedRequestState( return CachedRequestState(
req_id=f"req_id_{req_id_suffix}", req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt=None,
sampling_params=_create_sampling_params(), sampling_params=_create_sampling_params(),
mm_inputs=[], mm_inputs=[],
mm_positions=[], mm_positions=[],
......
...@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
prompt="test",
mm_inputs=[], mm_inputs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
......
...@@ -1202,6 +1202,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, ...@@ -1202,6 +1202,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states, pad_slot_id) ssm_states, pad_slot_id)
# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor,
rows_per_block: int) -> torch.Tensor:
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor,
cu_count: int) -> torch.Tensor:
out = torch.empty((b.shape[0], a.shape[0]),
dtype=out_dtype,
device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
return out
# moe # moe
def moe_sum(input: torch.Tensor, output: torch.Tensor): def moe_sum(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum(input, output) torch.ops._moe_C.moe_sum(input, output)
...@@ -1251,6 +1271,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, ...@@ -1251,6 +1271,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output) token_expert_indicies, gating_output)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor, b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_past_padded: torch.Tensor,
topk_weights: torch.Tensor, moe_block_size: int,
top_k: int, mul_topk_weights: bool, is_ep: bool,
b_q_type: ScalarType, size_m: int, size_n: int,
size_k: int, is_k_full: bool, use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace,
sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights,
moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m,
size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce,
is_zp_float)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@register_fake("_moe_C::marlin_gemm_moe") @register_fake("_moe_C::marlin_gemm_moe")
...@@ -1269,6 +1312,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): ...@@ -1269,6 +1312,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype=a.dtype, dtype=a.dtype,
device=a.device) device=a.device)
@register_fake("_moe_C::moe_wna16_marlin_gemm")
def moe_wna16_marlin_gemm_fake(input: torch.Tensor,
output: Optional[torch.Tensor],
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_past_padded: torch.Tensor,
topk_weights: torch.Tensor,
moe_block_size: int, top_k: int,
mul_topk_weights: bool, is_ep: bool,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int, is_k_full: bool,
use_atomic_add: bool, use_fp32_reduce: bool,
is_zp_float: bool) -> torch.Tensor:
return torch.empty((size_m * top_k, size_n),
dtype=input.dtype,
device=input.device)
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
...@@ -1464,4 +1530,13 @@ def flash_mla_with_kvcache( ...@@ -1464,4 +1530,13 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
) )
return out, softmax_lse return out, softmax_lse
\ No newline at end of file
def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor, page_table: torch.Tensor,
scale: float) -> torch.Tensor:
torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
seq_lens, page_table, scale)
return out
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import Literal from typing import Literal, Optional
import cv2 import cv2
import numpy as np import numpy as np
...@@ -10,8 +10,15 @@ import numpy.typing as npt ...@@ -10,8 +10,15 @@ import numpy.typing as npt
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from vllm.utils import PlaceholderModule
from .base import get_cache_dir from .base import get_cache_dir
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
@lru_cache @lru_cache
def download_video_asset(filename: str) -> str: def download_video_asset(filename: str) -> str:
...@@ -85,3 +92,12 @@ class VideoAsset: ...@@ -85,3 +92,12 @@ class VideoAsset:
video_path = download_video_asset(self.name) video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames) ret = video_to_ndarrays(video_path, self.num_frames)
return ret return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path = download_video_asset(self.name)
return librosa.load(video_path, sr=sampling_rate)[0]
...@@ -77,6 +77,10 @@ class AttentionBackend(ABC): ...@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
raise NotImplementedError raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def swap_blocks( def swap_blocks(
...@@ -237,6 +241,7 @@ class AttentionLayer(Protocol): ...@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
_v_scale: torch.Tensor _v_scale: torch.Tensor
_k_scale_float: float _k_scale_float: float
_v_scale_float: float _v_scale_float: float
_prob_scale: torch.Tensor
def forward( def forward(
self, self,
......
...@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import ( ...@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache) flash_attn_with_kvcache)
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16: if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert ( assert (
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
"key/v_scale is only supported in FlashAttention 3 with " "key/v_scale is only supported in FlashAttention 3 with "
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import dataclasses import dataclasses
import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad) make_tensor_with_pad)
...@@ -48,6 +49,9 @@ if TYPE_CHECKING: ...@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
"NHD").upper()
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend): ...@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size) return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
cache_layout = FLASHINFER_KV_CACHE_LAYOUT
assert (cache_layout in ("NHD", "HND"))
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
2, 4)
return stride_order
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
src_kv_cache: torch.Tensor, src_kv_cache: torch.Tensor,
...@@ -128,12 +140,10 @@ def get_per_layer_parameters( ...@@ -128,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`. to use during `plan`.
""" """
layers = vllm_config.compilation_config.static_forward_context layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: Dict[str, PerLayerParameters] = {} per_layer_params: Dict[str, PerLayerParameters] = {}
for key, layer in layers.items(): for key, layer in layers.items():
assert isinstance(layer, Attention)
impl = layer.impl impl = layer.impl
assert isinstance(impl, FlashInferImpl) assert isinstance(impl, FlashInferImpl)
...@@ -187,7 +197,8 @@ class FlashInferState(AttentionState): ...@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
# Global hyperparameters shared by all attention layers # Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config() self.vllm_config = self.runner.vllm_config
self._kv_cache_layout = None
def _get_workspace_buffer(self): def _get_workspace_buffer(self):
if self._workspace_buffer is None: if self._workspace_buffer is None:
...@@ -197,10 +208,15 @@ class FlashInferState(AttentionState): ...@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
device=self.runner.device) device=self.runner.device)
return self._workspace_buffer return self._workspace_buffer
def get_kv_cache_layout(self):
if self._kv_cache_layout is None:
self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT
return self._kv_cache_layout
def _get_prefill_wrapper(self): def _get_prefill_wrapper(self):
if self._prefill_wrapper is None: if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD") self._get_workspace_buffer(), self.get_kv_cache_layout())
return self._prefill_wrapper return self._prefill_wrapper
def _get_decode_wrapper(self): def _get_decode_wrapper(self):
...@@ -213,7 +229,7 @@ class FlashInferState(AttentionState): ...@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
num_qo_heads // num_kv_heads > 4) num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(), self._get_workspace_buffer(),
"NHD", self.get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores) use_tensor_cores=use_tensor_cores)
return self._decode_wrapper return self._decode_wrapper
...@@ -274,7 +290,8 @@ class FlashInferState(AttentionState): ...@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
self._graph_decode_wrapper = \ self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD", self._graph_indices_buffer, _last_page_len_buffer,
self.get_kv_cache_layout(),
use_tensor_cores) use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"): if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
...@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers # Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config() self.vllm_config = self.runner.vllm_config
def prepare(self): def prepare(self):
self.slot_mapping: List[int] = [] self.slot_mapping: List[int] = []
...@@ -1005,6 +1022,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1005,6 +1022,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output: Optional[torch.Tensor] = None prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill # We will use flash attention for prefill
# when kv_cache is not provided. # when kv_cache is not provided.
...@@ -1036,7 +1054,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1036,7 +1054,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output = prefill_meta.prefill_wrapper.run( prefill_output = prefill_meta.prefill_wrapper.run(
query, query,
kv_cache, kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float, k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float, v_scale=layer._v_scale_float,
) )
...@@ -1051,7 +1069,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1051,7 +1069,7 @@ class FlashInferImpl(AttentionImpl):
decode_output = decode_meta.decode_wrapper.run( decode_output = decode_meta.decode_wrapper.run(
decode_query, decode_query,
kv_cache, kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float, k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float, v_scale=layer._v_scale_float,
) )
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
############################################################################### ###############################################################################
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, from vllm_hpu_extension.flags import enabled_flags
VLLMKVCache) from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionLayer,
...@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.block2batch_matmul = Matmul() self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache() self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache() self.v_cache = VLLMKVCache()
ops.pa_impl = ops.pa self.fused_scaled_dot_product_attention = kernels.fsdpa()
self.prefill_impl = 'naive'
if "flex_attention" in enabled_flags():
self.prefill_impl = 'flex'
if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
...@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', if self.prefill_impl == 'fsdpa':
'0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \ assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!' 'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")
supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes: if head_size not in supported_head_sizes:
...@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.") f"Supported head sizes are: {supported_head_sizes}.")
if attn_type != AttentionType.DECODER: self.attn_type = attn_type
if self.attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
"are not implemented for " "are not implemented for "
...@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape _, seq_len_kv, _ = key.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt: key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1)) key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None: if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache = HPUPagedAttention.split_kv_cache( key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
...@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None
query_shape = (batch_size, seq_len, self.num_heads, self.head_size) query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size) self.head_size)
attn_bias = attn_metadata.attn_bias
if attn_bias is not None and self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
out = ops.prompt_attention( out = ops.prompt_attention(
query.view(query_shape), impl=self.prefill_impl,
key.view(kv_shape), query=query.view(query_shape),
value.view(kv_shape), key=key.view(kv_shape),
value=value.view(kv_shape),
is_causal=True,
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, valid_seq_lengths=attn_metadata.seq_lens_tensor,
scale=self.scale, **self.common_attention_args())
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
output = out.reshape(batch_size, seq_len, hidden_size) output = out.reshape(batch_size, seq_len, hidden_size)
else: else:
# Decoding run. # Decoding run.
...@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list=attn_metadata.block_list, block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping, block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias, block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups, block_groups=attn_metadata.block_groups,
scale=self.scale, **self.common_attention_args())
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size) return output.view(batch_size, seq_len, hidden_size)
def common_attention_args(self):
fsdpa_op = self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None
return {
'scale': self.scale,
'matmul_qk_op': self.matmul_qk,
'matmul_av_op': self.matmul_av,
'batch2block_matmul_op': self.batch2block_matmul,
'block2batch_matmul_op': self.block2batch_matmul,
'fsdpa_op': fsdpa_op,
'keys_fetch_func': self.k_cache.fetch_from_cache,
'values_fetch_func': self.v_cache.fetch_from_cache,
'softmax_op': self.softmax,
}
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
......
...@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache, value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale_float,
layer._v_scale, layer._v_scale_float,
) )
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
...@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale_float,
layer._v_scale, layer._v_scale_float,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale_float,
layer._v_scale, layer._v_scale_float,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -206,6 +206,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -206,6 +206,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
...@@ -215,7 +216,7 @@ from vllm.multimodal import MultiModalPlaceholderMap ...@@ -215,7 +216,7 @@ from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
# from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON: if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.attention.ops.triton_flash_attention import triton_attention
...@@ -712,12 +713,24 @@ class MLACommonMetadata(AttentionMetadata): ...@@ -712,12 +713,24 @@ class MLACommonMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
self._ops_advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions)
def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops.advance_step_flashattn(num_seqs=num_seqs, ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries, num_queries=num_queries,
block_size=block_size, block_size=block_size,
input_tokens=model_input.input_tokens, input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions, input_positions=input_positions,
seq_lens=self.seq_lens_tensor, seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping, slot_mapping=self.slot_mapping,
block_tables=self.block_tables) block_tables=self.block_tables)
...@@ -728,6 +741,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): ...@@ -728,6 +741,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
""" """
BLOCK_TABLE_EXTENDER: list[list[int]] = []
def __init__(self, input_builder: "ModelInputForGPUBuilder"): def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder self.input_builder = input_builder
...@@ -878,8 +892,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): ...@@ -878,8 +892,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size) self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables( block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables) num_seqs, self.block_tables)
else: else:
...@@ -1044,8 +1060,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1044,8 +1060,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.q_proj = q_proj self.q_proj = q_proj
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.triton_fa_func = triton_attention
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn # Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the # and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3 # latter has an additional parameter to control FA2 vs FA3
...@@ -1058,6 +1074,77 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1058,6 +1074,77 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
return_softmax_lse, **kwargs):
maybe_padded_v = v
if self._pad_v:
maybe_padded_v = torch.nn.functional.pad(
v, [0, q.shape[-1] - v.shape[-1]], value=0)
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
and not return_softmax_lse:
attn_out = self.triton_fa_func(
q,
k,
maybe_padded_v,
None, # output
kwargs["cu_seqlens_q"],
kwargs["cu_seqlens_k"],
kwargs["max_seqlen_q"],
kwargs["max_seqlen_k"],
kwargs["causal"],
softmax_scale,
None, # bias
)
if is_vllm_fa:
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
v=maybe_padded_v,
return_softmax_lse=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
else:
# Use return_attn_probs instead of return_softmax_lse for RoCM
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
v=maybe_padded_v,
return_attn_probs=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
# Unpack the output if there is multiple results,
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`
# flash_attn (RoCM) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest = None
if isinstance(attn_out, tuple):
attn_out, *rest = attn_out
# unpad if necessary
if self._pad_v:
attn_out = attn_out[..., :v.shape[-1]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
assert rest is not None
return attn_out, rest[0]
return attn_out
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
...@@ -1190,40 +1277,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1190,40 +1277,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1) dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad attn_output, attn_softmax_lse = \
# out v with 0s to match the qk head dim self._flash_attn_varlen_diff_headdims(
v_padded = torch.nn.functional.pad(v, q=q,
[0, q.shape[-1] - v.shape[-1]], k=k,
value=0) v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
if is_vllm_fa: cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
attn_output, attn_softmax_lse = self.flash_attn_varlen_func( max_seqlen_q=prefill_metadata.max_query_len,
q=q, max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
k=k, softmax_scale=self.scale,
v=v_padded, causal=False, # Context is unmasked
cu_seqlens_q=prefill_metadata.query_start_loc, return_softmax_lse=True,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], )
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_attn_probs=True,
)
if output is None: if output is None:
output = attn_output output = attn_output
...@@ -1266,58 +1332,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1266,58 +1332,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out output = self._flash_attn_varlen_diff_headdims(
# v with 0s to match the qk head dim q=q,
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], k=k,
value=0) v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: cu_seqlens_k=prefill_metadata.query_start_loc,
output = self.triton_fa_func( max_seqlen_q=prefill_metadata.max_prefill_seq_len,
q, max_seqlen_k=prefill_metadata.max_prefill_seq_len,
k, softmax_scale=self.scale,
v_padded, causal=True,
None, return_softmax_lse=has_context,
prefill_metadata.query_start_loc, )
prefill_metadata.query_start_loc,
prefill_metadata.max_prefill_seq_len,
prefill_metadata.max_prefill_seq_len,
True, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if not has_context:
output = output[0]
elif is_vllm_fa:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
else:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_attn_probs=has_context,
)
if has_context: if has_context:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2 # ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse, *rest = output suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \ context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata) q, kv_c_and_k_pe_cache, attn_metadata)
...@@ -1330,12 +1360,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1330,12 +1360,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse=suffix_lse, suffix_lse=suffix_lse,
) )
# slice by `:v.shape[-1]` in order to remove v headdim padding return self.o_proj(output.flatten(start_dim=-2))[0]
output = output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
@abstractmethod @abstractmethod
def _forward_decode( def _forward_decode(
......
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Type, Union
import torch
import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
MLACommonState)
from vllm.attention.backends.utils import (compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
get_aiter_mla_metadata)
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
@staticmethod
def get_impl_cls() -> Type["AiterMLAImpl"]:
return AiterMLAImpl
@staticmethod
def get_metadata_cls() -> Type["AiterMLAMetadata"]:
return AiterMLAMetadata
@staticmethod
def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["AiterMLAState"]:
return AiterMLAState
@dataclass
class AiterMLAMetadata(MLACommonMetadata):
# The following 4 tensors are for current version of AITER MLA
block_table_bound: Optional[torch.Tensor] = None
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens: Optional[torch.Tensor] = None
@property
def prefill_metadata(self):
prefill_metadata = super().prefill_metadata
self._cached_prefill_metadata = prefill_metadata
if prefill_metadata is not None:
prefill_metadata.paged_kv_indptr = self.paged_kv_indptr
prefill_metadata.paged_kv_indices = self.paged_kv_indices
prefill_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
prefill_metadata.block_table_bound = self.block_table_bound
# update the cache
self._cached_prefill_metadata = self.__class__(
**prefill_metadata.__dict__)
return self._cached_prefill_metadata
@property
def decode_metadata(self):
decode_metadata = super().decode_metadata
self._cached_decode_metadata = decode_metadata
if decode_metadata is not None:
decode_metadata.paged_kv_indptr = self.paged_kv_indptr
decode_metadata.paged_kv_indices = self.paged_kv_indices
decode_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
decode_metadata.block_table_bound = self.block_table_bound
# update the cache
self._cached_decode_metadata = self.__class__(
**decode_metadata.__dict__)
return self._cached_decode_metadata
def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
ops.advance_step_flashinfer(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables,
paged_kv_indices=self.paged_kv_indices,
paged_kv_indptr=self.paged_kv_indptr,
paged_kv_last_page_lens=self.paged_kv_last_page_lens,
block_table_bound=self.block_table_bound)
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
super().__init__(input_builder)
assert self.runner.model_config.max_model_len == 32768,\
"AITER MLA requires max model len to be set to 32768"
assert self.block_size == 1, "AITER MLA requires only block size 1."
def prepare(self):
super().prepare()
self.paged_kv_indices: list[int] = []
self.paged_kv_indptr: list[int] = [0]
self.paged_kv_last_page_lens: list[int] = []
self.total_blocks = 0
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
if is_profile_run:
return
# Update paged_kv_* tensors only for non-profile run
block_table = block_tables[seq_id]
self._update_paged_kv_tensors(block_table, seq_len)
def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_lens.append(last_page_len)
def build(self, seq_lens: list[int], query_lens: list[int],
cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
batch_size)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
if use_captured_graph:
last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size)
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
# For current version of AITER MLA
if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device=device,
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
device=device,
dtype=torch.int)
paged_kv_last_page_lens_tensor = torch.tensor(
self.paged_kv_last_page_lens, device=device, dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device=device,
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_lens_tensor = None
block_table_bound_tensor = None
metadata.paged_kv_indptr = paged_kv_indptr_tensor
metadata.paged_kv_indices = paged_kv_indices_tensor
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
metadata.block_table_bound = block_table_bound_tensor
return metadata
class AiterMLAState(MLACommonState[AiterMLAMetadata]):
@contextmanager
def graph_capture(self, max_batch_size: int):
kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
max_batch_size=max_batch_size,
block_size=self.runner.block_size,
max_block_per_batch=self.runner.get_max_block_per_batch(),
device=self.runner.device)
self._paged_kv_indices_tensor = kv_indices
self._paged_kv_indptr_tensor = kv_indptr
self._paged_kv_last_page_lens_tensor = last_page_lens
with super().graph_capture(max_batch_size):
yield
del self._paged_kv_indices_tensor
del self._paged_kv_indptr_tensor
del self._paged_kv_last_page_lens_tensor
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
metadata = super().graph_capture_get_metadata_for_batch(
batch_size, is_encoder_decoder_model)
paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
paged_kv_indices = self._paged_kv_indices_tensor
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
batch_size]
metadata.paged_kv_indptr = paged_kv_indptr
metadata.paged_kv_indices = paged_kv_indices
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
return metadata
def get_graph_input_buffers(self,
attn_metadata: AiterMLAMetadata,
is_encoder_decoder_model: bool = False):
input_buffers = super().get_graph_input_buffers(
attn_metadata, is_encoder_decoder_model)
input_buffers[
'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr
input_buffers[
"paged_kv_indices"] = attn_metadata.\
decode_metadata.paged_kv_indices
input_buffers[
"paged_kv_last_page_lens"] = attn_metadata.\
decode_metadata.paged_kv_last_page_lens
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata: AiterMLAMetadata,
is_encoder_decoder_model: bool = False):
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
is_encoder_decoder_model)
num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[
0]
input_buffers["paged_kv_indptr"].copy_(
attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True)
input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True)
input_buffers["paged_kv_last_page_lens"].copy_(
attn_metadata.decode_metadata.paged_kv_last_page_lens,
non_blocking=True)
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
softmax_scale: float, return_softmax_lse: bool,
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
return output
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_lens)
return self._v_up_proj_and_o_proj(o)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""Attention layer ROCm GPUs.""" """Attention layer ROCm GPUs."""
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch import torch
...@@ -26,7 +27,34 @@ logger = init_logger(__name__) ...@@ -26,7 +27,34 @@ logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 256 _PARTITION_SIZE_ROCM = 256
@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
and envs.VLLM_ROCM_USE_AITER \
@cache
def _get_paged_attn_module() -> PagedAttention:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
"""
if is_rocm_aiter_paged_attn_enabled():
# Import AITERPagedAttention only when the flag is enabled
from vllm.attention.ops.rocm_aiter_paged_attn import (
AITERPagedAttention)
return AITERPagedAttention()
return PagedAttention()
class ROCmFlashAttentionBackend(AttentionBackend): class ROCmFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -55,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend): ...@@ -55,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size, paged_attn = _get_paged_attn_module()
num_kv_heads, head_size) return paged_attn.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
...@@ -64,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend): ...@@ -64,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
dst_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor, src_to_dst: torch.Tensor,
) -> None: ) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) paged_attn = _get_paged_attn_module()
paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor, src_to_dists: torch.Tensor,
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) paged_attn = _get_paged_attn_module()
paged_attn.copy_blocks(kv_caches, src_to_dists)
@dataclass @dataclass
...@@ -495,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -495,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes() self.paged_attn_module = _get_paged_attn_module()
supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
)
if head_size not in supported_head_sizes: if head_size not in supported_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
...@@ -515,7 +549,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -515,7 +549,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention) triton_attention)
self.attn_func = triton_attention self.triton_attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend") logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1): if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support " logger.warning("ROCm Triton FA does not currently support "
...@@ -531,7 +565,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -531,7 +565,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func self.fa_attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend") logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
...@@ -542,9 +576,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -542,9 +576,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"ROCm Naive FlashAttention does not support " "ROCm Naive FlashAttention does not support "
"attention logits soft capping.") "attention logits soft capping.")
self.attn_func = _sdpa_attention self.sdpa_attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend") logger.debug("Using naive (SDPA) attention in ROCmBackend")
self.aiter_kv_scales_initialized = False
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)""" """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape tokens, n_kv_heads, head_dim = x.shape
...@@ -613,6 +649,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -613,6 +649,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
if key is not None: if key is not None:
assert value is not None assert value is not None
...@@ -621,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -621,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
assert value is None assert value is None
paged_attn = self.paged_attn_module
# Reshaping kv tensors is required for AITER paged attention kernel
# because it works on a different tensor shape,
# when the size of one element is one byte (int8/fp8 dtypes).
# This reshaping is only required on the first forward call
# and the kv cache must not be empty.
if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
and not self.aiter_kv_scales_initialized
and kv_cache.shape != torch.Size([0])):
num_blocks = kv_cache.shape[1]
block_size = kv_cache.shape[2] // (self.num_kv_heads *
self.head_size)
k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
self.aiter_kv_scales_initialized = True
k_scale.fill_(layer._k_scale.item())
v_scale.fill_(layer._v_scale.item())
layer._k_scale = k_scale
layer._v_scale = v_scale
# Only update KV cache for decoder self-attention # Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention # and encoder-decoder cross-attention
if self.attn_type not in [ if self.attn_type not in [
AttentionType.ENCODER, AttentionType.ENCODER_ONLY AttentionType.ENCODER, AttentionType.ENCODER_ONLY
] and kv_cache.numel() > 0: ] and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = paged_attn.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
if key is not None and value is not None: if key is not None and value is not None:
...@@ -634,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -634,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# cache. If kv_cache is not provided, the new key and value # cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial # tensors are not cached. This happens during the initial
# memory profiling run. # memory profiling run.
PagedAttention.write_to_paged_cache( paged_attn.write_to_paged_cache(
key, key,
value, value,
key_cache, key_cache,
...@@ -656,7 +719,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -656,7 +719,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert attn_metadata.num_encoder_tokens is not None assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens num_prefill_tokens = attn_metadata.num_encoder_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
...@@ -704,11 +766,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -704,11 +766,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype, query.dtype,
seq_lens, seq_lens,
make_attn_mask=causal_mask) # type: ignore make_attn_mask=causal_mask) # type: ignore
out, _ = self.attn_func( use_fp8_scales = (layer._q_scale and layer._k_scale
and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8")
full_scales = (
layer._q_scale, layer._k_scale, layer._v_scale,
layer._prob_scale) if use_fp8_scales else None
self.triton_attn_func(
query, query,
key, key,
value, value,
None, output[:num_prefill_tokens],
query_seq_start_loc, query_seq_start_loc,
key_seq_start_loc, key_seq_start_loc,
query_max_seq_len, query_max_seq_len,
...@@ -717,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -717,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.scale, self.scale,
attn_masks[0][None] attn_masks[0][None]
if attn_masks is not None else None, if attn_masks is not None else None,
full_scales,
) )
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
...@@ -733,10 +802,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -733,10 +802,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key = key.movedim(0, key.dim() - 2) key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2) value = value.movedim(0, value.dim() - 2)
# sdpa math backend attention # sdpa math backend attention
out = self.attn_func( self.sdpa_attn_func(
query, query,
key, key,
value, value,
output[:num_prefill_tokens],
query_seq_start_loc, query_seq_start_loc,
num_prefill_tokens, num_prefill_tokens,
self.num_heads, self.num_heads,
...@@ -745,7 +815,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -745,7 +815,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks, attn_masks,
) )
else: else:
out = self.attn_func( # upstream FA does not support an output arg, copy
output[:num_prefill_tokens] = self.fa_attn_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -760,33 +831,26 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -760,33 +831,26 @@ class ROCmFlashAttentionImpl(AttentionImpl):
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
) )
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
if output.shape[0] > num_prefill_tokens:
output[:num_prefill_tokens] = out
else:
output = out
else: else:
# prefix-enabled attention - # prefix-enabled attention -
# not applicable for encoder-only models # not applicable for encoder-only models
if self.attn_type != AttentionType.ENCODER_ONLY: if self.attn_type != AttentionType.ENCODER_ONLY:
output[: output[:num_prefill_tokens] = paged_attn.forward_prefix(
num_prefill_tokens] = PagedAttention.forward_prefix( query,
query, key,
key, value,
value, self.kv_cache_dtype,
self.kv_cache_dtype, key_cache,
key_cache, value_cache,
value_cache, prefill_meta.block_tables,
prefill_meta.block_tables, prefill_meta.query_start_loc,
prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor,
prefill_meta.seq_lens_tensor, prefill_meta.max_query_len,
prefill_meta.max_query_len, self.alibi_slopes,
self.alibi_slopes, self.sliding_window[0],
self.sliding_window[0], layer._k_scale,
layer._k_scale, layer._v_scale,
layer._v_scale, )
)
# Skip decode phase for encoder-only models # Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and ( if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY): self.attn_type != AttentionType.ENCODER_ONLY):
...@@ -819,14 +883,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -819,14 +883,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
device=output.device, device=output.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if num_prefill_tokens > 0:
out = output[num_prefill_tokens:]
else:
out = output
query_start_loc = None query_start_loc = None
ops.paged_attention_rocm( ops.paged_attention_rocm(
out, output[num_prefill_tokens:],
exp_sums, exp_sums,
max_logits, max_logits,
tmp_output, tmp_output,
...@@ -850,7 +910,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -850,7 +910,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer._v_scale, layer._v_scale,
) )
else: else:
output[num_prefill_tokens:] = PagedAttention.forward_decode( output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
...@@ -879,7 +939,8 @@ def _sdpa_attention( ...@@ -879,7 +939,8 @@ def _sdpa_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
seq_lens: List[int], output: torch.Tensor,
seq_lens: torch.Tensor,
num_tokens: int, num_tokens: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
...@@ -887,9 +948,9 @@ def _sdpa_attention( ...@@ -887,9 +948,9 @@ def _sdpa_attention(
attn_masks: Optional[List[torch.Tensor]] = None, attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
start = 0 start = 0
output = torch.empty((num_tokens, num_heads, head_size), assert output.shape == (num_tokens, num_heads, head_size)
dtype=query.dtype, assert output.dtype == query.dtype
device=query.device) assert output.device == query.device
for i, seq_len in enumerate(seq_lens): for i, seq_len in enumerate(seq_lens):
end = start + seq_len end = start + seq_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment