"vscode:/vscode.git/clone" did not exist on "6f1e7f7226447f606a0731376a2d0bd080aa2767"
Unverified Commit e0c15758 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Core] Modulize prepare input and attention metadata builder (#6596)

parent bdf5fd13
...@@ -7,7 +7,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, ...@@ -7,7 +7,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import torch import torch
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
...@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]): ...@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders.""" """Abstract class for attention metadata builders."""
@abstractmethod @abstractmethod
def __init__(self, input_builder) -> None: def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata", def build(self, seq_lens: List[int], query_lens: List[int],
token_lens: List[int], seq_lens: List[int], cuda_graph_pad_size: int, batch_size: int) -> T:
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise NotImplementedError
@abstractmethod
def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
"""Build attention metadata with on-device tensors.""" """Build attention metadata with on-device tensors."""
raise NotImplementedError raise NotImplementedError
......
...@@ -13,12 +13,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -13,12 +13,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import ModelInputForGPUBuilder
ModelInputForGPUBuilder)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
...@@ -212,30 +210,30 @@ class FlashAttentionMetadataBuilder( ...@@ -212,30 +210,30 @@ class FlashAttentionMetadataBuilder(
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size self.block_size = input_builder.block_size
self.use_v2_block_manager = ( self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager) input_builder.scheduler_config.use_v2_block_manager)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, def _add_seq_group(
token_lens: List[int], seq_lens: List[int], self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
curr_seq_lens: List[int], query_lens: List[int], chunked_prefill_enabled: bool):
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append """Add a sequence group to the metadata. Specifically update/append
1. context length. 1. context length.
2. block table. 2. block table.
3. slot mapping. 3. slot mapping.
""" """
is_prompt = seq_group_metadata.is_prompt is_prompt = inter_data.is_prompt
block_tables = seq_group_metadata.block_tables block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip( curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens, inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
curr_seq_lens, query_lens, context_lens, inter_data.orig_seq_lens, inter_data.seq_lens,
curr_sliding_window_blocks): inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
...@@ -254,7 +252,7 @@ class FlashAttentionMetadataBuilder( ...@@ -254,7 +252,7 @@ class FlashAttentionMetadataBuilder(
# only allowing multiple of block_size chunk size. # only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
block_table = [] block_table = []
if prefix_cache_hit: if inter_data.prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should # NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens. # include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id] block_table = block_tables[seq_id]
...@@ -270,16 +268,19 @@ class FlashAttentionMetadataBuilder( ...@@ -270,16 +268,19 @@ class FlashAttentionMetadataBuilder(
self.use_v2_block_manager) self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx, seq_len, context_len, start_idx,
self.block_size, self.block_size, inter_data.block_tables)
seq_group_metadata.block_tables)
def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int): cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.""" """Build attention metadata with on-device tensors."""
device = runner.device for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(runner.model_config.hf_config, logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None) "attn_logit_softcapping", None)
if logits_soft_cap is not None: if logits_soft_cap is not None:
raise ValueError( raise ValueError(
...@@ -300,7 +301,7 @@ class FlashAttentionMetadataBuilder( ...@@ -300,7 +301,7 @@ class FlashAttentionMetadataBuilder(
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size] input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
......
...@@ -21,12 +21,10 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -21,12 +21,10 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import ModelInputForGPUBuilder
ModelInputForGPUBuilder)
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -216,6 +214,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -216,6 +214,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size self.block_size = input_builder.block_size
self.use_v2_block_manager = ( self.use_v2_block_manager = (
...@@ -238,26 +239,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -238,26 +239,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# paged_kv_last_page_len is the length of the last page of each request # paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = [] self.paged_kv_last_page_len: List[int] = []
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, def _add_seq_group(
token_lens: List[int], seq_lens: List[int], self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
curr_seq_lens: List[int], query_lens: List[int], chunked_prefill_enabled: bool):
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append """Add a sequence group to the metadata. Specifically update/append
1. context length. 1. context length.
2. block table. 2. block table.
3. slot mapping. 3. slot mapping.
""" """
is_prompt = seq_group_metadata.is_prompt is_prompt = inter_data.is_prompt
block_tables = seq_group_metadata.block_tables block_tables = inter_data.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums computed_block_nums = inter_data.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip( curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens, inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
curr_seq_lens, query_lens, context_lens, inter_data.orig_seq_lens, inter_data.seq_lens,
curr_sliding_window_blocks): inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
self.num_prefills += 1 self.num_prefills += 1
...@@ -275,7 +274,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -275,7 +274,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# only allowing multiple of block_size chunk size. # only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
block_table = [] block_table = []
if prefix_cache_hit: if inter_data.prefix_cache_hit:
block_table = computed_block_nums block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt) elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None): and block_tables is not None):
...@@ -290,8 +289,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -290,8 +289,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.use_v2_block_manager) self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx, seq_len, context_len, start_idx,
self.block_size, self.block_size, inter_data.block_tables)
seq_group_metadata.block_tables)
# It is not necessary to add paged_kv_indices, paged_kv_indptr, # It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will # and paged_kv_last_page_len for profile run because we will
...@@ -317,9 +315,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -317,9 +315,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
last_page_len = self.block_size last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len) self.paged_kv_last_page_len.append(last_page_len)
def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int): cuda_graph_pad_size: int, batch_size: int):
device = runner.device for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens) max_query_len = max(query_lens)
...@@ -333,7 +335,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -333,7 +335,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size] input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
...@@ -377,7 +379,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -377,7 +379,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.long, dtype=torch.long,
device=device) device=device)
logits_soft_cap = getattr(runner.model_config.hf_config, logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None) "attn_logit_softcapping", None)
if len(self.paged_kv_indptr) > 0: if len(self.paged_kv_indptr) > 0:
...@@ -394,8 +396,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -394,8 +396,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_tensor = None paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype, kv_cache_dtype = get_kv_cache_torch_dtype(
runner.model_config.dtype) self.runner.kv_cache_dtype, self.runner.model_config.dtype)
return FlashInferMetadata( return FlashInferMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
...@@ -406,11 +408,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -406,11 +408,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor, paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=runner.model_config.get_num_attention_heads( num_qo_heads=self.runner.model_config.get_num_attention_heads(
runner.parallel_config), self.runner.parallel_config),
num_kv_heads=runner.model_config.get_num_kv_heads( num_kv_heads=self.runner.model_config.get_num_kv_heads(
runner.parallel_config), self.runner.parallel_config),
head_dim=runner.model_config.get_head_size(), head_dim=self.runner.model_config.get_head_size(),
page_size=self.block_size, page_size=self.block_size,
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
......
...@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union ...@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import torch import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
# Error string(s) for encoder/decoder # Error string(s) for encoder/decoder
...@@ -15,8 +14,7 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " ...@@ -15,8 +14,7 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
PAD_SLOT_ID = -1 PAD_SLOT_ID = -1
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import ModelInputForGPUBuilder
ModelInputForGPUBuilder)
def is_block_tables_empty(block_tables: Union[None, Dict]): def is_block_tables_empty(block_tables: Union[None, Dict]):
...@@ -95,26 +93,27 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -95,26 +93,27 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size self.block_size = input_builder.block_size
self.use_v2_block_manager = ( self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager) input_builder.scheduler_config.use_v2_block_manager)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, def _add_seq_group(
token_lens: List[int], seq_lens: List[int], self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
curr_seq_lens: List[int], query_lens: List[int], chunked_prefill_enabled: bool):
context_lens: List[int], is_prompt = inter_data.is_prompt
curr_sliding_window_blocks: List[int], prefix_cache_hit, block_tables = inter_data.block_tables
chunked_prefill_enabled): computed_block_nums = inter_data.computed_block_nums
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip( curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens, inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
curr_seq_lens, query_lens, context_lens, inter_data.orig_seq_lens, inter_data.seq_lens,
curr_sliding_window_blocks): inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
self.num_prefills += 1 self.num_prefills += 1
...@@ -132,7 +131,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -132,7 +131,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# only allowing multiple of block_size chunk size. # only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
block_table = [] block_table = []
if prefix_cache_hit: if inter_data.prefix_cache_hit:
block_table = computed_block_nums block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt) elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None): and block_tables is not None):
...@@ -146,16 +145,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -146,16 +145,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.use_v2_block_manager) self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx, seq_len, context_len, start_idx,
self.block_size, self.block_size, inter_data.block_tables)
seq_group_metadata.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int], device = self.runner.device
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int):
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(runner.model_config.hf_config, logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None) "attn_logit_softcapping", None)
if logits_soft_cap is not None: if logits_soft_cap is not None:
raise ValueError( raise ValueError(
...@@ -176,7 +177,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -176,7 +177,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size] input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
......
...@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]], ...@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict) return dict(merged_dict)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Flatten a list of lists into a single list.

    Items keep their original order: every item of the first sublist,
    then every item of the second, and so on. Only one level of nesting
    is removed, so items that are themselves lists are kept as-is.

    Args:
        lists: The sublists to concatenate. Empty sublists contribute
            nothing to the result.

    Returns:
        A new list holding the items of all sublists; the inputs are not
        modified.
    """
    return [item for sublist in lists for item in sublist]
def init_cached_hf_modules() -> None: def init_cached_hf_modules() -> None:
""" """
Lazy initialization of the Hugging Face modules. Lazy initialization of the Hugging Face modules.
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment