"vllm/vscode:/vscode.git/clone" did not exist on "0aa38d16f56327622c1689d7510171662757deee"
Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing_extensions import assert_never
......@@ -8,9 +8,10 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
SingletonPrompt)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
......@@ -19,9 +20,11 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
PromptComponents = Tuple[Optional[str], List[int],
Optional["MultiModalDataDict"]]
Optional["MultiModalDataDict"], Optional[Dict[str,
Any]]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional["MultiModalDataDict"]]
Optional["MultiModalDataDict"],
Optional[Dict[str, Any]]]
class InputPreprocessor:
......@@ -71,20 +74,21 @@ class InputPreprocessor:
'''
if not self.is_encoder_decoder_model():
logger.warning("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
print_warning_once("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
logger.warning("Using None for decoder start token id because "
"model config is not available.")
print_warning_once("Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
if dec_start_token_id is None:
logger.warning("Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available.")
print_warning_once("Falling back on <BOS> for decoder start token "
"id because decoder start token id is not "
"available.")
dec_start_token_id = self.get_bos_token_id()
return dec_start_token_id
......@@ -207,7 +211,7 @@ class InputPreprocessor:
def _extract_prompt_components(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
......@@ -217,7 +221,7 @@ class InputPreprocessor:
Arguments:
* request_id
* inputs: single encoder or decoder input prompt
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
......@@ -225,77 +229,89 @@ class InputPreprocessor:
* prompt
* prompt_token_ids
* multi_modal_data
* mm_processor_kwargs (request-level input processor/mapper overrides)
'''
parsed = parse_singleton_prompt(inputs)
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_text = parsed["content"]
prompt_token_ids = self._tokenize_prompt(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
mm_processor_kwargs = None
elif parsed["type"] == "tokens":
prompt = None
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_text = parsed["content"]["prompt"]
prompt_token_ids = self._tokenize_prompt(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
return (prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs)
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(inputs)
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_text = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
mm_processor_kwargs = None
elif parsed["type"] == "tokens":
prompt = None
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_text = parsed["content"]["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
return (prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs)
def _build_enc_dec_llm_inputs(
self,
encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
mm_processor_kwargs: Dict[str, Any],
) -> EncoderDecoderInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if decoder_mm_data is not None:
raise ValueError(
"Multi-modality decoder inputs of encoder-decoder models are "
......@@ -308,10 +324,11 @@ class InputPreprocessor:
decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs(
return EncoderDecoderInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
multi_modal_data=decoder_mm_data,
mm_processor_kwargs=mm_processor_kwargs,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
encoder_multi_modal_data=encoder_mm_data,
......@@ -319,13 +336,13 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
) -> EncoderDecoderInputs:
'''
For encoder/decoder models only:
Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance.
:class:`EncoderDecoderInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
......@@ -347,58 +364,67 @@ class InputPreprocessor:
Arguments:
* inputs: an input prompt
* prompt: an input prompt
* request_id
Returns:
* :class:`EncoderDecoderLLMInputs` instance
* :class:`EncoderDecoderInputs` instance
'''
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"],
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
decoder_comps = None, None, None
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_comps = None, None, None, None
else:
decoder_comps = self._extract_prompt_components(
decoder_input,
request_id=request_id,
)
# Handle this carefully in case it was directly initialized by user
mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
else:
encoder_comps = self._extract_prompt_components(
inputs,
prompt,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
# If there are no decoder components, we assume the
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
-1] is not None else {}
decoder_comps = None, None, None, None
return self._build_enc_dec_llm_inputs(
encoder_comps,
decoder_comps,
mm_processor_kwargs,
)
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
decoder_comps = None, None, None, None
else:
decoder_task = self._extract_prompt_components_async(
decoder_input,
......@@ -407,55 +433,65 @@ class InputPreprocessor:
encoder_comps, decoder_comps = await asyncio.gather(
encoder_task, decoder_task)
mm_processor_kwargs = prompt["mm_processor_kwargs"]
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
prompt,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
# If there are no decoder components, we assume the
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
-1] is not None else {}
decoder_comps = None, None, None, None
return self._build_enc_dec_llm_inputs(
encoder_comps,
decoder_comps,
mm_processor_kwargs,
)
def _build_decoder_only_llm_inputs(
self,
prompt_comps: PromptComponents,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
prompt, prompt_token_ids, multi_modal_data = prompt_comps
) -> DecoderOnlyInputs:
(prompt, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = prompt_comps
prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data)
return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
def _process_decoder_only_prompt(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
) -> DecoderOnlyInputs:
'''
For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance.
Process an input prompt into an :class:`DecoderOnlyInputs` instance.
Arguments:
* inputs: input prompt
* prompt: input prompt
* request_id
* lora_request
* prompt_adapter_request
Returns:
* :class:`LLMInputs` instance
* :class:`DecoderOnlyInputs` instance
'''
prompt_comps = self._extract_prompt_components(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
)
......@@ -467,14 +503,14 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
) -> DecoderOnlyInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
)
......@@ -486,27 +522,27 @@ class InputPreprocessor:
def preprocess(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Preprocess the input prompt."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
inputs,
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return self._process_decoder_only_prompt(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
......@@ -514,27 +550,27 @@ class InputPreprocessor:
async def preprocess_async(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Async version of :meth:`preprocess`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
inputs,
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
......
......@@ -9,9 +9,10 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs)
from .data import LLMInputs
from .data import DecoderOnlyInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -99,7 +100,7 @@ class _MultiModalCounts(UserDict):
raise KeyError(msg) from exc
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
"""Preprocess the inputs to the model."""
......@@ -133,7 +134,7 @@ class InputRegistry:
# Avoid circular import
from vllm.sequence import SequenceData
dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data
......@@ -185,16 +186,8 @@ class InputRegistry:
return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
if model_cls in self._dummy_encoder_factories_by_model_type:
dummy_factory = self._dummy_encoder_factories_by_model_type[
model_cls]
else:
logger.warning(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead.",
model_cls)
dummy_factory = self._get_dummy_data_factory(model_cls)
return dummy_factory
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def dummy_data_for_profiling(
self,
......@@ -235,9 +228,9 @@ class InputRegistry:
num_tokens = seq_data.prompt_token_ids
if len(num_tokens) < seq_len:
if is_encoder_data:
logger.warning(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead.", seq_len, len(num_tokens))
print_warning_once(
f"Expected at least {seq_len} dummy encoder tokens for "
f"profiling, but found {len(num_tokens)} tokens instead.")
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
......@@ -252,8 +245,11 @@ class InputRegistry:
return seq_data, mm_data
def _default_input_processor(self, ctx: InputContext,
inputs: LLMInputs) -> LLMInputs:
def _default_input_processor(
self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
"""The default input processor is a no-op."""
return inputs
......@@ -286,7 +282,7 @@ class InputRegistry:
.get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs:
inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
"""
Apply an input processor to an instance of model inputs.
......@@ -301,8 +297,14 @@ class InputRegistry:
model_cls, _ = get_model_architecture(model_config)
processor = self._get_model_input_processor(model_cls)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
processor, overrides=model_config.mm_processor_kwargs)
# Handle multimodal processor kwargs with priority:
# Inference kwargs -> Init kwargs -> {}
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
inputs.get("mm_processor_kwargs"),
processor,
)
return processor(InputContext(model_config), inputs,
**mm_processor_kwargs)
......
......@@ -39,6 +39,9 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
# Compressed Tensor
elif hasattr(base_layer, "weight_packed"):
return base_layer.weight_packed.device
# GPTQ/AWQ
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
......
......@@ -23,8 +23,10 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils import is_pin_memory_available
......@@ -232,6 +234,8 @@ class LoRAModel(AdapterModel):
# modules.
unexpected_modules = []
target_modules = config["target_modules"]
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
# Compatible with more modules,
# such as:layers.11.self_attn.k_proj
......@@ -242,8 +246,8 @@ class LoRAModel(AdapterModel):
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules:
print(unexpected_modules, "modules")
if unexpected_modules and not is_regex_target_modules(
config["target_modules"], expected_lora_modules):
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
......@@ -332,6 +336,12 @@ class LoRAModelManager(AdapterModelManager):
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping"))
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
# Dict instead of a Set for compatibility with LRUCache.
......@@ -437,12 +447,22 @@ class LoRAModelManager(AdapterModelManager):
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
module_name,
)
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
......@@ -460,6 +480,15 @@ class LoRAModelManager(AdapterModelManager):
module, self.lora_slots,
self.lora_config,
self.model.config))
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if self.supports_mm and not isinstance(new_module,
BaseLayerWithLoRA):
continue
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
......@@ -478,9 +507,10 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA) or isinstance(
module, LinearScalingRotaryEmbeddingWithLora):
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
or self._filter_unsupported_mm_module(module_name)):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
......@@ -541,6 +571,19 @@ class LoRAModelManager(AdapterModelManager):
module_name) or target_module == module_name
for target_module in self.supported_lora_modules)
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
return (prefix in module_mapping.connector
or prefix in module_mapping.tower_model)
return False
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
......
import os
from typing import List, Optional, Set, Tuple, Type
import re
from typing import List, Optional, Set, Tuple, Type, Union
import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
......@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
raise ValueError(f"{name} is unsupported LoRA weight")
def is_regex_target_modules(load_modules: Union[str, List[str]],
expected_lora_modules: List[str]) -> bool:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
determine whether the suffix in the regular expression is present in the
`expected_lora_modules`.
"""
def is_valid_regex(pattern):
try:
re.compile(pattern)
return True
except re.error:
return False
def is_subset(sub_list, full_list):
return set(sub_list).issubset(set(full_list))
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
if not isinstance(load_modules, str):
return False
if is_valid_regex(load_modules):
match = re.search(r"\((.*?)\)\$?$", load_modules)
if match:
suffix = match.group(1).split("|")
return is_subset(suffix, expected_lora_modules)
return False
def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.
......
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
......@@ -55,7 +56,7 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
return self.forward_native
if is_hip():
......
from typing import Optional, Union
from typing import Optional
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer)
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request
# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request
# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters
return request
......@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel
# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
guided_regex: str
......
......@@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if request.guided_json:
schema = _normalize_json_schema_object(request.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif request.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in request.guided_choice])
elif request.guided_regex:
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest,
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
......@@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if guided_options.guided_json:
schema = _normalize_json_schema_object(guided_options.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif guided_options.guided_choice:
if guided_params.json:
schema_dict = _normalize_json_schema_object(guided_params.json)
character_level_parser = JsonSchemaParser(schema_dict)
elif guided_params.choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_options.guided_choice])
elif guided_options.guided_regex:
character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object:
[StringParser(choice) for choice in guided_params.choice])
elif guided_params.regex:
character_level_parser = RegexParser(guided_params.regex)
elif guided_params.grammar:
# CFG grammar not supported by LMFE
raise ValueError("Cannot construct a guided decoding logits processor"
" using the grammar option with the"
" lm_format_enforcer backend.")
elif guided_params.json_object:
# None means any json object
character_level_parser = JsonSchemaParser(None)
else:
......@@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")
......
......@@ -5,16 +5,11 @@ from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
class GuidedDecodingMode(Enum):
......@@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm
async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest,
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor(
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(request)
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None
......@@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor(
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, request.guided_whitespace_pattern)
mode, guided_params.whitespace_pattern)
def get_local_outlines_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor(
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(guided_options)
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None
return _get_logits_processor(guide, tokenizer, mode,
guided_options.guided_whitespace_pattern)
guided_params.whitespace_pattern)
def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest]
guided_params: GuidedDecodingParams
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
# if the request is a chat completion request, AND the tool choice is a
# named tool choice, do guided decoding
# using that tool as the JSON schema
if isinstance(request, ChatCompletionRequest) and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
# Guided generation for tools/functions parameters
if request.tool_choice.type == "function":
for tool in request.tools:
if (tool.type == "function" and tool.function.name
== request.tool_choice.function.name):
json = json_dumps(tool.function.parameters, sort_keys=True)
return json, GuidedDecodingMode.JSON
return None, None
elif request.guided_json:
if isinstance(request.guided_json, dict):
if guided_params.json:
if isinstance(guided_params.json, dict):
# turn dict into hashable string
json = json_dumps(request.guided_json)
elif isinstance(request.guided_json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json = str(request.guided_json.__signature__)
json = json_dumps(guided_params.json)
else:
json = request.guided_json
json = guided_params.json
return json, GuidedDecodingMode.JSON
elif request.guided_regex:
return request.guided_regex, GuidedDecodingMode.REGEX
elif request.guided_choice:
elif guided_params.regex:
return guided_params.regex, GuidedDecodingMode.REGEX
elif guided_params.choice:
# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in request.guided_choice
regex_escape(str(choice)) for choice in guided_params.choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_object"):
elif guided_params.grammar:
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
elif guided_params.json_object:
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else:
return None, None
......
......@@ -14,6 +14,33 @@ from vllm.model_executor.utils import set_weight_attrs
import vllm.envs as envs
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, threshold: float = 0.):
super().__init__()
self.threshold = threshold
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x1 = x[..., :d]
x2 = x[..., d:]
x1 = F.threshold(x1, self.threshold, 0.0)
return x1 * x2
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 7
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 128,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"192": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 128,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 8
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"6144": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
}
}
\ No newline at end of file
......@@ -10,17 +10,27 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp:
assert num_bits == 4
return scalar_types.uint4
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int,
renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
......@@ -33,10 +43,12 @@ def single_marlin_moe(
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (torch.Tensor): The act_order indices.
- perm (torch.Tensor): The act_order input permutation.
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
- sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
......@@ -78,16 +90,34 @@ def single_marlin_moe(
max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
device=hidden_states.device,
requires_grad=False)
scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
has_zero_point = w_zeros is not None
if w_zeros is None:
w_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if g_idx is None:
g_idx = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
if sort_indices is None:
sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
block_size_m, True, False)
w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
is_k_full, E, topk, block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
......@@ -96,17 +126,20 @@ def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
perm1: torch.Tensor,
perm2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -116,21 +149,22 @@ def fused_marlin_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation.
- perm2 (torch.Tensor): The second act_order input permutation.
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
......@@ -150,6 +184,20 @@ def fused_marlin_moe(
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
......@@ -170,14 +218,42 @@ def fused_marlin_moe(
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
if has_no_zp:
w1_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
w2_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if has_no_act_order:
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
......@@ -192,14 +268,15 @@ def fused_marlin_moe(
topk_weights,
topk_ids,
w1_scale,
w1_zeros,
g_idx1,
perm1,
sort_indices1,
workspace,
scalar_type,
scalar_type1,
M,
2 * N,
K,
True,
is_k_full,
E,
topk,
block_size_m,
......@@ -216,14 +293,15 @@ def fused_marlin_moe(
topk_weights,
topk_ids,
w2_scale,
w2_zeros,
g_idx2,
perm2,
sort_indices2,
workspace,
scalar_type,
scalar_type2,
M,
K,
N,
True,
is_k_full,
E,
topk,
block_size_m,
......
......@@ -320,6 +320,9 @@ def get_moe_configs(E: int, N: int,
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
("Using default MoE config. Performance might be sub-optimal! "
"Config file not found at %s"), config_file_path)
return None
......
......@@ -19,10 +19,16 @@ class RMSNorm(CustomOp):
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size))
def forward_native(
self,
......@@ -36,7 +42,23 @@ class RMSNorm(CustomOp):
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
......@@ -49,6 +71,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm import _custom_ops as ops
if residual is not None:
......@@ -89,6 +114,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm._ipex_ops import ipex_ops as ops
if residual is not None:
......
......@@ -30,7 +30,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod"
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
]
......@@ -355,8 +355,12 @@ class ColumnParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
param_data = param.data
if output_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
......@@ -459,17 +463,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return
if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return
param_data = param.data
output_dim = getattr(param, "output_dim", None)
......@@ -534,18 +544,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
......@@ -802,17 +800,23 @@ class QKVParallelLinear(ColumnParallelLinear):
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return
if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return
param_data = param.data
output_dim = getattr(param, "output_dim", None)
......@@ -840,6 +844,9 @@ class QKVParallelLinear(ColumnParallelLinear):
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
......@@ -853,6 +860,23 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
"v":
((self.total_num_heads + self.total_num_kv_heads) *
self.head_size,
self.total_num_kv_heads * self.head_size),
"total":
((self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_size, 0)
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
......@@ -902,18 +926,6 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
......
......@@ -6,65 +6,57 @@ from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
activation: str = "silu",
):
def causal_conv1d_fn(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim, seqlen)
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether should the kernel take the current state as initial
state for the calculations
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
final_states_out, activation
in ["silu", "swish"])
return (out, None) if not return_final_states else (out, final_states_out)
ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
cache_indices, has_initial_state, activation
in ["silu", "swish"], pad_slot_id)
return x
def causal_conv1d_update(x: torch.Tensor,
......@@ -72,21 +64,39 @@ def causal_conv1d_update(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None,
conv_state_indices: Optional[torch.Tensor] = None):
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
activation_bool = activation in ["silu", "swish"]
return ops.causal_conv1d_update(x, conv_state, weight, bias,
activation_bool, conv_state_indices)
activation_val = activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
cache_seqlens, conv_state_indices, pad_slot_id)
if unsqueeze:
x = x.squeeze(-1)
return x
......@@ -7,6 +7,7 @@ import triton.language as tl
from packaging import version
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
......@@ -48,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr,
out_ptr,
state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions
batch,
nheads,
......@@ -141,10 +143,11 @@ def _selective_scan_update_kernel(
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
state = tl.load(state_ptrs, mask=mask, other=0.0)
state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
......@@ -175,9 +178,11 @@ def _selective_scan_update_kernel(
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
tl.store(state_ptrs,
state,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
......@@ -196,7 +201,8 @@ def selective_state_update(state,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None):
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
......@@ -208,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
......@@ -274,6 +286,7 @@ def selective_state_update(state,
z,
out,
state_batch_indices,
pad_slot_id,
batch,
nheads,
dim,
......@@ -318,6 +331,7 @@ def selective_state_update(state,
def selective_scan_fn(u,
ssm_states,
delta,
A,
B,
......@@ -326,11 +340,45 @@ def selective_scan_fn(u,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
query_start_loc=None,
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
C: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
dt_bias: (dim,) or (dim)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended with 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
A tensor with each cell is a correspondent
input and output ssm_state index
has_initial_state: (batch) bool
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
"""
if u.stride(-1) != 1:
u = u.contiguous()
......@@ -344,28 +392,20 @@ def selective_scan_fn(u,
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
if B.dim() == 3 and query_start_loc is None:
B = B.unsqueeze(1)
if C.dim() == 3:
if B.dim() == 2 and query_start_loc is not None:
B = B.unsqueeze(0)
if C.dim() == 3 and query_start_loc is None:
C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
x = torch.zeros((
u.shape[0],
u.shape[1],
n_chunks,
int(A.shape[1] * 2),
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if C.dim() == 2 and query_start_loc is not None:
C = C.unsqueeze(0)
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
query_start_loc, cache_indices, has_initial_state,
ssm_states, pad_slot_id)
if z is None:
return out if not return_last_state else (out, last_state)
return delta # output written inplace to delta
else:
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
return z # output written inplace to z
......@@ -11,6 +11,7 @@ from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
ALL = 1
class Pooler(nn.Module):
......@@ -43,6 +44,12 @@ class Pooler(nn.Module):
if self.pooling_type == PoolingType.LAST:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled_data = hidden_states[last_token_flat_indices]
elif self.pooling_type == PoolingType.ALL:
offset = 0
pooled_data = []
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
......
......@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.neuron_quant import (
......@@ -49,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment