Unverified Commit 6e435de7 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[1/n][Chunked Prefill] Refactor input query shapes (#3236)

parent 426ec4ec
......@@ -47,7 +47,7 @@ steps:
- pytest -v -s prefix_caching
- label: Samplers Test
command: pytest -v -s samplers --forked
command: pytest -v -s samplers
- label: Worker Test
command: pytest -v -s worker
......@@ -56,7 +56,7 @@ steps:
command: pytest -v -s spec_decode
- label: LoRA Test %N
command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- label: Metrics Test
......
......@@ -13,6 +13,7 @@ MODELS = [
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models(
hf_runner,
vllm_runner,
......@@ -20,12 +21,13 @@ def test_models(
model: str,
dtype: str,
max_tokens: int,
enforce_eager: bool,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
......
......@@ -10,7 +10,7 @@ from .utils import create_dummy_prompt
def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256)
scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
......@@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():
def test_scheduler_abort_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256)
scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
......@@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
block_size = 4
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256)
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
......@@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
running.append(seq_group)
# Schedule seq groups prompts.
num_tokens = block_size * num_seq_group
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running)
assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs(
)[0].get_len()
assert out.num_batched_tokens == num_tokens
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == num_seq_group
......@@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
def test_scheduler_schedule_preempt_abort():
block_size = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, 2, max_model_len, 256)
scheduler_config = SchedulerConfig(64, 2, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 2
cache_config.num_gpu_blocks = 2
......@@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts.
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 2
......@@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
scheduler.abort_seq_group("1")
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_b]
assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len()
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1
......@@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
num_seq_group = 4
max_seq_group = 2
max_model_len = 16
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256)
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
......
......@@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
revision=None,
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32, 256),
scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"),
local_rank=0,
rank=0,
......
......@@ -92,8 +92,8 @@ def test_same_output_for_single_step():
num_gpu_blocks,
seed,
)
multi_step_worker.model_runner = worker.model_runner
multi_step_worker.cache_engine = worker.cache_engine
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine
num_steps = 1
......
import random
import torch
from vllm.config import ModelConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
def get_aligned_size(batch_size: int, alignment: int):
return ((batch_size + alignment - 1) // alignment * alignment)
def test_prepare_prompt():
......@@ -12,6 +17,7 @@ def test_prepare_prompt():
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
......@@ -23,26 +29,165 @@ def test_prepare_prompt():
is_prompt=True,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
block_tables=block_tables,
))
expected_selected_token_indices = []
selected_token_start_idx = 0
max_seq_len = max(prompt_lens)
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
selected_token_start_idx += prompt_len
(input_tokens, input_positions, input_metadata, return_prompt_lens, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is True
assert torch.allclose(input_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert input_metadata.prompt_lens == prompt_lens
assert input_metadata.num_prompt_tokens == sum(prompt_lens)
assert input_metadata.num_generation_tokens == 0
assert input_metadata.max_seq_len == max(prompt_lens)
# Test subquery start locs.
start_idx = 0
start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
start_loc.append(start_idx)
assert torch.allclose(
input_metadata.subquery_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
seq_start_loc.append(start_idx)
assert torch.allclose(
input_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert input_metadata.max_context_len is None
assert torch.allclose(
input_metadata.context_lens,
torch.zeros(input_metadata.context_lens.shape[0],
dtype=torch.int,
device=device))
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.allclose(input_metadata.block_tables, expected)
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is False
assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
assert input_positions.shape == (batch_size, max_seq_len)
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
def test_prepare_decode_cuda_graph():
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
))
input_tokens, input_positions, input_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None
assert input_metadata.max_context_len == max(prompt_lens)
assert torch.allclose(
input_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert input_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim correspondsd to each token's block number.
# It is padded up to
assert input_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
assert input_positions.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
torch.testing.assert_close(input_tokens, input_positions)
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
......
......@@ -535,7 +535,6 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
"""
def __init__(
......@@ -543,7 +542,6 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
max_paddings: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
......@@ -553,7 +551,6 @@ class SchedulerConfig:
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.max_paddings = max_paddings
self._verify_args()
def _verify_args(self) -> None:
......
......@@ -173,12 +173,12 @@ class Scheduler:
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences = deque()
num_batched_tokens = 0
while self.waiting:
seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs(
......@@ -223,8 +223,7 @@ class Scheduler:
continue
# If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
num_batched_tokens += num_prompt_tokens
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break
......@@ -236,11 +235,6 @@ class Scheduler:
self.scheduler_config.max_num_seqs):
break
num_paddings = num_batched_tokens - sum(new_seq_lens)
if num_paddings > self.scheduler_config.max_paddings:
break
seq_lens = new_seq_lens
if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.waiting.popleft()
......@@ -255,8 +249,7 @@ class Scheduler:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=len(seq_lens) *
max(seq_lens) if seq_lens else 0,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
......
......@@ -31,7 +31,6 @@ class EngineArgs:
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False
revision: Optional[str] = None
......@@ -213,10 +212,6 @@ class EngineArgs:
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--max-paddings',
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument(
'--max-logprobs',
type=int,
......@@ -347,8 +342,7 @@ class EngineArgs:
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings)
model_config.max_model_len)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
......
......@@ -561,7 +561,6 @@ class LLMEngine:
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs
def step(self) -> List[RequestOutput]:
......
from dataclasses import dataclass, fields
from typing import Optional, Any, Dict
from typing import Optional, List, Any, Dict
import torch
from xformers.ops.fmha.attn_bias import AttentionBias
@dataclass
class InputMetadata:
"""Metadata for input sequences. Used in PagedAttention.
Args:
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
max_context_len: The maximum context length.
context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type to store kv cache.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
prompt_lens: Optional[torch.Tensor]
max_seq_len: Optional[int]
start_loc: Optional[torch.Tensor]
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
"""
Definition of context_len, subquery_len, and seqlen.
|---------- N-1 iteration --------|
|---------------- N iteration ---------------------|
|- tokenA -|......................|-- newTokens ---|
|---------- context_len ----------|
|-------------------- seqlen ----------------------|
|- subquery_len -|
WARNING: context_len has different definition depending on if it is
prefill vs decoding. When it is prefill, it doesn't include new
tokens. When it is for decoding, it includes a new token.
"""
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum context length in the batch.
max_context_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
use_cuda_graph: bool
kv_cache_dtype: str
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias = None
self.attn_bias: Optional[List[AttentionBias]] = None
# Cuda graph is only used for decoding now.
if self.use_cuda_graph:
assert self.num_prompt_tokens == 0
def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
......
......@@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor:
......
......@@ -17,11 +17,12 @@ class Attention(nn.Module):
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
3. Output the output tensor.
"""
def __init__(
......
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional
from flash_attn import flash_attn_func
from flash_attn import flash_attn_varlen_func
import torch
from vllm.model_executor.input_metadata import InputMetadata
......@@ -10,6 +10,21 @@ from vllm.model_executor.layers.attention.ops.paged_attn import (
class FlashAttentionBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
self,
......@@ -52,18 +67,18 @@ class FlashAttentionBackend:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
......@@ -82,13 +97,16 @@ class FlashAttentionBackend:
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
output = flash_attn_func(
query,
key,
value,
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=input_metadata.seq_start_loc,
cu_seqlens_k=input_metadata.seq_start_loc,
max_seqlen_q=input_metadata.max_seq_len,
max_seqlen_k=input_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
......@@ -118,4 +136,4 @@ class FlashAttentionBackend:
)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
return output.view(num_tokens, hidden_size)
......@@ -14,6 +14,21 @@ from vllm.utils import is_hip
class XFormersBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
self,
......@@ -55,19 +70,18 @@ class XFormersBackend:
"""Forward pass with xFormers and PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
......@@ -82,9 +96,10 @@ class XFormersBackend:
if input_metadata.is_prompt:
# Prompt run.
# key_cache and value_cache are None when it is a profiling run.
# block tables are empty if the prompt has never been computed.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
......@@ -103,61 +118,33 @@ class XFormersBackend:
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
if self.use_ref_attention:
output = _ref_masked_attention(
query,
key,
value,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.scale,
)
print("ref attention used.")
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = _ref_masked_attention(
query[None, start:end],
key[None, start:end],
value[None, start:end],
self.num_heads,
self.num_kv_heads,
self.head_size,
self.scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len
# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return output.reshape(batch_size, seq_len, hidden_size)
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
return output.reshape(num_tokens, hidden_size)
output = self._run_memory_efficient_xformer_forward(
query, key, value, input_metadata)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
......@@ -182,41 +169,117 @@ class XFormersBackend:
)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformer_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
input_metadata.prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = [attn_bias]
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
input_metadata)
op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
is_hip()) else None
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias[0],
p=0.0,
scale=self.scale,
op=op)
return out.view_as(query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query)
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=op)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
input_metadata: InputMetadata,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
batch_size,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias
attn_biases = []
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias = bias[None, :] - bias[:, None]
padded_len = (prompt_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
prompt_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
return attn_biases
def _check_use_ref_attention() -> bool:
......@@ -239,7 +302,6 @@ def _ref_masked_attention(
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
......
......@@ -128,11 +128,12 @@ class PagedAttentionImpl:
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.block_tables,
# subquery_start_loc is (batch_size + 1,)
input_metadata.subquery_start_loc[:-1],
input_metadata.prompt_lens_tensor,
input_metadata.context_lens,
input_metadata.max_seq_len,
input_metadata.max_subquery_len,
alibi_slopes,
)
return output
......@@ -128,7 +128,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment