Commit dcb5624a authored by zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import pytest
from vllm import LLM, SamplingParams
_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
"Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501
]
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Generate with n-gram speculation while max_tokens == max_model_len.

    The request asks for as many new tokens as the model length allows
    (100) and ignores EOS; the test passes if generation completes.
    """
    ngram_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }
    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")
        engine = LLM(
            model="facebook/opt-125m",
            max_model_len=100,
            enforce_eager=True,  # For faster initialization.
            speculative_config=ngram_config,
        )
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        engine.generate(_PROMPTS, params)
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
    monkeypatch: pytest.MonkeyPatch,
    num_speculative_tokens: int,
):
    """Generate with EAGLE speculation while max_tokens == max_model_len.

    The request asks for as many new tokens as the model length allows
    (100) and ignores EOS; the test passes if generation completes.
    """
    eagle_config = {
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": num_speculative_tokens,
    }
    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")
        engine = LLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enforce_eager=True,  # For faster initialization.
            speculative_config=eagle_config,
            max_model_len=100,
        )
        params = SamplingParams(max_tokens=100, ignore_eos=True)
        engine.generate(_PROMPTS, params)
......@@ -2,6 +2,7 @@
import numpy as np
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
_find_subarray_kmp,
_kmp_lps_array)
......@@ -39,50 +40,50 @@ def test_find_subarray_kmp():
def test_ngram_proposer():
proposer = NgramProposer()
def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# Dummy model config. Just to set max_model_len.
model_config = ModelConfig(model="facebook/opt-125m",
task="generate",
max_model_len=100,
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
dtype="auto",
seed=None,
trust_remote_code=False)
return NgramProposer(
vllm_config=VllmConfig(model_config=model_config,
speculative_config=SpeculativeConfig.
from_dict({
"prompt_lookup_min": min_n,
"prompt_lookup_max": max_n,
"num_speculative_tokens": k,
"method": "ngram",
})))
# No match.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 5]),
min_n=2,
max_n=2,
k=2,
)
result = ngram_proposer(
2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
assert result is None
# No match for 4-gram.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
min_n=4,
max_n=4,
k=2,
)
result = ngram_proposer(
4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert result is None
# No match for 4-gram but match for 3-gram.
result = proposer.propose(
context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
min_n=3,
max_n=4,
k=2,
)
result = ngram_proposer(
3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
assert np.array_equal(result, np.array([4, 1]))
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
result = proposer.propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]),
min_n=3,
max_n=4,
k=2,
)
result = ngram_proposer(3, 4, 2).propose(
context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram.
result = proposer.propose(
context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]),
min_n=2,
max_n=4,
k=2,
)
result = ngram_proposer(
2, 4,
2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
assert np.array_equal(result, np.array([1, 2])) # Not [5, 2]
......@@ -2,17 +2,13 @@
import pytest
from vllm.v1.structured_output.utils import (
from vllm.v1.structured_output.backend_xgrammar import (
has_xgrammar_unsupported_json_features)
@pytest.fixture
def unsupported_string_schemas():
return [
{
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
{
"type": "string",
"format": "email"
......@@ -23,22 +19,6 @@ def unsupported_string_schemas():
@pytest.fixture
def unsupported_integer_schemas():
return [
{
"type": "integer",
"minimum": 0
},
{
"type": "integer",
"maximum": 120
},
{
"type": "integer",
"exclusiveMinimum": 120
},
{
"type": "integer",
"exclusiveMaximum": 120
},
{
"type": "integer",
"multipleOf": 120
......@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
@pytest.fixture
def unsupported_number_schemas():
return [
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"maximum": 120
},
{
"type": "number",
"exclusiveMinimum": 120
},
{
"type": "number",
"exclusiveMaximum": 120
},
{
"type": "number",
"multipleOf": 120
......@@ -156,13 +120,28 @@ def supported_schema():
"type": "string",
"enum": ["sedan", "suv", "truck"]
},
"car_brand": {
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
"short_description": {
"type": "string",
"maxLength": 50
},
"mileage": {
"type": "number",
"minimum": 0,
"maximum": 1000000
},
"model_year": {
"type": "integer",
"exclusiveMinimum": 1900,
"exclusiveMaximum": 2100
},
"long_description": {
"type": "string",
"minLength": 50
"minLength": 50,
"maxLength": 2000
},
"address": {
"type": "object",
......
......@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
for _ in range(10):
if core_client.num_engines_running == 0:
if not core_client.engines_running:
break
await asyncio.sleep(0.5)
assert core_client.num_engines_running == 0
assert not core_client.engines_running
assert not core_client.reqs_in_flight
# SPDX-License-Identifier: Apache-2.0
from collections import UserDict
from dataclasses import dataclass
from typing import Optional
import msgspec
import numpy as np
import torch
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
......@@ -26,6 +32,7 @@ class MyType:
large_f_contig_tensor: torch.Tensor
small_non_contig_tensor: torch.Tensor
large_non_contig_tensor: torch.Tensor
empty_tensor: torch.Tensor
def test_encode_decode():
......@@ -41,6 +48,10 @@ def test_encode_decode():
torch.rand((1, 10), dtype=torch.float32),
torch.rand((3, 5, 4000), dtype=torch.float64),
torch.tensor(1984), # test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
torch.tensor([float("-inf"), float("inf")] * 1024,
dtype=torch.bfloat16),
],
numpy_array=np.arange(512),
unrecognized=UnrecognizedType(33),
......@@ -48,9 +59,10 @@ def test_encode_decode():
large_f_contig_tensor=torch.rand(1024, 4).t(),
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
empty_tensor=torch.empty(0),
)
encoder = MsgpackEncoder()
encoder = MsgpackEncoder(size_threshold=256)
decoder = MsgpackDecoder(MyType)
encoded = encoder.encode(obj)
......@@ -58,7 +70,7 @@ def test_encode_decode():
# There should be the main buffer + 4 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline.
assert len(encoded) == 6
assert len(encoded) == 8
decoded: MyType = decoder.decode(encoded)
......@@ -70,7 +82,7 @@ def test_encode_decode():
encoded2 = encoder.encode_into(obj, preallocated)
assert len(encoded2) == 6
assert len(encoded2) == 8
assert encoded2[0] is preallocated
decoded2: MyType = decoder.decode(encoded2)
......@@ -78,6 +90,97 @@ def test_encode_decode():
assert_equal(decoded2, obj)
class MyRequest(msgspec.Struct):
    """Mock request used by the tests to carry multimodal kwargs."""

    mm: Optional[list[MultiModalKwargs]]
def test_multimodal_kwargs():
    """Round-trip a MultiModalKwargs holding flat and nested tensors.

    Checks the number of encoded buffers, the approximate total encoded
    size, and that every entry decodes to equal tensors.
    """
    d = {
        "foo":
        torch.zeros(20000, dtype=torch.float16),
        "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
        "baz": [
            torch.rand(256, dtype=torch.float16),
            [
                torch.rand((1, 12), dtype=torch.float32),
                torch.rand((3, 5, 7), dtype=torch.float64),
            ],
            [torch.rand((4, 4), dtype=torch.float16)],
        ],
    }

    # pack mm kwargs into a mock request so that it can be decoded properly
    request = MyRequest(mm=[MultiModalKwargs(d)])

    msg_encoder = MsgpackEncoder()
    msg_decoder = MsgpackDecoder(MyRequest)

    buffers = msg_encoder.encode(request)
    assert len(buffers) == 6

    total_len = sum(memoryview(buf).cast("B").nbytes for buf in buffers)

    # expected total encoding length, should be 44559, +-20 for minor changes
    assert 44539 <= total_len <= 44579

    decoded: MultiModalKwargs = msg_decoder.decode(buffers).mm[0]
    for key in d:
        assert nested_equal(d[key], decoded[key])
def test_multimodal_items_by_modality():
    """Round-trip MultiModalKwargs built from per-modality items.

    Builds audio, video and image items, encodes them inside a mock
    request, and checks buffer count, approximate encoded size, and that
    the decoded modalities, item keys and tensors match the originals.
    """
    audio_elem = MultiModalFieldElem("audio", "a0",
                                     torch.zeros(1000, dtype=torch.bfloat16),
                                     MultiModalBatchedField())
    video_frames = [torch.zeros(1000, dtype=torch.int8) for _ in range(4)]
    video_elem = MultiModalFieldElem(
        "video",
        "v0",
        video_frames,
        MultiModalBatchedField(),
    )
    image_elem_shared = MultiModalFieldElem(
        "image", "i0", torch.zeros(1000, dtype=torch.int32),
        MultiModalSharedField(4))
    image_elem_batched = MultiModalFieldElem(
        "image", "i1", torch.zeros(1000, dtype=torch.int32),
        MultiModalBatchedField())

    audio = MultiModalKwargsItem.from_elems([audio_elem])
    video = MultiModalKwargsItem.from_elems([video_elem])
    image = MultiModalKwargsItem.from_elems(
        [image_elem_shared, image_elem_batched])
    mm = MultiModalKwargs.from_items([audio, video, image])

    # pack mm kwargs into a mock request so that it can be decoded properly
    request = MyRequest([mm])

    msg_encoder = MsgpackEncoder()
    msg_decoder = MsgpackDecoder(MyRequest)

    buffers = msg_encoder.encode(request)
    assert len(buffers) == 8

    total_len = sum(memoryview(buf).cast("B").nbytes for buf in buffers)

    # expected total encoding length, should be 14255, +-20 for minor changes
    assert 14235 <= total_len <= 14275

    decoded: MultiModalKwargs = msg_decoder.decode(buffers).mm[0]

    # check all modalities were recovered and do some basic sanity checks
    assert len(decoded.modalities) == 3
    images = decoded.get_items("image")
    assert len(images) == 1
    first_image = images[0]
    assert len(first_image.items()) == 2
    assert list(first_image.keys()) == ["i0", "i1"]

    # check the tensor contents and layout in the main dict
    for key in mm:
        assert nested_equal(mm[key], decoded[key])
def nested_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """Return True if two nested tensor structures are element-wise equal.

    A structure is either a ``torch.Tensor`` or a sequence of structures.
    Tensors are compared with ``torch.equal``; sequences must have the same
    length and pairwise-equal members.

    Args:
        a: Expected structure.
        b: Structure to compare against ``a``.

    Returns:
        True when both structures have the same nesting, the same lengths
        at every level, and equal tensors at the leaves; False otherwise.
    """
    if isinstance(a, torch.Tensor):
        # A tensor can only equal another tensor, never a sequence.
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    # zip() stops at the shorter input, so lengths are checked explicitly;
    # otherwise a missing or extra trailing element would go unnoticed.
    if isinstance(b, torch.Tensor) or len(a) != len(b):
        return False
    return all(nested_equal(x, y) for x, y in zip(a, b))
def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.tensor1, obj2.tensor1)
assert obj1.a_string == obj2.a_string
......@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2.small_non_contig_tensor)
assert torch.equal(obj1.large_non_contig_tensor,
obj2.large_non_contig_tensor)
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
......@@ -22,6 +22,7 @@ MODELS = [
]
TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
......@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
model: str,
max_tokens: int,
tensor_parallel_size: int,
max_num_seqs: int,
) -> None:
prompt = "The next numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:"
......@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
max_num_batched_tokens=1024,
max_model_len=8196,
max_model_len=8192,
gpu_memory_utilization=0.7,
max_num_seqs=16,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
......
# SPDX-License-Identifier: Apache-2.0
import openai
import pytest
from vllm import envs
from vllm.multimodal.utils import encode_image_base64, fetch_image
from vllm.platforms import current_platform
from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS
from ...utils import RemoteOpenAIServer
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.fixture(scope="session")
def base64_encoded_image() -> dict[str, str]:
    """Map every URL in TEST_IMAGE_URLS to its base64-encoded image."""
    encoded_by_url: dict[str, str] = {}
    for url in TEST_IMAGE_URLS:
        encoded_by_url[url] = encode_image_base64(fetch_image(url))
    return encoded_by_url
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str,
                            base64_encoded_image: dict[str, str]):
    """Send one image chat request per test image to a remote server.

    Each image is sent as a base64 data URL. The reply must stop because
    of the token limit, come from the assistant, and contain at least 10
    characters of content.
    """

    def whats_in_this_image_msg(b64):
        # Single user turn: a text question plus the image as a data URL.
        return [{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{b64}"
                    },
                },
            ],
        }]

    server_args = [
        "--max-model-len",
        "1024",
        "--max-num-seqs",
        "16",
        "--gpu-memory-utilization",
        "0.95",
        "--trust-remote-code",
        "--max-num-batched-tokens",
        "576",
        # NOTE: max-num-batched-tokens>=mm_item_size
        "--disable_chunked_mm_input",
        "--chat-template",
        "examples/template_llava.jinja"
    ]

    # Server will pre-compile on first startup (takes a long time).
    with RemoteOpenAIServer(model_name, server_args,
                            max_wait_seconds=600) as remote_server:
        client: openai.AsyncOpenAI = remote_server.get_async_client()

        # Other requests now should be much faster
        for image_url in TEST_IMAGE_URLS:
            image_base64 = base64_encoded_image[image_url]
            result = await client.chat.completions.create(
                model=model_name,
                messages=whats_in_this_image_msg(image_base64),
                max_completion_tokens=24,
                temperature=0.0)
            assert result

            choice = result.choices[0]
            assert choice.finish_reason == "length"

            # The original assigned `message` twice from the same choice;
            # one assignment is enough.
            message = choice.message
            assert message.content is not None and len(message.content) >= 10
            assert message.role == "assistant"
# SPDX-License-Identifier: Apache-2.0
import random
import pytest
from vllm import LLM, envs
......@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param.
sampling_params = SamplingParams(temperature=0.3, seed=42)
output2 = llm.generate(prompts, sampling_params)
# Batch-case with TopK/P
for B in [4, 16]:
p = prompts * B
sampling_params = [
SamplingParams(
temperature=0.1,
min_p=0.8,
max_tokens=64,
# Vary number of ks
top_k=random.randint(4, 12),
top_p=random.random()) for _ in range(B)
]
# Make sure first two reqs have the same K/P
sampling_params[0] = sampling_params[1]
output = llm.generate(p, sampling_params)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
......@@ -5,7 +5,8 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
apply_top_k_top_p_tpu)
if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
......@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6
def test_topk_equivalence_to_native_impl():
    """Compare the TPU top-k path against apply_top_k_top_p on one batch."""
    with torch.device(xm.xla_device()):
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))

        # Random per-request top-k values drawn from randint(1, 10);
        # the upper bound of torch.randint is exclusive.
        top_k = torch.randint(1, 10, (BATCH_SIZE, ))

        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
        disabled = torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool)
        top_k.masked_fill_(disabled, VOCAB_SIZE)

        tpu_logits = apply_top_k_top_p_tpu(logits=logits.clone(),
                                           k=top_k,
                                           p=None)
        native_logits = apply_top_k_top_p(logits=logits.clone(),
                                          k=top_k,
                                          p=None)
        assert torch.allclose(native_logits, tpu_logits)
def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
......
......@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
prompt="test",
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
......@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
def test_get_paddings():
    """Check _get_token_paddings for bucketed and exponential padding.

    Covers a non-zero padding gap (bucketed) and a zero gap (exponential),
    each with a power-of-two maximum and with a maximum of 317.
    """
    # Bucketed padding
    min_token_size, max_token_size, padding_gap = 16, 512, 64
    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    # This result used to be overwritten by the next case without ever
    # being checked.
    assert actual_paddings == expected_paddings

    # Bucketed padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding.
    max_token_size, padding_gap = 1024, 0
    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 256, 512]
    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                          padding_gap)
    assert actual_paddings == expected_paddings
......
......@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
prompt=None,
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
......
......@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
prompt="test",
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
......
......@@ -1616,6 +1616,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states, pad_slot_id)
# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor,
          rows_per_block: int) -> torch.Tensor:
    """Call the `_rocm_C.LLMM1` custom op and return its result."""
    llmm1_op = torch.ops._rocm_C.LLMM1
    return llmm1_op(a, b, rows_per_block)
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
    """Call the `_rocm_C.wvSplitK` custom op and return its result."""
    wv_split_k_op = torch.ops._rocm_C.wvSplitK
    return wv_split_k_op(a, b, cu_count)
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
              scale_a: torch.Tensor, scale_b: torch.Tensor,
              cu_count: int) -> torch.Tensor:
    """Call the `_rocm_C.wvSplitKQ` custom op with a fresh output tensor.

    The output is allocated on `b`'s device with dtype `out_dtype` and
    shape `(b.shape[0], a.shape[0])`, passed to the op, and returned.
    """
    out_shape = (b.shape[0], a.shape[0])
    result = torch.empty(out_shape, dtype=out_dtype, device=b.device)
    torch.ops._rocm_C.wvSplitKQ(a, b, result, scale_a, scale_b, cu_count)
    return result
# moe
def moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Call the `_moe_C.moe_sum` custom op on `input` and `output`."""
    moe_sum_op = torch.ops._moe_C.moe_sum
    moe_sum_op(input, output)
......@@ -1665,6 +1685,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies, gating_output)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
                          b_qweight: torch.Tensor, b_scales: torch.Tensor,
                          b_qzeros: Optional[torch.Tensor],
                          g_idx: Optional[torch.Tensor],
                          perm: Optional[torch.Tensor],
                          workspace: torch.Tensor,
                          sorted_token_ids: torch.Tensor,
                          expert_ids: torch.Tensor,
                          num_tokens_past_padded: torch.Tensor,
                          topk_weights: torch.Tensor, moe_block_size: int,
                          top_k: int, mul_topk_weights: bool, is_ep: bool,
                          b_q_type: ScalarType, size_m: int, size_n: int,
                          size_k: int, is_k_full: bool, use_atomic_add: bool,
                          use_fp32_reduce: bool,
                          is_zp_float: bool) -> torch.Tensor:
    """Call the `_moe_C.moe_wna16_marlin_gemm` custom op.

    All arguments are forwarded positionally in declaration order, except
    `b_q_type`, for which `b_q_type.id` is passed instead of the object.
    """
    gemm_op = torch.ops._moe_C.moe_wna16_marlin_gemm
    # The op receives the scalar type's id rather than the ScalarType object.
    b_q_type_id = b_q_type.id
    return gemm_op(
        input,
        output,
        b_qweight,
        b_scales,
        b_qzeros,
        g_idx,
        perm,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_past_padded,
        topk_weights,
        moe_block_size,
        top_k,
        mul_topk_weights,
        is_ep,
        b_q_type_id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@register_fake("_moe_C::marlin_gemm_moe")
......@@ -1683,6 +1726,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype=a.dtype,
device=a.device)
@register_fake("_moe_C::moe_wna16_marlin_gemm")
def moe_wna16_marlin_gemm_fake(input: torch.Tensor,
                               output: Optional[torch.Tensor],
                               b_qweight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_qzeros: Optional[torch.Tensor],
                               g_idx: Optional[torch.Tensor],
                               perm: Optional[torch.Tensor],
                               workspace: torch.Tensor,
                               sorted_token_ids: torch.Tensor,
                               expert_ids: torch.Tensor,
                               num_tokens_past_padded: torch.Tensor,
                               topk_weights: torch.Tensor,
                               moe_block_size: int, top_k: int,
                               mul_topk_weights: bool, is_ep: bool,
                               b_q_type: ScalarType, size_m: int,
                               size_n: int, size_k: int, is_k_full: bool,
                               use_atomic_add: bool, use_fp32_reduce: bool,
                               is_zp_float: bool) -> torch.Tensor:
    """Fake implementation registered for `_moe_C::moe_wna16_marlin_gemm`.

    Returns an uninitialized tensor of shape `(size_m * top_k, size_n)`
    with the dtype and device of `input`; no other argument is read.
    """
    fake_shape = (size_m * top_k, size_n)
    return torch.empty(fake_shape, dtype=input.dtype, device=input.device)
def reshape_and_cache(
key: torch.Tensor,
......@@ -1904,3 +1970,12 @@ def flash_mla_with_kvcache(
num_splits,
)
return out, softmax_lse
# def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
# q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
# seq_lens: torch.Tensor, page_table: torch.Tensor,
# scale: float) -> torch.Tensor:
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# return out
......@@ -2,7 +2,7 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal
from typing import Literal, Optional
import cv2
import numpy as np
......@@ -10,8 +10,15 @@ import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image
from vllm.utils import PlaceholderModule
from .base import get_cache_dir
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
@lru_cache
def download_video_asset(filename: str) -> str:
......@@ -85,3 +92,12 @@ class VideoAsset:
video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
    """
    Load the audio track of this video asset via `librosa.load`.

    The asset is fetched with `download_video_asset(self.name)` and
    `sampling_rate` is passed to librosa as `sr`. Only the first element
    of librosa's return value is returned.

    Used in Qwen2.5-Omni examples.
    See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
    """
    local_path = download_video_asset(self.name)
    loaded = librosa.load(local_path, sr=sampling_rate)
    return loaded[0]
......@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
......@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
_v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
_prob_scale: torch.Tensor
def forward(
self,
......
......@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
......@@ -691,7 +691,7 @@ class FlashAttentionImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert (
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
"key/v_scale is only supported in FlashAttention 3 with "
......
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
......@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
"NHD").upper()
class FlashInferBackend(AttentionBackend):
......@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
    """Return the permutation of KV cache dimensions for the chosen layout.

    The layout is read from the module-level FLASHINFER_KV_CACHE_LAYOUT.
    "NHD" keeps the dimensions in order; "HND" swaps dimensions 2 and 3.

    Raises:
        ValueError: if the configured layout is neither "NHD" nor "HND".
    """
    stride_orders = {
        "NHD": (0, 1, 2, 3, 4),
        "HND": (0, 1, 3, 2, 4),
    }
    cache_layout = FLASHINFER_KV_CACHE_LAYOUT
    # Raise explicitly rather than assert: asserts are stripped under -O,
    # and any unknown layout would then silently get the HND order.
    if cache_layout not in stride_orders:
        raise ValueError(
            f"Unsupported FLASHINFER_KV_CACHE_LAYOUT {cache_layout!r}; "
            "expected 'NHD' or 'HND'.")
    return stride_orders[cache_layout]
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
......@@ -128,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`.
"""
layers = vllm_config.compilation_config.static_forward_context
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: Dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
assert isinstance(layer, Attention)
impl = layer.impl
assert isinstance(impl, FlashInferImpl)
......@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config()
self.vllm_config = self.runner.vllm_config
self._kv_cache_layout = None
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
......@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
device=self.runner.device)
return self._workspace_buffer
def get_kv_cache_layout(self):
    """Return the KV cache layout, caching it on first use.

    If `self._kv_cache_layout` is unset (None) it is filled from the
    module-level FLASHINFER_KV_CACHE_LAYOUT before being returned.
    """
    layout = self._kv_cache_layout
    if layout is None:
        layout = FLASHINFER_KV_CACHE_LAYOUT
        self._kv_cache_layout = layout
    return layout
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
self._get_workspace_buffer(), self.get_kv_cache_layout())
return self._prefill_wrapper
def _get_decode_wrapper(self):
......@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
self.get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper
......@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
self._graph_indices_buffer, _last_page_len_buffer,
self.get_kv_cache_layout(),
use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
......@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = get_current_vllm_config()
self.vllm_config = self.runner.vllm_config
def prepare(self):
self.slot_mapping: List[int] = []
......@@ -1007,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
......@@ -1038,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
......@@ -1053,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
......
......@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
......@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
ops.pa_impl = ops.pa
self.fused_scaled_dot_product_attention = kernels.fsdpa()
self.prefill_impl = 'naive'
if "flex_attention" in enabled_flags():
self.prefill_impl = 'flex'
if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
......@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa:
if self.prefill_impl == 'fsdpa':
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")
supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
......@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if attn_type != AttentionType.DECODER:
self.attn_type = attn_type
if self.attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
......@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt:
key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
......@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if attn_metadata.is_prompt:
# Prompt run.
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
attn_bias = attn_metadata.attn_bias
if attn_bias is not None and self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
impl=self.prefill_impl,
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
is_causal=True,
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args())
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
......@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
**self.common_attention_args())
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
def common_attention_args(self):
    """Build the keyword arguments shared by the attention op calls.

    Returns:
        A dict with the softmax scale, the matmul and softmax op modules,
        the key/value cache fetch callables and ``fsdpa_op``.
        ``fsdpa_op`` is the ``apply`` attribute of
        ``self.fused_scaled_dot_product_attention``, or ``None`` when that
        attribute is ``None``.
    """
    fused_sdpa = self.fused_scaled_dot_product_attention
    # Only dereference ``.apply`` when a fused kernel object is present.
    fsdpa_op = None if fused_sdpa is None else fused_sdpa.apply

    shared_args = dict(
        scale=self.scale,
        matmul_qk_op=self.matmul_qk,
        matmul_av_op=self.matmul_av,
        batch2block_matmul_op=self.batch2block_matmul,
        block2batch_matmul_op=self.block2batch_matmul,
        fsdpa_op=fsdpa_op,
        keys_fetch_func=self.k_cache.fetch_from_cache,
        values_fetch_func=self.v_cache.fetch_from_cache,
        softmax_op=self.softmax,
    )
    return shared_args
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
......
......@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
layer._k_scale_float,
layer._v_scale_float,
)
if attn_metadata.is_prompt:
......@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
layer._k_scale_float,
layer._v_scale_float,
)
else:
# Run PagedAttention V2.
......@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
layer._k_scale_float,
layer._v_scale_float,
)
# Reshape the output tensor.
......
......@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
......@@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention
......@@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
self._ops_advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions)
def _ops_advance_step(self, num_seqs: int, num_queries: int,
                      block_size: int, input_tokens: torch.Tensor,
                      sampled_token_ids: torch.Tensor,
                      input_positions: torch.Tensor) -> None:
    """Forward one advance step to ``ops.advance_step_flashattn``.

    The caller-supplied arguments are passed through unchanged, together
    with this metadata object's ``seq_lens_tensor``, ``slot_mapping`` and
    ``block_tables``.

    Args:
        num_seqs: Passed through as ``num_seqs``.
        num_queries: Passed through as ``num_queries``.
        block_size: Passed through as ``block_size``.
        input_tokens: Passed through as ``input_tokens``.
        sampled_token_ids: Passed through as ``sampled_token_ids``.
        input_positions: Passed through as ``input_positions``.
    """
    # Each keyword is passed exactly once and is taken from this method's
    # own parameters; there is no ``model_input`` in scope here.
    # NOTE(review): presumably the op updates the tensors in place (this
    # method returns None) -- confirm against the op's documentation.
    ops.advance_step_flashattn(num_seqs=num_seqs,
                               num_queries=num_queries,
                               block_size=block_size,
                               input_tokens=input_tokens,
                               sampled_token_ids=sampled_token_ids,
                               input_positions=input_positions,
                               seq_lens=self.seq_lens_tensor,
                               slot_mapping=self.slot_mapping,
                               block_tables=self.block_tables)
......@@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
BLOCK_TABLE_EXTENDER: list[list[int]] = []
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
......@@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
......@@ -1043,8 +1058,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.triton_fa_func = triton_attention
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
......@@ -1057,6 +1072,82 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9
and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 )
def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
                                     return_softmax_lse, **kwargs):
    """Run varlen flash attention when ``v`` has a smaller head dim than q/k.

    Dispatches to one of three kernels: the Triton kernel (HIP with
    ``VLLM_USE_TRITON_FLASH_ATTN`` set and no LSE requested), the
    vllm_flash_attn kernel (``is_vllm_fa``), or the ROCm flash_attn kernel.

    Args:
        q, k, v: Query, key and value tensors; ``v`` is padded along its
            last dimension when ``self._pad_v`` is set.
        softmax_scale: Scale forwarded to the selected kernel.
        return_softmax_lse: When true, the softmax LSE is returned as well.
        **kwargs: Forwarded to the flash-attn kernels. The Triton path
            reads ``cu_seqlens_q``, ``cu_seqlens_k``, ``max_seqlen_q``,
            ``max_seqlen_k`` and ``causal`` from it.

    Returns:
        ``attn_out`` when ``return_softmax_lse`` is false, otherwise the
        tuple ``(attn_out, softmax_lse)``.
    """
    maybe_padded_v = v
    # Default for the ROCm flash_attn branch so it is always bound, even
    # when no padding is applied.
    v_tmp = v
    if self._pad_v:
        # NOTE(review): the pad width is 32 columns short of the q/v
        # head-dim gap, and the last 32 columns are sliced off again for
        # ``v_tmp`` -- presumably a stride/alignment workaround; confirm.
        pad_width = q.shape[-1] - v.shape[-1] - 32
        maybe_padded_v = torch.nn.functional.pad(v, [0, pad_width],
                                                 value=0)
        v_tmp = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],
                                                  v.shape[2])

    if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
            and not return_softmax_lse:
        attn_out = self.triton_fa_func(
            q,
            k,
            maybe_padded_v,
            None,  # output
            kwargs["cu_seqlens_q"],
            kwargs["cu_seqlens_k"],
            kwargs["max_seqlen_q"],
            kwargs["max_seqlen_k"],
            kwargs["causal"],
            softmax_scale,
            None,  # bias
        )
    elif is_vllm_fa:
        # ``elif`` keeps the branches mutually exclusive so a Triton
        # result is not recomputed and overwritten below.
        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=maybe_padded_v,
            return_softmax_lse=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs,
        )
    else:
        # Use return_attn_probs instead of return_softmax_lse for ROCm
        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v_tmp,
            return_attn_probs=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs,
        )

    # Unpack the output if there are multiple results:
    # triton always returns (output, softmax_lse),
    # vllm_flash_attn returns (output, softmax_lse) when
    # `return_softmax_lse = True`,
    # flash_attn (ROCm) returns (output, softmax_lse, ...) when
    # `return_attn_probs = True`
    rest = None
    if isinstance(attn_out, tuple):
        attn_out, *rest = attn_out

    # unpad if necessary
    if self._pad_v:
        attn_out = attn_out[..., :v.shape[-1]]

    # Remain consistent with old `flash_attn_varlen_func` where there
    # is only one output tensor if `return_softmax_lse` is False.
    if return_softmax_lse:
        assert rest is not None
        return attn_out, rest[0]
    return attn_out
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
......@@ -1181,40 +1272,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad
# out v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v,
[0, q.shape[-1] - v.shape[-1]],
value=0)
if is_vllm_fa:
attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_attn_probs=True,
)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
if output is None:
output = attn_output
......@@ -1257,61 +1327,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
v_padded = torch.nn.functional.pad(v, [0, (q.shape[-1] - v.shape[-1] -32)],
value=0)
v_tmp = v_padded[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2])
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
output = self.triton_fa_func(
q,
k,
v_padded,
None,
prefill_metadata.query_start_loc,
prefill_metadata.query_start_loc,
prefill_metadata.max_prefill_seq_len,
prefill_metadata.max_prefill_seq_len,
True, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if not has_context:
output = output[0]
elif is_vllm_fa:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
else:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_tmp if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 else v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_attn_probs=has_context,
)
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
if has_context:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse, *rest = output
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
......@@ -1324,14 +1355,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse=suffix_lse,
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
# output = output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
output = output\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
return self.o_proj(output.flatten(start_dim=-2))[0]
@abstractmethod
def _forward_decode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment