"vscode:/vscode.git/clone" did not exist on "8aca27fa11bfe0539b72002761add7d990af325e"
Commit 052059d9 authored by guanyu1's avatar guanyu1
Browse files

detok修改

parent 2344d22e
...@@ -7,7 +7,7 @@ import random ...@@ -7,7 +7,7 @@ import random
import time import time
from functools import cache from functools import cache
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import os
import numpy as np import numpy as np
import torch import torch
import uvloop import uvloop
...@@ -179,7 +179,7 @@ def run_vllm( ...@@ -179,7 +179,7 @@ def run_vllm(
sampling_params: List[SamplingParams] = [] sampling_params: List[SamplingParams] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TextPrompt(prompt=request.prompt, TextPrompt(prompt="helloworld",
multi_modal_data=request.multi_modal_data)) multi_modal_data=request.multi_modal_data))
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
...@@ -205,21 +205,25 @@ def run_vllm( ...@@ -205,21 +205,25 @@ def run_vllm(
dummy_prompts: List[PromptType] = [{ dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]
print(f'{os.environ.get("VLLM_ZERO_OVERHEAD") == "1"}')
print("Warming up...") print("Warming up...")
for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"): for _ in tqdm(range(num_iters_warmup), desc="Warmup iterations"):
llm.generate(dummy_prompts, llm.generate(dummy_prompts,
sampling_params=warmup_sampling_params, sampling_params=warmup_sampling_params,
use_tqdm=False) use_tqdm=False)
use_beam_search = False
use_beam_search = False
print("testing")
if not use_beam_search: if not use_beam_search:
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, outputs=llm.generate(prompts,
sampling_params, sampling_params,
lora_request=lora_requests, lora_request=lora_requests,
use_tqdm=True) use_tqdm=True)
for output in outputs:
generated_text=output.outputs[0].text
print(f"test生成的文本: {generated_text}")
end = time.perf_counter() end = time.perf_counter()
else: else:
assert lora_requests is None, "BeamSearch API does not support LoRA" assert lora_requests is None, "BeamSearch API does not support LoRA"
......
...@@ -42,6 +42,57 @@ class StopChecker: ...@@ -42,6 +42,57 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet; # Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not # skip the stop string/token checks if not
if self.zero_overhead:
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
return
#new char count的 暂时未修改逻辑
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.zero_overhead_get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.zero_overhead_get_output_len() >= sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
else:
if seq.get_output_len() < sampling_params.min_tokens: if seq.get_output_len() < sampling_params.min_tokens:
return return
......
...@@ -177,7 +177,7 @@ class SequenceData(msgspec.Struct, ...@@ -177,7 +177,7 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta: Optional[int] = None _mrope_position_delta: Optional[int] = None
_first_step_flag: bool = True _first_step_flag: bool = True
_effective_length:int =0
@staticmethod @staticmethod
def from_prompt_token_counts( def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData": *token_counts: Tuple[int, int]) -> "SequenceData":
...@@ -311,11 +311,15 @@ class SequenceData(msgspec.Struct, ...@@ -311,11 +311,15 @@ class SequenceData(msgspec.Struct,
def get_len(self) -> int: def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids) return len(self._output_token_ids) + len(self._prompt_token_ids)
def zero_overhead_get_len(self) -> int:
    """Return the prompt length plus ``_effective_length``.

    Counterpart of ``get_len`` that uses the ``_effective_length`` counter
    instead of ``len(self._output_token_ids)`` for the output part.

    NOTE(review): ``_effective_length`` presumably counts the output tokens
    already committed in zero-overhead mode -- confirm where it is updated.
    """
    return self._effective_length + len(self._prompt_token_ids)
def get_prompt_len(self) -> int: def get_prompt_len(self) -> int:
return len(self._prompt_token_ids) return len(self._prompt_token_ids)
def get_output_len(self) -> int: def get_output_len(self) -> int:
return len(self._output_token_ids) return len(self._output_token_ids)
def zero_overhead_get_output_len(self) -> int:
    """Return ``_effective_length`` as the output length.

    Counterpart of ``get_output_len`` that reports the ``_effective_length``
    counter rather than ``len(self._output_token_ids)``.

    The return annotation is ``int`` (it was previously declared as
    ``Tuple[int, ...]``): the value returned is the integer counter itself
    and callers compare it numerically against token limits.
    """
    return self._effective_length
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> List[int]:
return self._cached_all_token_ids return self._cached_all_token_ids
...@@ -372,6 +376,10 @@ class SequenceData(msgspec.Struct, ...@@ -372,6 +376,10 @@ class SequenceData(msgspec.Struct,
if not self._output_token_ids: if not self._output_token_ids:
return self._prompt_token_ids[-1] return self._prompt_token_ids[-1]
return self._output_token_ids[-1] return self._output_token_ids[-1]
def zero_overhead_get_last_token_id(self) -> int:
    """Return the last token id as seen through ``_effective_length``.

    When ``_effective_length`` is 0 the last prompt token is returned;
    otherwise the output token at index ``_effective_length - 1`` is
    returned, ignoring any entries stored beyond that index.

    NOTE(review): assumes ``_effective_length`` never exceeds
    ``len(self._output_token_ids)``; an IndexError is raised otherwise --
    confirm the invariant at the place the counter is advanced.
    """
    if self._effective_length == 0:
        return self._prompt_token_ids[-1]
    return self._output_token_ids[self._effective_length - 1]
def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.prompt_token_ids return self.prompt_token_ids
...@@ -590,12 +598,16 @@ class Sequence: ...@@ -590,12 +598,16 @@ class Sequence:
def get_len(self) -> int: def get_len(self) -> int:
return self.data.get_len() return self.data.get_len()
def zero_overhead_get_len(self) -> int:
    """Return ``self.data.zero_overhead_get_len()``.

    Thin delegate to the sequence's data object, mirroring how ``get_len``
    forwards to ``self.data.get_len()``.
    """
    return self.data.zero_overhead_get_len()
def get_prompt_len(self) -> int: def get_prompt_len(self) -> int:
return self.data.get_prompt_len() return self.data.get_prompt_len()
def get_output_len(self) -> int: def get_output_len(self) -> int:
return self.data.get_output_len() return self.data.get_output_len()
def zero_overhead_get_output_len(self) -> int:
    """Return ``self.data.zero_overhead_get_output_len()``.

    Thin delegate to the sequence's data object, mirroring how
    ``get_output_len`` forwards to ``self.data.get_output_len()``.
    """
    return self.data.zero_overhead_get_output_len()
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> List[int]:
return self.data.get_token_ids() return self.data.get_token_ids()
...@@ -604,7 +616,8 @@ class Sequence: ...@@ -604,7 +616,8 @@ class Sequence:
def get_last_token_id(self) -> int: def get_last_token_id(self) -> int:
return self.data.get_last_token_id() return self.data.get_last_token_id()
def zero_overhead_get_last_token_id(self) -> int:
    """Return ``self.data.zero_overhead_get_last_token_id()``.

    Thin delegate to the sequence's data object, mirroring how
    ``get_last_token_id`` forwards to ``self.data.get_last_token_id()``.
    """
    return self.data.zero_overhead_get_last_token_id()
def get_output_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self) -> Tuple[int, ...]:
return self.data.get_output_token_ids() return self.data.get_output_token_ids()
......
...@@ -108,6 +108,9 @@ class Detokenizer: ...@@ -108,6 +108,9 @@ class Detokenizer:
The number of characters added to the output text. The number of characters added to the output text.
""" """
all_input_ids = seq.get_token_ids() all_input_ids = seq.get_token_ids()
if self.zero_overhead:
all_input_ids = seq.get_token_ids()[:seq.get_prompt_len()+self.data._effective_length]
print(f'{all_input_ids=}')
token_id_generated_this_iteration = all_input_ids[-1] token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq) tokenizer = self.get_tokenizer_for_seq(seq)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment