Unverified commit 66818e5b, authored by youkaichao, committed by GitHub
Browse files

[core] separate builder init and builder prepare for each batch (#12253)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 222a9dc3
...@@ -65,11 +65,6 @@ class AttentionBackend(ABC): ...@@ -65,11 +65,6 @@ class AttentionBackend(ABC):
def get_builder_cls() -> Type["AttentionMetadataBuilder"]: def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError raise NotImplementedError
@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_kv_cache_shape( def get_kv_cache_shape(
...@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]): ...@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
@abstractmethod @abstractmethod
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError
@abstractmethod
def prepare(self) -> None:
"""Prepare for one batch."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
...@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder( ...@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]): AttentionMetadataBuilder[FlashAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"): def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = [] self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = [] self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = [] self.context_lens: List[int] = []
...@@ -388,11 +394,6 @@ class FlashAttentionMetadataBuilder( ...@@ -388,11 +394,6 @@ class FlashAttentionMetadataBuilder(
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.has_prefix_cache_hit = False self.has_prefix_cache_hit = False
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def _add_seq_group( def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool, prefix_cache_hit: bool): chunked_prefill_enabled: bool, prefix_cache_hit: bool):
......
...@@ -488,6 +488,14 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -488,6 +488,14 @@ class FlashInferMetadata(AttentionMetadata):
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"): def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = [] self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = [] self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = [] self.context_lens: List[int] = []
...@@ -500,12 +508,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -500,12 +508,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields. # for the precise definition of the following fields.
# An example: # An example:
......
...@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]): AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"): def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
def prepare(self):
self.prefill_seq_lens: List[int] = [] self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = [] self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = [] self.curr_seq_lens: List[int] = []
...@@ -263,9 +268,6 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -263,9 +268,6 @@ class PlaceholderAttentionMetadataBuilder(
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
def _add_seq_group( def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool): chunked_prefill_enabled: bool):
......
...@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): ...@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill self.chunked_prefill = input_builder.chunked_prefill
self.input_data = input_builder.input_data self.input_builder = input_builder
def prepare(self):
self.input_data = self.input_builder.input_data
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
......
...@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
_metadata_cls: Type[TAttentionMetadata] _metadata_cls: Type[TAttentionMetadata]
def __init__(self, input_builder: "ModelInputForGPUBuilder"): def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = [] self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = [] self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = [] self.context_lens: List[int] = []
...@@ -134,12 +141,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -134,12 +141,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def _add_seq_group( def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool): chunked_prefill_enabled: bool):
......
...@@ -144,9 +144,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -144,9 +144,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
runner: "CPUModelRunner", runner: "CPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None: finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__() super().__init__()
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.runner = runner self.runner = runner
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
or runner.cache_config.enable_prefix_caching) or runner.cache_config.enable_prefix_caching)
self.model_input_cls = self.runner._model_input_cls self.model_input_cls = self.runner._model_input_cls
...@@ -156,10 +154,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -156,10 +154,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self.device = self.runner.device self.device = self.runner.device
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.enable_lora = self.runner.lora_config is not None self.enable_lora = self.runner.lora_config is not None
if self.runner.attn_backend is not None:
# spec decode (e.g. Medusa) does not have an attn backend
attn_backend = self.runner.attn_backend
self.att_metadata_builder = attn_backend.get_builder_cls()(self)
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.input_data = ModelInputForCPUBuilder.ModelInputData( self.input_data = ModelInputForCPUBuilder.ModelInputData(
self.runner.model_config.uses_mrope) self.runner.model_config.uses_mrope)
self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( self.att_metadata_builder.prepare()
self)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata) self.seq_group_metadata_list.append(seq_group_metadata)
...@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
""" """
_model_input_cls: Type[TModelInputForCPU] _model_input_cls: Type[TModelInputForCPU]
_builder_cls: Type[ModelInputForCPUBuilder] _builder_cls: Type[ModelInputForCPUBuilder]
builder: ModelInputForCPUBuilder
def __init__( def __init__(
self, self,
...@@ -477,6 +483,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -477,6 +483,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
# Set after load_model. # Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
...@@ -522,10 +532,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): ...@@ -522,10 +532,10 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
metadata for possible additional steps, e.g., sampling. metadata for possible additional steps, e.g., sampling.
""" """
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) self.builder.prepare(finished_requests_ids)
builder.set_seq_group_list(seq_group_metadata_list) self.builder.set_seq_group_list(seq_group_metadata_list)
return builder.build() # type: ignore return self.builder.build() # type: ignore
# sampler property will be used by spec_decode_worker # sampler property will be used by spec_decode_worker
@property @property
......
...@@ -457,17 +457,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -457,17 +457,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.enable_prompt_adapter = (self.runner.prompt_adapter_config self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None) is not None)
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.finished_requests_ids = finished_requests_ids
self.decode_only = True self.decode_only = True
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self.inter_data_list: List[
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
# Attention metadata inputs. # Attention metadata inputs.
self.attn_metadata_builder = self.attn_backend.make_metadata_builder( if self.attn_backend is not None:
weakref.proxy(self)) # spec decode (e.g. Medusa) does not have an attn backend
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
# Engine/Model configurations. # Engine/Model configurations.
self.chunked_prefill_enabled = ( self.chunked_prefill_enabled = (
...@@ -479,6 +475,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -479,6 +475,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.block_aligned_sliding_window = \ self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size self.sliding_window_blocks * self.block_size
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.finished_requests_ids = finished_requests_ids
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self.inter_data_list: List[
ModelInputForGPUBuilder.InterDataForSeqGroup] = []
self.attn_metadata_builder.prepare()
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens """Compute context length, sequence length and tokens
...@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
""" """
_model_input_cls: Type[TModelInputForGPU] _model_input_cls: Type[TModelInputForGPU]
_builder_cls: Type[ModelInputForGPUBuilder] _builder_cls: Type[ModelInputForGPUBuilder]
builder: ModelInputForGPUBuilder
def __init__( def __init__(
self, self,
...@@ -1093,6 +1101,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1093,6 +1101,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
SamplingMetadataCache() \ SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None if self.parallel_config.pipeline_parallel_size == 1 else None
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
...@@ -1226,13 +1238,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1226,13 +1238,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs. If cuda graph is required, this API automatically pads inputs.
""" """
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) self.builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata) self.builder.add_seq_group(seq_group_metadata)
builder.reset_cached_inter_data() self.builder.reset_cached_inter_data()
return builder.build() # type: ignore return self.builder.build() # type: ignore
@contextmanager @contextmanager
def set_in_profile_run(self): def set_in_profile_run(self):
......
...@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]): ...@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects. """A builder to create ModelRunnerInputBase objects.
""" """
@abstractmethod
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
raise NotImplementedError
@abstractmethod @abstractmethod
def add_seq_group(self, seq_group_metadata): def add_seq_group(self, seq_group_metadata):
"""TBA""" """TBA"""
......
...@@ -113,7 +113,6 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -113,7 +113,6 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
runner: "XPUModelRunner", runner: "XPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None: finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__() super().__init__()
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.runner = runner self.runner = runner
self.model_input_cls = self.runner._model_input_cls self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend self.attn_backend = self.runner.attn_backend
...@@ -121,6 +120,10 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -121,6 +120,10 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
self.block_size = self.runner.block_size self.block_size = self.runner.block_size
self.device = self.runner.device self.device = self.runner.device
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata) self.seq_group_metadata_list.append(seq_group_metadata)
...@@ -408,6 +411,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -408,6 +411,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
SamplingMetadataCache() \ SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None if self.parallel_config.pipeline_parallel_size == 1 else None
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None: def load_model(self) -> None:
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
self.model = get_model(vllm_config=self.vllm_config) self.model = get_model(vllm_config=self.vllm_config)
...@@ -517,7 +522,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -517,7 +522,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
metadata for possible additional steps, e.g., sampling. metadata for possible additional steps, e.g., sampling.
""" """
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) builder = self.builder
builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata) builder.add_seq_group(seq_group_metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment