Unverified Commit 15702038 authored by Lily Liu's avatar Lily Liu Committed by GitHub
Browse files

[Spec Decode] (1/2) Remove batch expansion (#8839)

parent 22f5851b
import functools import functools
from typing import Callable, List from typing import Callable, List, Optional
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
...@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def process_outputs(self, def process_outputs(self,
sequence_group: SequenceGroup, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput], outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None: is_async: bool = False) -> Optional[int]:
"""Append new tokens in the outputs to sequences in the sequence group. """Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than This only supports sequence groups of size 1. It supports greater than
...@@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
tokens from the previous step. If this is true, then tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done no tokens need to be appended since it is already done
externally (before the next schedule() call) externally (before the next schedule() call)
Returns:
The number of tokens appended to the sequence. This is optional
because only speculative decode uses this return value.
""" """
# Sequences can be in RUNNING or FINISHED_ABORTED state # Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINISHED_ABORTED # once scheduled, as a sequence is moved to FINISHED_ABORTED
...@@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# was already appended, so we only need to do the rest of the # was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic # postprocessor: Detokenization + stopping logic
self._process_decode_and_stop(seq, sequence_group.sampling_params) self._process_decode_and_stop(seq, sequence_group.sampling_params)
return None
else: else:
# Standard multi-step case # Standard multi-step case
...@@ -121,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -121,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
] ]
assert valid_samples assert valid_samples
self._process_seq_outputs(seq, valid_samples, return self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params) sequence_group.sampling_params)
def _process_decode_and_stop(self, seq: Sequence, def _process_decode_and_stop(self, seq: Sequence,
...@@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def _process_seq_outputs(self, seq: Sequence, def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput], valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None: sampling_params: SamplingParams) -> int:
output_token_ids = [sample.output_token for sample in valid_samples] output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples]
...@@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids)) len(output_token_ids))
if remaining_tokens < 0: if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens] output_token_ids = output_token_ids[:remaining_tokens]
# Truncate any tokens after EOS. This is required as spec decode # Truncate any tokens after EOS. This is required as spec decode
...@@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
for i in range(len(output_token_ids)): for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id: if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1] output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break break
# Incrementally append tokens to the sequence, as if we had only one new # Incrementally append tokens to the sequence, as if we had only one new
...@@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
token_id=output_token_id, token_id=output_token_id,
logprobs=output_logprob, logprobs=output_logprob,
) )
seq.data.update_num_computed_tokens(1)
self._process_decode_and_stop(seq, sampling_params) self._process_decode_and_stop(seq, sampling_params)
if seq.is_finished(): if seq.is_finished():
break break
return len(output_token_ids)
...@@ -912,7 +912,7 @@ def get_logprobs( ...@@ -912,7 +912,7 @@ def get_logprobs(
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sample_results: SampleResultType, sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs. """Return sample logprobs and prompt logprobs.
The logic consists of 3 parts. The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and - Select indices to compute logprob from, ranks of token ids, and
......
...@@ -146,7 +146,7 @@ class SamplingMetadata: ...@@ -146,7 +146,7 @@ class SamplingMetadata:
def prepare( def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int], seq_lens: List[int],
query_lens: Optional[List[int]], query_lens: List[int],
device: str, device: str,
pin_memory: bool, pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None, generators: Optional[Dict[str, torch.Generator]] = None,
...@@ -194,7 +194,7 @@ class SamplingMetadata: ...@@ -194,7 +194,7 @@ class SamplingMetadata:
def _prepare_seq_groups( def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int], seq_lens: List[int],
query_lens: Optional[List[int]], query_lens: List[int],
device: str, device: str,
generators: Optional[Dict[str, torch.Generator]] = None, generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None, cache: Optional[SamplingMetadataCache] = None,
...@@ -284,7 +284,8 @@ def _prepare_seq_groups( ...@@ -284,7 +284,8 @@ def _prepare_seq_groups(
else: else:
# Decode # Decode
prompt_logprob_len = 0 prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0 query_len = query_lens[i] if query_lens is not None else 1
sample_len = len(seq_ids) * query_len if do_sample else 0
if sampling_params.seed is not None and generators is not None: if sampling_params.seed is not None and generators is not None:
generator = generators.get(seq_group_metadata.request_id) generator = generators.get(seq_group_metadata.request_id)
...@@ -440,14 +441,14 @@ class SamplingTensors: ...@@ -440,14 +441,14 @@ class SamplingTensors:
if seq_group.do_sample: if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices) sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids) assert sample_lens >= len(seq_ids)
temperatures += [temperature] * len(seq_ids) temperatures += [temperature] * sample_lens
top_ps += [top_p] * len(seq_ids) top_ps += [top_p] * sample_lens
top_ks += [top_k] * len(seq_ids) top_ks += [top_k] * sample_lens
min_ps += [min_p] * len(seq_ids) min_ps += [min_p] * sample_lens
presence_penalties += [p] * len(seq_ids) presence_penalties += [p] * sample_lens
frequency_penalties += [f] * len(seq_ids) frequency_penalties += [f] * sample_lens
repetition_penalties += [r] * len(seq_ids) repetition_penalties += [r] * sample_lens
if do_penalties: if do_penalties:
for seq_group in sampling_metadata.seq_groups: for seq_group in sampling_metadata.seq_groups:
......
...@@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, ...@@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.worker.worker_base import WorkerBase
SeqId = int SeqId = int
TargetSeqId = int TargetSeqId = int
...@@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree. of topk/tree.
""" """
def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size
@nvtx_range("BatchExpansionTop1Scorer.score_proposals") @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals( def score_proposals(
self, self,
......
...@@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner):
assert seq_group.is_prompt is False # No prompt assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _gpu_advance_step( def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata, self, model_input: ModelInputForGPUWithSamplingMetadata,
......
...@@ -5,6 +5,7 @@ from typing import Optional, Set ...@@ -5,6 +5,7 @@ from typing import Optional, Set
import torch import torch
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.worker.worker_base import WorkerBase
@dataclass @dataclass
...@@ -74,6 +75,12 @@ class SpeculativeProposer(ABC): ...@@ -74,6 +75,12 @@ class SpeculativeProposer(ABC):
class SpeculativeScorer(ABC): class SpeculativeScorer(ABC):
def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size
@abstractmethod @abstractmethod
def score_proposals( def score_proposals(
self, self,
......
from vllm.sequence import (ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
# Plain aliases of `int`. Presumably meant for annotating sequence ids
# (TargetSeqId for the ids assigned to scoring sequences) — neither alias is
# referenced by the MQAScorer code in this file section; TODO confirm usage.
SeqId = int
TargetSeqId = int
class MQAScorer(SpeculativeScorer):
    """Score draft proposals with a single call to the scorer worker.

    For every sequence group in the incoming request a new single-sequence
    group is built whose output tokens are the original output tokens
    followed by that sequence's proposal tokens. All of these groups are
    sent to the scorer worker in one ``execute_model`` call and the sampler
    output is reshaped to one row per sequence.
    """

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Run the scorer worker over the proposals and collect its outputs.

        Args:
            execute_model_req: The request whose ``seq_group_metadata_list``
                is scored. Each group must hold exactly one sequence, and
                that sequence must already have at least one output token.
            proposals: Draft proposals. ``proposal_token_ids`` is unpacked
                as a 2-D ``(bs, k)`` tensor and row ``i`` is appended to the
                ``i``-th sequence group of the request.

        Returns:
            A ``SpeculativeScores`` whose ``token_ids`` are reshaped to
            ``(bs, k + 1)``, ``probs`` and ``logprobs`` to
            ``(bs, k + 1, vocab_size)``, and ``hidden_states`` to
            ``(bs, k + 1, -1)`` when the scorer worker returned any
            (otherwise ``None``).

        Raises:
            AssertionError: If a sequence group does not contain exactly one
                sequence, or a sequence has no output tokens yet.
        """
        scoring_groups = []
        source_groups = execute_model_req.seq_group_metadata_list
        # New sequence ids start right after the largest id in the request.
        first_scoring_id = max(get_all_seq_ids(source_groups)) + 1
        draft_tokens_per_seq = proposals.proposal_token_ids.tolist()

        for idx, source_group in enumerate(source_groups):
            seqs_by_id = source_group.seq_data
            assert len(seqs_by_id) == 1
            source_seq_id, source_seq = next(iter(seqs_by_id.items()))

            prompt_ids = source_seq.get_prompt_token_ids()
            prior_output_ids = source_seq.get_output_token_ids()
            extended_output_ids = [
                *prior_output_ids,
                *draft_tokens_per_seq[idx],
            ]

            scoring_seq_id = first_scoring_id + idx
            scoring_seq = SequenceData.from_seqs(
                prompt_token_ids=prompt_ids,
                output_token_ids=extended_output_ids,
            )
            # Everything except the last pre-existing output token is
            # reported as already computed; the proposal tokens are not.
            num_computed = len(prompt_ids) + len(prior_output_ids) - 1
            scoring_seq.update_num_computed_tokens(num_computed)

            # Ensure that the new sequence has at least one token
            # because we only use mqa scorer in the decoding stage.
            assert len(prior_output_ids) >= 1

            scoring_groups.append(
                SequenceGroupMetadata(
                    request_id=source_group.request_id,
                    is_prompt=source_group.is_prompt,
                    seq_data={scoring_seq_id: scoring_seq},
                    sampling_params=source_group.sampling_params,
                    # The original sequence's block table is reused under
                    # the new sequence id.
                    block_tables={
                        scoring_seq_id:
                        source_group.block_tables[source_seq_id],
                    },
                    lora_request=None,
                    token_chunk_size=1,
                ))

        scoring_request = execute_model_req.clone(
            seq_group_metadata_list=scoring_groups)
        # Only the first element of the worker's return value is consumed.
        sampler_output = self._scorer_worker.execute_model(
            execute_model_req=scoring_request)[0]

        bs, k = proposals.proposal_token_ids.shape
        tokens_per_seq = k + 1
        vocab_size = self._vocab_size

        token_ids = sampler_output.sampled_token_ids.reshape(
            bs, tokens_per_seq)
        probs = sampler_output.sampled_token_probs.reshape(
            bs, tokens_per_seq, vocab_size)
        logprobs = sampler_output.logprobs.reshape(
            bs, tokens_per_seq, vocab_size)

        raw_hidden = sampler_output.hidden_states
        hidden_states = (None if raw_hidden is None else raw_hidden.reshape(
            bs, tokens_per_seq, -1))

        return SpeculativeScores(probs=probs,
                                 token_ids=token_ids,
                                 logprobs=logprobs,
                                 hidden_states=hidden_states)
from collections import defaultdict from collections import defaultdict
from functools import cached_property from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch import torch
...@@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, ...@@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from vllm.spec_decode.medusa_worker import MedusaWorker from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
...@@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
spec_decode_worker = SpecDecodeWorker.create_worker( spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker, scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs, draft_worker_kwargs=draft_worker_kwargs,
disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
disable_by_batch_size=speculative_config. disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size, speculative_disable_by_batch_size,
draft_token_acceptance_method=speculative_config. draft_token_acceptance_method=speculative_config.
...@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
cls, cls,
scorer_worker: Worker, scorer_worker: Worker,
draft_worker_kwargs: Dict[str, Any], draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int], disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str, draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_threshold: float,
...@@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha, posterior_alpha=typical_acceptance_sampler_posterior_alpha,
) )
logger.info("Configuring SpecDecodeWorker with sampler=%s", logger.info(
type(spec_decode_sampler)) "[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name(
) != "flash-attn":
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")
if ngram_prompt_lookup_max > 0:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"NGramWorker does not support MQA scorer.")
if "model_config" in draft_worker_kwargs and \
draft_worker_kwargs["model_config"].max_model_len < \
scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len.")
if not scorer_worker.model_runner.model_config.enforce_eager:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
return SpecDecodeWorker( return SpecDecodeWorker(
proposer_worker, proposer_worker,
scorer_worker, scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs, disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats, disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size, disable_by_batch_size=disable_by_batch_size,
...@@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker: ProposerWorkerBase, proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase, scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler, spec_decode_sampler: SpecDecodeBaseSampler,
disable_mqa_scorer: bool = False,
disable_logprobs: bool = False, disable_logprobs: bool = False,
disable_log_stats: bool = False, disable_log_stats: bool = False,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
...@@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler. instance of RejectionSampler or TypicalAcceptanceSampler.
disable_mqa_scorer: If set to True, disable the MQA scorer and use
the BatchExpansionTop1Scorer instead.
disable_logprobs: If set to True, token log probabilities will disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker. not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both. If set to False, log probabilities will be output by both.
...@@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initialization. # Lazy initialization.
self.scorer: SpeculativeScorer self.scorer: SpeculativeScorer
self.disable_mqa_scorer = disable_mqa_scorer
# Hidden states from target model to pass to proposer # Hidden states from target model to pass to proposer
# in the subsequent step. # in the subsequent step.
...@@ -270,8 +308,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -270,8 +308,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._metrics.init_gpu_tensors(self.rank) self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer( scorer_cls: Type[SpeculativeScorer]
scorer_worker=self.scorer_worker, if self.disable_mqa_scorer:
scorer_cls = BatchExpansionTop1Scorer
logger.info("[Speculative Decoding] Use batch "
"expansion for scoring proposals.")
else:
scorer_cls = MQAScorer
logger.info(
"[Speculative Decoding] Use MQA scorer for scoring proposals.")
self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
device=self.device, device=self.device,
vocab_size=self._vocab_size) vocab_size=self._vocab_size)
......
...@@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute context length (the number of tokens that are # Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens). # already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len() seq_len = seq_data.get_len()
if inter_data.is_prompt: if inter_data.is_prompt:
context_len = seq_data.get_num_computed_tokens() context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size) seq_len = min(seq_len, context_len + token_chunk_size)
elif self.runner.scheduler_config.is_multi_step or \
self.runner.model_config.is_encoder_decoder_model:
context_len = seq_len - 1
else:
context_len = seq_data.get_num_computed_tokens()
# Compute tokens. # Compute tokens.
if inter_data.is_prompt: tokens = seq_data.get_token_ids()[context_len:seq_len]
tokens = seq_data.get_token_ids()
if context_len != 0 or seq_len < len(tokens):
tokens = tokens[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = seq_data.get_last_token_id()
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
if isinstance(tokens, list):
inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_tokens[seq_idx].extend(tokens)
else: inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.input_tokens[seq_idx].append(tokens) inter_data.query_lens[seq_idx] = seq_len - context_len
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
if seq_data.mrope_position_delta is not None: if seq_data.mrope_position_delta is not None:
if inter_data.mrope_input_positions is None: if inter_data.mrope_input_positions is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment