Commit 6d0996e9 authored by guanyu1's avatar guanyu1
Browse files

detok补充

parent 3a49f3d2
...@@ -1515,6 +1515,8 @@ class LLMEngine: ...@@ -1515,6 +1515,8 @@ class LLMEngine:
if seq.seq_id == seq_id: if seq.seq_id == seq_id:
sample.output_token = token_id[0] sample.output_token = token_id[0]
seq.fix_last_token_id(sample.output_token) seq.fix_last_token_id(sample.output_token)
self.fix_process_model_output(ctx_output_queue,ctx_request_outputs,
ctx_multi_step_stream_outputs)
break break
def _advance_to_next_step( def _advance_to_next_step(
......
...@@ -6,7 +6,7 @@ from vllm.lora.request import LoRARequest ...@@ -6,7 +6,7 @@ from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
import os
class StopChecker: class StopChecker:
"""LLMEngine helper class which separates out the logic involving stop """LLMEngine helper class which separates out the logic involving stop
...@@ -20,6 +20,8 @@ class StopChecker: ...@@ -20,6 +20,8 @@ class StopChecker:
# Do not use it directly, but use `self._get_max_model_len`. # Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def _get_max_model_len(self, lora_req: Optional[LoRARequest]): def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len: if lora_req and lora_req.long_lora_max_len:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional from typing import Dict, List, Optional
import os
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup) Sequence, SequenceGroup)
...@@ -16,6 +16,7 @@ class Detokenizer: ...@@ -16,6 +16,7 @@ class Detokenizer:
def __init__(self, tokenizer_group: BaseTokenizerGroup): def __init__(self, tokenizer_group: BaseTokenizerGroup):
self.tokenizer_group = tokenizer_group self.tokenizer_group = tokenizer_group
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence.""" """Returns the HF tokenizer to use for a given sequence."""
...@@ -109,7 +110,8 @@ class Detokenizer: ...@@ -109,7 +110,8 @@ class Detokenizer:
""" """
all_input_ids = seq.get_token_ids() all_input_ids = seq.get_token_ids()
if self.zero_overhead: if self.zero_overhead:
all_input_ids = seq.get_token_ids()[:seq.get_prompt_len()+self.data._effective_length] eff_length=seq.get_prompt_len()+seq.data._effective_length
all_input_ids = seq.get_token_ids()[:eff_length]
print(f'{all_input_ids=}') print(f'{all_input_ids=}')
token_id_generated_this_iteration = all_input_ids[-1] token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq) tokenizer = self.get_tokenizer_for_seq(seq)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment