Commit 583034f1 authored by zhuwenwen
Browse files

[models] support step3v

parent 0adf9cda
from typing import Union
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
class GPUToTensor(torch.nn.Module):
    """Turn a PIL image or a NumPy array into a channel-first tensor.

    PIL images are delegated to torchvision's ``ToTensor``. NumPy arrays are
    moved to CUDA when it is available (CPU otherwise), their axes are
    reordered from ``(0, 1, 2)`` to ``(2, 0, 1)``, and ``uint8`` data is
    converted to ``float32`` and divided by 255.
    """

    def forward(self, raw_image: Union[np.ndarray,
                                       Image.Image]) -> torch.Tensor:
        """Convert ``raw_image`` to a tensor.

        Args:
            raw_image: A PIL image, or a 2-D / 3-D NumPy array.
                NOTE(review): 3-D arrays are presumably (H, W, C); confirm
                against callers.

        Returns:
            The converted tensor; for array input it lives on the selected
            device and is contiguous.
        """
        if isinstance(raw_image, Image.Image):
            to_tensor = transforms.ToTensor()
            return to_tensor(raw_image)

        if raw_image.ndim == 2:
            # Single-plane input: add a trailing axis and replicate it 3x.
            raw_image = raw_image[:, :, None].repeat(3, -1)

        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tensor = torch.from_numpy(raw_image).to(target)
        tensor = tensor.permute(2, 0, 1).contiguous()

        if tensor.dtype != torch.uint8:
            return tensor
        return tensor.to(torch.float32).div(255)
class StepPreprocessor:
    """Image preprocessing pipelines for full images and for patches.

    Both pipelines run ``GPUToTensor``, normalise with fixed per-channel
    mean/std, and resize to a square with antialiasing. They differ only in
    the target edge length.
    """

    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        """Build the two transform pipelines.

        Args:
            size: Edge length of the square output for full images.
            interpolation_mode: ``"bicubic"`` selects bicubic resizing; any
                other value selects bilinear.
            patch_size: Edge length for patch outputs; defaults to ``size``.
        """
        if patch_size is None:
            patch_size = size

        if interpolation_mode == "bicubic":
            resample = InterpolationMode.BICUBIC
        else:
            resample = InterpolationMode.BILINEAR

        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]

        def build_pipeline(edge):
            # One pipeline per target edge length; everything else is shared.
            return transforms.Compose([
                GPUToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize((edge, edge),
                                  interpolation=resample,
                                  antialias=True),
            ])

        self.transform = build_pipeline(size)
        # Only None when both ``size`` and ``patch_size`` were None.
        self.patch_transform = (None if patch_size is None else
                                build_pipeline(patch_size))

    def preprocess(self, image, return_tensors="pt", is_patch=False):  # noqa
        """Run the selected pipeline and add a leading batch axis.

        Args:
            image: Input accepted by ``GPUToTensor``.
            return_tensors: Unused by this method.
            is_patch: Use the patch pipeline instead of the full-image one.

        Returns:
            ``{"pixel_values": tensor}``, the same key CLIPImageProcessor
            uses.
        """
        pipeline = self.patch_transform if is_patch else self.transform
        pixel_values = pipeline(image).unsqueeze(0)
        return {"pixel_values": pixel_values}
\ No newline at end of file
...@@ -19,7 +19,7 @@ from vllm.logger import init_logger ...@@ -19,7 +19,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_base import (TokenizerBase, from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
TokenizerRegistry) TokenizerRegistry)
from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer, SentencePieceTokenizer
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async from vllm.utils import make_async
...@@ -28,6 +28,11 @@ if TYPE_CHECKING: ...@@ -28,6 +28,11 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
_TOKENIZER_REGISTRY = {
"step1": SentencePieceTokenizer,
"step2": SentencePieceTokenizer
}
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
TokenizerBase] TokenizerBase]
......
...@@ -3,10 +3,13 @@ ...@@ -3,10 +3,13 @@
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
truncate_tool_call_ids, validate_request_params) truncate_tool_call_ids, validate_request_params)
from vllm.transformers_utils.tokenizers.sentencepiece_tokenizer import (
SentencePieceTokenizer)
from vllm.transformers_utils.tokenizers.cpm_9g import CPM9GTokenizer from vllm.transformers_utils.tokenizers.cpm_9g import CPM9GTokenizer
__all__ = [ __all__ = [
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids", "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids",
"validate_request_params", "validate_request_params",
"CPM9GTokenizer" "CPM9GTokenizer",
"SentencePieceTokenizer"
] ]
# SPDX-License-Identifier: Apache-2.0
# mypy: ignore-errors
import glob
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import sentencepiece

from vllm.transformers_utils.tokenizer_base import TokenizerBase
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ConversationMessage
@dataclass
class Encoding:
    """Minimal tokenization result that carries only the token ids."""

    # Token ids produced for the input text.
    input_ids: List[int]
class SentencePieceTokenizer(TokenizerBase):
    """Tokenizer backed by a SentencePiece model with StepChat ChatML support.

    Control pieces of the model, plus the unknown piece, are treated as
    special tokens. ``apply_chat_template`` renders a conversation directly
    to token ids using the ``<|BOT|>`` / ``<|EOT|>`` framing.
    """

    def __init__(self, model_file):
        """Load the SentencePiece model and index its vocabulary.

        Args:
            model_file: Path of the ``.model`` file handed to
                ``sentencepiece.SentencePieceProcessor``.
        """
        self.name = "SentencePieceTokenizer"
        self.sp_model = sentencepiece.SentencePieceProcessor(
            model_file=model_file)

        self._special_tokens = {}
        self._all_special_tokens = []
        self._all_special_ids = []
        # Set views of the two lists above for O(1) membership tests.
        self._special_id_set = set()
        self._special_token_set = set()
        self._vocab = {}

        def register_special(piece, idx):
            # Ignore repeats so the unknown piece is never listed twice.
            if idx in self._special_id_set:
                return
            self._special_tokens[piece] = idx
            self._all_special_tokens.append(piece)
            self._all_special_ids.append(idx)
            self._special_id_set.add(idx)
            self._special_token_set.add(piece)

        for idx in range(self.sp_model.get_piece_size()):
            piece = self.sp_model.id_to_piece(idx)
            self._vocab[piece] = idx
            if self.sp_model.is_control(idx):
                register_special(piece, idx)

        # The unknown piece counts as special even when it is not a control
        # piece.
        unk_id = self.sp_model.unk_id()
        register_special(self.sp_model.id_to_piece(unk_id), unk_id)

        # FIXME: compatible for decode
        self.length = self.sp_model.get_piece_size()

    @property
    def all_special_tokens_extended(self) -> List[str]:
        """Same list as ``all_special_tokens``."""
        return self._all_special_tokens

    @property
    def all_special_tokens(self) -> List[str]:
        """Pieces of all control tokens plus the unknown token."""
        return self._all_special_tokens

    @property
    def all_special_ids(self) -> List[int]:
        """Ids matching ``all_special_tokens``, in the same order."""
        return self._all_special_ids

    @property
    def eos_token_id(self):
        """End-of-sequence id reported by the SentencePiece model."""
        return self.sp_model.eos_id()

    @property
    def eos_token(self):
        """Piece string of the end-of-sequence token."""
        return self.sp_model.id_to_piece(self.eos_token_id)

    @property
    def bos_token_id(self):
        """Begin-of-sequence id reported by the SentencePiece model."""
        return self.sp_model.bos_id()

    @property
    def unk_token_id(self):
        """Unknown-token id reported by the SentencePiece model."""
        return self.sp_model.unk_id()

    @property
    def sep_token(self) -> str:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Number of pieces in the model, captured at construction."""
        return self.length

    @property
    def is_fast(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def max_token_id(self) -> int:
        """Largest valid token id (piece count minus one)."""
        return self.sp_model.get_piece_size() - 1

    def get_vocab(self):
        """Return the piece-to-id mapping built at construction."""
        return self._vocab

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        """Encode ``text`` with the defaults of ``encode``.

        Special-token strings inside ``text`` are not parsed. When
        ``truncation`` is true the result is cut to ``max_length`` ids
        (``None`` leaves it uncut).
        """
        input_ids = self.encode(text)
        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def encode(self,
               text: str,
               add_special_tokens: bool = False,
               add_bos: bool = True) -> List[int]:
        """Encode ``text`` to token ids.

        Args:
            text: Text to encode.
            add_special_tokens: When true, occurrences of special-token
                strings in ``text`` are emitted as their single ids and only
                the text between them goes through SentencePiece. When false
                the whole string is encoded by SentencePiece.
            add_bos: Prepend the begin-of-sequence id.

        Returns:
            The token ids.
        """
        if not add_special_tokens:
            return self.sp_model.encode(text, add_bos=add_bos)
        parts = self._split_on_special_tokens(text)
        return self.encode_chatml(parts, add_bos=add_bos)

    def _split_on_special_tokens(self, text):
        """Split ``text`` into plain strings and ``{"token": piece}`` dicts."""
        matches = []
        for token in self._special_tokens:
            if not token:
                # An empty piece would match at every position.
                continue
            start = text.find(token)
            while start != -1:
                matches.append((start, token))
                start = text.find(token, start + 1)

        # Left to right; at the same position the longest token wins.
        matches.sort(key=lambda match: (match[0], -len(match[1])))

        parts = []
        cursor = 0
        for pos, token in matches:
            if pos < cursor:
                # Overlaps a token that was already emitted.
                continue
            if pos > cursor:
                parts.append(text[cursor:pos])
            parts.append({"token": token})
            cursor = pos + len(token)
        if cursor < len(text):
            parts.append(text[cursor:])
        return parts

    def _decode_keeping_specials(self, items, is_special, to_text):
        """Decode runs of normal items and splice special ones in verbatim.

        ``items`` are ids or pieces, ``is_special`` tests one item, and
        ``to_text`` renders a special item as a string.
        """
        chunks = []
        pending = []
        for item in items:
            if is_special(item):
                if pending:
                    chunks.append(self.sp_model.decode(pending))
                    pending = []
                chunks.append(to_text(item))
            else:
                pending.append(item)
        if pending:
            chunks.append(self.sp_model.decode(pending))
        return ''.join(chunks)

    def decode(self,
               token_ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        """Decode token ids to text.

        Args:
            token_ids: A single id or an iterable of int-like ids.
            skip_special_tokens: When true, all ids are handed to
                SentencePiece in one call. NOTE(review): this relies on
                SentencePiece omitting control pieces from its output;
                confirm. When false, special ids are rendered as their piece
                strings between the decoded runs of normal ids.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, int):
            ids = [token_ids]
        else:
            ids = [int(token) for token in token_ids]

        if skip_special_tokens:
            return self.sp_model.decode(ids)
        return self._decode_keeping_specials(
            ids, self._special_id_set.__contains__, self.convert_id_to_token)

    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Encode ``text`` with BOS and wrap the ids in an ``Encoding``.

        ``text_pair`` and ``add_special_tokens`` are accepted but not used.
        When ``truncation`` is true the ids are cut to ``max_length``.
        """
        input_ids = self.encode(text, add_bos=True)
        if truncation:
            input_ids = input_ids[:max_length]
        return Encoding(input_ids=input_ids)

    def convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def convert_tokens_to_ids(self, tokens):
        """Convert pieces to ids via ``piece_to_id``."""
        return self.sp_model.piece_to_id(tokens)

    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.id_to_piece(index)

    def convert_ids_to_tokens(
            self, ids, **kwargs):  # kwargs for compatibility of HF tokenizer
        """Convert ids to pieces via ``id_to_piece``; kwargs are ignored."""
        return self.sp_model.id_to_piece(ids)

    def convert_tokens_to_string(self, tokens, skip_special_tokens=True):
        """Join pieces back into text.

        With ``skip_special_tokens`` false, special pieces are kept verbatim
        between the decoded runs of normal pieces.
        """
        # FIXME(ys): hack for tool call tokens
        if skip_special_tokens:
            return self.sp_model.decode(tokens)
        return self._decode_keeping_specials(
            tokens, self._special_token_set.__contains__,
            lambda token: token)

    @classmethod
    def from_pretrained(cls, model_path, *args, **kwargs):
        """Create a tokenizer from a ``.model`` file or a directory.

        Args:
            model_path: Path to a ``.model`` file, or a directory holding
                exactly one ``*.model`` file.
            *args, **kwargs: Accepted and ignored so loaders that forward
                extra options do not fail. NOTE(review): confirm which
                options callers actually pass.

        Raises:
            ValueError: If the directory does not contain exactly one
                ``*.model`` file.
        """
        model_path = str(model_path)
        if model_path.endswith(".model"):
            return cls(model_file=model_path)

        # Escape the directory so characters like "[" are matched literally.
        pattern = f"{glob.escape(model_path)}/*.model"
        possible_files = sorted(glob.glob(pattern))
        if len(possible_files) != 1:
            raise ValueError(
                f"Expected exactly one .model file for tokenizer initialization in {model_path}, but found {possible_files}"
            )
        return cls(model_file=possible_files[0])

    def encode_chatml(self, input, add_bos=True):
        """Encode a sequence of ChatML parts.

        Args:
            input: A string, or a list whose items are strings (encoded by
                SentencePiece) or ``{"token": piece}`` dicts (mapped to one
                id). Dicts without a ``"token"`` key are skipped.
            add_bos: Prepend the begin-of-sequence id.
        """
        input_ids = [self.bos_token_id] if add_bos else []
        if isinstance(input, str):
            input = [input]
        # Compatible with the StepChat ChatML Protocol.
        for subprompt in input:
            if isinstance(subprompt, str):
                input_ids += self.encode(subprompt, add_bos=False)
            elif isinstance(subprompt, dict) and "token" in subprompt:
                input_ids.append(self.convert_token_to_id(subprompt["token"]))
        return input_ids

    def get_added_vocab(self) -> Dict[str, int]:
        """Return an empty mapping; no tokens are added on top of the model."""
        return {}

    def __len__(self):
        """Number of pieces in the model."""
        return self.length

    @staticmethod
    def _chat_role(message, tool_calls_by_id):
        """Return the role header for ``message``.

        A ``tool`` message answering a known tool call becomes
        ``"<type>_output\\n<name>"``; ``user`` becomes ``human``; any other
        role is used unchanged.
        """
        if (message["role"] == "tool"
                and message.get("tool_call_id") in tool_calls_by_id):
            call = tool_calls_by_id[message["tool_call_id"]]
            return f"{call['type']}_output\n{call['name']}"
        return "human" if message["role"] == "user" else message["role"]

    def _append_content_parts(self, ret, prefix, parts):
        """Append the ids for a list-valued message content to ``ret``.

        ``prefix`` (the role header) is encoded together with the first text
        part, or on its own before the first non-text part.
        """
        text = prefix
        for item in parts:
            kind = item.get("type") if isinstance(item, dict) else None
            if kind == "text":
                ret.extend(self.encode(text + item["text"], add_bos=False))
                text = ""
                continue

            # Every non-text part first flushes the pending header text.
            if text:
                ret.extend(self.encode(text, add_bos=False))
                text = ""
            if kind == "image":
                ret.append(self._special_tokens["<im_patch>"])
            elif kind == "audio":
                ret.append(self._special_tokens["<audio_patch>"])
            elif kind == "audio_token":
                ret.append(self._special_tokens["<audio_start>"])
                ret.extend(self.encode(item["audio_token"], add_bos=False))
                ret.append(self._special_tokens["<audio_end>"])
            # Other part types contribute no tokens of their own.

    @staticmethod
    def _parse_tool_call(tool_call):
        """Return ``(type_name, name, content)`` for one tool call.

        Raises:
            ValueError: For an unknown call type, or when the payload for
                the declared type is missing.
        """
        call_type = tool_call.get("type")
        if call_type == "function":
            function = tool_call.get("function")
            if function is None:
                raise ValueError(f"Function is not set for tool call {tool_call.get('id')}: {tool_call}")
            arguments = function.get("arguments")
            if not isinstance(arguments, str):
                arguments = json.dumps(arguments, ensure_ascii=False)
            return "function", function.get("name"), arguments

        if call_type == "code_interpreter":
            code_interpreter = tool_call.get("code_interpreter")
            if code_interpreter is None:
                raise ValueError(
                    f"Code interpreter is not set for tool call {tool_call.get('id')}: {tool_call}"
                )
            return ("code_interpreter", code_interpreter.get("language"),
                    code_interpreter.get("code"))

        raise ValueError(
            f"Unknown tool call type {tool_call.get('type')}, must be either 'function' or 'code_interpreter': {tool_call}"
        )

    def apply_chat_template(self,
                            conversation: List["ConversationMessage"],
                            tools: Optional[Dict[str, Any]] = None,
                            **kwargs) -> List[int]:
        """Convert chat messages to token IDs sequence.

        Each message is rendered as ``<|BOT|>``, the role header and content,
        any tool calls wrapped in ``<|CALL_START|>`` / ``<|CALL_END|>``, then
        ``<|EOT|>``. When ``tools`` is given, a ``tool_json_schemas`` message
        holding the JSON-serialised tools is inserted after a leading system
        message, or first otherwise.

        Args:
            conversation: List of chat messages
            tools: Tool configurations (optional)
            **kwargs: Only ``continue_final_message`` is read; when true no
                assistant prompt is appended and a trailing ``<|EOT|>`` is
                removed.

        Returns:
            List[int]: Sequence of token IDs

        Raises:
            KeyError: If a required special token is missing from the model.
            ValueError: If a tool call is malformed.
        """
        continue_final_message = kwargs.get("continue_final_message", False)

        messages = list(conversation)
        if tools:
            tools_message = {
                "role": "tool_json_schemas",
                "content": json.dumps(tools, ensure_ascii=False),
            }
            has_system = bool(messages) and messages[0]["role"] == "system"
            messages.insert(1 if has_system else 0, tools_message)

        ret = [self.bos_token_id]
        # Tool calls seen so far, keyed by id, used to label tool outputs.
        tool_calls_by_id = {}
        for message in messages:
            ret.append(self._special_tokens["<|BOT|>"])
            role = self._chat_role(message, tool_calls_by_id)

            content = message.get("content") or ""
            if isinstance(content, str):
                ret.extend(self.encode(f"{role}\n{content}", add_bos=False))
            elif isinstance(content, list):
                self._append_content_parts(ret, f"{role}\n", content)

            for tool_call in message.get("tool_calls") or ():
                type_name, name, call_content = self._parse_tool_call(
                    tool_call)
                tool_calls_by_id[tool_call.get("id")] = {
                    "type": type_name,
                    "name": name,
                    "content": call_content,
                }
                ret.append(self._special_tokens["<|CALL_START|>"])
                ret.extend(
                    self.encode(f"{type_name}\n{name}\n{call_content}",
                                add_bos=False))
                ret.append(self._special_tokens["<|CALL_END|>"])

            ret.append(self._special_tokens["<|EOT|>"])

        # An empty conversation is treated like one not ending with the
        # assistant, instead of failing on messages[-1].
        ends_with_assistant = (bool(messages)
                               and messages[-1]["role"] == "assistant")
        if not ends_with_assistant and not continue_final_message:
            # Open an assistant turn for generation.
            ret.append(self._special_tokens["<|BOT|>"])
            ret.extend(self.encode("assistant\n", add_bos=False))
        elif ret[-1] == self._special_tokens["<|EOT|>"]:
            # Leave the final turn open so generation continues it.
            ret.pop()
        return ret
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment