Commit 675ba75f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-ori

parents 5cc98918 296c6572
......@@ -51,6 +51,7 @@ class ExecutorBase(ABC):
self.observability_config = vllm_config.observability_config
self._init_executor()
self.is_sleeping = False
self.sleeping_tags: set[str] = set()
@abstractmethod
def _init_executor(self) -> None:
......@@ -204,20 +205,34 @@ class ExecutorBase(ABC):
time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep)
def wake_up(self, tags: Optional[list[str]] = None):
    """Wake the executor up, either completely or only for some tags.

    Args:
        tags: Names of the sleeping resources to wake up. When ``None``
            (or empty) every tag still recorded in ``self.sleeping_tags``
            is treated as woken up.

    The call is a no-op (with a warning) when the executor is not
    sleeping, or when any requested tag is not currently sleeping.
    ``self.is_sleeping`` is only cleared once no sleeping tag remains.
    """
    if not self.is_sleeping:
        logger.warning("Executor is not sleeping.")
        return
    if tags:
        # Validate every tag up front so that nothing is woken up when the
        # request is partially invalid.
        for tag in tags:
            if tag not in self.sleeping_tags:
                logger.warning("Tag %s is not in sleeping tags %s", tag,
                               self.sleeping_tags)
                return
    time_before_wakeup = time.perf_counter()
    # Issue the wake-up exactly once, forwarding the requested tags.
    self.collective_rpc("wake_up", kwargs=dict(tags=tags))
    time_after_wakeup = time.perf_counter()
    logger.info("It took %.6f seconds to wake up tags %s.",
                time_after_wakeup - time_before_wakeup,
                tags if tags is not None else self.sleeping_tags)
    if tags:
        # difference_update tolerates duplicated entries in ``tags``;
        # set.remove() would raise KeyError on the second occurrence, after
        # the RPC above has already been issued.
        self.sleeping_tags.difference_update(tags)
    else:
        self.sleeping_tags.clear()
    # Only report the executor as awake once every tag has been woken up.
    if not self.sleeping_tags:
        self.is_sleeping = False
def save_sharded_state(
self,
......
......@@ -79,7 +79,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# For TPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
......@@ -546,10 +546,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
"Run `pip install ray[cgraph]` to install it.")
cupy_spec = importlib.util.find_spec("cupy")
if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
if (cupy_spec is None
and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
raise ValueError(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
"Run `pip install ray[cgraph]` and check cupy installation.")
def _compiled_ray_dag(self, enable_asyncio: bool):
......@@ -557,10 +558,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
self._check_ray_cgraph_installation()
from ray.dag import InputNode, MultiOutputNode
logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if channel_type not in ("auto", "nccl", "shm"):
raise ValueError(
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
......@@ -605,13 +613,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
]
last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
if (pp_rank < last_pp_rank and
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
# Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last
# pp stage.
transport = "nccl" \
if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
else "auto"
# pp stage or when using shared memory (the default).
transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
outputs = [
output.with_tensor_transport(transport=transport)
for output in outputs
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Literal, TypedDict, Union, cast, overload
from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
class ParsedText(TypedDict):
......@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs(
        inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
    """Return True when ``inputs`` carries both an encoder and a decoder part.

    The check is purely key based: both the ``"encoder"`` and the
    ``"decoder"`` keys must be present.
    """
    has_encoder = "encoder" in inputs
    has_decoder = "decoder" in inputs
    return has_encoder and has_decoder
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
    """Split processed inputs into an ``(encoder, decoder)`` pair.

    When ``inputs`` has both an ``"encoder"`` and a ``"decoder"`` key, the
    two corresponding values are returned. Otherwise ``inputs`` itself is
    returned as the decoder part and the encoder part is ``None``.
    """
    is_enc_dec = "encoder" in inputs and "decoder" in inputs
    if not is_enc_dec:
        # Decoder-only case: the whole mapping is the decoder input.
        return None, inputs

    # NOTE: This passes pyright but not mypy
    encoder_part = inputs["encoder"]  # type: ignore[typeddict-item]
    decoder_part = inputs["decoder"]  # type: ignore[typeddict-item]
    return encoder_part, decoder_part
......@@ -261,13 +261,13 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal
# input.
if not self.tokenizer:
tokenizer = None
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
......@@ -288,14 +288,14 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal
# input.
if not self.tokenizer:
tokenizer = None
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
......@@ -528,6 +528,7 @@ class InputPreprocessor:
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
......@@ -536,6 +537,7 @@ class InputPreprocessor:
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
elif inputs["type"] == "token":
......
......@@ -13,13 +13,12 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
from .parse import split_enc_dec_inputs
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -329,17 +328,27 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.sequence import SequenceData
if mm_registry.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config)
processor = mm_registry.create_processor(model_config,
tokenizer,
disable_cache=True)
profiler = MultiModalProfiler(processor)
dummy_data_factory = (profiler.get_encoder_dummy_data
if is_encoder_data else
profiler.get_decoder_dummy_data)
dummy_data = dummy_data_factory(seq_len)
dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len)
if is_encoder_data else
profiler.get_decoder_dummy_data(seq_len))
_seq_data = SequenceData.from_seqs(
dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined]
dummy_data = DummyData(
seq_data=_seq_data,
multi_modal_data=getattr(dummy_data_v1, "multi_modal_data",
None),
multi_modal_placeholders=getattr(dummy_data_v1,
"multi_modal_placeholders",
None),
)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
......@@ -462,13 +471,11 @@ class InputRegistry:
**mm_processor_kwargs,
)
if is_encoder_decoder_inputs(processed_inputs):
self._ensure_mm_kwargs(processed_inputs["encoder"],
mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"],
mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
if encoder_inputs is not None:
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
if decoder_inputs is not None:
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
return processed_inputs
......
......@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path, map_location=device)
tensors = torch.load(lora_bin_file_path,
map_location=device,
weights_only=True)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
......
......@@ -130,7 +130,7 @@ def do_expand_kernel(
# Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K)
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride, )
offset_k[None, :] * input_d2_stride)
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride +
rbn[None, :] * cur_lora_d1_stride)
......
......@@ -136,6 +136,7 @@ def _lora_expand(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
......@@ -157,11 +158,19 @@ def _lora_expand(
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
......@@ -170,6 +179,8 @@ def _lora_expand(
assert output_tensor.is_contiguous()
# metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
......@@ -181,7 +192,6 @@ def _lora_expand(
inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank
M = inputs.size(1)
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
......@@ -263,6 +273,7 @@ def _lora_expand_fake(
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
......
......@@ -17,6 +17,17 @@ class LoRAKernelMeta:
num_tokens_per_lora: torch.Tensor
lora_token_start_loc: torch.Tensor
# The V1 architecture uses the traced torch.compile graphs to execute
# a forward pass. Things to note about this process,
# 1. The tracing infers all python scalar datatype objects into a constant
# value.
# 2. The tracing cannot handle dynamic control flow. (dynamic control flow
# is an experimental feature in pytorch)
# 3. The internals of torch.ops functions are not traced.
# We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
@staticmethod
def make(max_loras: int, max_num_tokens: int,
device: Union[torch.device, str]) -> "LoRAKernelMeta":
......@@ -47,17 +58,24 @@ class LoRAKernelMeta:
lora_token_start_loc = torch.zeros(max_loras + 2,
dtype=torch.int32,
device=device)
no_lora_flag_cpu = torch.tensor([False],
dtype=torch.bool,
device='cpu')
return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc)
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu)
def _reset(self):
self.active_lora_ids.fill_(-1)
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
......@@ -70,6 +88,14 @@ class LoRAKernelMeta:
self._reset()
# Check and record no-lora case.
no_lora = torch.all(token_lora_mapping == -1)
self.no_lora_flag_cpu[0] = no_lora
if no_lora:
# Early exit. LoRA kernels will not be run.
return
num_tokens = token_lora_mapping.size(0)
# copy token lora mapping
......@@ -100,7 +126,7 @@ class LoRAKernelMeta:
def meta_args(
self, token_nums: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]:
torch.Tensor, torch.Tensor]:
"""
This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the
......@@ -111,7 +137,11 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward
pass.
"""
return (self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora, self.lora_token_start_loc,
self.active_lora_ids)
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
)
......@@ -106,6 +106,7 @@ def _lora_shrink(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
scaling: float,
) -> None:
"""
......@@ -126,8 +127,16 @@ def _lora_shrink(
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
scaling (float): Scaling factor.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
......@@ -138,6 +147,8 @@ def _lora_shrink(
assert output_tensor.is_contiguous()
# metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
......@@ -146,7 +157,6 @@ def _lora_shrink(
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
M = inputs.size(0)
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
......@@ -218,6 +228,7 @@ def _lora_shrink_fake(
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float,
) -> None:
return
......
......@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
......@@ -79,12 +79,6 @@ def maybe_backend_fallback(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")
elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
......@@ -107,7 +101,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params)
......@@ -146,8 +144,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed; it stays None when reasoning_backend is None.
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
......
......@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
"""
grm = ""
any_whitespace = 'disable-any-whitespace' not in \
guided_params.backend_options()
if guided_params.json:
grm = llguidance.LLMatcher.grammar_from_json_schema(
guided_params.json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.json_object:
grm = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}',
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.regex:
grm = llguidance.grammar_from("regex", guided_params.regex)
elif guided_params.choice:
......
......@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
......@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
......
......@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__)
......@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner
self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
......@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids)
input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids))
......@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.
......@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
......@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.
Parameters
......
......@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
return None
elif reasoning_backend == "deepseek_r1":
return DeepSeekReasoner.from_tokenizer(tokenizer)
elif reasoning_backend == "granite":
logger.warning(
"Granite reasoner not yet implemented for structured outputs")
return None
else:
# Raise a warning for unknown reasoning backend and return None
# We cannot raise an error here because some reasoning models
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
    """
    Reasoner for DeepSeek R series models.

    Reasoning is delimited by a start and an end marker; the token ids of
    both markers are looked up from the tokenizer in ``from_tokenizer``.
    """
    start_token_id: int
    end_token_id: int

    start_token: str = "<think>"
    end_token: str = "</think>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """Build a reasoner from the first token id of each marker."""
        think_start_ids = tokenizer.encode("<think>", add_special_tokens=False)
        think_end_ids = tokenizer.encode("</think>", add_special_tokens=False)
        return cls(start_token_id=think_start_ids[0],
                   end_token_id=think_end_ids[0])

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once the end token id appears in ``input_ids``."""
        return self.end_token_id in input_ids

    def extract_content(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens

        Returns an empty list when the end token is absent or is the last
        element; otherwise everything after its first occurrence.
        """
        try:
            end_position = input_ids.index(self.end_token_id)
        except ValueError:
            # The end token has not been produced yet.
            return []
        # Slicing past the last element already yields an empty list.
        return input_ids[end_position + 1:]
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
    """Abstract interface for reasoners used by guided decoding.

    Concrete subclasses supply an alternate constructor that takes a
    tokenizer, a check for the end of the reasoning section, and a way to
    extract the content token ids from a token id sequence.
    """

    # Declared as an abstract *classmethod*: the method takes ``cls`` and
    # concrete subclasses override it with ``@classmethod``.
    @classmethod
    @abstractmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """Create a reasoner instance from ``tokenizer``."""

    @abstractmethod
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return whether ``input_ids`` marks the end of the reasoning."""

    @abstractmethod
    def extract_content(self, input_ids: list[int]) -> list[int]:
        """Return the content token ids extracted from ``input_ids``."""
......@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
......@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoner: Reasoner | None,
reasoner: ReasoningParser | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
......@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
reasoner: Reasoner | None = None
reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
......@@ -320,7 +320,10 @@ class XGrammarLogitsProcessor:
elif self.config.grammar_str is not None:
self.ctx = compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object:
self.ctx = compiler.compile_builtin_json_grammar()
any_whitespace = self.config.any_whitespace
self.ctx = compiler\
.compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace)
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")
......
......@@ -35,6 +35,8 @@ if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
......@@ -45,4 +47,5 @@ if HAS_TRITON:
"fused_experts",
"get_config_file_name",
"grouped_topk",
"cutlass_moe_fp8",
]
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment