Unverified commit 3521ba4f, authored by SangBin Cho, committed by GitHub
Browse files

[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

parent 2d7bce9c
...@@ -44,7 +44,8 @@ class EngineArgs: ...@@ -44,7 +44,8 @@ class EngineArgs:
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: bool = False
max_context_len_to_capture: int = 8192 max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0 tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray" tokenizer_pool_type: str = "ray"
...@@ -322,6 +323,14 @@ class EngineArgs: ...@@ -322,6 +323,14 @@ class EngineArgs:
default=EngineArgs.max_context_len_to_capture, default=EngineArgs.max_context_len_to_capture,
help='Maximum context length covered by CUDA ' help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length ' 'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')')
parser.add_argument('--max-seq_len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.') 'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce', parser.add_argument('--disable-custom-all-reduce',
action='store_true', action='store_true',
...@@ -492,7 +501,8 @@ class EngineArgs: ...@@ -492,7 +501,8 @@ class EngineArgs:
self.code_revision, self.tokenizer_revision, self.max_model_len, self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path, self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture, self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init) self.max_seq_len_to_capture, self.max_logprobs,
self.skip_tokenizer_init)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
......
...@@ -69,6 +69,9 @@ class LLM: ...@@ -69,6 +69,9 @@ class LLM:
disable CUDA graph and always execute the model in eager mode. disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid. If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode. to eager mode.
disable_custom_all_reduce: See ParallelConfig disable_custom_all_reduce: See ParallelConfig
...@@ -90,7 +93,8 @@ class LLM: ...@@ -90,7 +93,8 @@ class LLM:
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: int = 8192, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -112,6 +116,7 @@ class LLM: ...@@ -112,6 +116,7 @@ class LLM:
swap_space=swap_space, swap_space=swap_space,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture, max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs, **kwargs,
) )
......
...@@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: ...@@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
assert seq_group.is_prompt, ( assert seq_group.is_prompt, (
"Caller should ensure the sequence group is in a prefill stage.") "Caller should ensure the sequence group is in a prefill stage.")
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
subquery_len = seq_group.subquery_len query_len = seq_group.query_len
assert subquery_len is not None assert query_len is not None
# prompt has only 1 seq id. # prompt has only 1 seq id.
assert len(seq_ids) == 1 assert len(seq_ids) == 1
seq_data = seq_group.seq_data[seq_ids[0]] seq_data = seq_group.seq_data[seq_ids[0]]
...@@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: ...@@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
prompt_tokens = seq_data.prompt_token_ids prompt_tokens = seq_data.prompt_token_ids
# +1 because we are looking for a next prompt token. # +1 because we are looking for a next prompt token.
next_token_index_start = computed_len + 1 next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + subquery_len + 1, next_token_index_end = min(computed_len + query_len + 1,
len(prompt_tokens)) len(prompt_tokens))
next_prompt_tokens = prompt_tokens[ next_prompt_tokens = prompt_tokens[
next_token_index_start:next_token_index_end] next_token_index_start:next_token_index_end]
......
...@@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558 ...@@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558
@dataclass @dataclass
class SequenceGroupToSample: class SequenceGroupToSample:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step. # Sequence ids for the sequence group in a previous step.
seq_ids: List[int] seq_ids: List[int]
sampling_params: SamplingParams sampling_params: SamplingParams
# seq_id -> sequence data. # seq_id -> sequence data.
seq_data: Dict[int, SequenceData] seq_data: Dict[int, SequenceData]
# The length of the prompt of the sequence group. None if it is in a decode # The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage. # stage.
prompt_len: Optional[int] seq_len: Optional[int]
# The length of the query tokens to compute in the current step. None if it # The length of new query tokens to compute in the current step. None if it
# is in a decode stage. The length of subquery_len <= prompt_len. # is in a decode stage. The length of query_len <= seq_len if chunked
subquery_len: Optional[int] # prefill is enabled.
query_len: Optional[int]
# A random number generator for sampling. # A random number generator for sampling.
generator: Optional[torch.Generator] generator: Optional[torch.Generator]
# True if the sequence group is in prefill stage. False if it is in a # True if the sequence group is in prefill stage. False if it is in a
...@@ -46,8 +55,8 @@ class SequenceGroupToSample: ...@@ -46,8 +55,8 @@ class SequenceGroupToSample:
if len(self.prompt_logprob_indices) > 0: if len(self.prompt_logprob_indices) > 0:
assert self.sampling_params.prompt_logprobs is not None assert self.sampling_params.prompt_logprobs is not None
if self.is_prompt: if self.is_prompt:
assert self.prompt_len is not None assert self.seq_len is not None
assert self.subquery_len is not None assert self.query_len is not None
class SamplingMetadata: class SamplingMetadata:
...@@ -94,8 +103,8 @@ class SamplingMetadata: ...@@ -94,8 +103,8 @@ class SamplingMetadata:
@staticmethod @staticmethod
def prepare( def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int], seq_lens: List[int],
subquery_lens: Optional[List[int]], query_lens: Optional[List[int]],
device: str, device: str,
pin_memory: bool, pin_memory: bool,
) -> "SamplingMetadata": ) -> "SamplingMetadata":
...@@ -104,8 +113,8 @@ class SamplingMetadata: ...@@ -104,8 +113,8 @@ class SamplingMetadata:
selected_token_indices, selected_token_indices,
categorized_sample_indices, categorized_sample_indices,
num_prompts, num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
subquery_lens, device) device)
selected_token_indices = async_tensor_h2d(selected_token_indices, selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long, dtype=torch.long,
target_device=device, target_device=device,
...@@ -137,8 +146,8 @@ class SamplingMetadata: ...@@ -137,8 +146,8 @@ class SamplingMetadata:
def _prepare_seq_groups( def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int], seq_lens: List[int],
subquery_lens: Optional[List[int]], query_lens: Optional[List[int]],
device: str, device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]: SamplingType, List[Tuple[int, int]]], int]:
...@@ -146,9 +155,9 @@ def _prepare_seq_groups( ...@@ -146,9 +155,9 @@ def _prepare_seq_groups(
Args: Args:
seq_group_metadata_list: A list of sequence group to batch. seq_group_metadata_list: A list of sequence group to batch.
prompt_lens: A list of prompt lens per sequence group. seq_lens: A list of sequence lens per sequence group.
Index of prompt len should match with seq_group_metadata_list. Index of prompt len should match with seq_group_metadata_list.
subquery_lens: A list of query lengths. Prompt lens include the length query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter. of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator, device: A device to use for random number generator,
`SequenceGroupToSample.generator`. `SequenceGroupToSample.generator`.
...@@ -189,8 +198,8 @@ def _prepare_seq_groups( ...@@ -189,8 +198,8 @@ def _prepare_seq_groups(
is_prompt = seq_group_metadata.is_prompt is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None. # If the current seq group is in decode stage, it is None.
prompt_len: Optional[int] = None seq_len: Optional[int] = None
subquery_len: Optional[int] = None query_len: Optional[int] = None
prompt_logprob_indices: List[int] = [] prompt_logprob_indices: List[int] = []
sample_indices: List[int] = [] sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample do_sample = seq_group_metadata.do_sample
...@@ -203,12 +212,12 @@ def _prepare_seq_groups( ...@@ -203,12 +212,12 @@ def _prepare_seq_groups(
num_prompts += 1 num_prompts += 1
num_prefill_sample = len(seq_ids) num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1 assert num_prefill_sample == 1
assert subquery_lens is not None and prompt_lens is not None assert query_lens is not None and seq_lens is not None
subquery_len, prompt_len = subquery_lens[i], prompt_lens[i] query_len, seq_len = query_lens[i], seq_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from # If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob. # prompt logprob.
prompt_logprob_len = (subquery_len - num_prefill_sample prompt_logprob_len = (query_len - num_prefill_sample
if do_sample else subquery_len) if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0 sample_len = num_prefill_sample if do_sample else 0
else: else:
# Decode # Decode
...@@ -267,8 +276,8 @@ def _prepare_seq_groups( ...@@ -267,8 +276,8 @@ def _prepare_seq_groups(
seq_ids=seq_ids, seq_ids=seq_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data, seq_data=seq_group_metadata.seq_data,
prompt_len=prompt_len, seq_len=seq_len,
subquery_len=subquery_len, query_len=query_len,
generator=generator, generator=generator,
is_prompt=is_prompt, is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices), prompt_logprob_indices=list(prompt_logprob_indices),
...@@ -367,8 +376,8 @@ class SamplingTensors: ...@@ -367,8 +376,8 @@ class SamplingTensors:
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get # For tokens in the prompt that we only need to get
# their logprobs # their logprobs
subquery_len = seq_group.subquery_len query_len = seq_group.query_len
assert subquery_len is not None assert query_len is not None
prefill_len = len(seq_group.prompt_logprob_indices) prefill_len = len(seq_group.prompt_logprob_indices)
temperatures += [temperature] * prefill_len temperatures += [temperature] * prefill_len
top_ps += [top_p] * prefill_len top_ps += [top_p] * prefill_len
...@@ -397,8 +406,8 @@ class SamplingTensors: ...@@ -397,8 +406,8 @@ class SamplingTensors:
if is_prompt: if is_prompt:
prompt_best_of.append(sampling_params.best_of) prompt_best_of.append(sampling_params.best_of)
subquery_len = seq_group.subquery_len query_len = seq_group.query_len
assert subquery_len is not None assert query_len is not None
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
......
...@@ -80,7 +80,7 @@ class CPUModelRunner: ...@@ -80,7 +80,7 @@ class CPUModelRunner:
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
prompt_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
...@@ -92,15 +92,15 @@ class CPUModelRunner: ...@@ -92,15 +92,15 @@ class CPUModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids() prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens() computed_len = seq_data.get_num_computed_tokens()
prompt_len = len(prompt_tokens) seq_len = len(prompt_tokens)
prompt_lens.append(prompt_len) # Prompt token num seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids input_tokens.extend(prompt_tokens) # Token ids
# Token position ids # Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len))) input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data: if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append( multi_modal_input_list.append(
...@@ -109,15 +109,15 @@ class CPUModelRunner: ...@@ -109,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping. # Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window). # where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and # For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot # block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0 start_idx = 0
if self.sliding_window is not None: if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window) start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prompt_len): for i in range(computed_len, seq_len):
if i < start_idx: if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
continue continue
...@@ -151,19 +151,19 @@ class CPUModelRunner: ...@@ -151,19 +151,19 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
prompt_lens=prompt_lens, seq_lens=seq_lens,
num_prefills=len(prompt_lens), seq_lens_tensor=None,
max_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens, num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0, num_decode_tokens=0,
prefill_metadata=None, prefill_metadata=None,
decode_metadata=None, decode_metadata=None,
max_context_len=None,
context_lens=None,
block_tables=torch.tensor([]), block_tables=torch.tensor([]),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, prompt_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input) multi_modal_input)
def _prepare_decode( def _prepare_decode(
...@@ -174,7 +174,7 @@ class CPUModelRunner: ...@@ -174,7 +174,7 @@ class CPUModelRunner:
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
context_lens: List[int] = [] seq_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
...@@ -192,9 +192,9 @@ class CPUModelRunner: ...@@ -192,9 +192,9 @@ class CPUModelRunner:
position = seq_len - 1 position = seq_len - 1
input_positions.append(position) input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min( seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window) seq_len, self.sliding_window)
context_lens.append(context_len) seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size] block_number = block_table[position // self.block_size]
...@@ -208,7 +208,7 @@ class CPUModelRunner: ...@@ -208,7 +208,7 @@ class CPUModelRunner:
block_table = block_table[-sliding_window_blocks:] block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table) block_tables.append(block_table)
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens, input_tokens = torch.tensor(input_tokens,
dtype=torch.long, dtype=torch.long,
...@@ -219,9 +219,9 @@ class CPUModelRunner: ...@@ -219,9 +219,9 @@ class CPUModelRunner:
slot_mapping = torch.tensor(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
context_lens = torch.tensor(context_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
max_block_table_len = max( max_block_table_len = max(
len(block_table) for block_table in block_tables) len(block_table) for block_table in block_tables)
...@@ -236,14 +236,14 @@ class CPUModelRunner: ...@@ -236,14 +236,14 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=None, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_seq_len=max_seq_len,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=len(input_tokens), num_decode_tokens=len(input_tokens),
max_context_len=max_context_len,
num_prefills=0, num_prefills=0,
prefill_metadata=None, prefill_metadata=None,
decode_metadata=None, decode_metadata=None,
context_lens=context_lens,
block_tables=block_tables, block_tables=block_tables,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
...@@ -265,20 +265,20 @@ class CPUModelRunner: ...@@ -265,20 +265,20 @@ class CPUModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens, (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list) ) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, (input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list) attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] seq_lens = []
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
# subquery_lens is not needed if chunked prefill is not # query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill # supported. Since CPU worker doesn't support chunked prefill
# just use prompt_lens instead. # just use seq_lens instead.
prompt_lens, seq_lens,
self.device, self.device,
pin_memory=False) pin_memory=False)
# Broadcast the metadata. # Broadcast the metadata.
...@@ -300,7 +300,7 @@ class CPUModelRunner: ...@@ -300,7 +300,7 @@ class CPUModelRunner:
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
prompt_lens=None, seq_lens=None,
selected_token_indices=selected_token_indices, selected_token_indices=selected_token_indices,
categorized_sample_indices=None, categorized_sample_indices=None,
generators=None, generators=None,
......
...@@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple): ...@@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens: List[int] input_tokens: List[int]
input_positions: List[int] input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage] attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int] seq_lens: List[int]
subquery_lens: List[int] query_lens: List[int]
lora_index_mapping: List[int] lora_index_mapping: List[int]
lora_prompt_mapping: List[int] lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest] lora_requests: Set[LoRARequest]
...@@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple): ...@@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens=[], input_tokens=[],
input_positions=[], input_positions=[],
attn_metadata=None, attn_metadata=None,
prompt_lens=[], seq_lens=[],
subquery_lens=[], query_lens=[],
lora_index_mapping=[], lora_index_mapping=[],
lora_prompt_mapping=[], lora_prompt_mapping=[],
lora_requests=set(), lora_requests=set(),
...@@ -134,9 +134,8 @@ class ModelRunner: ...@@ -134,9 +134,8 @@ class ModelRunner:
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture. int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = ( self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
self.model_config.max_context_len_to_capture if self.model_config is not None else 0)
if self.model_config is not None else 0)
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
...@@ -149,7 +148,7 @@ class ModelRunner: ...@@ -149,7 +148,7 @@ class ModelRunner:
self.model: torch.nn.Module # Set after load_model self.model: torch.nn.Module # Set after load_model
self.block_size: int # Set after initial profiling. self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to # When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in # max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table # Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration. # in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be # The shape of the cached block table will be
...@@ -218,7 +217,7 @@ class ModelRunner: ...@@ -218,7 +217,7 @@ class ModelRunner:
def get_max_block_per_batch(self) -> int: def get_max_block_per_batch(self) -> int:
block_size = self.block_size block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_prompt( def _prepare_prompt(
self, self,
...@@ -231,9 +230,9 @@ class ModelRunner: ...@@ -231,9 +230,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = [] seq_lens: List[int] = []
context_lens: List[int] = [] context_lens: List[int] = []
subquery_lens: List[int] = [] query_lens: List[int] = []
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_input_list: List[torch.Tensor] = []
...@@ -257,21 +256,19 @@ class ModelRunner: ...@@ -257,21 +256,19 @@ class ModelRunner:
token_chunk_size = seq_group_metadata.token_chunk_size token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
computed_len = seq_data.get_num_computed_tokens() context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption # We should use get_len here because in case of preemption
# it contains output tokens. # it contains output tokens.
prefill_end = min(seq_data.get_len(), seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
computed_len + token_chunk_size) prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] seq_lens.append(seq_len)
prompt_len = prefill_end
prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len( if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None: computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window # Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:] prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums) prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled: elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None: if seq_group_metadata.block_tables is not None:
...@@ -285,25 +282,25 @@ class ModelRunner: ...@@ -285,25 +282,25 @@ class ModelRunner:
prefix_block_tables.append([]) prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this # Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced. # assumption can be changed once chunked prefill is introduced.
assert computed_len == 0 assert context_len == 0
# actual prompt lens # actual prompt lens
context_lens.append(computed_len) context_lens.append(context_len)
subquery_lens.append(prompt_len - computed_len) query_lens.append(seq_len - context_len)
input_tokens.extend(prompt_tokens) input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end))) input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id lora_id = seq_group_metadata.lora_int_id
if lora_id > 0: if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request) lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (prompt_len - computed_len) lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend( lora_prompt_mapping.extend(
[lora_id] * [lora_id] *
(prompt_len - computed_len (seq_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data: if seq_group_metadata.multi_modal_data:
...@@ -313,24 +310,24 @@ class ModelRunner: ...@@ -313,24 +310,24 @@ class ModelRunner:
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping. # yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue continue
# Compute the slot mapping. # Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window). # where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and # For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot # block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0 start_idx = 0
if self.sliding_window is not None: if self.sliding_window is not None:
assert computed_len == 0, ( assert context_len == 0, (
"Prefix caching is currently not supported with " "Prefix caching is currently not supported with "
"sliding window attention") "sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window) start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prefill_end): for i in range(context_len, seq_len):
if i < start_idx: if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
continue continue
...@@ -340,9 +337,9 @@ class ModelRunner: ...@@ -340,9 +337,9 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append(slot) slot_mapping.append(slot)
max_subquery_len = max(subquery_lens) max_query_len = max(query_lens)
max_prompt_len = max(prompt_lens) max_seq_len = max(seq_lens)
assert max_subquery_len > 0 assert max_query_len > 0
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
...@@ -369,40 +366,39 @@ class ModelRunner: ...@@ -369,40 +366,39 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill # Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached. # is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens, query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long, dtype=torch.int,
device=self.device) device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
torch.cumsum(subquery_lens_tensor, torch.cumsum(query_lens_tensor,
dim=0, dim=0,
dtype=subquery_start_loc.dtype, dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:]) out=subquery_start_loc[1:])
torch.cumsum(prompt_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
out=seq_start_loc[1:]) out=seq_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
prompt_lens=prompt_lens, seq_lens=seq_lens,
prompt_lens_tensor=prompt_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_subquery_len=max_subquery_len, max_query_len=max_query_len,
max_context_len=None, max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len,
subquery_start_loc=subquery_start_loc, subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
) )
...@@ -411,8 +407,8 @@ class ModelRunner: ...@@ -411,8 +407,8 @@ class ModelRunner:
input_tokens=input_tokens, input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
prompt_lens=prompt_lens, seq_lens=seq_lens,
subquery_lens=subquery_lens, query_lens=query_lens,
lora_index_mapping=lora_index_mapping, lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping, lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests, lora_requests=lora_requests,
...@@ -427,7 +423,7 @@ class ModelRunner: ...@@ -427,7 +423,7 @@ class ModelRunner:
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
context_lens: List[int] = [] seq_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
lora_index_mapping: List[int] = [] lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
...@@ -455,9 +451,9 @@ class ModelRunner: ...@@ -455,9 +451,9 @@ class ModelRunner:
position = seq_len - 1 position = seq_len - 1
input_positions.append(position) input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min( seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window) seq_len, self.sliding_window)
context_lens.append(context_len) seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size] block_number = block_table[position // self.block_size]
...@@ -477,11 +473,10 @@ class ModelRunner: ...@@ -477,11 +473,10 @@ class ModelRunner:
# See `capture_model` API for more details. # See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens. # For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens) batch_size = len(input_tokens)
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
use_captured_graph = ( use_captured_graph = (not self.model_config.enforce_eager
not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_seq_len <= self.max_seq_len_to_capture)
and max_context_len <= self.max_context_len_to_capture)
if use_captured_graph: if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size) graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size assert graph_batch_size >= batch_size
...@@ -489,21 +484,21 @@ class ModelRunner: ...@@ -489,21 +484,21 @@ class ModelRunner:
input_tokens.append(0) input_tokens.append(0)
input_positions.append(0) input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1) seq_lens.append(1)
block_tables.append([]) block_tables.append([])
lora_index_mapping.append(0) lora_index_mapping.append(0)
batch_size = graph_batch_size batch_size = graph_batch_size
context_lens_tensor = torch.tensor(context_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
if use_captured_graph: if use_captured_graph:
# When using cuda-graph all these tensors should be # When using cuda-graph all these tensors should be
# padded. # padded.
assert context_lens_tensor.shape[0] == len(input_tokens) assert seq_lens_tensor.shape[0] == len(input_tokens)
assert context_lens_tensor.shape[0] == len(input_positions) assert seq_lens_tensor.shape[0] == len(input_positions)
assert context_lens_tensor.shape[0] == len(slot_mapping) assert seq_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
...@@ -525,14 +520,13 @@ class ModelRunner: ...@@ -525,14 +520,13 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
prompt_lens=None, seq_lens=None,
prompt_lens_tensor=None, seq_lens_tensor=seq_lens_tensor,
max_subquery_len=None, max_query_len=None,
max_context_len=max_context_len, max_seq_len=max_seq_len,
max_prompt_len=None,
subquery_start_loc=None, subquery_start_loc=None,
seq_start_loc=None, seq_start_loc=None,
context_lens=context_lens_tensor, context_lens_tensor=None,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
) )
...@@ -565,8 +559,8 @@ class ModelRunner: ...@@ -565,8 +559,8 @@ class ModelRunner:
input_tokens, input_tokens,
input_positions, input_positions,
prefill_attn_metadata, prefill_attn_metadata,
prompt_lens, seq_lens,
subquery_lens, query_lens,
lora_index_mapping, lora_index_mapping,
lora_prompt_mapping, lora_prompt_mapping,
lora_requests, lora_requests,
...@@ -583,13 +577,13 @@ class ModelRunner: ...@@ -583,13 +577,13 @@ class ModelRunner:
decode_slot_mapping, decode_slot_mapping,
) = self._prepare_decode(decode_reqs) ) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, prompt_lens, subquery_lens, seq_group_metadata_list, seq_lens, query_lens, self.device,
self.device, self.pin_memory) self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled: if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0 assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens) num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens) num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens) num_decode_tokens = len(decode_input_tokens)
...@@ -886,7 +880,7 @@ class ModelRunner: ...@@ -886,7 +880,7 @@ class ModelRunner:
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID) slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda()
graph_batch_size = _get_graph_batch_size( graph_batch_size = _get_graph_batch_size(
...@@ -908,14 +902,13 @@ class ModelRunner: ...@@ -908,14 +902,13 @@ class ModelRunner:
# Create dummy attn_metadata. # Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata( decode_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
prompt_lens=None, seq_lens=None,
prompt_lens_tensor=None, seq_lens_tensor=seq_lens[:batch_size],
max_subquery_len=None, max_query_len=None,
max_context_len=self.max_context_len_to_capture, max_seq_len=self.max_seq_len_to_capture,
max_prompt_len=None,
subquery_start_loc=None, subquery_start_loc=None,
seq_start_loc=None, seq_start_loc=None,
context_lens=context_lens[:batch_size], context_lens_tensor=None,
block_tables=block_tables[:batch_size], block_tables=block_tables[:batch_size],
use_cuda_graph=True, use_cuda_graph=True,
) )
...@@ -1025,7 +1018,7 @@ class CUDAGraphRunner: ...@@ -1025,7 +1018,7 @@ class CUDAGraphRunner:
"positions": positions, "positions": positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.decode_metadata.context_lens, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} self.output_buffers = {"hidden_states": hidden_states}
...@@ -1047,8 +1040,8 @@ class CUDAGraphRunner: ...@@ -1047,8 +1040,8 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True) non_blocking=True)
self.input_buffers["context_lens"].copy_( self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True) attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_( self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph. # Run the graph.
......
...@@ -52,7 +52,7 @@ class NeuronModelRunner: ...@@ -52,7 +52,7 @@ class NeuronModelRunner:
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
input_block_ids: List[int] = [] input_block_ids: List[int] = []
prompt_lens: List[int] = [] seq_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -61,26 +61,26 @@ class NeuronModelRunner: ...@@ -61,26 +61,26 @@ class NeuronModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids() prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens) seq_len = len(prompt_tokens)
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
input_tokens.append(prompt_tokens) input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len))) input_positions.append(list(range(seq_len)))
assert seq_group_metadata.block_tables is not None assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1 assert len(block_table) == 1
input_block_ids.append(block_table[0]) input_block_ids.append(block_table[0])
max_prompt_len = max(prompt_lens) max_seq_len = max(seq_lens)
assert max_prompt_len > 0 assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens, input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len, max_seq_len,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = make_tensor_with_pad(input_positions, input_positions = make_tensor_with_pad(input_positions,
max_prompt_len, max_seq_len,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
...@@ -88,7 +88,7 @@ class NeuronModelRunner: ...@@ -88,7 +88,7 @@ class NeuronModelRunner:
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
return input_tokens, input_positions, input_block_ids, prompt_lens return input_tokens, input_positions, input_block_ids, seq_lens
def _prepare_decode( def _prepare_decode(
self, self,
...@@ -149,18 +149,18 @@ class NeuronModelRunner: ...@@ -149,18 +149,18 @@ class NeuronModelRunner:
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, input_block_ids, (input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list) seq_lens) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, (input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list) input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] seq_lens = []
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens, seq_lens,
# subquery_lens is not needed if chunked prefill is not # query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill # supported. Since neuron worker doesn't support chunked prefill
# just use prompt_lens instead. # just use seq_lens instead.
prompt_lens, seq_lens,
self.device, self.device,
self.pin_memory) self.pin_memory)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment