"vscode:/vscode.git/clone" did not exist on "fa023f32fbf24a0cac4f733ff876188ccf193a8f"
Unverified commit 68486481, authored by Xinyuan Tong, committed by GitHub

Feat: deepseek-ocr logits processor (#12415)


Co-authored-by: xinyuant <xinyuant@usc.edu>
parent 410225b7
@@ -15,6 +15,10 @@ from sglang.srt.multimodal.customized_mm_processor_utils import (
    register_customized_processor,
)
from sglang.srt.sampling.custom_logit_processor import (
DeepseekOCRNoRepeatNGramLogitProcessor,
)
BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
@@ -26,6 +30,24 @@ PRINT_NUM_VIS_TOKENS = False
SKIP_REPEAT = True
MODEL_PATH = "deepseek-ai/DeepSeek-OCR"  # change to your model path
NGRAM_NO_REPEAT_SIZE = 30
NGRAM_NO_REPEAT_WINDOW = 90
# Whitelist `<td>` and `</td>` token ids to allow table structures.
NGRAM_NO_REPEAT_WHITELIST = (128821, 128822)
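# Serialized form of the processor so it can be attached to generation requests.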
DEFAULT_CUSTOM_LOGIT_PROCESSOR = DeepseekOCRNoRepeatNGramLogitProcessor.to_str()
def get_default_ngram_custom_params() -> Dict[str, Any]:
"""Return default custom params for the DeepSeek-OCR n-gram no repeat processor."""
return {
"ngram_size": NGRAM_NO_REPEAT_SIZE,
"window_size": NGRAM_NO_REPEAT_WINDOW,
"whitelist_token_ids": list(NGRAM_NO_REPEAT_WHITELIST),
}
PROMPT = "<image>\n<|grounding|>Convert the document to markdown." PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
...
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
import dill
import orjson
@@ -126,3 +126,69 @@ class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    THINKING_START_TOKEN_ID: int = 128798
    THINKING_END_TOKEN_ID: int = 128799
    NEW_LINE_TOKEN_ID: int = 201
# Adapted from DeepSeek's implementation: https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/ngram_norepeat.py
class DeepseekOCRNoRepeatNGramLogitProcessor(CustomLogitProcessor):
"""Block n-gram repetitions within a sliding window for DeepSeek-OCR outputs."""
def __call__(
self,
logits: torch.Tensor,
custom_param_list: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
if not custom_param_list:
return logits
for batch_idx, params in enumerate(custom_param_list):
if not params:
continue
req = params.get("__req__")
if req is None:
continue
try:
ngram_size = int(params.get("ngram_size") or 0)
window_size = int(params.get("window_size") or 0)
except (TypeError, ValueError):
continue
if ngram_size <= 0 or window_size <= 0:
continue
sequence: List[int] = req.origin_input_ids + req.output_ids
if len(sequence) < ngram_size:
continue
search_start = max(0, len(sequence) - window_size)
search_end = len(sequence) - ngram_size + 1
if search_end <= search_start:
continue
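            # The trailing (ngram_size - 1) tokens are the current prefix; any token
            # that would complete an n-gram already seen with this prefix is banned.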
if ngram_size > 1:
current_prefix = tuple(sequence[-(ngram_size - 1) :])
else:
current_prefix = tuple()
banned_tokens: Set[int] = set()
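            # Scan the recent window for earlier n-grams sharing this prefix and
            # collect their final tokens as forbidden continuations.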
for idx in range(search_start, search_end):
ngram = sequence[idx : idx + ngram_size]
if ngram_size == 1 or tuple(ngram[:-1]) == current_prefix:
banned_tokens.add(ngram[-1])
whitelist_ids = params.get("whitelist_token_ids") or []
try:
whitelist = {int(token_id) for token_id in whitelist_ids}
except (TypeError, ValueError):
whitelist = set()
banned_tokens.difference_update(whitelist)
if not banned_tokens:
continue
indices = list(banned_tokens)
logits[batch_idx, indices] = -float("inf")
return logits
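For intuition about what the processor masks, here is a small standalone illustration (not part of this diff) that re-implements just the banning rule on a toy sequence; the helper name banned_next_tokens is made up for the example, and nothing beyond PyTorch is assumed.

from typing import List, Set

import torch

def banned_next_tokens(sequence: List[int], ngram_size: int, window_size: int) -> Set[int]:
    """Tokens whose emission would repeat an n-gram seen in the trailing window."""
    if ngram_size <= 0 or window_size <= 0 or len(sequence) < ngram_size:
        return set()
    prefix = tuple(sequence[-(ngram_size - 1):]) if ngram_size > 1 else tuple()
    start = max(0, len(sequence) - window_size)
    banned: Set[int] = set()
    for idx in range(start, len(sequence) - ngram_size + 1):
        ngram = sequence[idx : idx + ngram_size]
        if ngram_size == 1 or tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned

seq = [5, 6, 7, 9, 5, 6]  # the 3-gram (5, 6, 7) already occurred; the current prefix is (5, 6)
banned = banned_next_tokens(seq, ngram_size=3, window_size=90)
print(banned)  # {7}: emitting token 7 would repeat (5, 6, 7)

logits = torch.zeros(1, 10)
logits[0, list(banned)] = -float("inf")  # same masking step the processor applies per request
print(logits[0])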