Commit 052059d9 authored by guanyu1's avatar guanyu1
Browse files

detok修改

parent 2344d22e
......@@ -7,7 +7,7 @@ import random
import time
from functools import cache
from typing import Dict, List, Optional, Tuple
import os
import numpy as np
import torch
import uvloop
......@@ -179,7 +179,7 @@ def run_vllm(
sampling_params: List[SamplingParams] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
TextPrompt(prompt="helloworld",
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
......@@ -205,21 +205,25 @@ def run_vllm(
dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
print(f'{os.environ.get("VLLM_ZERO_OVERHEAD") == "1"}')
print("Warming up...")
for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
llm.generate(dummy_prompts,
sampling_params=warmup_sampling_params,
use_tqdm=False)
use_beam_search = False
print("testing")
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts,
outputs=llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
for output in outputs:
generated_text=output.outputs[0].text
print(f"test生成的文本: {generated_text}")
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
......
......@@ -42,53 +42,104 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
if self.zero_overhead:
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
return
#new char count的 暂时未修改逻辑
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
if ((not sampling_params.ignore_eos)
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.zero_overhead_get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.zero_overhead_get_output_len() >= sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
else:
if seq.get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
@staticmethod
def check_stop_strings(
......
......@@ -177,7 +177,7 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta: Optional[int] = None
_first_step_flag: bool = True
_effective_length:int =0
@staticmethod
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
......@@ -310,12 +310,16 @@ class SequenceData(msgspec.Struct,
def get_len(self) -> int:
    """Return the total number of stored token ids (output plus prompt)."""
    output_count = len(self._output_token_ids)
    prompt_count = len(self._prompt_token_ids)
    return output_count + prompt_count
def zero_overhead_get_len(self) -> int:
    """Return ``_effective_length`` plus the number of prompt token ids.

    Unlike ``get_len``, the output part is taken from the
    ``_effective_length`` counter rather than from the size of
    ``_output_token_ids``.
    """
    prompt_count = len(self._prompt_token_ids)
    return self._effective_length + prompt_count
def get_prompt_len(self) -> int:
    """Return the number of token ids in the prompt."""
    prompt_ids = self._prompt_token_ids
    return len(prompt_ids)
def get_output_len(self) -> int:
    """Return the number of token ids currently stored as output."""
    output_ids = self._output_token_ids
    return len(output_ids)
def zero_overhead_get_output_len(self) -> int:
    """Return the ``_effective_length`` counter as the output length.

    The return annotation is ``int`` (it was previously declared as
    ``Tuple[int, ...]``): ``_effective_length`` is an integer field, and the
    ``Sequence`` wrapper of the same name is annotated ``-> int``.
    """
    return self._effective_length
def get_token_ids(self) -> List[int]:
    """Return the cached list holding all token ids of this sequence."""
    all_ids = self._cached_all_token_ids
    return all_ids
......@@ -372,7 +376,11 @@ class SequenceData(msgspec.Struct,
if not self._output_token_ids:
return self._prompt_token_ids[-1]
return self._output_token_ids[-1]
def zero_overhead_get_last_token_id(self) -> int:
    """Return the token id at the end of the effective sequence.

    When ``_effective_length`` is zero the last prompt token id is returned;
    otherwise the output token id at index ``_effective_length - 1`` is
    returned.
    """
    effective = self._effective_length
    if effective != 0:
        return self._output_token_ids[effective - 1]
    return self._prompt_token_ids[-1]
def get_prompt_token_ids(self) -> Tuple[int, ...]:
    """Return the value of this object's ``prompt_token_ids`` attribute."""
    prompt_ids = self.prompt_token_ids
    return prompt_ids
......@@ -589,13 +597,17 @@ class Sequence:
def get_len(self) -> int:
    """Return the length reported by the underlying ``data.get_len()``."""
    seq_data = self.data
    return seq_data.get_len()
def zero_overhead_get_len(self) -> int:
    """Return the result of ``data.zero_overhead_get_len()``."""
    seq_data = self.data
    return seq_data.zero_overhead_get_len()
def get_prompt_len(self) -> int:
    """Return the prompt length reported by ``data.get_prompt_len()``."""
    seq_data = self.data
    return seq_data.get_prompt_len()
def get_output_len(self) -> int:
    """Return the output length reported by ``data.get_output_len()``."""
    seq_data = self.data
    return seq_data.get_output_len()
def zero_overhead_get_output_len(self) -> int:
    """Return the result of ``data.zero_overhead_get_output_len()``."""
    seq_data = self.data
    return seq_data.zero_overhead_get_output_len()
def get_token_ids(self) -> List[int]:
    """Return the token ids reported by ``data.get_token_ids()``."""
    seq_data = self.data
    return seq_data.get_token_ids()
......@@ -604,7 +616,8 @@ class Sequence:
def get_last_token_id(self) -> int:
    """Return the token id reported by ``data.get_last_token_id()``."""
    seq_data = self.data
    return seq_data.get_last_token_id()
def zero_overhead_get_last_token_id(self) -> int:
    """Return the result of ``data.zero_overhead_get_last_token_id()``."""
    seq_data = self.data
    return seq_data.zero_overhead_get_last_token_id()
def get_output_token_ids(self) -> Tuple[int, ...]:
    """Return the output token ids reported by ``data.get_output_token_ids()``."""
    seq_data = self.data
    return seq_data.get_output_token_ids()
......
......@@ -108,6 +108,9 @@ class Detokenizer:
The number of characters added to the output text.
"""
all_input_ids = seq.get_token_ids()
if self.zero_overhead:
all_input_ids = seq.get_token_ids()[:seq.get_prompt_len()+self.data._effective_length]
print(f'{all_input_ids=}')
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment