"vscode:/vscode.git/clone" did not exist on "0313cf854d87a41c84efb69e89a79cd7b5897593"
Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
__all__ = [
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",
"JambaToolParser"
]
......@@ -53,7 +53,8 @@ class Hermes2ProToolParser(ToolParser):
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
if (self.tool_call_start_token_id is None
or self.tool_call_end_token_id is None):
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
......
import json
import re
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("jamba")
class JambaToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
    """Initialise streaming state and resolve the tool-call delimiters.

    Args:
        tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.

    Raises:
        ValueError: If the tokenizer is a ``MistralTokenizer`` or if no
            model tokenizer is available after base-class initialisation.
        RuntimeError: If either delimiter token is absent from the
            vocabulary.
    """
    super().__init__(tokenizer)

    if isinstance(self.model_tokenizer, MistralTokenizer):
        raise ValueError(
            "Detected a MistralTokenizer tokenizer when using a Jamba model")

    # Text delimiters wrapping the tool-call array in the model output.
    start_token = "<tool_calls>"
    end_token = "</tool_calls>"
    self.tool_calls_start_token: str = start_token
    self.tool_calls_end_token: str = end_token
    self.tool_calls_regex = re.compile(rf"{start_token}(.*?){end_token}",
                                       re.DOTALL)

    # Per-response streaming state.
    self.current_tool_id: int = -1
    self.current_tool_name_sent: bool = False
    self.prev_tool_call_arr: List[Dict] = []
    # What has been streamed so far for each tool, indexed by tool id.
    self.streamed_args_for_tool: List[str] = []

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction.")

    # Look the delimiters up in the vocabulary; an id of 0 is valid, so the
    # check below compares against None rather than relying on truthiness.
    start_id = self.vocab.get(start_token)
    end_id = self.vocab.get(end_token)
    self.tool_calls_start_token_id = start_id
    self.tool_calls_end_token_id = end_id
    if start_id is None or end_id is None:
        raise RuntimeError(
            "Jamba Tool parser could not locate tool calls start/end "
            "tokens in the tokenizer!")
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Keep special tokens in the output when tool calling is possible.

    Jamba marks the start and end of the tool-call information with special
    tokens, so they must not be stripped while detokenizing.

    Args:
        request: The incoming chat completion request; mutated in place.

    Returns:
        The same ``request`` object.
    """
    # Nothing to adjust when no tools were supplied or tool use is disabled.
    if not request.tools or request.tool_choice == 'none':
        return request

    request.skip_special_tokens = False
    return request
def extract_tool_calls(
        self, model_output: str,
        request: ChatCompletionRequest) -> ExtractedToolCallInformation:
    """Extract tool calls from a complete (non-streaming) model response.

    Args:
        model_output: Full text generated by the model.
        request: The originating request (not read by this method).

    Returns:
        ``ExtractedToolCallInformation`` with ``tools_called=True`` and the
        parsed calls when the delimited JSON array parses; otherwise
        ``tools_called=False`` with the untouched ``model_output`` as
        content.
    """
    # Fast path: no start delimiter means plain content.
    if self.tool_calls_start_token not in model_output:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    try:
        # First delimited span holds the JSON array of calls.
        payload = self.tool_calls_regex.findall(model_output)[0]
        parsed_calls = json.loads(payload)

        tool_calls = []
        for call in parsed_calls:
            # Arguments are JSON, but are carried as a string.
            fn = FunctionCall(name=call["name"],
                              arguments=json.dumps(call["arguments"]))
            tool_calls.append(ToolCall(type="function", function=fn))

        # Text before the start delimiter is regular content; an empty or
        # single-space prefix is reported as no content.
        leading_text = model_output.partition(
            self.tool_calls_start_token)[0]
        if leading_text in ("", " "):
            leading_text = None

        return ExtractedToolCallInformation(tools_called=True,
                                            tool_calls=tool_calls,
                                            content=leading_text)
    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Turn one streamed chunk into a tool-call (or content) delta.

    On every call the text between the tool-call delimiters is re-parsed
    with a partial JSON parser and compared with the state kept on this
    instance (``current_tool_id``, ``current_tool_name_sent``,
    ``prev_tool_call_arr`` and ``streamed_args_for_tool``).

    Args:
        previous_text: Not read by this method.
        current_text: All text generated so far, including this chunk.
        delta_text: Text of the current chunk.
        previous_token_ids: Not read by this method.
        current_token_ids: Not read by this method.
        delta_token_ids: Token ids of the current chunk.
        request: Not read by this method.

    Returns:
        A ``DeltaMessage`` carrying plain content when no tool-call start
        token has been generated yet, a ``DeltaMessage`` carrying a tool
        name or an argument fragment when one is available, or ``None``
        when there is nothing to send for this chunk (including when an
        exception is raised while processing it).
    """

    # if the tool call token is not in the tokens generated so far, append
    # output to contents since it's not a tool
    if self.tool_calls_start_token not in current_text:
        return DeltaMessage(content=delta_text)

    # if the tool call token ID IS in the tokens generated so far, that
    # means we're parsing as tool calls now

    # handle if we detected the start of tool calls token which means
    # the start of tool calling
    if (self.tool_calls_start_token_id in delta_token_ids
            and len(delta_token_ids) == 1):
        # if it's the only token, return None, so we don't send a chat
        # completion and don't send a control token
        return None

    # bit mask flags for partial JSON parsing. If the name hasn't been
    # sent yet, don't allow sending
    # an incomplete string since OpenAI only ever (as far as I have
    # seen) allows sending the entire tool/ function name at once.
    flags = Allow.ALL if self.current_tool_name_sent \
        else Allow.ALL & ~Allow.STR
    try:

        # Extract the tool calls between the special tool call tokens
        # (text after the last start token, up to the first end token).
        parsable_arr = current_text.split(
            self.tool_calls_start_token)[-1].split(
                self.tool_calls_end_token)[0]

        # tool calls are generated in an array, so do partial JSON
        # parsing on the entire array
        try:
            tool_call_arr: List[Dict] = partial_json_parser.loads(
                parsable_arr, flags)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug('not enough tokens to parse into JSON yet')
            return None

        # select as the current tool call the one we're on the state at.
        # While current_tool_id is still -1 this indexes the last element;
        # on that path the value is not used (the "new tool" branch below
        # returns before reading it).
        current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
            if len(tool_call_arr) > 0 else {}

        # case -- if no tokens have been streamed for the tool, e.g.
        #   only the array brackets, stream nothing
        if len(tool_call_arr) == 0:
            return None

        # case: we are starting a new tool in the array
        #   -> array has > 0 length AND length has moved past cursor
        elif (len(tool_call_arr) > 0
              and len(tool_call_arr) > self.current_tool_id + 1):

            # if we're moving on to a new call, first make sure we
            # haven't missed anything in the previous one that was
            # auto-generated due to JSON completions, but wasn't
            # streamed to the client yet.
            if self.current_tool_id >= 0:
                diff: Union[str, None] = current_tool_call.get("arguments")

                if diff:
                    # NOTE(review): str.replace removes every occurrence
                    # of the already-streamed text, not only a leading
                    # prefix — confirm this is the intended diff.
                    diff = json.dumps(diff).replace(
                        self.streamed_args_for_tool[self.current_tool_id],
                        "")
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += diff
                else:
                    delta = None
            else:
                delta = None
            # re-set stuff pertaining to progress in the current tool
            self.current_tool_id = len(tool_call_arr) - 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("starting on new tool %d", self.current_tool_id)
            # NOTE: prev_tool_call_arr is not updated on this path.
            return delta

        # case: update an existing tool - this is handled below

        # if the current tool name hasn't been sent, send if available
        # - otherwise send nothing
        if not self.current_tool_name_sent:
            function_name = current_tool_call.get("name")
            if function_name:

                # The name delta also carries a freshly generated call id.
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  type="function",
                                  id=f"chatcmpl-tool-{random_uuid()}",
                                  function=DeltaFunctionCall(
                                      name=function_name).model_dump(
                                          exclude_none=True))
                ])
                self.current_tool_name_sent = True
            else:
                delta = None

        # now we know we're on the same tool call and we're streaming
        # arguments
        else:

            prev_arguments = self.prev_tool_call_arr[
                self.current_tool_id].get("arguments")
            cur_arguments = current_tool_call.get("arguments")

            # Single quotes in the chunk are mapped to double quotes before
            # it is searched for in the json.dumps() output below.
            new_text = delta_text.replace("\'", "\"")

            if not cur_arguments and not prev_arguments:

                delta = None
            elif not cur_arguments and prev_arguments:
                logger.error(
                    "INVARIANT - impossible to have arguments reset "
                    "mid-arguments")
                delta = None
            elif cur_arguments and not prev_arguments:
                # First argument text for this tool: send the serialized
                # arguments up to and including the current chunk's text.
                # str.index raises ValueError when the chunk is not found;
                # that is handled by the outer ``except`` (returns None).
                cur_arguments_json = json.dumps(cur_arguments)
                logger.debug("finding %s in %s", new_text,
                             cur_arguments_json)

                arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                     index(new_text) +
                                                     len(new_text)]
                logger.debug("First tokens in arguments received: %s",
                             arguments_delta)
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).
                                  model_dump(exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += arguments_delta

            elif cur_arguments and prev_arguments:
                # Arguments grew since the previous chunk: send only the
                # difference computed by extract_intermediate_diff.
                cur_args_json = json.dumps(cur_arguments)
                prev_args_json = json.dumps(prev_arguments)
                logger.debug("Searching for diff between \n%s\n%s",
                             cur_args_json, prev_args_json)

                argument_diff = extract_intermediate_diff(
                    cur_args_json, prev_args_json)
                logger.debug("got arguments diff: %s", argument_diff)
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=argument_diff).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += argument_diff
            else:
                # The four branches above cover every truthiness
                # combination, so this is a defensive fallback.
                # try parsing it with regular JSON - if it works we're
                # at the end, and we need to send the difference between
                # tokens streamed so far and the valid JSON
                delta = None

        # check to see if the name is defined and has been sent. if so,
        # stream the name - otherwise keep waiting
        # finish by setting old and returning None as base case
        self.prev_tool_call_arr = tool_call_arr
        return delta

    except Exception:
        # Best effort: a chunk that fails to process is logged and dropped.
        logger.exception("Error trying to handle streaming tool call.")
        logger.debug(
            "Skipping chunk as a result of tool streaming extraction "
            "error")
        return None
......@@ -63,7 +63,7 @@ class MistralToolParser(ToolParser):
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id:
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")
......
......@@ -30,11 +30,13 @@ if TYPE_CHECKING:
VLLM_USAGE_SOURCE: str = ""
VLLM_CONFIGURE_LOGGING: int = 1
VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_PREFIX: str = ""
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
......@@ -68,7 +70,9 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
def get_default_cache_root():
......@@ -223,7 +227,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all, 'none' to disable all.
# Also specify a list of custom op names to enable (prefixed with a '+'),
# or disable (prefixed with a '-').
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS":
lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
......@@ -273,6 +287,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_LEVEL":
lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"),
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
"VLLM_LOGGING_PREFIX":
lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),
# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
......@@ -293,6 +311,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
......@@ -451,6 +474,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
}
# end-env-vars-definition
......
......@@ -15,8 +15,10 @@ import vllm.envs as envs
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
"%(filename)s:%(lineno)d] %(message)s")
_DATE_FORMAT = "%m-%d %H:%M:%S"
DEFAULT_LOGGING_CONFIG = {
......
from functools import lru_cache
from typing import Dict, Type
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
from vllm.utils import is_hip, is_xpu, print_warning_once
logger = init_logger(__name__)
class CustomOp(nn.Module):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""
def __init__(self, *args, **kwargs):
def __init__(self):
super().__init__()
self._forward_method = self.dispatch_forward()
......@@ -17,7 +27,6 @@ class CustomOp(nn.Module):
def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
......@@ -56,12 +65,16 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
enabled = self.enabled()
logger.debug("custom op %s %s", self.__class__.name,
"enabled" if enabled else "disabled")
if not enabled:
return self.forward_native
if is_hip():
return self.forward_hip
elif is_cpu():
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_tpu():
return self.forward_tpu
......@@ -69,3 +82,50 @@ class CustomOp(nn.Module):
return self.forward_xpu
else:
return self.forward_cuda
@classmethod
def enabled(cls) -> bool:
    """Report whether this custom op should use its custom implementation.

    An op without a ``name`` attribute was never registered and follows the
    global default. A registered op follows the global default unless
    ``VLLM_CUSTOM_OPS`` contains ``+<name>`` (force on) or ``-<name>``
    (force off); listing both is rejected by an assertion.
    """
    # if no name, then it was not registered
    if not hasattr(cls, "name"):
        print_warning_once(
            f"Custom op {cls.__name__} was not registered, "
            f"which means it won't appear in the op registry. "
            f"It will be enabled/disabled based on the global settings.")
        return CustomOp.default_on()

    explicitly_on = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
    explicitly_off = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
    assert not (explicitly_on and explicitly_off
                ), f"Cannot enable and disable {cls.name}"

    # The global default is always evaluated, so its own validation of
    # 'all'/'none' runs even when the op is explicitly disabled.
    globally_on = CustomOp.default_on()
    if explicitly_off:
        return False
    return globally_on or explicitly_on
# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@staticmethod
@lru_cache()
def default_on() -> bool:
    """Return the global default for custom ops (cached after first call).

    ``'all'`` in ``VLLM_CUSTOM_OPS`` forces the default on and ``'none'``
    forces it off; at most one of the two may appear. With neither, the
    default is on only below the Inductor compilation level.
    """
    none_hits = envs.VLLM_CUSTOM_OPS.count("none")
    all_hits = envs.VLLM_CUSTOM_OPS.count("all")
    assert none_hits + all_hits <= 1, "Can only specify 'none' or 'all'"

    below_inductor = (envs.VLLM_TORCH_COMPILE_LEVEL <
                      CompilationLevel.INDUCTOR)
    # Explicit grouping of the original `a and not b or c` expression.
    return (below_inductor and not none_hits > 0) or all_hits > 0
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: Dict[str, Type['CustomOp']] = {}
# Decorator to register custom ops.
@classmethod
def register(cls, name: str):
    """Build a class decorator that registers an op under ``name``.

    The decorator stores ``name`` on the decorated class, records the class
    in ``cls.op_registry`` and returns the class unchanged. Registering the
    same name twice trips an assertion.
    """

    def decorator(target_cls):
        registry = cls.op_registry
        assert name not in registry, f"Duplicate op name: {name}"
        target_cls.name = name
        registry[name] = target_cls
        return target_cls

    return decorator
......@@ -11,12 +11,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict
import vllm.envs as envs
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
......@@ -41,6 +43,7 @@ class FatreluAndMul(CustomOp):
return self.forward_native(x)
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
......@@ -78,6 +81,7 @@ class SiluAndMul(CustomOp):
return out
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
......@@ -133,6 +137,7 @@ class GeluAndMul(CustomOp):
return f'approximate={repr(self.approximate)}'
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
......@@ -154,6 +159,7 @@ class NewGELU(CustomOp):
return ops.gelu_new(x)
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
......@@ -174,8 +180,8 @@ class FastGELU(CustomOp):
return ops.gelu_fast(x)
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
......@@ -199,6 +205,7 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
......@@ -254,15 +261,24 @@ class ScaledActivation(nn.Module):
param_data.copy_(loaded_weight)
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"relu2": ReLUSquaredActivation(),
"quick_gelu": QuickGELU(),
}
_ACTIVATION_REGISTRY = LazyDict({
"gelu":
lambda: nn.GELU(),
"gelu_fast":
lambda: FastGELU(),
"gelu_new":
lambda: NewGELU(),
"gelu_pytorch_tanh":
lambda: nn.GELU(approximate="tanh"),
"relu":
lambda: nn.ReLU(),
"relu2":
lambda: ReLUSquaredActivation(),
"silu":
lambda: nn.SiLU(),
"quick_gelu":
lambda: QuickGELU(),
})
def get_act_fn(
......
......@@ -116,7 +116,7 @@ def single_marlin_moe(
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
is_k_full, E, topk, block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
......@@ -272,7 +272,7 @@ def fused_marlin_moe(
g_idx1,
sort_indices1,
workspace,
scalar_type1,
scalar_type1.id,
M,
2 * N,
K,
......@@ -297,7 +297,7 @@ def fused_marlin_moe(
g_idx2,
sort_indices2,
workspace,
scalar_type2,
scalar_type2.id,
M,
K,
N,
......
......@@ -37,13 +37,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
......@@ -74,7 +74,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
router_logits=router_logits,
......@@ -97,7 +96,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
......@@ -134,7 +132,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
......
......@@ -8,6 +8,7 @@ from vllm.model_executor.custom_op import CustomOp
import vllm.envs as envs
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
......@@ -27,7 +28,6 @@ class RMSNorm(CustomOp):
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size))
def forward_native(
......@@ -139,6 +139,7 @@ class RMSNorm(CustomOp):
return s
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.
......
......@@ -48,14 +48,15 @@ class LogitsProcessor(nn.Module):
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
if sampling_metadata is not None:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
......@@ -69,7 +70,8 @@ class LogitsProcessor(nn.Module):
logits *= self.scale
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
if sampling_metadata is not None:
logits = _apply_logits_processors(logits, sampling_metadata)
return logits
......
......@@ -12,6 +12,7 @@ class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
ALL = 1
CLS = 2
class Pooler(nn.Module):
......@@ -23,12 +24,13 @@ class Pooler(nn.Module):
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
pooling_type: The type of pooling to use (LAST, ALL, CLS).
normalize: Whether to normalize the pooled data.
"""
def __init__(self, pooling_type: PoolingType, normalize: bool):
super().__init__()
self.pooling_type = pooling_type
self.normalize = normalize
......@@ -38,10 +40,16 @@ class Pooler(nn.Module):
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools specific information from hidden states based on metadata."""
prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
if self.pooling_type == PoolingType.LAST:
if self.pooling_type is PoolingType.CLS:
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
dim=0)[:-1]
pooled_data = hidden_states[first_token_flat_indices]
elif self.pooling_type == PoolingType.LAST:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled_data = hidden_states[last_token_flat_indices]
elif self.pooling_type == PoolingType.ALL:
......
......@@ -100,12 +100,21 @@ class CompressedTensorsConfig(QuantizationConfig):
target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights"))
try:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
except Exception:
target_scheme_map[target]["input_activations"] = None
target_scheme_map[target]["input_activations"] = None
if is_activation_quantization_format(quant_format):
input_activations = quant_config.get("input_activations")
# The only case where we have activation quant supported
# but no input_activations provided in the config
# should be w8a16fp8. w8a16fp8 can also run for cases where
# there is an input_quant but it is ignored.
if not input_activations:
assert target_scheme_map[target][
"weights"].type == QuantizationType.FLOAT
else:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
......@@ -244,8 +253,6 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
......@@ -256,16 +263,19 @@ class CompressedTensorsConfig(QuantizationConfig):
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
else:
# note: input_quant will be present for converted models;
# will be ignored during inference post loading
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
is_static_input_scheme=not input_quant.dynamic)
# note: input_quant can be None
if self._is_fp8_w8a16(weight_quant, input_quant):
is_static_input_scheme = (input_quant
and not input_quant.dynamic)
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
is_static_input_scheme=is_static_input_scheme)
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
......
......@@ -72,6 +72,7 @@ def _apply_rotary_emb(
return torch.stack((o1, o2), dim=-1).flatten(-2)
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
......@@ -468,7 +469,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self.long_factor = long_factor
scale = self.max_position_embeddings / \
self.original_max_position_embeddings
self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
......
......@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import transformers
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
......@@ -108,39 +107,11 @@ class NeuronCasualLM(nn.Module):
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
split_model_dir = f"{model_name_or_path}-split"
if _is_pretrained_neuron_checkpoint(model_name_or_path):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls_name)
from transformers_neuronx.module import save_pretrained_split
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = neuronx_model_cls.from_pretrained(split_model_dir,
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
**kwargs)
self.model.to_neuron()
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Tell whether ``model_name_or_path`` already holds a neuron checkpoint.

    Two layouts are recognised:

    * old format: a ``pytorch_model.bin`` *directory* inside the path;
    * new format: both ``config.json`` and ``generation_config.json`` are
      present as files and at least one entry ends with ``.safetensors``.
    """
    # Old format.
    legacy_marker = os.path.join(model_name_or_path, "pytorch_model.bin")
    if os.path.isdir(legacy_marker):
        return True

    # New format: every required config file must exist ...
    required_files = ("config.json", "generation_config.json")
    has_all_configs = all(
        os.path.isfile(os.path.join(model_name_or_path, fname))
        for fname in required_files)
    if not has_all_configs:
        return False

    # ... alongside at least one safetensors file.
    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
......
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
class BertEmbedding(nn.Module):
    """BERT input embeddings.

    Sums word, token-type and position embeddings and applies LayerNorm.
    Only the ``"absolute"`` position embedding type is accepted.
    """

    def __init__(self, config: BertConfig):
        """Create the embedding tables and LayerNorm from ``config``.

        Raises:
            ValueError: If ``config.position_embedding_type`` is not
                ``"absolute"``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        max_positions = config.max_position_embeddings

        self.size = hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      hidden_size)
        self.position_embeddings = VocabParallelEmbedding(
            max_positions, hidden_size)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        # Uninitialised (1, max_positions) parameter; it is not read by
        # forward() in this class.
        self.position_ids = nn.Parameter(torch.empty((1, max_positions)))

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type"
                             " is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` at ``position_ids`` and normalise the sum.

        Token-type ids are always all zeros, with the shape of
        ``input_ids``.
        """
        ids_shape = input_ids.size()

        word_embeds = self.word_embeddings(input_ids)
        pos_embeds = self.position_embeddings(position_ids)

        # Token type embeddings. (TODO: move off hotpath?)
        segment_ids = torch.zeros(ids_shape,
                                  dtype=torch.long,
                                  device=word_embeds.device)
        type_embeds = self.token_type_embeddings(segment_ids)

        summed = word_embeds + type_embeds + pos_embeds
        return self.LayerNorm(summed)
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the layers, naming each one ``{prefix}.layer.{index}``."""
        super().__init__()
        encoder_layers = [
            BertLayer(config=config,
                      cache_config=cache_config,
                      quant_config=quant_config,
                      prefix=f"{prefix}.layer.{layer_idx}")
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layer = nn.ModuleList(encoder_layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layers in order, giving layer ``i`` ``kv_caches[i]``."""
        for idx, bert_layer in enumerate(self.layer):
            hidden_states = bert_layer(hidden_states, kv_caches[idx],
                                       attn_metadata)
        return hidden_states
class BertLayer(nn.Module):
    """One encoder layer: attention -> intermediate -> output.

    The output block receives both the intermediate activations and the
    attention output.
    """

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        layer_norm_eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention")

        self.intermediate = BertIntermediate(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate")

        self.output = BertOutput(hidden_size=hidden_size,
                                 intermediate_size=intermediate_size,
                                 layer_norm_eps=layer_norm_eps,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ):
        """Apply attention, then the feed-forward pair, and return the result."""
        attended = self.attention(hidden_states, kv_cache, attn_metadata)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class BertAttention(nn.Module):
    """Self-attention followed by its output projection block.

    Holds two submodules: ``self`` (``BertSelfAttention``) and ``output``
    (``BertSelfOutput``). Each is given a prefix that matches its own
    attribute name, so the names derived from the prefixes stay distinct.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # The prefix must name this submodule ("self"), not its sibling.
        # Previously it was "<prefix>.output", which duplicated the prefix
        # handed to BertSelfOutput below.
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and pass the result through ``output``.

        ``hidden_states`` is also handed to the output block as its second
        argument.
        """
        self_output = self.self(hidden_states, kv_cache, attn_metadata)
        return self.output(self_output, hidden_states)
class BertSelfAttention(nn.Module):
    """Fused QKV projection plus encoder-only attention.

    The number of KV heads equals the number of query heads. Heads are
    divided across the tensor-parallel world size. The constructor raises
    ``ValueError`` unless the attention implementation is ``XFormersImpl``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        # Heads must split evenly over the tensor-parallel ranks.
        assert num_attention_heads % tp_size == 0
        heads_per_rank = num_attention_heads // tp_size

        # hidden_size must be an exact multiple of the head count.
        head_dim = hidden_size // num_attention_heads
        assert head_dim * num_attention_heads == hidden_size

        self.hidden_size = hidden_size
        self.total_num_heads = num_attention_heads
        self.num_heads = heads_per_rank
        self.total_num_kv_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths used to split the fused projection output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj")

        self.attn = Attention(num_heads=self.num_heads,
                              head_size=self.head_dim,
                              scale=self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        if not isinstance(self.attn.impl, XFormersImpl):
            raise ValueError(
                "Encoder-only models currently require XFORMERS attention "
                "backend. Set VLLM_ATTENTION_BACKEND=XFORMERS to use BERT.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, split along the last dim, and run attention."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)
        return self.attn(q,
                         k,
                         v,
                         kv_cache,
                         attn_metadata,
                         attn_type=AttentionType.ENCODER_ONLY)
class BertSelfOutput(nn.Module):
    """Dense projection of the attention output, residual add, LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Square projection: hidden_size -> hidden_size, with bias.
        self.dense = RowParallelLinear(input_size=hidden_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The dense layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)
class BertIntermediate(nn.Module):
    """Dense projection to ``intermediate_size`` followed by an activation.

    The activation is looked up by name with ``get_act_fn(hidden_act)``.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = ColumnParallelLinear(input_size=hidden_size,
                                          output_size=intermediate_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.dense")
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``act(dense(hidden_states))``."""
        # The dense layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)
class BertOutput(nn.Module):
    """Dense projection back to ``hidden_size``, residual add, LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Projects intermediate_size -> hidden_size, with bias.
        self.dense = RowParallelLinear(input_size=intermediate_size,
                                       output_size=hidden_size,
                                       bias=True,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The dense layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)
class BertModel(nn.Module):
    """BERT backbone: an embedding stage followed by the encoder stack."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(config,
                                   cache_config,
                                   quant_config,
                                   prefix=f"{prefix}.encoder")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the encoder output.

        When ``inputs_embeds`` is given it is fed to the encoder directly and
        the embedding stage is skipped. ``intermediate_tensors`` is accepted
        but not used here.
        """
        if inputs_embeds is None:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            position_ids=position_ids)
        else:
            hidden_states = inputs_embeds
        return self.encoder(hidden_states, kv_caches, attn_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Tensors whose name contains ``"pooler"`` are ignored. Names that
        contain ``query`` / ``key`` / ``value`` are rewritten to ``qkv_proj``
        and loaded with shard id ``q`` / ``k`` / ``v`` through the parameter's
        ``weight_loader``. Everything else is loaded with the parameter's
        ``weight_loader`` if it has one, otherwise ``default_weight_loader``.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "pooler" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    extra_bias = (name.endswith(".bias")
                                  and name not in params_dict)
                    if not extra_bias:
                        param = params_dict[name]
                        param.weight_loader(param, loaded_weight, shard_id)
                        break
            else:
                # Reached only when the loop above did not break.
                # Skip loading extra bias for GPTQ models.
                extra_bias = (name.endswith(".bias")
                              and name not in params_dict)
                if not extra_bias:
                    param = params_dict[name]
                    loader = getattr(param, "weight_loader",
                                     default_weight_loader)
                    loader(param, loaded_weight)
class BertEmbeddingModel(nn.Module):
    """Embedding model built on top of ``BertModel``.

    Wraps a ``BertModel`` and pairs it with a pooler so callers get both the
    forward pass and a pooling step from one object.

    Attributes:
        model: The ``BertModel`` that performs the forward pass.
        _pooler: A ``Pooler`` created with ``PoolingType.CLS`` and
            ``normalize=True``.
    """

    def __init__(
        self,
        config: BertConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.model = BertModel(config, cache_config, quant_config)
        self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped model; ``positions`` becomes ``position_ids``."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured pooler to ``hidden_states``."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Forward the checkpoint weights to the wrapped model's loader."""
        self.model.load_weights(weights)
......@@ -44,7 +44,7 @@ class EAGLE(nn.Module):
self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=getattr(self.config, "bias", False))
bias=getattr(self.config, "eagle_fc_bias", False))
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
......
......@@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": 0,
"inputs_embeds": 0,
"intermediate_tensors": 0,
})
@support_torch_compile
class Gemma2Model(nn.Module):
def __init__(
......
......@@ -97,6 +97,37 @@ class InternVisionEmbeddings(nn.Module):
return embeddings
class InternVisionPatchModel(nn.Module):
    """Vision model that consists only of the ``InternVisionEmbeddings`` stage."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        """Return the embedding module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return ``pixel_embeds`` if given, else embed ``pixel_values``.

        Raises:
            ValueError: If both arguments are ``None``, or if
                ``pixel_values`` is used and is not 4-dimensional.
        """
        # Ready-made embeddings take precedence over raw pixel values.
        if pixel_embeds is not None:
            return pixel_embeds

        if pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')

        return self.embeddings(pixel_values)
class InternParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment