Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
import asyncio import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -8,9 +8,10 @@ from vllm.logger import init_logger ...@@ -8,9 +8,10 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
SingletonPromptInputs) SingletonPrompt)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -19,9 +20,11 @@ if TYPE_CHECKING: ...@@ -19,9 +20,11 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
PromptComponents = Tuple[Optional[str], List[int], PromptComponents = Tuple[Optional[str], List[int],
Optional["MultiModalDataDict"]] Optional["MultiModalDataDict"], Optional[Dict[str,
Any]]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional["MultiModalDataDict"]] Optional["MultiModalDataDict"],
Optional[Dict[str, Any]]]
class InputPreprocessor: class InputPreprocessor:
...@@ -71,20 +74,21 @@ class InputPreprocessor: ...@@ -71,20 +74,21 @@ class InputPreprocessor:
''' '''
if not self.is_encoder_decoder_model(): if not self.is_encoder_decoder_model():
logger.warning("Using None for decoder start token id because " print_warning_once("Using None for decoder start token id because "
"this is not an encoder/decoder model.") "this is not an encoder/decoder model.")
return None return None
if (self.model_config is None or self.model_config.hf_config is None): if (self.model_config is None or self.model_config.hf_config is None):
logger.warning("Using None for decoder start token id because " print_warning_once("Using None for decoder start token id because "
"model config is not available.") "model config is not available.")
return None return None
dec_start_token_id = getattr(self.model_config.hf_config, dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None) 'decoder_start_token_id', None)
if dec_start_token_id is None: if dec_start_token_id is None:
logger.warning("Falling back on <BOS> for decoder start token id " print_warning_once("Falling back on <BOS> for decoder start token "
"because decoder start token id is not available.") "id because decoder start token id is not "
"available.")
dec_start_token_id = self.get_bos_token_id() dec_start_token_id = self.get_bos_token_id()
return dec_start_token_id return dec_start_token_id
...@@ -207,7 +211,7 @@ class InputPreprocessor: ...@@ -207,7 +211,7 @@ class InputPreprocessor:
def _extract_prompt_components( def _extract_prompt_components(
self, self,
inputs: SingletonPromptInputs, prompt: SingletonPrompt,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> PromptComponents: ) -> PromptComponents:
...@@ -217,7 +221,7 @@ class InputPreprocessor: ...@@ -217,7 +221,7 @@ class InputPreprocessor:
Arguments: Arguments:
* request_id * request_id
* inputs: single encoder or decoder input prompt * prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts * lora_request: this is only valid for decoder prompts
Returns: Returns:
...@@ -225,77 +229,89 @@ class InputPreprocessor: ...@@ -225,77 +229,89 @@ class InputPreprocessor:
* prompt * prompt
* prompt_token_ids * prompt_token_ids
* multi_modal_data * multi_modal_data
* mm_processor_kwargs (request-level input processor/mapper overrides)
''' '''
parsed = parse_singleton_prompt(inputs) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str": if parsed["type"] == "str":
prompt = parsed["content"] prompt_text = parsed["content"]
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt, prompt_text,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
multi_modal_data = None multi_modal_data = None
mm_processor_kwargs = None
elif parsed["type"] == "tokens": elif parsed["type"] == "tokens":
prompt = None prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"] prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data") multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
elif parsed["type"] == "text": elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"] prompt_text = parsed["content"]["prompt"]
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt, prompt_text,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
multi_modal_data = parsed["content"].get("multi_modal_data") multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
else: else:
assert_never(parsed) assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data return (prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs)
async def _extract_prompt_components_async( async def _extract_prompt_components_async(
self, self,
inputs: SingletonPromptInputs, prompt: SingletonPrompt,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> PromptComponents: ) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`.""" """Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(inputs) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str": if parsed["type"] == "str":
prompt = parsed["content"] prompt_text = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt, prompt_text,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
multi_modal_data = None multi_modal_data = None
mm_processor_kwargs = None
elif parsed["type"] == "tokens": elif parsed["type"] == "tokens":
prompt = None prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"] prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data") multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
elif parsed["type"] == "text": elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"] prompt_text = parsed["content"]["prompt"]
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt, prompt_text,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
multi_modal_data = parsed["content"].get("multi_modal_data") multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
else: else:
assert_never(parsed) assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data return (prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs)
def _build_enc_dec_llm_inputs( def _build_enc_dec_llm_inputs(
self, self,
encoder_comps: PromptComponents, encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents, decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs: mm_processor_kwargs: Dict[str, Any],
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps ) -> EncoderDecoderInputs:
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if decoder_mm_data is not None: if decoder_mm_data is not None:
raise ValueError( raise ValueError(
"Multi-modality decoder inputs of encoder-decoder models are " "Multi-modality decoder inputs of encoder-decoder models are "
...@@ -308,10 +324,11 @@ class InputPreprocessor: ...@@ -308,10 +324,11 @@ class InputPreprocessor:
decoder_prompt_ids, decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None))) force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs( return EncoderDecoderInputs(
prompt_token_ids=decoder_prompt_ids, prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt, prompt=decoder_prompt,
multi_modal_data=decoder_mm_data, multi_modal_data=decoder_mm_data,
mm_processor_kwargs=mm_processor_kwargs,
encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt, encoder_prompt=encoder_prompt,
encoder_multi_modal_data=encoder_mm_data, encoder_multi_modal_data=encoder_mm_data,
...@@ -319,13 +336,13 @@ class InputPreprocessor: ...@@ -319,13 +336,13 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
inputs: PromptInputs, prompt: PromptType,
request_id: str, request_id: str,
) -> EncoderDecoderLLMInputs: ) -> EncoderDecoderInputs:
''' '''
For encoder/decoder models only: For encoder/decoder models only:
Process an input prompt into an Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance. :class:`EncoderDecoderInputs` instance.
There are two types of input prompts: There are two types of input prompts:
singleton prompts which carry only the singleton prompts which carry only the
...@@ -347,58 +364,67 @@ class InputPreprocessor: ...@@ -347,58 +364,67 @@ class InputPreprocessor:
Arguments: Arguments:
* inputs: an input prompt * prompt: an input prompt
* request_id * request_id
Returns: Returns:
* :class:`EncoderDecoderLLMInputs` instance * :class:`EncoderDecoderInputs` instance
''' '''
encoder_comps: PromptComponents encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs): if is_explicit_encoder_decoder_prompt(prompt):
encoder_comps = self._extract_prompt_components( encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"], prompt["encoder_prompt"],
request_id=request_id, request_id=request_id,
) )
if (decoder_input := inputs["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_comps = None, None, None decoder_comps = None, None, None, None
else: else:
decoder_comps = self._extract_prompt_components( decoder_comps = self._extract_prompt_components(
decoder_input, decoder_input,
request_id=request_id, request_id=request_id,
) )
# Handle this carefully in case it was directly initialized by user
mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
else: else:
encoder_comps = self._extract_prompt_components( encoder_comps = self._extract_prompt_components(
inputs, prompt,
request_id=request_id, request_id=request_id,
) )
# If there are no decoder components, we assume the
decoder_comps = None, None, None # mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) -1] is not None else {}
decoder_comps = None, None, None, None
return self._build_enc_dec_llm_inputs(
encoder_comps,
decoder_comps,
mm_processor_kwargs,
)
async def _process_encoder_decoder_prompt_async( async def _process_encoder_decoder_prompt_async(
self, self,
inputs: PromptInputs, prompt: PromptType,
request_id: str, request_id: str,
) -> EncoderDecoderLLMInputs: ) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`.""" """Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs): if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._extract_prompt_components_async( encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"], prompt["encoder_prompt"],
request_id=request_id, request_id=request_id,
) )
if (decoder_input := inputs["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_comps = await encoder_task encoder_comps = await encoder_task
decoder_comps = None, None, None decoder_comps = None, None, None, None
else: else:
decoder_task = self._extract_prompt_components_async( decoder_task = self._extract_prompt_components_async(
decoder_input, decoder_input,
...@@ -407,55 +433,65 @@ class InputPreprocessor: ...@@ -407,55 +433,65 @@ class InputPreprocessor:
encoder_comps, decoder_comps = await asyncio.gather( encoder_comps, decoder_comps = await asyncio.gather(
encoder_task, decoder_task) encoder_task, decoder_task)
mm_processor_kwargs = prompt["mm_processor_kwargs"]
else: else:
encoder_comps = await self._extract_prompt_components_async( encoder_comps = await self._extract_prompt_components_async(
inputs, prompt,
request_id=request_id, request_id=request_id,
) )
# If there are no decoder components, we assume the
decoder_comps = None, None, None # mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) -1] is not None else {}
decoder_comps = None, None, None, None
return self._build_enc_dec_llm_inputs(
encoder_comps,
decoder_comps,
mm_processor_kwargs,
)
def _build_decoder_only_llm_inputs( def _build_decoder_only_llm_inputs(
self, self,
prompt_comps: PromptComponents, prompt_comps: PromptComponents,
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs: ) -> DecoderOnlyInputs:
prompt, prompt_token_ids, multi_modal_data = prompt_comps (prompt, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = prompt_comps
prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request) prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids, return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt, prompt=prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
self, self,
inputs: SingletonPromptInputs, prompt: SingletonPrompt,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> DecoderOnlyInputs:
''' '''
For decoder-only models: For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance. Process an input prompt into an :class:`DecoderOnlyInputs` instance.
Arguments: Arguments:
* inputs: input prompt * prompt: input prompt
* request_id * request_id
* lora_request * lora_request
* prompt_adapter_request * prompt_adapter_request
Returns: Returns:
* :class:`LLMInputs` instance * :class:`DecoderOnlyInputs` instance
''' '''
prompt_comps = self._extract_prompt_components( prompt_comps = self._extract_prompt_components(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -467,14 +503,14 @@ class InputPreprocessor: ...@@ -467,14 +503,14 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async( async def _process_decoder_only_prompt_async(
self, self,
inputs: SingletonPromptInputs, prompt: SingletonPrompt,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> DecoderOnlyInputs:
"""Async version of :meth:`_process_decoder_only_prompt`.""" """Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async( prompt_comps = await self._extract_prompt_components_async(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -486,27 +522,27 @@ class InputPreprocessor: ...@@ -486,27 +522,27 @@ class InputPreprocessor:
def preprocess( def preprocess(
self, self,
inputs: PromptInputs, prompt: PromptType,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]: ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
if self.is_encoder_decoder_model(): if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
inputs, prompt,
request_id=request_id, request_id=request_id,
) )
if is_explicit_encoder_decoder_prompt(inputs): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models") "to decoder-only models")
# Decoder-only operation # Decoder-only operation
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
...@@ -514,27 +550,27 @@ class InputPreprocessor: ...@@ -514,27 +550,27 @@ class InputPreprocessor:
async def preprocess_async( async def preprocess_async(
self, self,
inputs: PromptInputs, prompt: PromptType,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]: ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Async version of :meth:`preprocess`.""" """Async version of :meth:`preprocess`."""
if self.is_encoder_decoder_model(): if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async( return await self._process_encoder_decoder_prompt_async(
inputs, prompt,
request_id=request_id, request_id=request_id,
) )
if is_explicit_encoder_decoder_prompt(inputs): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models") "to decoder-only models")
# Decoder-only operation # Decoder-only operation
return await self._process_decoder_only_prompt_async( return await self._process_decoder_only_prompt_async(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
......
...@@ -9,9 +9,10 @@ from transformers import PretrainedConfig ...@@ -9,9 +9,10 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs)
from .data import LLMInputs from .data import DecoderOnlyInputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -99,7 +100,7 @@ class _MultiModalCounts(UserDict): ...@@ -99,7 +100,7 @@ class _MultiModalCounts(UserDict):
raise KeyError(msg) from exc raise KeyError(msg) from exc
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
"""Preprocess the inputs to the model.""" """Preprocess the inputs to the model."""
...@@ -133,7 +134,7 @@ class InputRegistry: ...@@ -133,7 +134,7 @@ class InputRegistry:
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
dummy_seq_data = SequenceData.from_token_counts((0, seq_len)) dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
dummy_multi_modal_data = None dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data return dummy_seq_data, dummy_multi_modal_data
...@@ -185,16 +186,8 @@ class InputRegistry: ...@@ -185,16 +186,8 @@ class InputRegistry:
return wrapper return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]): def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
if model_cls in self._dummy_encoder_factories_by_model_type: return self._dummy_encoder_factories_by_model_type \
dummy_factory = self._dummy_encoder_factories_by_model_type[ .get(model_cls, self._default_dummy_data_factory)
model_cls]
else:
logger.warning(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead.",
model_cls)
dummy_factory = self._get_dummy_data_factory(model_cls)
return dummy_factory
def dummy_data_for_profiling( def dummy_data_for_profiling(
self, self,
...@@ -235,9 +228,9 @@ class InputRegistry: ...@@ -235,9 +228,9 @@ class InputRegistry:
num_tokens = seq_data.prompt_token_ids num_tokens = seq_data.prompt_token_ids
if len(num_tokens) < seq_len: if len(num_tokens) < seq_len:
if is_encoder_data: if is_encoder_data:
logger.warning( print_warning_once(
"Expected at least %d dummy encoder tokens for profiling, " f"Expected at least {seq_len} dummy encoder tokens for "
"but found %d tokens instead.", seq_len, len(num_tokens)) f"profiling, but found {len(num_tokens)} tokens instead.")
else: else:
raise AssertionError( raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, " f"Expected at least {seq_len} dummy tokens for profiling, "
...@@ -252,8 +245,11 @@ class InputRegistry: ...@@ -252,8 +245,11 @@ class InputRegistry:
return seq_data, mm_data return seq_data, mm_data
def _default_input_processor(self, ctx: InputContext, def _default_input_processor(
inputs: LLMInputs) -> LLMInputs: self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
"""The default input processor is a no-op.""" """The default input processor is a no-op."""
return inputs return inputs
...@@ -286,7 +282,7 @@ class InputRegistry: ...@@ -286,7 +282,7 @@ class InputRegistry:
.get(model_cls, self._default_input_processor) .get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig", def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs: inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
""" """
Apply an input processor to an instance of model inputs. Apply an input processor to an instance of model inputs.
...@@ -301,8 +297,14 @@ class InputRegistry: ...@@ -301,8 +297,14 @@ class InputRegistry:
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
processor = self._get_model_input_processor(model_cls) processor = self._get_model_input_processor(model_cls)
mm_processor_kwargs = get_allowed_kwarg_only_overrides( # Handle multimodal processor kwargs with priority:
processor, overrides=model_config.mm_processor_kwargs) # Inference kwargs -> Init kwargs -> {}
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
inputs.get("mm_processor_kwargs"),
processor,
)
return processor(InputContext(model_config), inputs, return processor(InputContext(model_config), inputs,
**mm_processor_kwargs) **mm_processor_kwargs)
......
...@@ -39,6 +39,9 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: ...@@ -39,6 +39,9 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# unquantizedLinear # unquantizedLinear
if hasattr(base_layer, "weight"): if hasattr(base_layer, "weight"):
return base_layer.weight.device return base_layer.weight.device
# Compressed Tensor
elif hasattr(base_layer, "weight_packed"):
return base_layer.weight_packed.device
# GPTQ/AWQ # GPTQ/AWQ
elif hasattr(base_layer, "qweight"): elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device return base_layer.qweight.device
......
...@@ -23,8 +23,10 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ...@@ -23,8 +23,10 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -232,6 +234,8 @@ class LoRAModel(AdapterModel): ...@@ -232,6 +234,8 @@ class LoRAModel(AdapterModel):
# modules. # modules.
unexpected_modules = [] unexpected_modules = []
target_modules = config["target_modules"] target_modules = config["target_modules"]
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules: for module in target_modules:
# Compatible with more modules, # Compatible with more modules,
# such as:layers.11.self_attn.k_proj # such as:layers.11.self_attn.k_proj
...@@ -242,8 +246,8 @@ class LoRAModel(AdapterModel): ...@@ -242,8 +246,8 @@ class LoRAModel(AdapterModel):
# expected_lora_modules. It is not reliable. See # expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no # https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism. # other better mechanism.
if unexpected_modules: if unexpected_modules and not is_regex_target_modules(
print(unexpected_modules, "modules") config["target_modules"], expected_lora_modules):
raise ValueError( raise ValueError(
f"While loading {lora_dir}, expected" f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}" f" target modules in {expected_lora_modules}"
...@@ -332,6 +336,12 @@ class LoRAModelManager(AdapterModelManager): ...@@ -332,6 +336,12 @@ class LoRAModelManager(AdapterModelManager):
self.supported_lora_modules.append("rotary_emb") self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy( self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping) self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping"))
self.packed_modules: Dict[str, List[str]] = {} self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a Set for compatibility with LRUCache.
...@@ -437,12 +447,22 @@ class LoRAModelManager(AdapterModelManager): ...@@ -437,12 +447,22 @@ class LoRAModelManager(AdapterModelManager):
continue continue
if not self._match_target_modules(module_name): if not self._match_target_modules(module_name):
continue continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
module_name,
)
continue
parts = module_name.split(".")[-1] parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, []) packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule( new_module = replace_submodule(
self.model, module_name, self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config, from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config)) packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle # LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata. # long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
...@@ -460,6 +480,15 @@ class LoRAModelManager(AdapterModelManager): ...@@ -460,6 +480,15 @@ class LoRAModelManager(AdapterModelManager):
module, self.lora_slots, module, self.lora_slots,
self.lora_config, self.lora_config,
self.model.config)) self.model.config))
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if self.supports_mm and not isinstance(new_module,
BaseLayerWithLoRA):
continue
self.register_module(module_name, new_module) self.register_module(module_name, new_module)
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference. # All lora layers share the same punica_wrapper based on reference.
...@@ -478,9 +507,10 @@ class LoRAModelManager(AdapterModelManager): ...@@ -478,9 +507,10 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup.""" """Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor) model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance( if (not self._match_target_modules(module_name)
module, BaseLayerWithLoRA) or isinstance( or not isinstance(module, BaseLayerWithLoRA)
module, LinearScalingRotaryEmbeddingWithLora): or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
or self._filter_unsupported_mm_module(module_name)):
continue continue
parts = module_name.split(".") parts = module_name.split(".")
if module_name not in self.packed_modules: if module_name not in self.packed_modules:
...@@ -541,6 +571,19 @@ class LoRAModelManager(AdapterModelManager): ...@@ -541,6 +571,19 @@ class LoRAModelManager(AdapterModelManager):
module_name) or target_module == module_name module_name) or target_module == module_name
for target_module in self.supported_lora_modules) for target_module in self.supported_lora_modules)
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
return (prefix in module_mapping.connector
or prefix in module_mapping.tower_model)
return False
def _register_packed_modules(self, module_full_name: str) -> None: def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".") parts = module_full_name.split(".")
module_name = parts[-1] module_name = parts[-1]
......
import os import os
from typing import List, Optional, Set, Tuple, Type import re
from typing import List, Optional, Set, Tuple, Type, Union
import huggingface_hub import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
...@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: ...@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
raise ValueError(f"{name} is unsupported LoRA weight") raise ValueError(f"{name} is unsupported LoRA weight")
def is_regex_target_modules(load_modules: Union[str, List[str]],
                            expected_lora_modules: List[str]) -> bool:
    """
    Check whether a regex-style `target_modules` value only targets
    expected LoRA modules.

    PEFT supports passing `target_modules` as a regular expression, such as
    `model.*(q_proj|k_proj|v_proj)$`. This function extracts the
    alternatives of the parenthesised group that ends the pattern
    (optionally followed by `$`) and reports whether every alternative is
    present in `expected_lora_modules`.

    Returns:
        True only when `load_modules` is a string that compiles as a regex,
        ends with a parenthesised group, and all `|`-separated entries of
        that group are in `expected_lora_modules`; False in every other
        case.
    """
    # Similar to PEFT's processing logic, regex handling only applies when
    # load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    # Reject strings that are not valid regular expressions.
    try:
        re.compile(load_modules)
    except re.error:
        return False

    trailing_group = re.search(r"\((.*?)\)\$?$", load_modules)
    if trailing_group is None:
        return False

    alternatives = trailing_group.group(1).split("|")
    return set(alternatives) <= set(expected_lora_modules)
def get_adapter_absolute_path(lora_path: str) -> str: def get_adapter_absolute_path(lora_path: str) -> str:
""" """
Resolves the given lora_path to an absolute local path. Resolves the given lora_path to an absolute local path.
......
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu from vllm.utils import is_cpu, is_hip, is_xpu
...@@ -55,7 +56,7 @@ class CustomOp(nn.Module): ...@@ -55,7 +56,7 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one # NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching. # specific backend. Currently, we do not support dynamic dispatching.
if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
return self.forward_native return self.forward_native
if is_hip(): if is_hip():
......
from typing import Optional, Union from typing import Optional
from vllm.entrypoints.openai.protocol import ( from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
async def get_guided_decoding_logits_processor( async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest, guided_params: GuidedDecodingParams,
ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]: tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request) # CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor) get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor(
request, tokenizer) guided_params, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer': if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor) get_local_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor( return get_local_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer) guided_params, tokenizer)
raise ValueError( raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. " f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'") "Must be one of 'outlines, 'lm-format-enforcer'")
def get_local_guided_decoding_logits_processor( def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest, guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]: tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request) # CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor) get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor( return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer) guided_params, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer': if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor) get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor( return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer) guided_params, tokenizer)
raise ValueError( raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. " f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'") "Must be one of 'outlines, 'lm-format-enforcer'")
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
    """
    Turn a named tool choice into a guided-JSON constraint on the request.

    The request is returned unchanged when it is a legacy completion
    request, when `tool_choice` is "none" or "auto", or when `tool_choice`
    is not a `ChatCompletionNamedToolChoiceParam`. Otherwise the
    parameters of the named tool are stored in `request.guided_json`
    (the request is mutated in place) and the request is returned.

    Raises:
        ValueError: if the named tool is not among `request.tools`.
    """
    # The legacy completion API does not support tool use.
    if type(request) is CompletionRequest:
        return request

    choice = request.tool_choice

    # Either no tool is wanted, or the model is free to pick one.
    if choice == "none" or choice == "auto":
        return request

    # Anything other than an explicitly named tool needs no adaptation.
    if type(choice) is not ChatCompletionNamedToolChoiceParam:
        return request

    tool_name = choice.function.name
    available = {tool.function.name: tool.function for tool in request.tools}
    if tool_name not in available:
        raise ValueError(
            f"Tool '{tool_name}' has not been passed in `tools`.")
    request.guided_json = available[tool_name].parameters
    return request
...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union ...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel from pydantic import BaseModel
# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False): class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str] guided_json: Union[Dict, BaseModel, str]
guided_regex: str guided_regex: str
......
...@@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, ...@@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
TokenEnforcerTokenizerData, UnionParser) TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import ( from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
    """
    Build an LM Format Enforcer logits processor for an OpenAI-compatible
    request.

    The guide is chosen with this precedence: `guided_json`,
    `guided_choice`, `guided_regex`, `guided_grammar`, then
    `response_format` of type "json_object" or "json_schema". A grammar
    guide is delegated to the outlines backend. Returns None when the
    request carries no guided decoding parameter.

    NOTE(review): the tokenizer data comes from a helper whose name
    suggests it is cached per tokenizer; confirm against its definition.
    """
    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer)

    parser: CharacterLevelParser
    if request.guided_json:
        parser = JsonSchemaParser(
            _normalize_json_schema_object(request.guided_json))
    elif request.guided_choice:
        choice_parsers = [
            StringParser(choice) for choice in request.guided_choice
        ]
        parser = UnionParser(choice_parsers)
    elif request.guided_regex:
        parser = RegexParser(request.guided_regex)
    elif request.guided_grammar:
        # CFG grammar not supported by LMFE, revert to outlines
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (
            get_outlines_guided_decoding_logits_processor)
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    else:
        response_format = request.response_format
        if response_format is None:
            return None
        if response_format.type == "json_object":
            # None means any json object
            parser = JsonSchemaParser(None)
        elif (response_format.type == "json_schema"
              and response_format.json_schema is not None
              and response_format.json_schema.json_schema is not None):
            parser = JsonSchemaParser(
                _normalize_json_schema_object(
                    response_format.json_schema.json_schema))
        else:
            return None

    return build_vllm_logits_processor(tokenizer_data, parser)
def get_local_lm_format_enforcer_guided_decoding_logits_processor( def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]: tokenizer) -> Optional[LogitsProcessor]:
""" """
Given an OpenAI-compatible request, check for guided decoding parameters Given an OpenAI-compatible request, check for guided decoding parameters
...@@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( ...@@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer) tokenizer)
character_level_parser: CharacterLevelParser character_level_parser: CharacterLevelParser
if guided_options.guided_json: if guided_params.json:
schema = _normalize_json_schema_object(guided_options.guided_json) schema_dict = _normalize_json_schema_object(guided_params.json)
character_level_parser = JsonSchemaParser(schema) character_level_parser = JsonSchemaParser(schema_dict)
elif guided_options.guided_choice: elif guided_params.choice:
character_level_parser = UnionParser( character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_options.guided_choice]) [StringParser(choice) for choice in guided_params.choice])
elif guided_options.guided_regex: elif guided_params.regex:
character_level_parser = RegexParser(guided_options.guided_regex) character_level_parser = RegexParser(guided_params.regex)
elif guided_options.guided_grammar: elif guided_params.grammar:
# CFG grammar not supported by LMFE, revert to outlines # CFG grammar not supported by LMFE
raise ValueError("Cannot construct a guided decoding logits processor"
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 " using the grammar option with the"
from vllm.model_executor.guided_decoding.outlines_decoding import ( " lm_format_enforcer backend.")
get_local_outlines_guided_decoding_logits_processor) elif guided_params.json_object:
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object:
# None means any json object # None means any json object
character_level_parser = JsonSchemaParser(None) character_level_parser = JsonSchemaParser(None)
else: else:
...@@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( ...@@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
return logits_processor return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
if isinstance(schema, str): if isinstance(schema, str):
return json_loads(schema) return json_loads(schema)
if isinstance(schema, dict): if isinstance(schema, dict):
return schema return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}") raise AssertionError(f"Unsupported schema type {schema}")
......
...@@ -5,16 +5,11 @@ from json import dumps as json_dumps ...@@ -5,16 +5,11 @@ from json import dumps as json_dumps
from re import escape as regex_escape from re import escape as regex_escape
from typing import Tuple, Union from typing import Tuple, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
class GuidedDecodingMode(Enum): class GuidedDecodingMode(Enum):
...@@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm ...@@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm
async def get_outlines_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor(
we make a shallow copy to reuse the same underlying FSM. we make a shallow copy to reuse the same underlying FSM.
""" """
global global_thread_pool global global_thread_pool
guide, mode = _get_guide_and_mode(request) guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode: if not guide or not mode:
return None return None
...@@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor(
return await loop.run_in_executor(global_thread_pool, return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer, _get_logits_processor, guide, tokenizer,
mode, request.guided_whitespace_pattern) mode, guided_params.whitespace_pattern)
def get_local_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor( ...@@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor(
We cache logit processors by (guide, tokenizer), and on cache hit We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM. we make a shallow copy to reuse the same underlying FSM.
""" """
guide, mode = _get_guide_and_mode(guided_options) guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode: if not guide or not mode:
return None return None
return _get_logits_processor(guide, tokenizer, mode, return _get_logits_processor(guide, tokenizer, mode,
guided_options.guided_whitespace_pattern) guided_params.whitespace_pattern)
def _get_guide_and_mode( def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest, guided_params: GuidedDecodingParams
GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
# if the request is a chat completion request, AND the tool choice is a if guided_params.json:
# named tool choice, do guided decoding if isinstance(guided_params.json, dict):
# using that tool as the JSON schema
if isinstance(request, ChatCompletionRequest) and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
# Guided generation for tools/functions parameters
if request.tool_choice.type == "function":
for tool in request.tools:
if (tool.type == "function" and tool.function.name
== request.tool_choice.function.name):
json = json_dumps(tool.function.parameters, sort_keys=True)
return json, GuidedDecodingMode.JSON
return None, None
elif request.guided_json:
if isinstance(request.guided_json, dict):
# turn dict into hashable string # turn dict into hashable string
json = json_dumps(request.guided_json) json = json_dumps(guided_params.json)
elif isinstance(request.guided_json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json = str(request.guided_json.__signature__)
else: else:
json = request.guided_json json = guided_params.json
return json, GuidedDecodingMode.JSON return json, GuidedDecodingMode.JSON
elif request.guided_regex: elif guided_params.regex:
return request.guided_regex, GuidedDecodingMode.REGEX return guided_params.regex, GuidedDecodingMode.REGEX
elif request.guided_choice: elif guided_params.choice:
# choice just uses regex # choice just uses regex
choices = [ choices = [
regex_escape(str(choice)) for choice in request.guided_choice regex_escape(str(choice)) for choice in guided_params.choice
] ]
choices_regex = "(" + "|".join(choices) + ")" choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE return choices_regex, GuidedDecodingMode.CHOICE
elif request.guided_grammar: elif guided_params.grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR return guided_params.grammar, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest) elif guided_params.json_object:
and request.response_format is not None
and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else: else:
return None, None return None, None
......
...@@ -14,6 +14,33 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -14,6 +14,33 @@ from vllm.model_executor.utils import set_weight_attrs
import vllm.envs as envs import vllm.envs as envs
class FatreluAndMul(CustomOp):
    """Gated activation built on FATReLU.

    Splits the last dimension of the input in two halves and returns
    FATReLU(first_half) * second_half, where FATReLU keeps values that are
    strictly greater than `threshold` and zeroes the rest.

    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        # Values less than or equal to this are zeroed in the gate half.
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the thresholded gate with plain PyTorch ops."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the native implementation."""
        return self.forward_native(x)
class SiluAndMul(CustomOp): class SiluAndMul(CustomOp):
"""An activation function for SwiGLU. """An activation function for SwiGLU.
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 7
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 128,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"192": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 128,
"num_warps": 2,
"num_ctas": 1,
"num_stages": 8
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"6144": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_ctas": 1,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 16,
"num_ctas": 1,
"num_stages": 2
}
}
\ No newline at end of file
...@@ -10,17 +10,27 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -10,17 +10,27 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool):
    """Pick the `scalar_types` entry for the given quantization setup.

    With zero points, only 4-bit weights are accepted (asserted) and
    `uint4` is returned. Without zero points, 4-bit weights map to
    `uint4b8` and any other bit width maps to `uint8b128`.
    """
    if not has_zp:
        if num_bits == 4:
            return scalar_types.uint4b8
        return scalar_types.uint8b128

    # Zero points are only handled for 4-bit weights.
    assert num_bits == 4
    return scalar_types.uint4
def single_marlin_moe( def single_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w: torch.Tensor, w: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes the multiplication of hidden_states with expert This function computes the multiplication of hidden_states with expert
...@@ -33,10 +43,12 @@ def single_marlin_moe( ...@@ -33,10 +43,12 @@ def single_marlin_moe(
- scales (torch.Tensor): The quantization scales. - scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation - gating_output (torch.Tensor): The output of the gating operation
(before softmax). (before softmax).
- g_idx (torch.Tensor): The act_order indices. - g_idx (Optional[torch.Tensor]): Optional act_order indices.
- perm (torch.Tensor): The act_order input permutation. - sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- num_bits (int): The number of bits in expert weights quantization. - num_bits (int): The number of bits in expert weights quantization.
...@@ -78,16 +90,34 @@ def single_marlin_moe( ...@@ -78,16 +90,34 @@ def single_marlin_moe(
max_workspace_size = (N // 64) * 16 max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device=hidden_states.device,
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8 has_zero_point = w_zeros is not None
if num_bits == 4 else scalar_types.uint8b128) if w_zeros is None:
w_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if g_idx is None:
g_idx = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
if sort_indices is None:
sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
block_size_m, True, False) is_k_full, E, topk, block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
...@@ -96,17 +126,20 @@ def fused_marlin_moe( ...@@ -96,17 +126,20 @@ def fused_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
perm1: torch.Tensor,
perm2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -116,21 +149,22 @@ def fused_marlin_moe( ...@@ -116,21 +149,22 @@ def fused_marlin_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation - gating_output (torch.Tensor): The output of the gating operation
(before softmax). (before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices. - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices. - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation. - sort_indices1 (Optional[torch.Tensor]): The first act_order input
- perm2 (torch.Tensor): The second act_order input permutation. permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights. - topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of top-k elements. - topk_ids (torch.Tensor): Indices of top-k elements.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- num_bits (int): The number of bits in expert weights quantization. - num_bits (int): The number of bits in expert weights quantization.
Returns: Returns:
...@@ -150,6 +184,20 @@ def fused_marlin_moe( ...@@ -150,6 +184,20 @@ def fused_marlin_moe(
assert hidden_states.dtype == torch.float16 assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8] assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape M, K = hidden_states.shape
E = w1.shape[0] E = w1.shape[0]
N = w2.shape[1] * 16 N = w2.shape[1] * 16
...@@ -170,14 +218,42 @@ def fused_marlin_moe( ...@@ -170,14 +218,42 @@ def fused_marlin_moe(
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device="cuda",
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8 if has_no_zp:
if num_bits == 4 else scalar_types.uint8b128) w1_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
w2_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if has_no_act_order:
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk_ids.shape[1], N),
...@@ -192,14 +268,15 @@ def fused_marlin_moe( ...@@ -192,14 +268,15 @@ def fused_marlin_moe(
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale, w1_scale,
w1_zeros,
g_idx1, g_idx1,
perm1, sort_indices1,
workspace, workspace,
scalar_type, scalar_type1,
M, M,
2 * N, 2 * N,
K, K,
True, is_k_full,
E, E,
topk, topk,
block_size_m, block_size_m,
...@@ -216,14 +293,15 @@ def fused_marlin_moe( ...@@ -216,14 +293,15 @@ def fused_marlin_moe(
topk_weights, topk_weights,
topk_ids, topk_ids,
w2_scale, w2_scale,
w2_zeros,
g_idx2, g_idx2,
perm2, sort_indices2,
workspace, workspace,
scalar_type, scalar_type2,
M, M,
K, K,
N, N,
True, is_k_full,
E, E,
topk, topk,
block_size_m, block_size_m,
......
...@@ -320,6 +320,9 @@ def get_moe_configs(E: int, N: int, ...@@ -320,6 +320,9 @@ def get_moe_configs(E: int, N: int,
# If no optimized configuration is available, we will use the default # If no optimized configuration is available, we will use the default
# configuration # configuration
logger.warning(
("Using default MoE config. Performance might be sub-optimal! "
"Config file not found at %s"), config_file_path)
return None return None
......
...@@ -19,10 +19,16 @@ class RMSNorm(CustomOp): ...@@ -19,10 +19,16 @@ class RMSNorm(CustomOp):
self, self,
hidden_size: int, hidden_size: int,
eps: float = 1e-6, eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.hidden_size = hidden_size
self.variance_epsilon = eps self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size))
def forward_native( def forward_native(
self, self,
...@@ -36,7 +42,23 @@ class RMSNorm(CustomOp): ...@@ -36,7 +42,23 @@ class RMSNorm(CustomOp):
x = x + residual.to(torch.float32) x = x + residual.to(torch.float32)
residual = x.to(orig_dtype) residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True) hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight x = x.to(orig_dtype) * self.weight
if residual is None: if residual is None:
...@@ -49,6 +71,9 @@ class RMSNorm(CustomOp): ...@@ -49,6 +71,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if residual is not None: if residual is not None:
...@@ -89,6 +114,9 @@ class RMSNorm(CustomOp): ...@@ -89,6 +114,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
if residual is not None: if residual is not None:
......
...@@ -30,7 +30,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ ...@@ -30,7 +30,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod" "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
] ]
...@@ -355,8 +355,12 @@ class ColumnParallelLinear(LinearBase): ...@@ -355,8 +355,12 @@ class ColumnParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
param_data = param.data param_data = param.data
if output_dim is not None: # bitsandbytes loads the weights of the specific portion
# no need to narrow here
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim] shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
...@@ -459,17 +463,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -459,17 +463,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.shard_weight_type[loaded_shard_id] = loaded_weight.item() param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return return
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight:
from gguf.constants import GGML_QUANT_SIZES tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
ori_shape = param.tensor_shape output_dim = getattr(param, "output_dim", None)
weight_types = self.qweight_type.shard_weight_type.values() shard_size = loaded_weight.size(output_dim) // tp_size
row_size = [] start_idx = tp_rank * shard_size
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type] loaded_weight = loaded_weight.narrow(output_dim, start_idx,
row_size.append(ori_shape[1] // block_size * type_size) shard_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
...@@ -534,18 +544,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -534,18 +544,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
...@@ -802,17 +800,23 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -802,17 +800,23 @@ class QKVParallelLinear(ColumnParallelLinear):
param.shard_weight_type[loaded_shard_id] = loaded_weight.item() param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return return
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight:
from gguf.constants import GGML_QUANT_SIZES tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
ori_shape = param.tensor_shape param.shard_id.append(loaded_shard_id)
weight_types = self.qweight_type.shard_weight_type.values() param.shard_id_map[loaded_shard_id] = len(param.data_container)
row_size = [] param.data_container.append(loaded_weight)
for weight_type in weight_types: if len(param.data_container) == 3:
block_size, type_size = GGML_QUANT_SIZES[weight_type] self.qweight = param.materialize_nested()
row_size.append(ori_shape[1] // block_size * type_size) return
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
...@@ -840,6 +844,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -840,6 +844,9 @@ class QKVParallelLinear(ColumnParallelLinear):
("v", (self.total_num_heads + self.total_num_kv_heads) * ("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size), self.head_size, self.total_num_kv_heads * self.head_size),
] ]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets: for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights. # Special case for Quantized Weights.
...@@ -853,6 +860,23 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -853,6 +860,23 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
"v":
((self.total_num_heads + self.total_num_kv_heads) *
self.head_size,
self.total_num_kv_heads * self.head_size),
"total":
((self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_size, 0)
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id)
loaded_weight_shard = loaded_weight.narrow( loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size) output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id) self.weight_loader(param, loaded_weight_shard, shard_id)
...@@ -902,18 +926,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -902,18 +926,6 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
......
...@@ -6,65 +6,57 @@ from typing import Optional ...@@ -6,65 +6,57 @@ from typing import Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_fn( def causal_conv1d_fn(x: torch.Tensor,
x: torch.Tensor, weight: torch.Tensor,
weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, query_start_loc: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None, cache_indices: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None, has_initial_state: Optional[torch.Tensor] = None,
return_final_states: bool = False, conv_states: Optional[torch.Tensor] = None,
final_states_out=None, activation: Optional[str] = "silu",
activation: str = "silu", pad_slot_id: int = PAD_SLOT_ID):
):
""" """
x: (batch, dim, seqlen) x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width) weight: (dim, width)
bias: (dim,) bias: (dim,)
seq_idx: (batch, seqlen) query_start_loc: (batch + 1) int32
initial_states: (batch, dim, width - 1) The cumulative sequence lengths of the sequences in
final_states_out: (batch, dim, width - 1), to be written to the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether the kernel should take the current state as the initial
state for the calculations
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish" activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen) out: (batch, dim, seqlen)
""" """
if activation not in [None, "silu", "swish"]: if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish") raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1: if x.stride(-1) != 1:
x = x.contiguous() x = x.contiguous()
bias = bias.contiguous() if bias is not None else None bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states, ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
final_states_out, activation cache_indices, has_initial_state, activation
in ["silu", "swish"]) in ["silu", "swish"], pad_slot_id)
return (out, None) if not return_final_states else (out, final_states_out) return x
def causal_conv1d_update(x: torch.Tensor, def causal_conv1d_update(x: torch.Tensor,
...@@ -72,21 +64,39 @@ def causal_conv1d_update(x: torch.Tensor, ...@@ -72,21 +64,39 @@ def causal_conv1d_update(x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None, activation: Optional[str] = None,
conv_state_indices: Optional[torch.Tensor] = None): cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID):
""" """
x: (batch, dim) x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, width) conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width) weight: (dim, width)
bias: (dim,) bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32 conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim, If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices. and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario. Useful for a continuous batching scenario.
pad_slot_id: int
out: (batch, dim) if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
""" """
if activation not in [None, "silu", "swish"]: if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish") raise NotImplementedError("activation must be None, silu, or swish")
activation_bool = activation in ["silu", "swish"] activation_val = activation in ["silu", "swish"]
return ops.causal_conv1d_update(x, conv_state, weight, bias, unsqueeze = x.dim() == 2
activation_bool, conv_state_indices) if unsqueeze:
x = x.unsqueeze(-1)
ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
cache_seqlens, conv_state_indices, pad_slot_id)
if unsqueeze:
x = x.squeeze(-1)
return x
...@@ -7,6 +7,7 @@ import triton.language as tl ...@@ -7,6 +7,7 @@ import triton.language as tl
from packaging import version from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
...@@ -48,6 +49,7 @@ def _selective_scan_update_kernel( ...@@ -48,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr, z_ptr,
out_ptr, out_ptr,
state_batch_indices_ptr, state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions # Matrix dimensions
batch, batch,
nheads, nheads,
...@@ -141,10 +143,11 @@ def _selective_scan_update_kernel( ...@@ -141,10 +143,11 @@ def _selective_scan_update_kernel(
if HAS_Z: if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
state = tl.load(state_ptrs, mask=mask, other=0.0)
state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM: if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
...@@ -175,9 +178,11 @@ def _selective_scan_update_kernel( ...@@ -175,9 +178,11 @@ def _selective_scan_update_kernel(
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None] state = state * dA + dB * x[:, None]
tl.store(state_ptrs,
state, mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1) out = tl.sum(state * C[None, :], axis=1)
if HAS_D: if HAS_D:
out += x * D out += x * D
...@@ -196,7 +201,8 @@ def selective_state_update(state, ...@@ -196,7 +201,8 @@ def selective_state_update(state,
z=None, z=None,
dt_bias=None, dt_bias=None,
dt_softplus=False, dt_softplus=False,
state_batch_indices=None): state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID):
""" """
Argument: Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate) state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
...@@ -208,6 +214,12 @@ def selective_state_update(state, ...@@ -208,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim) D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim) z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim) dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if state_batch_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return: Return:
out: (batch, dim) or (batch, nheads, dim) out: (batch, dim) or (batch, nheads, dim)
""" """
...@@ -274,6 +286,7 @@ def selective_state_update(state, ...@@ -274,6 +286,7 @@ def selective_state_update(state,
z, z,
out, out,
state_batch_indices, state_batch_indices,
pad_slot_id,
batch, batch,
nheads, nheads,
dim, dim,
...@@ -318,6 +331,7 @@ def selective_state_update(state, ...@@ -318,6 +331,7 @@ def selective_state_update(state,
def selective_scan_fn(u, def selective_scan_fn(u,
ssm_states,
delta, delta,
A, A,
B, B,
...@@ -326,11 +340,45 @@ def selective_scan_fn(u, ...@@ -326,11 +340,45 @@ def selective_scan_fn(u,
z=None, z=None,
delta_bias=None, delta_bias=None,
delta_softplus=False, delta_softplus=False,
return_last_state=False, query_start_loc=None,
position_indices=None, cache_indices=None,
prev_state=None): has_initial_state=None,
"""if return_last_state is True, returns (out, last_state) pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
last_state has shape (batch, dim, dstate). """
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
C: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
delta_bias: (dim,)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into the packed input; prepended with 0.
for example: query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32),
u.shape = (dim, 17)
cache_indices: (batch) int32
A tensor in which each entry is the index of the ssm_state
that the corresponding sequence reads from and writes to
has_initial_state: (batch) bool
A tensor populated with ones and zeros that
indicates whether the ssm_state at the corresponding index should be
used as the initial state. If this argument is not provided,
no initial state is assumed
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id];
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
""" """
if u.stride(-1) != 1: if u.stride(-1) != 1:
u = u.contiguous() u = u.contiguous()
...@@ -344,28 +392,20 @@ def selective_scan_fn(u, ...@@ -344,28 +392,20 @@ def selective_scan_fn(u,
C = C.contiguous() C = C.contiguous()
if z is not None and z.stride(-1) != 1: if z is not None and z.stride(-1) != 1:
z = z.contiguous() z = z.contiguous()
if B.dim() == 3: if B.dim() == 3 and query_start_loc is None:
B = B.unsqueeze(1) B = B.unsqueeze(1)
if C.dim() == 3: if B.dim() == 2 and query_start_loc is not None:
B = B.unsqueeze(0)
if C.dim() == 3 and query_start_loc is None:
C = C.unsqueeze(1) C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) if C.dim() == 2 and query_start_loc is not None:
x = torch.zeros(( C = C.unsqueeze(0)
u.shape[0],
u.shape[1], ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
n_chunks, query_start_loc, cache_indices, has_initial_state,
int(A.shape[1] * 2), ssm_states, pad_slot_id)
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, position_indices, x)
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if z is None: if z is None:
return out if not return_last_state else (out, last_state) return delta # output written inplace to delta
else: else:
out_z = rest[0] return z # output written inplace to z
return out_z if not return_last_state else (out_z, last_state)
...@@ -11,6 +11,7 @@ from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput ...@@ -11,6 +11,7 @@ from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
class PoolingType(IntEnum): class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods.""" """Enumeration for different types of pooling methods."""
LAST = 0 LAST = 0
ALL = 1
class Pooler(nn.Module): class Pooler(nn.Module):
...@@ -43,6 +44,12 @@ class Pooler(nn.Module): ...@@ -43,6 +44,12 @@ class Pooler(nn.Module):
if self.pooling_type == PoolingType.LAST: if self.pooling_type == PoolingType.LAST:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled_data = hidden_states[last_token_flat_indices] pooled_data = hidden_states[last_token_flat_indices]
elif self.pooling_type == PoolingType.ALL:
offset = 0
pooled_data = []
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
else: else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}") raise ValueError(f"Invalid pooling type: {self.pooling_type}")
......
...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( ...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig) GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config) GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.neuron_quant import ( from vllm.model_executor.layers.quantization.neuron_quant import (
...@@ -49,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -49,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"qqq": QQQConfig, "qqq": QQQConfig,
"experts_int8": ExpertsInt8Config, "experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig, "neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment