Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import pytest
from vllm import LLM, SamplingParams
_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
"Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Smoke-test n-gram speculative decoding with a 100-token model length.

    ``max_tokens`` is set to the same value as ``max_model_len`` and EOS is
    ignored. The test makes no assertions on the outputs; it passes when
    ``generate`` returns without raising.
    """
    ngram_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }
    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")

        engine = LLM(
            model="facebook/opt-125m",
            max_model_len=100,
            # Eager mode for faster initialization.
            enforce_eager=True,
            speculative_config=ngram_config,
        )
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        engine.generate(_PROMPTS, params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Smoke-test EAGLE speculative decoding with a 100-token model length.

    ``max_tokens`` is set to the same value as ``max_model_len`` and EOS is
    ignored. The test makes no assertions on the outputs; it passes when
    ``generate`` returns without raising.
    """
    eagle_config = {
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": num_speculative_tokens,
    }
    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")

        engine = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            # Eager mode for faster initialization.
            enforce_eager=True,
            speculative_config=eagle_config,
            max_model_len=100,
        )
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        engine.generate(_PROMPTS, params)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import numpy as np import numpy as np
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp, _find_subarray_kmp,
_kmp_lps_array) _kmp_lps_array)
...@@ -39,50 +40,50 @@ def test_find_subarray_kmp(): ...@@ -39,50 +40,50 @@ def test_find_subarray_kmp():
def test_ngram_proposer(): def test_ngram_proposer():
proposer = NgramProposer()
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
    """Build an ``NgramProposer`` for the given lookup window and draft size.

    ``min_n`` / ``max_n`` become ``prompt_lookup_min`` / ``prompt_lookup_max``
    and ``k`` becomes ``num_speculative_tokens`` in the speculative config.
    """
    # The model config is a placeholder whose only purpose is to provide
    # max_model_len.
    dummy_model_config = ModelConfig(
        model="facebook/opt-125m",
        task="generate",
        max_model_len=100,
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        dtype="auto",
        seed=None,
        trust_remote_code=False,
    )
    spec_config = SpeculativeConfig.from_dict({
        "prompt_lookup_min": min_n,
        "prompt_lookup_max": max_n,
        "num_speculative_tokens": k,
        "method": "ngram",
    })
    full_config = VllmConfig(model_config=dummy_model_config,
                             speculative_config=spec_config)
    return NgramProposer(vllm_config=full_config)
# No match. # No match.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([1, 2, 3, 4, 5]), 2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
min_n=2,
max_n=2,
k=2,
)
assert result is None assert result is None
# No match for 4-gram. # No match for 4-gram.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), 4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
min_n=4,
max_n=4,
k=2,
)
assert result is None assert result is None
# No match for 4-gram but match for 3-gram. # No match for 4-gram but match for 3-gram.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), 3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
min_n=3,
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([4, 1])) assert np.array_equal(result, np.array([4, 1]))
# Match for both 4-gram and 3-gram. # Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match. # In this case, the proposer should return the 4-gram match.
result = proposer.propose( result = ngram_proposer(3, 4, 2).propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]), context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
min_n=3,
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram. # Match for 2-gram and 3-gram, but not 4-gram.
result = proposer.propose( result = ngram_proposer(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]), 2, 4,
min_n=2, 2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
max_n=4,
k=2,
)
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
...@@ -2,17 +2,13 @@ ...@@ -2,17 +2,13 @@
import pytest import pytest
from vllm.v1.structured_output.utils import ( from vllm.v1.structured_output.backend_xgrammar import (
has_xgrammar_unsupported_json_features) has_xgrammar_unsupported_json_features)
@pytest.fixture @pytest.fixture
def unsupported_string_schemas(): def unsupported_string_schemas():
return [ return [
{
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
{ {
"type": "string", "type": "string",
"format": "email" "format": "email"
...@@ -23,22 +19,6 @@ def unsupported_string_schemas(): ...@@ -23,22 +19,6 @@ def unsupported_string_schemas():
@pytest.fixture @pytest.fixture
def unsupported_integer_schemas(): def unsupported_integer_schemas():
return [ return [
{
"type": "integer",
"minimum": 0
},
{
"type": "integer",
"maximum": 120
},
{
"type": "integer",
"exclusiveMinimum": 120
},
{
"type": "integer",
"exclusiveMaximum": 120
},
{ {
"type": "integer", "type": "integer",
"multipleOf": 120 "multipleOf": 120
...@@ -49,22 +29,6 @@ def unsupported_integer_schemas(): ...@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
@pytest.fixture @pytest.fixture
def unsupported_number_schemas(): def unsupported_number_schemas():
return [ return [
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"maximum": 120
},
{
"type": "number",
"exclusiveMinimum": 120
},
{
"type": "number",
"exclusiveMaximum": 120
},
{ {
"type": "number", "type": "number",
"multipleOf": 120 "multipleOf": 120
...@@ -156,13 +120,28 @@ def supported_schema(): ...@@ -156,13 +120,28 @@ def supported_schema():
"type": "string", "type": "string",
"enum": ["sedan", "suv", "truck"] "enum": ["sedan", "suv", "truck"]
}, },
"car_brand": {
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
"short_description": { "short_description": {
"type": "string", "type": "string",
"maxLength": 50 "maxLength": 50
}, },
"mileage": {
"type": "number",
"minimum": 0,
"maximum": 1000000
},
"model_year": {
"type": "integer",
"exclusiveMinimum": 1900,
"exclusiveMaximum": 2100
},
"long_description": { "long_description": {
"type": "string", "type": "string",
"minLength": 50 "minLength": 50,
"maxLength": 2000
}, },
"address": { "address": {
"type": "object", "type": "object",
......
...@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind): ...@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so # the engines only synchronize stopping every N steps so
# allow a small amount of time here. # allow a small amount of time here.
for _ in range(10): for _ in range(10):
if core_client.num_engines_running == 0: if not core_client.engines_running:
break break
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
assert core_client.num_engines_running == 0 assert not core_client.engines_running
assert not core_client.reqs_in_flight assert not core_client.reqs_in_flight
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import msgspec
import numpy as np import numpy as np
import torch import torch
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
...@@ -26,6 +32,7 @@ class MyType: ...@@ -26,6 +32,7 @@ class MyType:
large_f_contig_tensor: torch.Tensor large_f_contig_tensor: torch.Tensor
small_non_contig_tensor: torch.Tensor small_non_contig_tensor: torch.Tensor
large_non_contig_tensor: torch.Tensor large_non_contig_tensor: torch.Tensor
empty_tensor: torch.Tensor
def test_encode_decode(): def test_encode_decode():
...@@ -41,6 +48,10 @@ def test_encode_decode(): ...@@ -41,6 +48,10 @@ def test_encode_decode():
torch.rand((1, 10), dtype=torch.float32), torch.rand((1, 10), dtype=torch.float32),
torch.rand((3, 5, 4000), dtype=torch.float64), torch.rand((3, 5, 4000), dtype=torch.float64),
torch.tensor(1984), # test scalar too torch.tensor(1984), # test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
torch.tensor([float("-inf"), float("inf")] * 1024,
dtype=torch.bfloat16),
], ],
numpy_array=np.arange(512), numpy_array=np.arange(512),
unrecognized=UnrecognizedType(33), unrecognized=UnrecognizedType(33),
...@@ -48,9 +59,10 @@ def test_encode_decode(): ...@@ -48,9 +59,10 @@ def test_encode_decode():
large_f_contig_tensor=torch.rand(1024, 4).t(), large_f_contig_tensor=torch.rand(1024, 4).t(),
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
empty_tensor=torch.empty(0),
) )
encoder = MsgpackEncoder() encoder = MsgpackEncoder(size_threshold=256)
decoder = MsgpackDecoder(MyType) decoder = MsgpackDecoder(MyType)
encoded = encoder.encode(obj) encoded = encoder.encode(obj)
...@@ -58,7 +70,7 @@ def test_encode_decode(): ...@@ -58,7 +70,7 @@ def test_encode_decode():
# There should be the main buffer + 4 large tensor buffers # There should be the main buffer + 4 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes. # + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline. # The two small tensors are encoded inline.
assert len(encoded) == 6 assert len(encoded) == 8
decoded: MyType = decoder.decode(encoded) decoded: MyType = decoder.decode(encoded)
...@@ -70,7 +82,7 @@ def test_encode_decode(): ...@@ -70,7 +82,7 @@ def test_encode_decode():
encoded2 = encoder.encode_into(obj, preallocated) encoded2 = encoder.encode_into(obj, preallocated)
assert len(encoded2) == 6 assert len(encoded2) == 8
assert encoded2[0] is preallocated assert encoded2[0] is preallocated
decoded2: MyType = decoder.decode(encoded2) decoded2: MyType = decoder.decode(encoded2)
...@@ -78,6 +90,97 @@ def test_encode_decode(): ...@@ -78,6 +90,97 @@ def test_encode_decode():
assert_equal(decoded2, obj) assert_equal(decoded2, obj)
class MyRequest(msgspec.Struct):
    """Mock request used by the tests to carry multimodal kwargs.

    Wrapping the kwargs in a typed struct gives ``MsgpackDecoder`` a
    concrete type to decode into.
    """
    # Optional list of multimodal kwargs.
    mm: Optional[list[MultiModalKwargs]]
def test_multimodal_kwargs():
    """Round-trip a nested ``MultiModalKwargs`` through the msgpack codec.

    Checks the number of encoded buffers, the total encoded size (within a
    +-20 byte tolerance) and that every decoded entry equals its source.
    """
    source = {
        "foo":
        torch.zeros(20000, dtype=torch.float16),
        "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
        "baz": [
            torch.rand(256, dtype=torch.float16),
            [
                torch.rand((1, 12), dtype=torch.float32),
                torch.rand((3, 5, 7), dtype=torch.float64),
            ],
            [torch.rand((4, 4), dtype=torch.float16)],
        ],
    }

    # The kwargs are wrapped in a mock request so the decoder has a typed
    # target to decode into.
    request = MyRequest(mm=[MultiModalKwargs(source)])

    buffers = MsgpackEncoder().encode(request)
    assert len(buffers) == 6

    # Nominal total encoding length is 44559 bytes; allow +-20 for minor
    # changes.
    encoded_bytes = sum(memoryview(buf).cast("B").nbytes for buf in buffers)
    assert 44539 <= encoded_bytes <= 44579

    restored: MultiModalKwargs = MsgpackDecoder(MyRequest).decode(
        buffers).mm[0]
    assert all(nested_equal(source[key], restored[key]) for key in source)
def test_multimodal_items_by_modality():
    """Round-trip ``MultiModalKwargs`` built from per-modality items.

    One audio, one video and one two-element image item are encoded and
    decoded; the test checks buffer count, total encoded size (within a
    +-20 byte tolerance), the recovered modalities and the tensor contents.
    """
    audio_elem = MultiModalFieldElem(
        "audio",
        "a0",
        torch.zeros(1000, dtype=torch.bfloat16),
        MultiModalBatchedField(),
    )
    video_elem = MultiModalFieldElem(
        "video",
        "v0",
        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
        MultiModalBatchedField(),
    )
    image_shared_elem = MultiModalFieldElem(
        "image",
        "i0",
        torch.zeros(1000, dtype=torch.int32),
        MultiModalSharedField(4),
    )
    image_batched_elem = MultiModalFieldElem(
        "image",
        "i1",
        torch.zeros(1000, dtype=torch.int32),
        MultiModalBatchedField(),
    )

    items = [
        MultiModalKwargsItem.from_elems([audio_elem]),
        MultiModalKwargsItem.from_elems([video_elem]),
        MultiModalKwargsItem.from_elems(
            [image_shared_elem, image_batched_elem]),
    ]
    mm = MultiModalKwargs.from_items(items)

    # The kwargs are wrapped in a mock request so the decoder has a typed
    # target to decode into.
    request = MyRequest([mm])

    buffers = MsgpackEncoder().encode(request)
    assert len(buffers) == 8

    # Nominal total encoding length is 14255 bytes; allow +-20 for minor
    # changes.
    encoded_bytes = sum(memoryview(buf).cast("B").nbytes for buf in buffers)
    assert 14235 <= encoded_bytes <= 14275

    restored: MultiModalKwargs = MsgpackDecoder(MyRequest).decode(
        buffers).mm[0]

    # All three modalities must come back, with the image item intact.
    assert len(restored.modalities) == 3
    image_items = restored.get_items("image")
    assert len(image_items) == 1
    assert len(image_items[0].items()) == 2
    assert list(image_items[0].keys()) == ["i0", "i1"]

    # Tensor contents and layout in the main dict must match the source.
    assert all(nested_equal(mm[key], restored[key]) for key in mm)
def nested_equal(a: NestedTensors, b: NestedTensors):
    """Return True when two nested tensor structures are element-wise equal.

    Args:
        a: A tensor, or an arbitrarily nested sequence of tensors.
        b: The structure to compare against ``a``.

    Returns:
        ``torch.equal(a, b)`` when ``a`` is a tensor; otherwise True only if
        both sequences have the same length and every pair of corresponding
        elements is recursively equal.
    """
    if isinstance(a, torch.Tensor):
        return torch.equal(a, b)
    # zip() stops at the shorter input, so without this check a sequence
    # that is merely a prefix of the other would be reported as equal.
    if len(a) != len(b):
        return False
    return all(nested_equal(x, y) for x, y in zip(a, b))
def assert_equal(obj1: MyType, obj2: MyType): def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.tensor1, obj2.tensor1) assert torch.equal(obj1.tensor1, obj2.tensor1)
assert obj1.a_string == obj2.a_string assert obj1.a_string == obj2.a_string
...@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType): ...@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2.small_non_contig_tensor) obj2.small_non_contig_tensor)
assert torch.equal(obj1.large_non_contig_tensor, assert torch.equal(obj1.large_non_contig_tensor,
obj2.large_non_contig_tensor) obj2.large_non_contig_tensor)
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
...@@ -22,6 +22,7 @@ MODELS = [ ...@@ -22,6 +22,7 @@ MODELS = [
] ]
TENSOR_PARALLEL_SIZES = [1] TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]
# TODO: Enable when CI/CD will have a multi-tpu instance # TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4] # TENSOR_PARALLEL_SIZES = [1, 4]
...@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1] ...@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic( def test_basic(
vllm_runner: type[VllmRunner], vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
tensor_parallel_size: int, tensor_parallel_size: int,
max_num_seqs: int,
) -> None: ) -> None:
prompt = "The next numbers of the sequence " + ", ".join( prompt = "The next numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:" str(i) for i in range(1024)) + " are:"
...@@ -51,9 +54,9 @@ def test_basic( ...@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to # Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt # actually test chunked prompt
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
max_model_len=8196, max_model_len=8192,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
max_num_seqs=16, max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size) as vllm_model: tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens) max_tokens)
......
# SPDX-License-Identifier: Apache-2.0
import openai
import pytest
from vllm import envs
from vllm.multimodal.utils import encode_image_base64, fetch_image
from vllm.platforms import current_platform
from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS
from ...utils import RemoteOpenAIServer
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.fixture(scope="session")
def base64_encoded_image() -> dict[str, str]:
    """Map each URL in ``TEST_IMAGE_URLS`` to its base64-encoded image."""
    encoded_by_url: dict[str, str] = {}
    for image_url in TEST_IMAGE_URLS:
        image = fetch_image(image_url)
        encoded_by_url[image_url] = encode_image_base64(image)
    return encoded_by_url
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str, base64_encoded_image: dict[str,
                                                                        str]):
    """Send one image chat request per test image to a served vision model.

    A server is started for ``model_name`` and, for every URL in
    ``TEST_IMAGE_URLS``, a chat completion is requested with the image
    passed as a base64 data URL. Each response must stop because of the
    token limit, come from the assistant and contain at least 10 characters.

    Args:
        model_name: Model served by the remote OpenAI-compatible server.
        base64_encoded_image: Mapping from image URL to base64 image data.
    """

    def whats_in_this_image_msg(b64):
        # Single user turn: a text question plus the image as a data URL.
        return [{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{b64}"
                    },
                },
            ],
        }]

    server_args = [
        "--max-model-len",
        "1024",
        "--max-num-seqs",
        "16",
        "--gpu-memory-utilization",
        "0.95",
        "--trust-remote-code",
        "--max-num-batched-tokens",
        "576",
        # NOTE: max-num-batched-tokens>=mm_item_size
        "--disable_chunked_mm_input",
        "--chat-template",
        "examples/template_llava.jinja"
    ]

    # Server will pre-compile on first startup (takes a long time).
    with RemoteOpenAIServer(model_name, server_args,
                            max_wait_seconds=600) as remote_server:
        client: openai.AsyncOpenAI = remote_server.get_async_client()

        # Other requests now should be much faster
        for image_url in TEST_IMAGE_URLS:
            image_base64 = base64_encoded_image[image_url]
            result = await client.chat.completions.create(
                model=model_name,
                messages=whats_in_this_image_msg(image_base64),
                max_completion_tokens=24,
                temperature=0.0)
            assert result

            choice = result.choices[0]
            # Only 24 completion tokens are allowed, so generation is
            # expected to be cut off by the length limit.
            assert choice.finish_reason == "length"

            message = choice.message
            assert message.content is not None and len(message.content) >= 10
            assert message.role == "assistant"
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random
import pytest import pytest
from vllm import LLM, envs from vllm import LLM, envs
...@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str): ...@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param. # Unsupported `seed` param.
sampling_params = SamplingParams(temperature=0.3, seed=42) sampling_params = SamplingParams(temperature=0.3, seed=42)
output2 = llm.generate(prompts, sampling_params) output2 = llm.generate(prompts, sampling_params)
# Batch-case with TopK/P
for B in [4, 16]:
p = prompts * B
sampling_params = [
SamplingParams(
temperature=0.1,
min_p=0.8,
max_tokens=64,
# Vary number of ks
top_k=random.randint(4, 12),
top_p=random.random()) for _ in range(B)
]
# Make sure first two reqs have the same K/P
sampling_params[0] = sampling_params[1]
output = llm.generate(p, sampling_params)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
...@@ -5,7 +5,8 @@ import pytest ...@@ -5,7 +5,8 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
apply_top_k_top_p_tpu)
if not current_platform.is_tpu(): if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True) pytest.skip("This test needs a TPU.", allow_module_level=True)
...@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024 ...@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6 TOLERANCE = 1e-6
def test_topk_equivalence_to_native_impl():
    """Check the TPU top-k path against the native top-k implementation.

    Both implementations receive clones of the same random logits and the
    same per-request ``k`` (with ``p=None``); their results must be close.
    """
    with torch.device(xm.xla_device()):
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))

        # Random per-request top-k values drawn from [1, 10).
        k = torch.randint(1, 10, (BATCH_SIZE, ))

        # For roughly half of the requests set k=vocab_size, which disables
        # top-k for them.
        disable_mask = torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool)
        k.masked_fill_(disable_mask, VOCAB_SIZE)

        tpu_out = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
        native_out = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

        assert torch.allclose(native_out, tpu_out)
def test_topp_result_sums_past_p(): def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()): with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33) xm.set_rng_state(seed=33)
......
...@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
prompt="test",
mm_inputs=[], mm_inputs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
...@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner): ...@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
def test_get_paddings():
    """Check ``_get_token_paddings`` for bucketed and exponential padding.

    A non-zero ``padding_gap`` is exercised as "bucketed" padding and a gap
    of 0 as "exponential" padding, each with a power-of-two and a
    non-power-of-two ``max_token_size``.
    """
    # Bucketed padding
    min_token_size, max_token_size, padding_gap = 16, 512, 64
    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    # This case was previously computed but never asserted, so its
    # expectation was silently overwritten by the next case.
    assert actual_paddings == expected_paddings

    # Bucketed padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding.
    max_token_size, padding_gap = 1024, 0
    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 256, 512]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings
......
...@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
return CachedRequestState( return CachedRequestState(
req_id=f"req_id_{req_id_suffix}", req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt=None,
sampling_params=_create_sampling_params(), sampling_params=_create_sampling_params(),
mm_inputs=[], mm_inputs=[],
mm_positions=[], mm_positions=[],
......
...@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
prompt="test",
mm_inputs=[], mm_inputs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
......
...@@ -1616,6 +1616,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, ...@@ -1616,6 +1616,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states, pad_slot_id) ssm_states, pad_slot_id)
# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor,
          rows_per_block: int) -> torch.Tensor:
    """Call the ``_rocm_C.LLMM1`` custom op (a ROCm skinny-GEMM kernel).

    The arguments are forwarded unchanged and the op's result is returned.
    NOTE(review): operand shapes/dtypes and the exact meaning of
    ``rows_per_block`` are defined by the kernel and are not visible here.
    """
    return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
    """Call the ``_rocm_C.wvSplitK`` custom op (a ROCm skinny-GEMM kernel).

    The arguments are forwarded unchanged and the op's result is returned.
    NOTE(review): ``cu_count`` is presumably the number of GPU compute
    units — confirm against the kernel.
    """
    return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
              scale_a: torch.Tensor, scale_b: torch.Tensor,
              cu_count: int) -> torch.Tensor:
    """Run the ``_rocm_C.wvSplitKQ`` custom op into a new output tensor.

    An uninitialized ``(b.shape[0], a.shape[0])`` tensor of ``out_dtype`` is
    allocated on ``b``'s device, passed to the op as its output argument,
    and returned.
    NOTE(review): ``scale_a`` / ``scale_b`` are presumably quantization
    scales for ``a`` / ``b`` — confirm against the kernel.
    """
    result_shape = (b.shape[0], a.shape[0])
    result = torch.empty(result_shape, dtype=out_dtype, device=b.device)
    torch.ops._rocm_C.wvSplitKQ(a, b, result, scale_a, scale_b, cu_count)
    return result
# moe # moe
def moe_sum(input: torch.Tensor, output: torch.Tensor): def moe_sum(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum(input, output) torch.ops._moe_C.moe_sum(input, output)
...@@ -1665,6 +1685,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, ...@@ -1665,6 +1685,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output) token_expert_indicies, gating_output)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
                          b_qweight: torch.Tensor, b_scales: torch.Tensor,
                          b_qzeros: Optional[torch.Tensor],
                          g_idx: Optional[torch.Tensor],
                          perm: Optional[torch.Tensor],
                          workspace: torch.Tensor,
                          sorted_token_ids: torch.Tensor,
                          expert_ids: torch.Tensor,
                          num_tokens_past_padded: torch.Tensor,
                          topk_weights: torch.Tensor, moe_block_size: int,
                          top_k: int, mul_topk_weights: bool, is_ep: bool,
                          b_q_type: ScalarType, size_m: int, size_n: int,
                          size_k: int, is_k_full: bool, use_atomic_add: bool,
                          use_fp32_reduce: bool,
                          is_zp_float: bool) -> torch.Tensor:
    """Call the ``_moe_C.moe_wna16_marlin_gemm`` custom op.

    Every argument is forwarded positionally in declaration order, except
    ``b_q_type``, for which only ``b_q_type.id`` is passed to the op. The
    op's return value is returned unchanged.
    NOTE(review): tensor layouts and flag semantics are defined by the
    kernel and are not visible here.
    """
    # The op takes the scalar type's id rather than the ScalarType object.
    return torch.ops._moe_C.moe_wna16_marlin_gemm(
        input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace,
        sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights,
        moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m,
        size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce,
        is_zp_float)
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@register_fake("_moe_C::marlin_gemm_moe") @register_fake("_moe_C::marlin_gemm_moe")
...@@ -1683,6 +1726,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): ...@@ -1683,6 +1726,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype=a.dtype, dtype=a.dtype,
device=a.device) device=a.device)
@register_fake("_moe_C::moe_wna16_marlin_gemm")
def moe_wna16_marlin_gemm_fake(input: torch.Tensor,
                               output: Optional[torch.Tensor],
                               b_qweight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_qzeros: Optional[torch.Tensor],
                               g_idx: Optional[torch.Tensor],
                               perm: Optional[torch.Tensor],
                               workspace: torch.Tensor,
                               sorted_token_ids: torch.Tensor,
                               expert_ids: torch.Tensor,
                               num_tokens_past_padded: torch.Tensor,
                               topk_weights: torch.Tensor,
                               moe_block_size: int, top_k: int,
                               mul_topk_weights: bool, is_ep: bool,
                               b_q_type: ScalarType, size_m: int,
                               size_n: int, size_k: int, is_k_full: bool,
                               use_atomic_add: bool, use_fp32_reduce: bool,
                               is_zp_float: bool) -> torch.Tensor:
    """Fake implementation registered for ``_moe_C::moe_wna16_marlin_gemm``.

    Runs no kernel: it only allocates an uninitialized
    ``(size_m * top_k, size_n)`` tensor with ``input``'s dtype and device.
    All other arguments are accepted to mirror the op's signature and are
    unused.
    """
    fake_shape = (size_m * top_k, size_n)
    return torch.empty(fake_shape, dtype=input.dtype, device=input.device)
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
...@@ -1904,3 +1970,12 @@ def flash_mla_with_kvcache( ...@@ -1904,3 +1970,12 @@ def flash_mla_with_kvcache(
num_splits, num_splits,
) )
return out, softmax_lse return out, softmax_lse
# def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
# q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
# seq_lens: torch.Tensor, page_table: torch.Tensor,
# scale: float) -> torch.Tensor:
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# return out
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import Literal from typing import Literal, Optional
import cv2 import cv2
import numpy as np import numpy as np
...@@ -10,8 +10,15 @@ import numpy.typing as npt ...@@ -10,8 +10,15 @@ import numpy.typing as npt
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from vllm.utils import PlaceholderModule
from .base import get_cache_dir from .base import get_cache_dir
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
@lru_cache @lru_cache
def download_video_asset(filename: str) -> str: def download_video_asset(filename: str) -> str:
...@@ -85,3 +92,12 @@ class VideoAsset: ...@@ -85,3 +92,12 @@ class VideoAsset:
video_path = download_video_asset(self.name) video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames) ret = video_to_ndarrays(video_path, self.num_frames)
return ret return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
    """
    Load the audio track of this video asset (used by the Qwen2.5-Omni
    examples).

    The asset named ``self.name`` is fetched with ``download_video_asset``
    and read with ``librosa.load``; ``sampling_rate`` is passed through as
    ``sr``. Only the first element of librosa's result is returned.

    See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
    """
    asset_path = download_video_asset(self.name)
    loaded = librosa.load(asset_path, sr=sampling_rate)
    return loaded[0]
...@@ -77,6 +77,10 @@ class AttentionBackend(ABC): ...@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
raise NotImplementedError raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
    """Return the backend's KV-cache stride order as a tuple of ints.

    The base implementation always raises ``NotImplementedError``.
    NOTE(review): presumably the tuple is a permutation of the KV-cache
    shape's dimension indices, supplied by backends that support a
    non-default layout — confirm against the overriding backends.
    """
    raise NotImplementedError
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def swap_blocks( def swap_blocks(
...@@ -237,6 +241,7 @@ class AttentionLayer(Protocol): ...@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
_v_scale: torch.Tensor _v_scale: torch.Tensor
_k_scale_float: float _k_scale_float: float
_v_scale_float: float _v_scale_float: float
_prob_scale: torch.Tensor
def forward( def forward(
self, self,
......
...@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import ( ...@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache) flash_attn_with_kvcache)
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
...@@ -691,7 +691,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -691,7 +691,7 @@ class FlashAttentionImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16: if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert ( assert (
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
"key/v_scale is only supported in FlashAttention 3 with " "key/v_scale is only supported in FlashAttention 3 with "
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import dataclasses import dataclasses
import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad) make_tensor_with_pad)
...@@ -48,6 +49,9 @@ if TYPE_CHECKING: ...@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
"NHD").upper()
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend): ...@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size) return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
    """Return the dimension permutation for the configured KV cache layout.

    The value is meant to be splatted into ``Tensor.permute`` on a KV cache
    shaped ``(num_blocks, 2, block_size, num_kv_heads, head_size)``.

    Returns:
        ``(0, 1, 2, 3, 4)`` (identity) for the "NHD" layout, or
        ``(0, 1, 3, 2, 4)`` for "HND", which swaps the ``block_size`` and
        ``num_kv_heads`` axes.

    Raises:
        ValueError: if ``FLASHINFER_KV_CACHE_LAYOUT`` is neither "NHD" nor
            "HND".
    """
    cache_layout = FLASHINFER_KV_CACHE_LAYOUT
    if cache_layout == "NHD":
        return (0, 1, 2, 3, 4)
    if cache_layout == "HND":
        return (0, 1, 3, 2, 4)
    # The layout comes from an environment variable, so validate it with a
    # real exception: an `assert` is stripped under `python -O`, which would
    # let a typo silently select the HND permutation.
    raise ValueError(
        "FLASHINFER_KV_CACHE_LAYOUT must be 'NHD' or 'HND', "
        f"got {cache_layout!r}")
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
src_kv_cache: torch.Tensor, src_kv_cache: torch.Tensor,
...@@ -128,12 +140,10 @@ def get_per_layer_parameters( ...@@ -128,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`. to use during `plan`.
""" """
layers = vllm_config.compilation_config.static_forward_context layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: Dict[str, PerLayerParameters] = {} per_layer_params: Dict[str, PerLayerParameters] = {}
for key, layer in layers.items(): for key, layer in layers.items():
assert isinstance(layer, Attention)
impl = layer.impl impl = layer.impl
assert isinstance(impl, FlashInferImpl) assert isinstance(impl, FlashInferImpl)
...@@ -187,7 +197,8 @@ class FlashInferState(AttentionState): ...@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
# Global hyperparameters shared by all attention layers # Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config() self.vllm_config = self.runner.vllm_config
self._kv_cache_layout = None
def _get_workspace_buffer(self): def _get_workspace_buffer(self):
if self._workspace_buffer is None: if self._workspace_buffer is None:
...@@ -197,10 +208,15 @@ class FlashInferState(AttentionState): ...@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
device=self.runner.device) device=self.runner.device)
return self._workspace_buffer return self._workspace_buffer
def get_kv_cache_layout(self):
    """Return the KV cache layout string, resolving it on first use.

    The first call copies the module-level ``FLASHINFER_KV_CACHE_LAYOUT``
    into ``self._kv_cache_layout``; later calls return the cached value.
    """
    layout = self._kv_cache_layout
    if layout is None:
        # Not resolved yet: take the process-wide setting and remember it.
        layout = FLASHINFER_KV_CACHE_LAYOUT
        self._kv_cache_layout = layout
    return layout
def _get_prefill_wrapper(self): def _get_prefill_wrapper(self):
if self._prefill_wrapper is None: if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD") self._get_workspace_buffer(), self.get_kv_cache_layout())
return self._prefill_wrapper return self._prefill_wrapper
def _get_decode_wrapper(self): def _get_decode_wrapper(self):
...@@ -213,7 +229,7 @@ class FlashInferState(AttentionState): ...@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
num_qo_heads // num_kv_heads > 4) num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(), self._get_workspace_buffer(),
"NHD", self.get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores) use_tensor_cores=use_tensor_cores)
return self._decode_wrapper return self._decode_wrapper
...@@ -274,7 +290,8 @@ class FlashInferState(AttentionState): ...@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
self._graph_decode_wrapper = \ self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD", self._graph_indices_buffer, _last_page_len_buffer,
self.get_kv_cache_layout(),
use_tensor_cores) use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"): if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
...@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers # Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config() self.vllm_config = self.runner.vllm_config
def prepare(self): def prepare(self):
self.slot_mapping: List[int] = [] self.slot_mapping: List[int] = []
...@@ -1007,6 +1024,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1007,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output: Optional[torch.Tensor] = None prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill # We will use flash attention for prefill
# when kv_cache is not provided. # when kv_cache is not provided.
...@@ -1038,7 +1056,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1038,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output = prefill_meta.prefill_wrapper.run( prefill_output = prefill_meta.prefill_wrapper.run(
query, query,
kv_cache, kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float, k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float, v_scale=layer._v_scale_float,
) )
...@@ -1053,7 +1071,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1053,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
decode_output = decode_meta.decode_wrapper.run( decode_output = decode_meta.decode_wrapper.run(
decode_query, decode_query,
kv_cache, kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float, k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float, v_scale=layer._v_scale_float,
) )
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
############################################################################### ###############################################################################
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, from vllm_hpu_extension.flags import enabled_flags
VLLMKVCache) from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionLayer,
...@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.block2batch_matmul = Matmul() self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache() self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache() self.v_cache = VLLMKVCache()
ops.pa_impl = ops.pa self.fused_scaled_dot_product_attention = kernels.fsdpa()
self.prefill_impl = 'naive'
if "flex_attention" in enabled_flags():
self.prefill_impl = 'flex'
if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
...@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', if self.prefill_impl == 'fsdpa':
'0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \ assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!' 'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")
supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes: if head_size not in supported_head_sizes:
...@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.") f"Supported head sizes are: {supported_head_sizes}.")
if attn_type != AttentionType.DECODER: self.attn_type = attn_type
if self.attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
"are not implemented for " "are not implemented for "
...@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape _, seq_len_kv, _ = key.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt: key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1)) key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None: if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache = HPUPagedAttention.split_kv_cache( key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
...@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None
query_shape = (batch_size, seq_len, self.num_heads, self.head_size) query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size) self.head_size)
attn_bias = attn_metadata.attn_bias
if attn_bias is not None and self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
out = ops.prompt_attention( out = ops.prompt_attention(
query.view(query_shape), impl=self.prefill_impl,
key.view(kv_shape), query=query.view(query_shape),
value.view(kv_shape), key=key.view(kv_shape),
value=value.view(kv_shape),
is_causal=True,
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, valid_seq_lengths=attn_metadata.seq_lens_tensor,
scale=self.scale, **self.common_attention_args())
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
output = out.reshape(batch_size, seq_len, hidden_size) output = out.reshape(batch_size, seq_len, hidden_size)
else: else:
# Decoding run. # Decoding run.
...@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list=attn_metadata.block_list, block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping, block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias, block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups, block_groups=attn_metadata.block_groups,
scale=self.scale, **self.common_attention_args())
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size) return output.view(batch_size, seq_len, hidden_size)
def common_attention_args(self):
    """Build the keyword arguments shared by the attention op calls.

    Returns:
        A dict holding the softmax scale, the matmul/softmax operator
        modules, the K/V cache fetch functions, and ``fsdpa_op``.
        ``fsdpa_op`` is the ``apply`` attribute of the fused SDPA kernel,
        or ``None`` when no fused kernel is set.
    """
    fused_kernel = self.fused_scaled_dot_product_attention
    if fused_kernel is None:
        fsdpa_op = None
    else:
        fsdpa_op = fused_kernel.apply

    return dict(
        scale=self.scale,
        matmul_qk_op=self.matmul_qk,
        matmul_av_op=self.matmul_av,
        batch2block_matmul_op=self.batch2block_matmul,
        block2batch_matmul_op=self.block2batch_matmul,
        fsdpa_op=fsdpa_op,
        keys_fetch_func=self.k_cache.fetch_from_cache,
        values_fetch_func=self.v_cache.fetch_from_cache,
        softmax_op=self.softmax,
    )
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
......
...@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache, value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale_float,
layer._v_scale, layer._v_scale_float,
) )
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
...@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale_float,
layer._v_scale, layer._v_scale_float,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale_float,
layer._v_scale, layer._v_scale_float,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
...@@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap ...@@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON: if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.attention.ops.triton_flash_attention import triton_attention
...@@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata): ...@@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
self._ops_advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions)
def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops.advance_step_flashattn(num_seqs=num_seqs, ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries, num_queries=num_queries,
block_size=block_size, block_size=block_size,
input_tokens=model_input.input_tokens, input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions, input_positions=input_positions,
seq_lens=self.seq_lens_tensor, seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping, slot_mapping=self.slot_mapping,
block_tables=self.block_tables) block_tables=self.block_tables)
...@@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): ...@@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
""" """
BLOCK_TABLE_EXTENDER: list[list[int]] = []
def __init__(self, input_builder: "ModelInputForGPUBuilder"): def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder self.input_builder = input_builder
...@@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): ...@@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size) self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables( block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables) num_seqs, self.block_tables)
else: else:
...@@ -1043,8 +1058,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1043,8 +1058,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.q_proj = q_proj self.q_proj = q_proj
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.triton_fa_func = triton_attention
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn # Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the # and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3 # latter has an additional parameter to control FA2 vs FA3
...@@ -1057,6 +1072,82 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1057,6 +1072,82 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9
and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 )
def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
                                     return_softmax_lse, **kwargs):
    """Run varlen attention when V's head dim differs from Q/K's.

    Dispatches to one of three backends: the Triton kernel (HIP with
    ``VLLM_USE_TRITON_FLASH_ATTN`` set and no LSE requested), the
    vllm_flash_attn kernel, or the ROCm ``flash_attn`` kernel. When
    ``self._pad_v`` is set, V is zero-padded first and the padding is
    sliced off the output afterwards.

    Args:
        q, k, v: query/key/value tensors; V's last dim may be smaller
            than Q's.
        softmax_scale: scaling factor forwarded to the backend.
        return_softmax_lse: when true, also return the softmax LSE.
        **kwargs: forwarded to the flash-attn backends. The Triton path
            reads ``cu_seqlens_q``, ``cu_seqlens_k``, ``max_seqlen_q``,
            ``max_seqlen_k`` and ``causal`` from it.

    Returns:
        The attention output, or ``(output, softmax_lse)`` when
        ``return_softmax_lse`` is true.
    """
    maybe_padded_v = v
    # Default for the ROCm flash_attn path so it is always defined, even
    # when no padding is applied.
    v_tmp = v
    if self._pad_v:
        # NOTE(review): the stock implementation pads V by the full
        # `q.shape[-1] - v.shape[-1]`; this variant pads 32 fewer columns
        # and hands a 32-column-trimmed view to the ROCm kernel. Confirm
        # the constant against the target kernel's requirements.
        maybe_padded_v = torch.nn.functional.pad(
            v, [0, q.shape[-1] - v.shape[-1] - 32], value=0)
        v_tmp = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],
                                                  v.shape[2])

    if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
            and not return_softmax_lse:
        attn_out = self.triton_fa_func(
            q,
            k,
            maybe_padded_v,
            None,  # output
            kwargs["cu_seqlens_q"],
            kwargs["cu_seqlens_k"],
            kwargs["max_seqlen_q"],
            kwargs["max_seqlen_k"],
            kwargs["causal"],
            softmax_scale,
            None,  # bias
        )
    elif is_vllm_fa:
        # `elif`, not `if`: otherwise the Triton result above would be
        # discarded and attention computed a second time.
        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=maybe_padded_v,
            return_softmax_lse=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs,
        )
    else:
        # Use return_attn_probs instead of return_softmax_lse for ROCm
        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v_tmp,
            return_attn_probs=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs,
        )

    # Unpack the output if there are multiple results:
    # triton always returns (output, softmax_lse),
    # vllm_flash_attn returns (output, softmax_lse) when
    # `return_softmax_lse = True`,
    # flash_attn (ROCm) returns (output, softmax_lse, ...) when
    # `return_attn_probs = True`.
    rest = None
    if isinstance(attn_out, tuple):
        attn_out, *rest = attn_out

    # Unpad if necessary.
    if self._pad_v:
        attn_out = attn_out[..., :v.shape[-1]]

    # Remain consistent with old `flash_attn_varlen_func` where there
    # is only one output tensor if `return_softmax_lse` is False.
    if return_softmax_lse:
        assert rest is not None
        return attn_out, rest[0]
    return attn_out
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
...@@ -1181,40 +1272,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1181,40 +1272,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1) dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad attn_output, attn_softmax_lse = \
# out v with 0s to match the qk head dim self._flash_attn_varlen_diff_headdims(
v_padded = torch.nn.functional.pad(v, q=q,
[0, q.shape[-1] - v.shape[-1]], k=k,
value=0) v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
if is_vllm_fa: cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
attn_output, attn_softmax_lse = self.flash_attn_varlen_func( max_seqlen_q=prefill_metadata.max_query_len,
q=q, max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
k=k, softmax_scale=self.scale,
v=v_padded, causal=False, # Context is unmasked
cu_seqlens_q=prefill_metadata.query_start_loc, return_softmax_lse=True,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], )
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_attn_probs=True,
)
if output is None: if output is None:
output = attn_output output = attn_output
...@@ -1257,61 +1327,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1257,61 +1327,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out output = self._flash_attn_varlen_diff_headdims(
# v with 0s to match the qk head dim q=q,
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], k=k,
# value=0) v=v,
v_padded = torch.nn.functional.pad(v, [0, (q.shape[-1] - v.shape[-1] -32)], cu_seqlens_q=prefill_metadata.query_start_loc,
value=0) cu_seqlens_k=prefill_metadata.query_start_loc,
v_tmp = v_padded[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2]) max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: softmax_scale=self.scale,
output = self.triton_fa_func( causal=True,
q, return_softmax_lse=has_context,
k, )
v_padded,
None,
prefill_metadata.query_start_loc,
prefill_metadata.query_start_loc,
prefill_metadata.max_prefill_seq_len,
prefill_metadata.max_prefill_seq_len,
True, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if not has_context:
output = output[0]
elif is_vllm_fa:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
else:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_tmp if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 else v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_attn_probs=has_context,
)
if has_context: if has_context:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2 # ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse, *rest = output suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \ context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata) q, kv_c_and_k_pe_cache, attn_metadata)
...@@ -1324,14 +1355,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1324,14 +1355,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse=suffix_lse, suffix_lse=suffix_lse,
) )
# slice by `:v.shape[-1]` in order to remove v headdim padding return self.o_proj(output.flatten(start_dim=-2))[0]
# output = output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
output = output\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
@abstractmethod @abstractmethod
def _forward_decode( def _forward_decode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment