Unverified Commit 0a0a1a19 authored by Kyuyeun Kim's avatar Kyuyeun Kim Committed by GitHub
Browse files

Add ability to replace oot ops when using lora (#37181)


Signed-off-by: default avatarKyuyeun Kim <kyuyeunk@google.com>
parent 6c1cfbad
...@@ -9,6 +9,7 @@ from transformers import PretrainedConfig ...@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -155,9 +156,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -155,9 +156,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
if type(source_layer) is ColumnParallelLinear: if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
return True return True
if type(source_layer) is MergedColumnParallelLinear: if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear):
if len(packed_modules_list) != 1: if len(packed_modules_list) != 1:
return False return False
# Exclude layers with 3+ output sizes - those are handled by # Exclude layers with 3+ output sizes - those are handled by
...@@ -606,7 +607,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA( ...@@ -606,7 +607,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA(
) -> bool: ) -> bool:
# Support MergedColumnParallelLinear with 3 or more slices # Support MergedColumnParallelLinear with 3 or more slices
# (2 slices are handled by MergedColumnParallelLinearWithLoRA) # (2 slices are handled by MergedColumnParallelLinearWithLoRA)
if type(source_layer) is not MergedColumnParallelLinear: if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear):
return False return False
# If packed_modules_list has 3+ items, use this class # If packed_modules_list has 3+ items, use this class
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from .base_linear import BaseLinearLayerWithLoRA from .base_linear import BaseLinearLayerWithLoRA
...@@ -55,7 +56,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -55,7 +56,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is ReplicatedLinear return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear)
def slice_lora_a( def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None] self, lora_a: torch.Tensor | list[torch.Tensor | None]
......
...@@ -11,6 +11,7 @@ from vllm.distributed import ( ...@@ -11,6 +11,7 @@ from vllm.distributed import (
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -89,7 +90,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -89,7 +90,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is RowParallelLinear return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear)
# The following layer is based on the tensor parallelism strategy given in # The following layer is based on the tensor parallelism strategy given in
......
...@@ -7,6 +7,7 @@ import torch.nn.functional as F ...@@ -7,6 +7,7 @@ import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -132,7 +133,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -132,7 +133,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is VocabParallelEmbedding return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding)
@property @property
def weight(self): def weight(self):
......
...@@ -22,10 +22,11 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} ...@@ -22,10 +22,11 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
def get_oot_class_by_name(class_name: str) -> type | None: def maybe_get_oot_by_class(class_type: type) -> type:
class_name = class_type.__name__
if class_name in op_registry_oot: if class_name in op_registry_oot:
return op_registry_oot[class_name] return op_registry_oot[class_name]
return None return class_type
class PluggableLayer(nn.Module): class PluggableLayer(nn.Module):
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
...@@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp): ...@@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: np.ndarray, cu_seqlens: np.ndarray,
device: torch.device, device: torch.device,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None: if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined] return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
if attn_backend != AttentionBackendEnum.FLASHINFER: if attn_backend != AttentionBackendEnum.FLASHINFER:
...@@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp): ...@@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp):
tp_size: int, tp_size: int,
device: torch.device, device: torch.device,
) -> torch.Tensor: ) -> torch.Tensor:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None: if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined] return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
attn_backend, cu_seqlens, hidden_size, tp_size, device attn_backend, cu_seqlens, hidden_size, tp_size, device
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment