Commit 705f6a35 authored by zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt,
TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry
# Module-level singleton: one shared InputRegistry instance, created at
# import time and exported through ``__all__`` below.
INPUT_REGISTRY = InputRegistry()
"""
The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model.
See also:
:ref:`input_processing_pipeline`
"""
# Public API of this package: the prompt/input types re-exported from
# ``.data`` plus the registry types and the global registry instance.
__all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
]
...@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, ...@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
from typing_extensions import NotRequired from typing_extensions import NotRequired
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.sequence import MultiModalData from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -72,7 +72,7 @@ class TextPrompt(TypedDict): ...@@ -72,7 +72,7 @@ class TextPrompt(TypedDict):
prompt: str prompt: str
"""The input text to be tokenized before passing to the model.""" """The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"] multi_modal_data: NotRequired["MultiModalDataDict"]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
...@@ -85,7 +85,7 @@ class TokensPrompt(TypedDict): ...@@ -85,7 +85,7 @@ class TokensPrompt(TypedDict):
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""A list of token IDs to pass to the model.""" """A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"] multi_modal_data: NotRequired["MultiModalDataDict"]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
...@@ -101,10 +101,9 @@ class TextTokensPrompt(TypedDict): ...@@ -101,10 +101,9 @@ class TextTokensPrompt(TypedDict):
"""The prompt text.""" """The prompt text."""
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the """The token IDs of the prompt."""
tokenizer to convert the prompts to token IDs."""
multi_modal_data: NotRequired["MultiModalData"] multi_modal_data: NotRequired["MultiModalDataDict"]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
...@@ -125,6 +124,20 @@ PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] ...@@ -125,6 +124,20 @@ PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
class LLMInputs(TypedDict): class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
"""
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""The token IDs of the prompt."""
prompt: NotRequired[Optional[str]] prompt: NotRequired[Optional[str]]
multi_modal_data: NotRequired[Optional["MultiModalData"]] """
The original prompt text corresponding to the token IDs, if available.
"""
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
TypeVar)
from torch import nn
from transformers import PretrainedConfig
from vllm.logger import init_logger
from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalDataDict
from vllm.sequence import SequenceData
logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig)
@dataclass(frozen=True)
class InputContext:
    """
    Immutable bundle of model information that input processors and
    dummy data factories can consult when adjusting the inputs.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_multimodal_config(self) -> "MultiModalConfig":
        """
        Return the multimodal configuration stored on the model config.

        Raises:
            ValueError: If the model config carries no multimodal config
                (i.e. ``multimodal_config`` is ``None``).
        """
        config = self.model_config.multimodal_config
        if config is not None:
            return config

        raise ValueError("No multimodal config found")

    def get_hf_config(self, hf_config_type: Type[C]) -> C:
        """
        Return the HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) of the model after
        verifying that it is an instance of ``hf_config_type``.

        Raises:
            TypeError: If the stored config is not an instance of
                ``hf_config_type``.
        """
        config = self.model_config.hf_config
        if isinstance(config, hf_config_type):
            return config

        actual_type = type(config)
        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {hf_config_type}, but "
                        f"found type: {actual_type}")
# Type variable for model classes (subclasses of ``nn.Module``); the
# ``register_*`` decorators of InputRegistry use it to return the very class
# they were applied to.
N = TypeVar("N", bound=Type[nn.Module])
# Signature of a dummy data factory: called with an InputContext and an int
# (InputRegistry passes ``seq_len``), returning a SequenceData together with
# optional multi-modal data.
DummyDataFactory = Callable[[InputContext, int],
Tuple["SequenceData",
Optional["MultiModalDataDict"]]]
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
# Signature of an input processor: maps (InputContext, LLMInputs) to
# LLMInputs.
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""
class InputRegistry:
    """
    Per-model-class lookup table for the callables that build profiling
    inputs and preprocess real inputs.
    """

    def __init__(self) -> None:
        self._dummy_factories_by_model_type: Dict[Type[nn.Module],
                                                  DummyDataFactory] = {}
        self._input_processors_by_model_type: Dict[Type[nn.Module],
                                                   InputProcessor] = {}

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
        """
        Fallback factory used when a model class registered none: a
        text-only sequence of ``seq_len`` zero token IDs and no
        multi-modal data.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Imported lazily to avoid a circular import
        from vllm.sequence import SequenceData

        return SequenceData([0] * seq_len), None

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Class decorator that associates ``factory`` with a model class.

        The factory is what :meth:`dummy_data_for_profiling` invokes to
        create dummy inputs for memory profiling, so the memory used with
        its output should be an upper bound of what the model would use
        at inference time. Registering a second factory for the same
        class logs a warning and replaces the first.
        """

        def wrapper(model_cls: N) -> N:
            already_registered = (
                model_cls in self._dummy_factories_by_model_type)
            if already_registered:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory
            return model_cls

        return wrapper

    def dummy_data_for_profiling(self, model_config: "ModelConfig",
                                 seq_len: int):
        """
        Build dummy data for profiling the memory usage of the model
        identified by ``model_config``, using the factory registered for
        its class or the default factory when there is none.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Imported lazily to avoid a circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        factories = self._dummy_factories_by_model_type
        build_dummy = factories.get(model_cls,
                                    self._default_dummy_data_factory)
        ctx = InputContext(model_config)

        return build_dummy(ctx, seq_len)

    def _default_input_processor(self, ctx: InputContext,
                                 inputs: LLMInputs) -> LLMInputs:
        """Fallback processor: hands ``inputs`` back untouched."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Class decorator that associates ``processor`` with a model class.

        The processor is invoked on each input to the model, before
        :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.
        Registering a second processor for the same class logs a warning
        and replaces the first.

        See also:
            :ref:`input_processing_pipeline`
        """

        def wrapper(model_cls: N) -> N:
            already_registered = (
                model_cls in self._input_processors_by_model_type)
            if already_registered:
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor
            return model_cls

        return wrapper

    def process_input(self, model_config: "ModelConfig",
                      inputs: LLMInputs) -> LLMInputs:
        """
        Run ``inputs`` through the input processor registered for the
        model identified by ``model_config`` (or the no-op default).

        See also:
            :ref:`input_processing_pipeline`
        """
        # Imported lazily to avoid a circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processors = self._input_processors_by_model_type
        run_processor = processors.get(model_cls,
                                       self._default_input_processor)
        ctx = InputContext(model_config)

        return run_processor(ctx, inputs)

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Return :meth:`process_input` with ``model_config`` already bound,
        i.e. a callable that only takes the inputs.
        """
        return functools.partial(self.process_input, model_config)
...@@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank ...@@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA, from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA) RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
...@@ -90,11 +91,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -90,11 +91,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
def _mcp_apply(x, bias, layer): def _mcp_apply(x, bias, layer):
""" """
MergedColumnParallelLinearWithShardedLoRA and MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method. LoRa weight application method.
The main difference is the step by shard_size for lora_b which can The main difference is the step by shard_size for lora_b which can
vary for QKVParallelLinearWithShardedLora but is constant for vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA. MergedColumnParallelLinearWithShardedLoRA.
""" """
# expecting 2 for column parallel and 3 for qkv # expecting 2 for column parallel and 3 for qkv
...@@ -167,7 +168,7 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -167,7 +168,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
) )
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
""" """
Differs from QKVParallelLinearWithLora by slicing the Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also. LoRA A's also.
...@@ -175,6 +176,57 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): ...@@ -175,6 +176,57 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
""" """
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's column slice of ``lora_a``.

    The slice width is dimension 2 of ``self.lora_a_stacked`` and the
    slice starts at ``rank * width`` along dimension 1 of ``lora_a``.
    """
    rank = get_tensor_model_parallel_rank()
    width = self.lora_a_stacked.shape[2]
    begin = rank * width
    return lora_a[:, begin:begin + width]
def apply(self, x: torch.Tensor,
          bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Run the base layer on ``x`` and apply the sharded LoRA weights.

    Steps: base-layer output via its quant method, a ``bgmv`` call with
    ``lora_a_stacked`` into a scratch buffer, an all-gather of that
    buffer across tensor-parallel ranks, then a ``bgmv`` call with
    ``lora_b_stacked`` targeting the base output.

    Returns:
        The output tensor, viewed back to the base layer's output shape.
    """
    output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
    # Flatten all leading dims so x and output are 2-D; keep the original
    # output shape to restore it before returning.
    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1,
                                         output.shape[-1]), output.shape
    # float32 scratch buffer: one row per row of x, width taken from
    # dimension 2 of lora_a_stacked.
    buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                         dtype=torch.float32,
                         device=x.device)
    # NOTE(review): bgmv's return value is unused, so it presumably writes
    # its result into the first argument in place — confirm against punica.
    bgmv(buffer, x, self.lora_a_stacked,
         self.indices[:self.indices_len[0]], 0, 1.0)
    # Combine the per-rank buffers before the lora_b step.
    buffer = tensor_model_parallel_all_gather(buffer)
    bgmv(output, buffer, self.lora_b_stacked,
         self.indices[:self.indices_len[0]], 0, 1.0)
    # now have column partitioned output
    output = output.view(*out_orig_shape)
    return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Report whether this class can replace ``source_layer``.

    Delegates to the parent class's ``can_replace_layer`` with
    ``decorate=False``; the ``_fully_sharded_can_replace`` decorator wraps
    the result (its exact check is not visible here — see its definition).
    """
    # specifying kwargs so they can be easily accessed in decorator
    return super().can_replace_layer(
        source_layer=source_layer,
        lora_config=lora_config,
        packed_modules_list=packed_modules_list,
        model_config=model_config,
        decorate=False,
    )
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
"""
Differs from MergedQKVParallelLinearWithLora by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a( def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]] self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]: ) -> List[Union[torch.Tensor, None]]:
......
...@@ -8,6 +8,7 @@ import torch.nn as nn ...@@ -8,6 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -134,15 +135,8 @@ def _apply_lora_packed_nslice( ...@@ -134,15 +135,8 @@ def _apply_lora_packed_nslice(
@dataclass @dataclass
class LoRAMapping: class LoRAMapping(AdapterMapping):
# Per every token in input_ids: pass
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
...@@ -641,6 +635,24 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -641,6 +635,24 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
self.base_layer.head_size) self.base_layer.head_size)
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
    """Return this rank's q, k and v column slices of ``lora_b``, joined.

    Side effect: stores ``self.q_shard_id`` (the tensor-parallel rank)
    and ``self.kv_shard_id`` (rank divided by the base layer's
    ``num_kv_head_replicas``). The k block starts after
    ``q_proj_total_size`` columns and the v block after a further
    ``kv_proj_total_size`` columns.
    """
    rank = get_tensor_model_parallel_rank()
    self.q_shard_id = rank
    self.kv_shard_id = rank // self.base_layer.num_kv_head_replicas

    # Start column of each projection's shard inside the fused matrix.
    q_begin = self.q_proj_shard_size * self.q_shard_id
    k_offset = self.q_proj_total_size
    k_begin = k_offset + self.kv_proj_shard_size * self.kv_shard_id
    v_offset = k_offset + self.kv_proj_total_size
    v_begin = v_offset + self.kv_proj_shard_size * self.kv_shard_id

    shards = [
        lora_b[:, q_begin:q_begin + self.q_proj_shard_size],
        lora_b[:, k_begin:k_begin + self.kv_proj_shard_size],
        lora_b[:, v_begin:v_begin + self.kv_proj_shard_size],
    ]
    return torch.cat(shards, dim=1)
def set_lora( def set_lora(
self, self,
index: int, index: int,
...@@ -650,21 +662,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -650,21 +662,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
tp_rank = get_tensor_model_parallel_rank() lora_a = self.slice_lora_a(lora_a)
self.q_shard_id = tp_rank lora_b = self.slice_lora_b(lora_b)
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
...@@ -674,6 +673,7 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -674,6 +673,7 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
@classmethod @classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
...@@ -1063,6 +1063,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1063,6 +1063,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def scale(self): def scale(self):
return self.base_layer.scale return self.base_layer.scale
@property
def soft_cap(self):
    """Expose the wrapped base layer's ``soft_cap`` attribute unchanged."""
    return self.base_layer.soft_cap
@property @property
def org_vocab_size(self): def org_vocab_size(self):
return self.base_layer.org_vocab_size return self.base_layer.org_vocab_size
...@@ -1162,11 +1166,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1162,11 +1166,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def _get_logits( def _get_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
embedding: torch.Tensor, lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t()) logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None: if embedding_bias is not None:
logits += embedding_bias logits += embedding_bias
logits = tensor_model_parallel_gather(logits) logits = tensor_model_parallel_gather(logits)
......
from typing import List, Optional from typing import List, Optional
from typing import Sequence as GenericSequence
import torch import torch
import torch.types
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -63,7 +65,7 @@ class LoRALayerWeights: ...@@ -63,7 +65,7 @@ class LoRALayerWeights:
output_dim: int, output_dim: int,
rank: int, rank: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank], lora_a = torch.zeros([input_dim, rank],
...@@ -120,7 +122,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -120,7 +122,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@classmethod @classmethod
def pack( def pack(
cls, loras: List[Optional["LoRALayerWeights"]] cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights": ) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA. """Pack a list of LoRAs into a single LoRA.
......
...@@ -4,12 +4,17 @@ import math ...@@ -4,12 +4,17 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors.torch import safetensors.torch
import torch import torch
from torch import nn from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA, from vllm.lora.layers import (BaseLayerWithLoRA,
...@@ -18,7 +23,8 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ...@@ -18,7 +23,8 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -152,7 +158,7 @@ def get_lora_id(): ...@@ -152,7 +158,7 @@ def get_lora_id():
return _GLOBAL_LORA_ID return _GLOBAL_LORA_ID
class LoRAModel: class LoRAModel(AdapterModel):
"""A LoRA fine-tuned model.""" """A LoRA fine-tuned model."""
def __init__( def __init__(
...@@ -302,25 +308,54 @@ class LoRAModel: ...@@ -302,25 +308,54 @@ class LoRAModel:
"new_embeddings.bin") "new_embeddings.bin")
with open(lora_config_path) as f: with open(lora_config_path) as f:
config = json.load(f) config = json.load(f)
target_modules = config["target_modules"]
unexpected_modules = []
for module in target_modules:
# Compatible with more modules, such as:layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules:
print(unexpected_modules, "modules")
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
if os.path.isfile(lora_tensor_path): if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path) tensors: Dict[str, torch.Tensor] = {}
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it won’t error and model will be trained with A, B
# loraified. C won’t exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
# Load tensors if there are only expected modules.
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path): elif os.path.isfile(lora_bin_file_path):
# When a bin file is provided, we rely on config to find unexpected
# modules.
unexpected_modules = []
target_modules = config["target_modules"]
for module in target_modules:
# Compatible with more modules,
# such as:layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules:
print(unexpected_modules, "modules")
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path) tensors = torch.load(lora_bin_file_path)
else: else:
raise ValueError(f"{lora_dir} doesn't contain tensors") raise ValueError(f"{lora_dir} doesn't contain tensors")
...@@ -358,12 +393,12 @@ class LoRAModel: ...@@ -358,12 +393,12 @@ class LoRAModel:
) )
class LoRAModelManager: class LoRAModelManager(AdapterModelManager):
"""A manager that manages multiple LoRA-fine-tuned models.""" """A manager that manages multiple LoRA-fine-tuned models."""
def __init__( def __init__(
self, self,
model: nn.Module, model: SupportsLoRA,
max_num_seqs: int, max_num_seqs: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
vocab_size: int, vocab_size: int,
...@@ -410,8 +445,7 @@ class LoRAModelManager: ...@@ -410,8 +445,7 @@ class LoRAModelManager:
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices # embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4 self.indices_len: List[Optional[int]] = [None] * 4
super().__init__(model)
self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy( self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules) self.model.supported_lora_modules)
...@@ -423,12 +457,11 @@ class LoRAModelManager: ...@@ -423,12 +457,11 @@ class LoRAModelManager:
self.model.packed_modules_mapping) self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {} self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
self.adapter_type = 'LoRa'
@property @property
def capacity(self) -> int: def capacity(self) -> int:
...@@ -438,15 +471,16 @@ class LoRAModelManager: ...@@ -438,15 +471,16 @@ class LoRAModelManager:
def lora_slots(self) -> int: def lora_slots(self) -> int:
return self.lora_config.max_loras return self.lora_config.max_loras
def __len__(self) -> int: @property
return len(self._registered_loras) def adapter_slots(self) -> int:
return self.lora_slots
def activate_lora( def activate_adapter(
self, self,
lora_id: int, lora_id: int,
) -> bool: ) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass.""" """Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_loras: if lora_id in self._active_adapters:
return False return False
first_free_slot = next( first_free_slot = next(
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
...@@ -454,8 +488,8 @@ class LoRAModelManager: ...@@ -454,8 +488,8 @@ class LoRAModelManager:
if first_free_slot is None: if first_free_slot is None:
raise ValueError("No free lora slots") raise ValueError("No free lora slots")
index, _ = first_free_slot index, _ = first_free_slot
self._active_loras[lora_id] = None self._active_adapters[lora_id] = None
lora_model = self._registered_loras[lora_id] lora_model = self._registered_adapters[lora_id]
logger.debug("Activating LoRA. int id: %d, slot index: %d", logger.debug("Activating LoRA. int id: %d, slot index: %d",
lora_model.id, index) lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id self.lora_index_to_id[index] = lora_model.id
...@@ -469,21 +503,13 @@ class LoRAModelManager: ...@@ -469,21 +503,13 @@ class LoRAModelManager:
module.reset_lora(index) module.reset_lora(index)
return True return True
def _deactivate_lora(self, lora_id: int): def _deactivate_adapter(self, lora_id: int):
try: try:
index = self.lora_index_to_id.index(lora_id) index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None self.lora_index_to_id[index] = None
except ValueError: except ValueError:
pass pass
def deactivate_lora(self, lora_id: int) -> bool:
"""Remove a LoRA from a GPU buffer."""
if lora_id in self._active_loras:
self._deactivate_lora(lora_id)
self._active_loras.pop(lora_id)
return True
return False
def _set_long_lora_context(self, lora: LoRAModel): def _set_long_lora_context(self, lora: LoRAModel):
if self.long_lora_context is None: if self.long_lora_context is None:
return return
...@@ -499,34 +525,19 @@ class LoRAModelManager: ...@@ -499,34 +525,19 @@ class LoRAModelManager:
if offsets: if offsets:
self.long_lora_context.offsets_by_lora_id[lora.id] = offsets self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
def _add_lora(self, lora: LoRAModel): def _add_adapter(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora) self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora self._registered_adapters[lora.id] = lora
self._set_long_lora_context(lora) self._set_long_lora_context(lora)
def add_lora(self, lora: LoRAModel) -> bool: def pin_adapter(self, lora_id: int) -> bool:
"""Add a LoRAModel to the manager CPU cache.""" """Pin a LoRAModel in the manager cache."""
logger.debug( raise NotImplementedError(
"Adding lora. Model id: %d, " "Pinning is not supported in LoRAModelManager."
"int id: %d, " "Use LRUCacheLoRAModelManager for pinning") # type: ignore
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
self._add_lora(lora)
return True
return False
def remove_lora(self, lora_id: int) -> bool:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
if self.long_lora_context:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized # TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None: def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded, (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_offsets_tensor, embeddings_indices, long_lora_offsets_tensor,
indices_len) = convert_mapping(mapping, self.lora_index_to_id, indices_len) = convert_mapping(mapping, self.lora_index_to_id,
...@@ -548,23 +559,11 @@ class LoRAModelManager: ...@@ -548,23 +559,11 @@ class LoRAModelManager:
# Maintain the reference # Maintain the reference
self.indices_len[:] = indices_len self.indices_len[:] = indices_len
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: def remove_all_adapters(self):
if self._last_mapping != lora_mapping:
self._set_lora_mapping(lora_mapping)
self._last_mapping = lora_mapping
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras)
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self):
"""Remove all LoRAModels from the manager.""" """Remove all LoRAModels from the manager."""
self._registered_loras.clear() self._registered_adapters.clear()
self.lora_index_to_id = [None] * self.lora_slots self.lora_index_to_id = [None] * self.lora_slots
self._active_loras.clear() self._active_adapters.clear()
def _create_lora_modules(self): def _create_lora_modules(self):
for module_name, module in self.model.named_modules( for module_name, module in self.model.named_modules(
...@@ -708,18 +707,39 @@ class LoRAModelManager: ...@@ -708,18 +707,39 @@ class LoRAModelManager:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras) replacement_loras)
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter)
def add_adapter(self, adapter: LoRAModel) -> bool:
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", adapter.id, adapter.id,
adapter.scaling_factor)
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
self._set_adapter_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]:
return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
return get_adapter(adapter_id, self._registered_adapters)
class LoRALRUCache(LRUCache[LoRAModel]):
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]): bool]):
super().__init__(capacity) super().__init__(capacity, deactivate_lora_fn)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: int, value: LoRAModel):
logger.debug("Removing LoRA. int id: %d", key)
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
class LRUCacheLoRAModelManager(LoRAModelManager): class LRUCacheLoRAModelManager(LoRAModelManager):
...@@ -735,48 +755,68 @@ class LRUCacheLoRAModelManager(LoRAModelManager): ...@@ -735,48 +755,68 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
): ):
super().__init__(model, max_num_seqs, max_num_batched_tokens, super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config) vocab_size, lora_config)
self._registered_loras: LoRALRUCache = LoRALRUCache( self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_lora) self.capacity, self.deactivate_adapter)
self._active_loras: LoRALRUCache = LoRALRUCache( self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_lora) self.lora_slots, self._deactivate_adapter)
def list_loras(self) -> Dict[int, LoRAModel]: def list_adapters(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels.""" """List all registered LoRAModels."""
return dict(self._registered_loras.cache) return dict(self._registered_adapters.cache)
def add_lora(self, lora: LoRAModel) -> bool: def add_adapter(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager.""" """Add a LoRAModel to the manager."""
logger.debug( logger.debug(
"Adding lora. Model id: %d, " "Adding lora. Model id: %d, "
"int id: %d, " "int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor) "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras: if lora.id not in self._registered_adapters:
self._add_lora(lora) self._add_adapter(lora)
was_added = True was_added = True
else: else:
# We always touch to update the LRU cache order # We always touch to update the LRU cache order
self._registered_loras.touch(lora.id) self._registered_adapters.touch(lora.id)
was_added = False was_added = False
return was_added return was_added
def activate_lora( def activate_adapter(
self, self,
lora_id: int, lora_id: int,
) -> bool: ) -> bool:
if lora_id not in self._active_loras and len( if lora_id not in self._active_adapters and len(
self._active_loras) >= self.lora_slots: self._active_adapters) >= self.lora_slots:
self._active_loras.remove_oldest() self._active_adapters.remove_oldest()
result = super().activate_lora(lora_id) result = super().activate_adapter(lora_id)
# We always touch to update the LRU cache order # We always touch to update the LRU cache order
self._active_loras.touch(lora_id) self._active_adapters.touch(lora_id)
return result return result
def remove_oldest_lora(self) -> bool: def remove_oldest_adapter(self) -> bool:
if len(self._registered_loras) > 0: if len(self._registered_adapters) > 0:
self._registered_loras.remove_oldest() self._registered_adapters.remove_oldest()
return True return True
return False return False
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
self._pin_lora_in_cpu_cache(lora_id)
self._pin_lora_in_gpu_cache(lora_id)
return True
def _pin_lora_in_cpu_cache(self, lora_id: int):
try:
self._registered_adapters.pin(lora_id)
except ValueError as err:
raise ValueError("Pinning failed. "
f"LoRA {lora_id} is not registered.") from err
def _pin_lora_in_gpu_cache(self, lora_id: int):
if lora_id not in self._active_adapters:
# move lora to gpu if not already active
self.activate_adapter(lora_id)
self._active_adapters.pin(lora_id)
def create_lora_manager( def create_lora_manager(
model: nn.Module, model: nn.Module,
......
...@@ -5,13 +5,14 @@ from typing import Optional ...@@ -5,13 +5,14 @@ from typing import Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform
def _check_punica_support(): def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return return
if torch.cuda.get_device_capability() < (8, 0): if current_platform.get_device_capability() < (8, 0):
raise ImportError( raise ImportError(
"punica LoRA kernels require compute capability >= 8.0") "punica LoRA kernels require compute capability >= 8.0")
else: else:
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from vllm.adapter_commons.request import AdapterRequest
@dataclass @dataclass
class LoRARequest: class LoRARequest(AdapterRequest):
""" """
Request for a LoRA adapter. Request for a LoRA adapter.
Note that this class should be be used internally. For online Note that this class should be used internally. For online
serving, it is recommended to not allow users to use this class but serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters. accessing unauthorized LoRA adapters.
...@@ -20,15 +22,16 @@ class LoRARequest: ...@@ -20,15 +22,16 @@ class LoRARequest:
lora_int_id: int lora_int_id: int
lora_local_path: str lora_local_path: str
long_lora_max_len: Optional[int] = None long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__
def __post_init__(self): @property
if self.lora_int_id < 1: def adapter_id(self):
raise ValueError( return self.lora_int_id
f"lora_int_id must be > 0, got {self.lora_int_id}")
def __eq__(self, value: object) -> bool: @property
return isinstance( def name(self):
value, LoRARequest) and self.lora_int_id == value.lora_int_id return self.lora_name
def __hash__(self) -> int: @property
return self.lora_int_id def local_path(self):
return self.lora_local_path
...@@ -8,7 +8,8 @@ from vllm.logger import init_logger ...@@ -8,7 +8,8 @@ from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import ( from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below # being imported for _all_lora_classes below
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -35,6 +36,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { ...@@ -35,6 +36,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, MergedQKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA,
......
from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
import torch import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import (LoRAModel, LoRAModelManager, from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager) LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -14,79 +17,13 @@ from vllm.lora.request import LoRARequest ...@@ -14,79 +17,13 @@ from vllm.lora.request import LoRARequest
logger = init_logger(__name__) logger = init_logger(__name__)
class AbstractWorkerLoRAManager(ABC): class WorkerLoRAManager(AbstractWorkerManager):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
max_position_embeddings: Optional[int] = None):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
# If False, do not cache. If None, cache is empty.
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@property
@abstractmethod
def is_enabled(self) -> bool:
...
@abstractmethod
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
...
@abstractmethod
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
...
@abstractmethod
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
...
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
def remove_all_loras(self):
...
@abstractmethod
def list_loras(self) -> Set[int]:
...
class WorkerLoRAManager(AbstractWorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side. """WorkerLoRAManager that manages LoRA models on the worker side.
Every request, the requested LoRAs will be loaded (unless they are already Every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded.""" loaded), and every other LoRA will be unloaded."""
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager _manager_cls: Type[LoRAModelManager] = LoRAModelManager
def __init__( def __init__(
self, self,
...@@ -103,16 +40,23 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -103,16 +40,23 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
self._lora_model_cls = lora_model_cls self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.lora_config = lora_config
self.max_position_embeddings = max_position_embeddings
super().__init__(device)
# Lazily initialized by create_lora_manager. # Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager self._adapter_manager: LoRAModelManager
super().__init__(
max_num_seqs, @contextmanager
max_num_batched_tokens, def dummy_lora_cache(self):
vocab_size, """Use this context manager to reuse the dummy lora model
lora_config, to avoid creating it repeatedly."""
device, self._cached_dummy_lora = None
max_position_embeddings=max_position_embeddings, yield
) self._cached_dummy_lora = False
@property @property
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
...@@ -128,44 +72,17 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -128,44 +72,17 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
lora_config=self.lora_config, lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls, lora_manager_cls=self._manager_cls,
) )
self._lora_manager = lora_manager self._adapter_manager = lora_manager
return lora_manager.model return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest], def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
new_loras = set(loras_map)
loras_to_add = new_loras - loras_that_exist
loras_to_remove = loras_that_exist - new_loras
for lora_id in loras_to_remove:
self.remove_lora(lora_id)
for lora_id in loras_to_add:
self.add_lora(loras_map[lora_id])
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try: try:
model = self._lora_manager.model model = self._adapter_manager.model
supported_lora_modules = model.supported_lora_modules supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping packed_modules_mapping = model.packed_modules_mapping
expected_lora_modules = [] expected_lora_modules: List[str] = []
for module in supported_lora_modules: for module in supported_lora_modules:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend( expected_lora_modules.extend(
...@@ -198,34 +115,45 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -198,34 +115,45 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
return lora return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras(): if lora_request.lora_int_id in self.list_adapters():
return False return False
if isinstance(self._cached_dummy_lora, LoRAModel): if isinstance(self._cached_dummy_lora, LoRAModel):
dummy_lora = self._cached_dummy_lora.clone( dummy_lora = self._cached_dummy_lora.clone(
lora_request.lora_int_id) lora_request.lora_int_id)
else: else:
dummy_lora = self._lora_manager.create_dummy_lora( dummy_lora = self._adapter_manager.create_dummy_lora(
lora_request.lora_int_id, rank, 1, self.embedding_modules) lora_request.lora_int_id, rank, 1, self.embedding_modules)
if self._cached_dummy_lora is None: if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora self._cached_dummy_lora = dummy_lora
return self._lora_manager.add_lora(dummy_lora) return self._adapter_manager.add_adapter(dummy_lora)
def add_lora(self, lora_request: LoRARequest) -> bool: def pin_adapter(self, adapter_id: int) -> bool:
if lora_request.lora_int_id in self.list_loras(): return self._adapter_manager.pin_adapter(adapter_id)
return False
lora = self._load_lora(lora_request) def set_active_adapters(self, requests: Set[Any],
loaded = self._lora_manager.add_lora(lora) mapping: Optional[Any]) -> None:
self._lora_manager.activate_lora(lora.id) set_active_adapters_worker(requests, mapping, self._apply_adapters,
return loaded self._adapter_manager.set_adapter_mapping)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
def add_adapter(self, adapter_request: Any) -> bool:
return add_adapter_worker(adapter_request, self.list_adapters,
self._load_adapter,
self._adapter_manager.add_adapter,
self._adapter_manager.activate_adapter)
def remove_lora(self, lora_id: int) -> bool: def remove_adapter(self, adapter_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id) return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_loras(self): def remove_all_adapters(self):
self._lora_manager.remove_all_loras() self._adapter_manager.remove_all_adapters()
def list_loras(self) -> Set[int]: def list_adapters(self) -> Set[int]:
return set(self._lora_manager.list_loras()) return list_adapters_worker(self._adapter_manager.list_adapters)
class LRUCacheWorkerLoRAManager(WorkerLoRAManager): class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
...@@ -235,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -235,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
(unless they are already loaded) and least recently used LoRAs will (unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity.""" be unloaded if the cache is above capacity."""
_lora_manager_cls: Type[ _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager( def create_lora_manager(
self, self,
...@@ -244,40 +171,41 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -244,40 +171,41 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
) -> Any: ) -> Any:
lora_manager = create_lora_manager( lora_manager = create_lora_manager(
model, model,
lora_manager_cls=self._lora_manager_cls, lora_manager_cls=self._manager_cls,
max_num_seqs=self.max_num_seqs, max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
lora_config=self.lora_config, lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
) )
self._lora_manager = lora_manager self._adapter_manager = lora_manager
return lora_manager.model return lora_manager.model
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = { loras_map = {
lora_request.lora_int_id: lora_request lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request for lora_request in lora_requests if lora_request
} }
if len(loras_map) > self._lora_manager.lora_slots: if len(loras_map) > self._adapter_manager.lora_slots:
raise RuntimeError( raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater " f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots " "than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).") f"({self._adapter_manager.lora_slots}).")
for lora in loras_map.values(): for lora in loras_map.values():
self.add_lora(lora) self.add_adapter(lora)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_adapter(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras(): if lora_request.lora_int_id not in self.list_adapters():
# Remove before we load the new lora to save memory # Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity: if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) assert isinstance(self._adapter_manager,
self._lora_manager.remove_oldest_lora() LRUCacheLoRAModelManager)
lora = self._load_lora(lora_request) self._adapter_manager.remove_oldest_adapter()
loaded = self._lora_manager.add_lora(lora) lora = self._load_adapter(lora_request)
loaded = self._adapter_manager.add_adapter(lora)
else: else:
# If the lora is already loaded, just touch it to # If the lora is already loaded, just touch it to
# update its position in the caches # update its position in the caches
loaded = self._lora_manager.get_lora( loaded = self._adapter_manager.get_adapter(
lora_request.lora_int_id) is not None lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id) self._adapter_manager.activate_adapter(lora_request.lora_int_id)
return loaded return loaded
import torch.nn as nn import torch.nn as nn
from vllm.utils import is_cpu, is_hip from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
class CustomOp(nn.Module): class CustomOp(nn.Module):
...@@ -29,9 +29,7 @@ class CustomOp(nn.Module): ...@@ -29,9 +29,7 @@ class CustomOp(nn.Module):
return self.forward_cuda(*args, **kwargs) return self.forward_cuda(*args, **kwargs)
def forward_xpu(self, *args, **kwargs): def forward_xpu(self, *args, **kwargs):
# By default, we assume that XPU ops are compatible with CUDA ops. raise NotImplementedError
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)
def forward_cpu(self, *args, **kwargs): def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops. # By default, we assume that CPU ops are compatible with CUDA ops.
...@@ -56,5 +54,9 @@ class CustomOp(nn.Module): ...@@ -56,5 +54,9 @@ class CustomOp(nn.Module):
return self.forward_hip return self.forward_hip
elif is_cpu(): elif is_cpu():
return self.forward_cpu return self.forward_cpu
elif is_tpu():
return self.forward_tpu
elif is_xpu():
return self.forward_xpu
else: else:
return self.forward_cuda return self.forward_cuda
...@@ -21,6 +21,7 @@ from functools import lru_cache ...@@ -21,6 +21,7 @@ from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union from typing import Callable, DefaultDict, Dict, List, Union
import torch import torch
from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel from pydantic import BaseModel
...@@ -67,7 +68,7 @@ class BaseLogitsProcessor: ...@@ -67,7 +68,7 @@ class BaseLogitsProcessor:
class RegexLogitsProcessor(BaseLogitsProcessor): class RegexLogitsProcessor(BaseLogitsProcessor):
@classmethod @classmethod
@lru_cache(maxsize=32) @cache()
def _get_guide(cls, regex_string: str, def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
...@@ -126,7 +127,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ...@@ -126,7 +127,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
class CFGLogitsProcessor(BaseLogitsProcessor): class CFGLogitsProcessor(BaseLogitsProcessor):
@classmethod @classmethod
@lru_cache(maxsize=32) @cache()
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer) return CFGGuide(cfg, tokenizer)
......
...@@ -37,6 +37,15 @@ class SiluAndMul(CustomOp): ...@@ -37,6 +37,15 @@ class SiluAndMul(CustomOp):
ops.silu_and_mul(out, x) ops.silu_and_mul(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
return out
class GeluAndMul(CustomOp): class GeluAndMul(CustomOp):
"""An activation function for GeGLU. """An activation function for GeGLU.
...@@ -71,6 +80,18 @@ class GeluAndMul(CustomOp): ...@@ -71,6 +80,18 @@ class GeluAndMul(CustomOp):
ops.gelu_tanh_and_mul(out, x) ops.gelu_tanh_and_mul(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out
def extra_repr(self) -> str: def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}' return f'approximate={repr(self.approximate)}'
...@@ -90,6 +111,13 @@ class NewGELU(CustomOp): ...@@ -90,6 +111,13 @@ class NewGELU(CustomOp):
ops.gelu_new(out, x) ops.gelu_new(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
out = torch.empty_like(x)
ops.gelu_new(out, x)
return out
class FastGELU(CustomOp): class FastGELU(CustomOp):
...@@ -105,6 +133,31 @@ class FastGELU(CustomOp): ...@@ -105,6 +133,31 @@ class FastGELU(CustomOp):
ops.gelu_fast(out, x) ops.gelu_fast(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops
out = torch.empty_like(x)
ops.gelu_fast(out, x)
return out
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.gelu_quick(out, x)
return out
# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
class ScaledActivation(nn.Module): class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters. """An activation function with post-scale parameters.
...@@ -154,6 +207,7 @@ _ACTIVATION_REGISTRY = { ...@@ -154,6 +207,7 @@ _ACTIVATION_REGISTRY = {
"gelu_new": NewGELU(), "gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"), "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(), "relu": nn.ReLU(),
"quick_gelu": QuickGELU(),
} }
......
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
__all__ = [ __all__ = [
"fused_moe", "fused_moe",
...@@ -7,4 +9,6 @@ __all__ = [ ...@@ -7,4 +9,6 @@ __all__ = [
"fused_experts", "fused_experts",
"get_config_file_name", "get_config_file_name",
"grouped_topk", "grouped_topk",
"FusedMoE",
"FusedMoEMethodBase",
] ]
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
{ {
"1": { "1": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_stages": 0 "num_warps": 2,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
}, },
"2": { "2": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_stages": 0 "num_warps": 2,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"4": { "4": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64, "GROUP_SIZE_M": 1,
"num_stages": 1 "num_warps": 2,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"8": { "8": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 1,
"num_stages": 1 "num_warps": 1,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"16": { "16": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 1,
"num_stages": 1 "num_warps": 4,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"24": { "24": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64, "GROUP_SIZE_M": 1,
"num_stages": 1 "num_warps": 1,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"32": { "32": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 4,
"num_stages": 1 "num_warps": 2,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
}, },
"48": { "48": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 4,
"num_stages": 0 "num_warps": 2,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"64": { "64": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 4,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"96": { "96": {
"BLOCK_SIZE_M": 32, "BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 4,
"num_stages": 0 "num_warps": 4,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"128": { "128": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 4,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"256": { "256": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 4,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
}, },
"512": { "512": {
"BLOCK_SIZE_M": 256, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8, "GROUP_SIZE_M": 4,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"1024": { "1024": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 32,
"kpack": 2
}, },
"1536": { "1536": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"2048": { "2048": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}, },
"3072": { "3072": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
}, },
"4096": { "4096": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_stages": 0 "num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment