Unverified Commit 0794e744 authored by Elfie Guo's avatar Elfie Guo Committed by GitHub
Browse files

[Misc] Add multistep chunked-prefill support for FlashInfer (#10467)

parent b7ee940a
...@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel( ...@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr, int64_t const block_tables_stride, int const* block_tables_ptr, int64_t const block_tables_stride,
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}
int num_query_blocks = div_ceil(num_queries, num_threads); int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x < num_query_blocks) { if (blockIdx.x < num_query_blocks) {
......
...@@ -5,6 +5,8 @@ from typing import Optional ...@@ -5,6 +5,8 @@ from typing import Optional
import pytest import pytest
from tests.kernels.utils import override_backend_env_variable
from ..models.utils import check_logprobs_close, check_outputs_equal from ..models.utils import check_logprobs_close, check_outputs_equal
MODELS = [ MODELS = [
...@@ -19,10 +21,11 @@ NUM_PROMPTS = [10] ...@@ -19,10 +21,11 @@ NUM_PROMPTS = [10]
@pytest.mark.parametrize("tp_size", [1]) @pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5]) @pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_multi_step_llm( def test_multi_step_llm(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
...@@ -36,6 +39,8 @@ def test_multi_step_llm( ...@@ -36,6 +39,8 @@ def test_multi_step_llm(
num_scheduler_steps: int, num_scheduler_steps: int,
num_prompts: int, num_prompts: int,
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine. """Test vLLM engine with multi-step scheduling via sync LLM Engine.
...@@ -63,6 +68,7 @@ def test_multi_step_llm( ...@@ -63,6 +68,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned. completions endpoint; `None` -> 1 logprob returned.
""" """
override_backend_env_variable(monkeypatch, attention_backend)
prompts = example_prompts prompts = example_prompts
if len(prompts) < num_prompts: if len(prompts) < num_prompts:
...@@ -114,6 +120,7 @@ def test_multi_step_llm( ...@@ -114,6 +120,7 @@ def test_multi_step_llm(
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) @pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
def test_multi_step_llm_w_prompt_logprobs( def test_multi_step_llm_w_prompt_logprobs(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
...@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs( ...@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts: int, num_prompts: int,
num_logprobs: Optional[int], num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int], num_prompt_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None: ) -> None:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine. """Test prompt logprobs with multi-step scheduling via sync LLM Engine.
...@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs( ...@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the note that this argument is not supported by the
OpenAI completions endpoint. OpenAI completions endpoint.
""" """
override_backend_env_variable(monkeypatch, attention_backend)
prompts = example_prompts prompts = example_prompts
if len(prompts) < num_prompts: if len(prompts) < num_prompts:
...@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs( ...@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5]) @pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
def test_multi_step_llm_chunked_prefill_prefix_cache( def test_multi_step_llm_chunked_prefill_prefix_cache(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
...@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps: int, num_scheduler_steps: int,
num_prompts: int, num_prompts: int,
num_logprobs: Optional[int], num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None: ) -> None:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC. """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
...@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( ...@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
# #
# The Incorrect scheduling behavior - if it occurs - will cause an exception # The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`. # in the model runner resulting from `do_sample=False`.
override_backend_env_variable(monkeypatch, attention_backend)
assert len(example_prompts) >= 2 assert len(example_prompts) >= 2
challenge_prompts = copy.deepcopy(example_prompts) challenge_prompts = copy.deepcopy(example_prompts)
challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
......
...@@ -256,7 +256,12 @@ class FlashInferState(AttentionState): ...@@ -256,7 +256,12 @@ class FlashInferState(AttentionState):
def begin_forward(self, model_input): def begin_forward(self, model_input):
assert not self._is_graph_capturing assert not self._is_graph_capturing
state = self state = self
if model_input.attn_metadata.use_cuda_graph: use_cuda_graph = model_input.attn_metadata.use_cuda_graph
is_decode = model_input.attn_metadata.num_prefills == 0
# In case of multistep chunked-prefill, there might be prefill requests
# scheduled while CUDA graph mode is enabled. We don't run graph in that
# case.
if use_cuda_graph and is_decode:
batch_size = model_input.input_tokens.shape[0] batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine] state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state) [batch_size].attn_state)
...@@ -429,10 +434,24 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -429,10 +434,24 @@ class FlashInferMetadata(AttentionMetadata):
Update metadata in-place to advance one decode step. Update metadata in-place to advance one decode step.
""" """
assert not turn_prefills_into_decodes, \ if turn_prefills_into_decodes:
("Chunked prefill is not supported with flashinfer yet." # When Multi-Step is enabled with Chunked-Prefill, prefills and
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " # decodes are scheduled together. In the first step, all the
"specific parameter.") # prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert self.decode_query_len == 1
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens_tensor is not None
assert num_seqs > 0 assert num_seqs > 0
assert num_queries > 0 assert num_queries > 0
......
...@@ -5,6 +5,7 @@ import itertools ...@@ -5,6 +5,7 @@ import itertools
import time import time
import warnings import warnings
import weakref import weakref
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple, Type, TypeVar, Union) Tuple, Type, TypeVar, Union)
...@@ -1028,6 +1029,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1028,6 +1029,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.has_inner_state = model_config.has_inner_state self.has_inner_state = model_config.has_inner_state
self.in_profile_run = False
# When using CUDA graph, the input block tables must be padded to # When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in # max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table # Python can be expensive. To optimize this, we cache the block table
...@@ -1228,110 +1231,123 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1228,110 +1231,123 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
return builder.build() # type: ignore return builder.build() # type: ignore
@contextmanager
def set_in_profile_run(self):
# Context manager that flags this runner as being inside a profile run:
# `self.in_profile_run` is True while the `with` body executes.
# Used by `profile_run` to wrap its whole body.
self.in_profile_run = True
try:
yield
finally:
# Always clear the flag, even if the `with` body raises. Note that the
# flag is reset to False unconditionally (the previous value is not
# restored), so nesting this context manager is not supported.
self.in_profile_run = False
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage. with self.set_in_profile_run():
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) # Enable top-k sampling to reflect the accurate memory usage.
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens sampling_params = \
max_num_seqs = self.scheduler_config.max_num_seqs SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
# This represents the maximum number of different requests max_num_batched_tokens = \
# that will have unique loras, and therefore the max amount of memory self.scheduler_config.max_num_batched_tokens
# consumption create dummy lora request copies from the lora request max_num_seqs = self.scheduler_config.max_num_seqs
# passed in, which contains a lora from the lora warmup path. # This represents the maximum number of different requests
dummy_lora_requests: List[LoRARequest] = [] # that will have unique loras, and therefore the max amount of memory
dummy_lora_requests_per_seq: List[LoRARequest] = [] # consumption create dummy lora request copies from the lora request
if self.lora_config: # passed in, which contains a lora from the lora warmup path.
assert self.lora_manager is not None dummy_lora_requests: List[LoRARequest] = []
with self.lora_manager.dummy_lora_cache(): dummy_lora_requests_per_seq: List[LoRARequest] = []
for idx in range(self.lora_config.max_loras): if self.lora_config:
lora_id = idx + 1 assert self.lora_manager is not None
dummy_lora_request = LoRARequest( with self.lora_manager.dummy_lora_cache():
lora_name=f"warmup_{lora_id}", for idx in range(self.lora_config.max_loras):
lora_int_id=lora_id, lora_id = idx + 1
lora_path="/not/a/real/path", dummy_lora_request = LoRARequest(
) lora_name=f"warmup_{lora_id}",
self.lora_manager.add_dummy_lora(dummy_lora_request, lora_int_id=lora_id,
rank=LORA_WARMUP_RANK) lora_path="/not/a/real/path",
dummy_lora_requests.append(dummy_lora_request) )
dummy_lora_requests_per_seq = [ self.lora_manager.add_dummy_lora(dummy_lora_request,
dummy_lora_requests[idx % len(dummy_lora_requests)] rank=LORA_WARMUP_RANK)
for idx in range(max_num_seqs) dummy_lora_requests.append(dummy_lora_request)
] dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
# Profile memory usage with max_num_sequences sequences and the total for idx in range(max_num_seqs)
# number of tokens equal to max_num_batched_tokens. ]
seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for multi-modal encoding, which # Profile memory usage with max_num_sequences sequences and the
# needs to be accounted for when calculating the GPU blocks for # total number of tokens equal to max_num_batched_tokens.
# vLLM blocker manager. seqs: List[SequenceGroupMetadata] = []
# To exercise the worst scenario for GPU memory consumption, # Additional GPU memory may be needed for multi-modal encoding,
# the number of seqs (batch_size) is chosen to maximize the number # which needs to be accounted for when calculating the GPU blocks
# of images processed. # for vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( # the number of seqs (batch_size) is chosen to maximize the number
self.model_config) # of images processed.
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
max_num_seqs = min(max_num_seqs, self.model_config)
max_num_batched_tokens // max_mm_tokens) if max_mm_tokens > 0:
if max_num_seqs < 1: max_num_seqs_orig = max_num_seqs
expr = (f"min({max_num_seqs_orig}, " max_num_seqs = min(max_num_seqs,
f"{max_num_batched_tokens} // {max_mm_tokens})") max_num_batched_tokens // max_mm_tokens)
logger.warning( if max_num_seqs < 1:
"Computed max_num_seqs (%s) to be less than 1. " expr = (f"min({max_num_seqs_orig}, "
"Setting it to the minimum value of 1.", expr) f"{max_num_batched_tokens} // {max_mm_tokens})")
max_num_seqs = 1 logger.warning(
"Computed max_num_seqs (%s) to be less than 1. "
batch_size = 0 "Setting it to the minimum value of 1.", expr)
for group_id in range(max_num_seqs): max_num_seqs = 1
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) batch_size = 0
batch_size += seq_len for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
dummy_data = self.input_registry \ (group_id < max_num_batched_tokens % max_num_seqs))
.dummy_data_for_profiling(self.model_config, batch_size += seq_len
seq_len,
self.mm_registry) dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq = SequenceGroupMetadata( seq_len,
request_id=str(group_id), self.mm_registry)
is_prompt=True,
seq_data={group_id: dummy_data.seq_data}, seq = SequenceGroupMetadata(
sampling_params=sampling_params, request_id=str(group_id),
block_tables=None, is_prompt=True,
lora_request=dummy_lora_requests_per_seq[group_id] seq_data={group_id: dummy_data.seq_data},
if dummy_lora_requests_per_seq else None, sampling_params=sampling_params,
multi_modal_data=dummy_data.multi_modal_data, block_tables=None,
multi_modal_placeholders=dummy_data.multi_modal_placeholders, lora_request=dummy_lora_requests_per_seq[group_id]
) if dummy_lora_requests_per_seq else None,
seqs.append(seq) multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.
# Run the model with the dummy inputs. multi_modal_placeholders,
num_layers = self.model_config.get_num_layers(self.parallel_config) )
# use an empty tensor instead of `None`` to force Dynamo to pass seqs.append(seq)
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as # Run the model with the dummy inputs.
# a placeholder (it has wide hardware support). num_layers = self.model_config.get_num_layers(self.parallel_config)
# it is important to create tensors inside the loop, rather than # use an empty tensor instead of `None`` to force Dynamo to pass
# multiplying the list, to avoid Dynamo from treating them as # it by reference, rather by specializing on the value ``None``.
# tensor aliasing. # the `dtype` argument does not matter, and we use `float32` as
kv_caches = [ # a placeholder (it has wide hardware support).
torch.tensor([], dtype=torch.float32, device=self.device) # it is important to create tensors inside the loop, rather than
for _ in range(num_layers) # multiplying the list, to avoid Dynamo from treating them as
] # tensor aliasing.
finished_requests_ids = [seq.request_id for seq in seqs] kv_caches = [
model_input = self.prepare_model_input( torch.tensor([], dtype=torch.float32, device=self.device)
seqs, finished_requests_ids=finished_requests_ids) for _ in range(num_layers)
intermediate_tensors = None ]
if not get_pp_group().is_first_rank: finished_requests_ids = [seq.request_id for seq in seqs]
intermediate_tensors = self.model.make_empty_intermediate_tensors( model_input = self.prepare_model_input(
batch_size=batch_size, seqs, finished_requests_ids=finished_requests_ids)
dtype=self.model_config.dtype, intermediate_tensors = None
device=self.device) if not get_pp_group().is_first_rank:
intermediate_tensors = \
self.execute_model(model_input, kv_caches, intermediate_tensors) self.model.make_empty_intermediate_tensors(
torch.cuda.synchronize() batch_size=batch_size,
return dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
def remove_all_loras(self): def remove_all_loras(self):
if not self.lora_manager: if not self.lora_manager:
......
...@@ -32,7 +32,7 @@ logger = init_logger(__name__) ...@@ -32,7 +32,7 @@ logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = [ MULTI_STEP_ATTENTION_BACKENDS = [
"FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
] ]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"] MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
-> List[str]: -> List[str]:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment