Unverified Commit 3521ba4f authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

parent 2d7bce9c
......@@ -44,7 +44,8 @@ class EngineArgs:
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
......@@ -322,6 +323,14 @@ class EngineArgs:
default=EngineArgs.max_context_len_to_capture,
help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')')
parser.add_argument('--max-seq_len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
......@@ -492,7 +501,8 @@ class EngineArgs:
self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init)
self.max_seq_len_to_capture, self.max_logprobs,
self.skip_tokenizer_init)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
......
......@@ -69,6 +69,9 @@ class LLM:
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
......@@ -90,7 +93,8 @@ class LLM:
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
......@@ -112,6 +116,7 @@ class LLM:
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
......
......@@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
assert seq_group.is_prompt, (
"Caller should ensure the sequence group is in a prefill stage.")
seq_ids = seq_group.seq_ids
subquery_len = seq_group.subquery_len
assert subquery_len is not None
query_len = seq_group.query_len
assert query_len is not None
# prompt has only 1 seq id.
assert len(seq_ids) == 1
seq_data = seq_group.seq_data[seq_ids[0]]
......@@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
prompt_tokens = seq_data.prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + subquery_len + 1,
next_token_index_end = min(computed_len + query_len + 1,
len(prompt_tokens))
next_prompt_tokens = prompt_tokens[
next_token_index_start:next_token_index_end]
......
......@@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558
@dataclass
class SequenceGroupToSample:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step.
seq_ids: List[int]
sampling_params: SamplingParams
# seq_id -> sequence data.
seq_data: Dict[int, SequenceData]
# The length of the prompt of the sequence group. None if it is in a decode
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
prompt_len: Optional[int]
# The length of the query tokens to compute in the current step. None if it
# is in a decode stage. The length of subquery_len <= prompt_len.
subquery_len: Optional[int]
seq_len: Optional[int]
# The length of new query tokens to compute in the current step. None if it
# is in a decode stage. The length of query_len <= seq_len if chunked
# prefill is enabled.
query_len: Optional[int]
# A random number generator for sampling.
generator: Optional[torch.Generator]
# True if the sequence group is in prefill stage. False if it is in a
......@@ -46,8 +55,8 @@ class SequenceGroupToSample:
if len(self.prompt_logprob_indices) > 0:
assert self.sampling_params.prompt_logprobs is not None
if self.is_prompt:
assert self.prompt_len is not None
assert self.subquery_len is not None
assert self.seq_len is not None
assert self.query_len is not None
class SamplingMetadata:
......@@ -94,8 +103,8 @@ class SamplingMetadata:
@staticmethod
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
pin_memory: bool,
) -> "SamplingMetadata":
......@@ -104,8 +113,8 @@ class SamplingMetadata:
selected_token_indices,
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
subquery_lens, device)
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
......@@ -137,8 +146,8 @@ class SamplingMetadata:
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
......@@ -146,9 +155,9 @@ def _prepare_seq_groups(
Args:
seq_group_metadata_list: A list of sequence group to batch.
prompt_lens: A list of prompt lens per sequence group.
seq_lens: A list of sequence lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
subquery_lens: A list of query lengths. Prompt lens include the length
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
......@@ -189,8 +198,8 @@ def _prepare_seq_groups(
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
prompt_len: Optional[int] = None
subquery_len: Optional[int] = None
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample
......@@ -203,12 +212,12 @@ def _prepare_seq_groups(
num_prompts += 1
num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1
assert subquery_lens is not None and prompt_lens is not None
subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
assert query_lens is not None and seq_lens is not None
query_len, seq_len = query_lens[i], seq_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len = (subquery_len - num_prefill_sample
if do_sample else subquery_len)
prompt_logprob_len = (query_len - num_prefill_sample
if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0
else:
# Decode
......@@ -267,8 +276,8 @@ def _prepare_seq_groups(
seq_ids=seq_ids,
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
prompt_len=prompt_len,
subquery_len=subquery_len,
seq_len=seq_len,
query_len=query_len,
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
......@@ -367,8 +376,8 @@ class SamplingTensors:
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
subquery_len = seq_group.subquery_len
assert subquery_len is not None
query_len = seq_group.query_len
assert query_len is not None
prefill_len = len(seq_group.prompt_logprob_indices)
temperatures += [temperature] * prefill_len
top_ps += [top_p] * prefill_len
......@@ -397,8 +406,8 @@ class SamplingTensors:
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
subquery_len = seq_group.subquery_len
assert subquery_len is not None
query_len = seq_group.query_len
assert query_len is not None
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
......
......@@ -80,7 +80,7 @@ class CPUModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_lens: List[int] = []
seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list:
......@@ -92,15 +92,15 @@ class CPUModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
prompt_len = len(prompt_tokens)
seq_len = len(prompt_tokens)
prompt_lens.append(prompt_len) # Prompt token num
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len)))
input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
......@@ -109,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window)
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prompt_len):
for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
......@@ -151,19 +151,19 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
prompt_lens=prompt_lens,
num_prefills=len(prompt_lens),
seq_lens=seq_lens,
seq_lens_tensor=None,
max_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
prefill_metadata=None,
decode_metadata=None,
max_context_len=None,
context_lens=None,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input)
def _prepare_decode(
......@@ -174,7 +174,7 @@ class CPUModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
......@@ -192,9 +192,9 @@ class CPUModelRunner:
position = seq_len - 1
input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min(
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
......@@ -208,7 +208,7 @@ class CPUModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_context_len = max(context_lens)
max_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
......@@ -219,9 +219,9 @@ class CPUModelRunner:
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
......@@ -236,14 +236,14 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_seq_len=max_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
max_context_len=max_context_len,
num_prefills=0,
prefill_metadata=None,
decode_metadata=None,
context_lens=context_lens,
block_tables=block_tables,
kv_cache_dtype=self.kv_cache_dtype,
)
......@@ -265,20 +265,20 @@ class CPUModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens,
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
# Broadcast the metadata.
......@@ -300,7 +300,7 @@ class CPUModelRunner:
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
seq_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
......
......@@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int]
subquery_lens: List[int]
seq_lens: List[int]
query_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
......@@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens=[],
input_positions=[],
attn_metadata=None,
prompt_lens=[],
subquery_lens=[],
seq_lens=[],
query_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
......@@ -134,9 +134,8 @@ class ModelRunner:
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = (
self.model_config.max_context_len_to_capture
if self.model_config is not None else 0)
self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
if self.model_config is not None else 0)
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
......@@ -149,7 +148,7 @@ class ModelRunner:
self.model: torch.nn.Module # Set after load_model
self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
......@@ -218,7 +217,7 @@ class ModelRunner:
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_prompt(
self,
......@@ -231,9 +230,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = []
seq_lens: List[int] = []
context_lens: List[int] = []
subquery_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
......@@ -257,21 +256,19 @@ class ModelRunner:
token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
computed_len = seq_data.get_num_computed_tokens()
context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption
# it contains output tokens.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = prefill_end
prompt_lens.append(prompt_len)
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
# NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
......@@ -285,25 +282,25 @@ class ModelRunner:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert computed_len == 0
assert context_len == 0
# actual prompt lens
context_lens.append(computed_len)
subquery_lens.append(prompt_len - computed_len)
context_lens.append(context_len)
query_lens.append(seq_len - context_len)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end)))
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - computed_len
(seq_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data:
......@@ -313,24 +310,24 @@ class ModelRunner:
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert computed_len == 0, (
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prefill_end):
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
......@@ -340,9 +337,9 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens)
assert max_subquery_len > 0
max_query_len = max(query_lens)
max_seq_len = max(seq_lens)
assert max_query_len > 0
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
......@@ -369,40 +366,39 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(subquery_lens_tensor,
torch.cumsum(query_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])
torch.cumsum(prompt_lens_tensor,
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_prompt_len=max_prompt_len,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
......@@ -411,8 +407,8 @@ class ModelRunner:
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
......@@ -427,7 +423,7 @@ class ModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
......@@ -455,9 +451,9 @@ class ModelRunner:
position = seq_len - 1
input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min(
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
......@@ -477,11 +473,10 @@ class ModelRunner:
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens)
max_context_len = max(context_lens)
use_captured_graph = (
not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_context_len <= self.max_context_len_to_capture)
max_seq_len = max(seq_lens)
use_captured_graph = (not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
......@@ -489,21 +484,21 @@ class ModelRunner:
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1)
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens_tensor.shape[0] == len(input_tokens)
assert context_lens_tensor.shape[0] == len(input_positions)
assert context_lens_tensor.shape[0] == len(slot_mapping)
assert seq_lens_tensor.shape[0] == len(input_tokens)
assert seq_lens_tensor.shape[0] == len(input_positions)
assert seq_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
......@@ -525,14 +520,13 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
prompt_lens_tensor=None,
max_subquery_len=None,
max_context_len=max_context_len,
max_prompt_len=None,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens_tensor,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
......@@ -565,8 +559,8 @@ class ModelRunner:
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
......@@ -583,13 +577,13 @@ class ModelRunner:
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, prompt_lens, subquery_lens,
self.device, self.pin_memory)
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
......@@ -886,7 +880,7 @@ class ModelRunner:
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
graph_batch_size = _get_graph_batch_size(
......@@ -908,14 +902,13 @@ class ModelRunner:
# Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
prompt_lens_tensor=None,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_seq_len=self.max_seq_len_to_capture,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size],
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
......@@ -1025,7 +1018,7 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.decode_metadata.context_lens,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
......@@ -1047,8 +1040,8 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.
......
......@@ -52,7 +52,7 @@ class NeuronModelRunner:
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
prompt_lens: List[int] = []
seq_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
......@@ -61,26 +61,26 @@ class NeuronModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
seq_len = len(prompt_tokens)
seq_lens.append(seq_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
input_positions.append(list(range(seq_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
......@@ -88,7 +88,7 @@ class NeuronModelRunner:
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids, prompt_lens
return input_tokens, input_positions, input_block_ids, seq_lens
def _prepare_decode(
self,
......@@ -149,18 +149,18 @@ class NeuronModelRunner:
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
seq_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment