Unverified Commit 9b5b39b6 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `vllm/lora` (#18128)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9ccc6ded
...@@ -78,7 +78,6 @@ exclude = [ ...@@ -78,7 +78,6 @@ exclude = [
"vllm/distributed/**/*.py" = ["UP006", "UP035"] "vllm/distributed/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/lora/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"] "vllm/platforms/**/*.py" = ["UP006", "UP035"]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# pylint: disable=unused-argument # pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Optional, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -118,7 +118,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -118,7 +118,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
...@@ -141,8 +141,8 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -141,8 +141,8 @@ class MergedColumnParallelLinearWithShardedLoRA(
""" """
def slice_lora_a( def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]] self, lora_a: list[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]: ) -> list[Union[torch.Tensor, None]]:
#NOTE: lora_a contains 2 subloras, and each sublora could be None. #NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2] output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size output_start_idx = self.tp_rank * output_shard_size
...@@ -165,7 +165,7 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -165,7 +165,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
...@@ -201,7 +201,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): ...@@ -201,7 +201,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
@classmethod @classmethod
@_fully_sharded_can_replace @_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: list,
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
...@@ -222,8 +222,8 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): ...@@ -222,8 +222,8 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
""" """
def slice_lora_a( def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]] self, lora_a: list[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]: ) -> list[Union[torch.Tensor, None]]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None. # NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
...@@ -248,7 +248,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): ...@@ -248,7 +248,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
...@@ -281,7 +281,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -281,7 +281,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None: if bias is None:
return bias return bias
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked) self.lora_bias_stacked)
shard_size = self.lora_bias_stacked[0].shape[2] shard_size = self.lora_bias_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
...@@ -341,7 +341,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -341,7 +341,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# pylint: disable=unused-argument # pylint: disable=unused-argument
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Optional, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -82,14 +82,14 @@ class LoRAMapping(AdapterMapping): ...@@ -82,14 +82,14 @@ class LoRAMapping(AdapterMapping):
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
def slice_lora_a( def slice_lora_a(
self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora a if splitting for tensor parallelism.""" """Slice lora a if splitting for tensor parallelism."""
... ...
def slice_lora_b( def slice_lora_b(
self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
"""Slice lora b if splitting with tensor parallelism.""" """Slice lora b if splitting with tensor parallelism."""
... ...
...@@ -128,7 +128,7 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -128,7 +128,7 @@ class BaseLayerWithLoRA(nn.Module):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
...@@ -140,7 +140,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -140,7 +140,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None: def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]] self.embeddings_slice: Optional[tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor] self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights( def create_lora_weights(
...@@ -279,7 +279,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -279,7 +279,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
return type(source_layer) is VocabParallelEmbedding return type(source_layer) is VocabParallelEmbedding
...@@ -296,9 +296,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -296,9 +296,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.base_layer = base_layer self.base_layer = base_layer
self.input_size = self.base_layer.input_size self.input_size = self.base_layer.input_size
self.device = _get_lora_device(self.base_layer) self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: Tuple[int, ...] self.output_slices: tuple[int, ...]
self.tp_size: int self.tp_size: int
self.output_size: int self.output_size: int
self.n_slices: int self.n_slices: int
...@@ -365,7 +365,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -365,7 +365,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[s_index][index] = 0 self.lora_b_stacked[s_index][index] = 0
if self.lora_config.bias_enabled: if self.lora_config.bias_enabled:
# Make mypy happy # Make mypy happy
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked) self.lora_bias_stacked)
self.lora_bias_stacked[s_index][index] = 0 self.lora_bias_stacked[s_index][index] = 0
...@@ -399,7 +399,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -399,7 +399,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
if lora_bias is not None: if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked) self.lora_bias_stacked)
assert len(self.lora_bias_stacked) assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
...@@ -497,7 +497,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -497,7 +497,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
return type(source_layer) is ReplicatedLinear return type(source_layer) is ReplicatedLinear
...@@ -597,7 +597,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -597,7 +597,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
return type(source_layer) is ColumnParallelLinear or ( return type(source_layer) is ColumnParallelLinear or (
...@@ -674,13 +674,13 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -674,13 +674,13 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) for output_size in self.output_slices) ) for output_size in self.output_slices)
def slice_lora_a( def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]] self, lora_a: list[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]: ) -> list[Union[torch.Tensor, None]]:
return lora_a return lora_a
def slice_lora_b( def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]] self, lora_b: list[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]: ) -> list[Union[torch.Tensor, None]]:
for i, (shard_id, shard_size) in enumerate( for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)): zip(self.output_ids, self.output_slices)):
if (lora_b_i := lora_b[i]) is not None: if (lora_b_i := lora_b[i]) is not None:
...@@ -689,8 +689,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -689,8 +689,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
return lora_b return lora_b
def slice_bias( def slice_bias(
self, bias: List[Union[torch.Tensor, self, bias: list[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]: None]]) -> list[Union[torch.Tensor, None]]:
for i, (shard_id, shard_size) in enumerate( for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)): zip(self.output_ids, self.output_slices)):
if (bias_i := bias[i]) is not None: if (bias_i := bias[i]) is not None:
...@@ -725,7 +725,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -725,7 +725,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_b_i.T, non_blocking=True) lora_b_i.T, non_blocking=True)
if lora_bias is not None: if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked) self.lora_bias_stacked)
for i in range(self.n_slices): for i in range(self.n_slices):
if (lora_bias_i := lora_bias[i]) is not None: if (lora_bias_i := lora_bias[i]) is not None:
...@@ -740,7 +740,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -740,7 +740,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
return (type(source_layer) is MergedColumnParallelLinear return (type(source_layer) is MergedColumnParallelLinear
...@@ -809,7 +809,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -809,7 +809,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: list,
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len( return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 1 packed_modules_list) == 1
...@@ -869,7 +869,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): ...@@ -869,7 +869,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
return (type(source_layer) is QKVParallelLinear return (type(source_layer) is QKVParallelLinear
...@@ -923,7 +923,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -923,7 +923,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
- output - output
- bias - bias
""" """
# Set up backprop all-reduce. # set up backprop all-reduce.
if self.base_layer.input_is_parallel: if self.base_layer.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
...@@ -958,7 +958,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -958,7 +958,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
return type(source_layer) is RowParallelLinear return type(source_layer) is RowParallelLinear
...@@ -981,7 +981,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -981,7 +981,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: LogitsProcessor, hidden_size: int, def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
dtype: torch.dtype, device: torch.device, dtype: torch.dtype, device: torch.device,
sharded_to_full_mapping: Optional[List[int]]) -> None: sharded_to_full_mapping: Optional[list[int]]) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -1189,7 +1189,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1189,7 +1189,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
# Special handling for the LogitsProcessor. # Special handling for the LogitsProcessor.
...@@ -1256,7 +1256,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -1256,7 +1256,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return self.base_layer( return self.base_layer(
positions, positions,
query, query,
...@@ -1265,7 +1265,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -1265,7 +1265,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
) )
@property @property
def scaling_factor_to_offset(self) -> Dict[float, int]: def scaling_factor_to_offset(self) -> dict[float, int]:
return self.base_layer.scaling_factor_to_offset return self.base_layer.scaling_factor_to_offset
@classmethod @classmethod
...@@ -1273,7 +1273,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -1273,7 +1273,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: Optional[PretrainedConfig],
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from collections.abc import Sequence as GenericSequence
from typing import Sequence as GenericSequence from typing import Optional
import torch import torch
import torch.types import torch.types
...@@ -125,11 +125,11 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -125,11 +125,11 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self, self,
module_name: str, module_name: str,
rank: int, rank: int,
lora_alphas: List[Optional[int]], lora_alphas: list[Optional[int]],
lora_a: List[Optional[torch.Tensor]], lora_a: list[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]], lora_b: list[Optional[torch.Tensor]],
bias: Optional[List[Optional[torch.Tensor]]] = None, bias: Optional[list[Optional[torch.Tensor]]] = None,
scaling: Optional[List[float]] = None, scaling: Optional[list[float]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
module_name=module_name, module_name=module_name,
......
...@@ -4,9 +4,9 @@ import copy ...@@ -4,9 +4,9 @@ import copy
import math import math
import os import os
import re import re
from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type, from typing import Any, Callable, Optional, Union
Union)
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -44,12 +44,12 @@ _GLOBAL_LORA_ID = 0 ...@@ -44,12 +44,12 @@ _GLOBAL_LORA_ID = 0
class LongContextLoRAContext: class LongContextLoRAContext:
"""Context for lora adapters that support long context.""" """Context for lora adapters that support long context."""
# The scaling factors to support long context lora fine tuned models. # The scaling factors to support long context lora fine tuned models.
scaling_factors: List[float] scaling_factors: list[float]
# dimension to apply rotary embedding. # dimension to apply rotary embedding.
rot_dim: int rot_dim: int
# offsets to the sin_cos_cache for each lora_id loaded. # offsets to the sin_cos_cache for each lora_id loaded.
# This value is dynamically modified. # This value is dynamically modified.
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
def get_lora_id(): def get_lora_id():
...@@ -65,7 +65,7 @@ class LoRAModel(AdapterModel): ...@@ -65,7 +65,7 @@ class LoRAModel(AdapterModel):
self, self,
lora_model_id: int, lora_model_id: int,
rank: int, rank: int,
loras: Dict[str, LoRALayerWeights], loras: dict[str, LoRALayerWeights],
scaling_factor: Optional[float] = None, scaling_factor: Optional[float] = None,
) -> None: ) -> None:
""" """
...@@ -84,7 +84,7 @@ class LoRAModel(AdapterModel): ...@@ -84,7 +84,7 @@ class LoRAModel(AdapterModel):
lora_model_id lora_model_id
> 0), f"a valid lora id should be greater than 0, got {self.id}" > 0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras self.loras: dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel": def clone(self, lora_model_id: int) -> "LoRAModel":
"""Return a copy of the object with different ids. """Return a copy of the object with different ids.
...@@ -113,19 +113,19 @@ class LoRAModel(AdapterModel): ...@@ -113,19 +113,19 @@ class LoRAModel(AdapterModel):
def from_lora_tensors( def from_lora_tensors(
cls, cls,
lora_model_id: int, lora_model_id: int,
tensors: Dict[str, torch.Tensor], tensors: dict[str, torch.Tensor],
peft_helper: PEFTHelper, peft_helper: PEFTHelper,
device: str = "cuda", device: str = "cuda",
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None, embeddings: Optional[dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None, target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None, embedding_modules: Optional[dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None, embedding_padding_modules: Optional[list[str]] = None,
weights_mapper: Optional[WeightsMapper] = None, weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel": ) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors.""" """Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {} loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items(): for tensor_name, tensor in tensors.items():
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name, weights_mapper) tensor_name, weights_mapper)
...@@ -187,15 +187,15 @@ class LoRAModel(AdapterModel): ...@@ -187,15 +187,15 @@ class LoRAModel(AdapterModel):
def from_local_checkpoint( def from_local_checkpoint(
cls, cls,
lora_dir: str, lora_dir: str,
expected_lora_modules: List[str], expected_lora_modules: list[str],
peft_helper: PEFTHelper, peft_helper: PEFTHelper,
*, *,
lora_model_id: Optional[int] = None, lora_model_id: Optional[int] = None,
device: str = "cuda", device: str = "cuda",
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None, target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None, embedding_modules: Optional[dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None, embedding_padding_modules: Optional[list[str]] = None,
weights_mapper: Optional[WeightsMapper] = None, weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel": ) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint. """Create a LoRAModel from a local checkpoint.
...@@ -220,9 +220,9 @@ class LoRAModel(AdapterModel): ...@@ -220,9 +220,9 @@ class LoRAModel(AdapterModel):
new_embeddings_bin_file_path = os.path.join(lora_dir, new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin") "new_embeddings.bin")
unexpected_modules: List[Union[list[str], str]] unexpected_modules: list[Union[list[str], str]]
if os.path.isfile(lora_tensor_path): if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {} tensors: dict[str, torch.Tensor] = {}
# Find unexpected modules. # Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules. # Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist # in peft if you have target_modules A, B, C and C does not exist
...@@ -329,7 +329,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -329,7 +329,7 @@ class LoRAModelManager(AdapterModelManager):
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = get_punica_wrapper( self.punica_wrapper = get_punica_wrapper(
...@@ -339,7 +339,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -339,7 +339,7 @@ class LoRAModelManager(AdapterModelManager):
max_loras=self.lora_config.max_loras) max_loras=self.lora_config.max_loras)
# Scaling factor -> offset to the sin_cos_cache to it. # Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora. # Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {} self.scaling_factor_to_offset: dict[float, int] = {}
super().__init__(model) super().__init__(model)
self.supported_lora_modules = get_supported_lora_modules(self.model) self.supported_lora_modules = get_supported_lora_modules(self.model)
...@@ -358,9 +358,9 @@ class LoRAModelManager(AdapterModelManager): ...@@ -358,9 +358,9 @@ class LoRAModelManager(AdapterModelManager):
# text modules (e.g. ChatGLM) # text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")) and hasattr(self.model, "get_mm_mapping"))
self.is_pooling_model = is_pooling_model(self.model) self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: Dict[str, List[str]] = {} self.packed_modules: dict[str, list[str]] = {}
self.modules: Dict[str, BaseLayerWithLoRA] = {} self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a set for compatibility with LRUCache.
self._last_mapping: Optional[LoRAMapping] = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
...@@ -530,7 +530,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -530,7 +530,7 @@ class LoRAModelManager(AdapterModelManager):
lora_id: int, lora_id: int,
rank: int, rank: int,
scaling_factor: Optional[float], scaling_factor: Optional[float],
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup.""" """Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor) model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules():
...@@ -578,7 +578,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -578,7 +578,7 @@ class LoRAModelManager(AdapterModelManager):
else: else:
parts = module_name.split(".") parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]] replacements = self.packed_modules_mapping[parts[-1]]
subloras: List[Optional[LoRALayerWeights]] = [] subloras: list[Optional[LoRALayerWeights]] = []
for i, r in enumerate(replacements): for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights( lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r, module_name + "." + r,
...@@ -630,8 +630,8 @@ class LoRAModelManager(AdapterModelManager): ...@@ -630,8 +630,8 @@ class LoRAModelManager(AdapterModelManager):
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items(): for module_name, new_module_names in self.packed_modules.items():
replacement_loras: List[Optional[LoRALayerWeights]] = [] replacement_loras: list[Optional[LoRALayerWeights]] = []
replaced_module: Set[str] = set() replaced_module: set[str] = set()
has_replacement = False has_replacement = False
for r in new_module_names: for r in new_module_names:
lora = self._get_lora_layer_weights(lora_model, r) lora = self._get_lora_layer_weights(lora_model, r)
...@@ -694,7 +694,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -694,7 +694,7 @@ class LoRAModelManager(AdapterModelManager):
return remove_adapter(adapter_id, self._registered_adapters, return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter) self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]: def list_adapters(self) -> dict[int, Any]:
return list_adapters(self._registered_adapters) return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]: def get_adapter(self, adapter_id: int) -> Optional[Any]:
...@@ -721,7 +721,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager): ...@@ -721,7 +721,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
self._active_adapters: LoRALRUCache = LoRALRUCache( self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_adapter) self.lora_slots, self._deactivate_adapter)
def list_adapters(self) -> Dict[int, LoRAModel]: def list_adapters(self) -> dict[int, LoRAModel]:
"""List all registered LoRAModels.""" """List all registered LoRAModels."""
return dict(self._registered_adapters.cache) return dict(self._registered_adapters.cache)
...@@ -786,7 +786,7 @@ def create_lora_manager( ...@@ -786,7 +786,7 @@ def create_lora_manager(
vocab_size: int, vocab_size: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
device: torch.device, device: torch.device,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager: **kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model.""" """Create a LoRA adapter for a given model."""
if not hasattr(model, "packed_modules_mapping"): if not hasattr(model, "packed_modules_mapping"):
......
...@@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving. ...@@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import List
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -127,7 +125,7 @@ def _lora_expand_kernel( ...@@ -127,7 +125,7 @@ def _lora_expand_kernel(
@torch.inference_mode() @torch.inference_mode()
def _lora_expand( def _lora_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: List[ lora_b_weights: list[
torch.Tensor], # shape [num_lora, hidden_size, lora_rank] torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
output_tensor: torch. output_tensor: torch.
Tensor, # shape [num_tokens, hidden_size * num_slices] Tensor, # shape [num_tokens, hidden_size * num_slices]
...@@ -143,7 +141,7 @@ def _lora_expand( ...@@ -143,7 +141,7 @@ def _lora_expand(
""" """
Args: Args:
inputs (torch.Tensor): input tensor inputs (torch.Tensor): input tensor
lora_b_weights (List[torch.Tensor]): lora'b weight lora_b_weights (list[torch.Tensor]): lora'b weight
output_tensor (torch.Tensor): output tensor output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that to the lora-id related to that token. A value of -1 indicates that
...@@ -264,7 +262,7 @@ def _lora_expand( ...@@ -264,7 +262,7 @@ def _lora_expand(
def _lora_expand_fake( def _lora_expand_fake(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: List[torch.Tensor], lora_b_weights: list[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor, token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor,
......
...@@ -4,7 +4,7 @@ LoRA kernels metadata preparation utilities. ...@@ -4,7 +4,7 @@ LoRA kernels metadata preparation utilities.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple, Union from typing import Union
import torch import torch
...@@ -125,7 +125,7 @@ class LoRAKernelMeta: ...@@ -125,7 +125,7 @@ class LoRAKernelMeta:
def meta_args( def meta_args(
self, token_nums: int self, token_nums: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]: torch.Tensor, torch.Tensor]:
""" """
This function returns the kernel metadata required for the current This function returns the kernel metadata required for the current
......
...@@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving. ...@@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import List
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
...@@ -98,7 +96,7 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, ...@@ -98,7 +96,7 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
@torch.inference_mode() @torch.inference_mode()
def _lora_shrink( def _lora_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size] inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: List[ lora_a_weights: list[
torch.Tensor], # shape [num_loras, lora_rank, hidden_size] torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens] token_lora_mapping: torch.Tensor, # shape [num_tokens]
...@@ -112,7 +110,7 @@ def _lora_shrink( ...@@ -112,7 +110,7 @@ def _lora_shrink(
""" """
Args: Args:
inputs (torch.Tensor): Input tensor inputs (torch.Tensor): Input tensor
lora_a_weights (List[torch.Tensor]): LoRA weights lora_a_weights (list[torch.Tensor]): LoRA weights
output_tensor (torch.Tensor): output tensor output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that to the lora-id related to that token. A value of -1 indicates that
...@@ -219,7 +217,7 @@ def _lora_shrink( ...@@ -219,7 +217,7 @@ def _lora_shrink(
def _lora_shrink_fake( def _lora_shrink_fake(
inputs: torch.Tensor, inputs: torch.Tensor,
lora_a_weights: List[torch.Tensor], lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor, token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Tuple
import torch import torch
_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
""" """
`_LORA_A_PTR_DICT` collects the required information during `profile_run`, `_LORA_A_PTR_DICT` collects the required information during `profile_run`,
After this, it remains constant and subsequent usage is through LUT. After this, it remains constant and subsequent usage is through LUT.
...@@ -53,7 +51,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): ...@@ -53,7 +51,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device):
return _LORA_A_PTR_DICT.get(key) return _LORA_A_PTR_DICT.get(key)
def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
device: torch.device): device: torch.device):
""" """
`_LORA_B_PTR_DICT` collects the required information during `profile_run`, `_LORA_B_PTR_DICT` collects the required information during `profile_run`,
......
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
import math import math
import os import os
from dataclasses import MISSING, dataclass, field, fields from dataclasses import MISSING, dataclass, field, fields
from typing import List, Literal, Optional, Union from typing import Literal, Optional, Union
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -40,7 +40,7 @@ class PEFTHelper: ...@@ -40,7 +40,7 @@ class PEFTHelper:
vllm_max_position_embeddings: Optional[int] = field(default=False) vllm_max_position_embeddings: Optional[int] = field(default=False)
vllm_long_context_scaling_factor: Optional[float] = field(default=None) vllm_long_context_scaling_factor: Optional[float] = field(default=None)
def _validate_features(self) -> List[str]: def _validate_features(self) -> list[str]:
""" """
Check if there are any unsupported LoRA features. Check if there are any unsupported LoRA features.
""" """
......
...@@ -7,7 +7,7 @@ https://arxiv.org/abs/2310.18547 ...@@ -7,7 +7,7 @@ https://arxiv.org/abs/2310.18547
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Optional, Union
import torch import torch
...@@ -28,7 +28,7 @@ class PunicaWrapperABC(ABC): ...@@ -28,7 +28,7 @@ class PunicaWrapperABC(ABC):
def update_metadata( def update_metadata(
self, self,
mapping: "LoRAMapping", mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]], lora_index_to_id: list[Optional[int]],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int, extra_vocab_size: int,
...@@ -43,9 +43,9 @@ class PunicaWrapperABC(ABC): ...@@ -43,9 +43,9 @@ class PunicaWrapperABC(ABC):
@abstractmethod @abstractmethod
def add_shrink( def add_shrink(
self, self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, scale: float,
**kwargs, **kwargs,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
...@@ -59,10 +59,10 @@ class PunicaWrapperABC(ABC): ...@@ -59,10 +59,10 @@ class PunicaWrapperABC(ABC):
def add_expand( def add_expand(
self, self,
y: torch.Tensor, y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
...@@ -91,13 +91,13 @@ class PunicaWrapperABC(ABC): ...@@ -91,13 +91,13 @@ class PunicaWrapperABC(ABC):
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
*, *,
buffer: Optional[Tuple[torch.Tensor, ...]] = None, buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> Optional[torch.Tensor]: **kwargs) -> Optional[torch.Tensor]:
""" """
Applicable to linear-related lora. Applicable to linear-related lora.
...@@ -150,7 +150,7 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -150,7 +150,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
# 5 is the number of indices tensors. # 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices # embeddings_indices,long_lora_indices
self.indices_len: List[Optional[int]] = [None] * 5 self.indices_len: list[Optional[int]] = [None] * 5
# these attributes are the information required for sgmv kernel # these attributes are the information required for sgmv kernel
self._seq_start_locs = torch.empty(max_batches, self._seq_start_locs = torch.empty(max_batches,
dtype=torch.long, dtype=torch.long,
...@@ -171,7 +171,7 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -171,7 +171,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
def _update_base_metadata( def _update_base_metadata(
self, self,
mapping: "LoRAMapping", mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]], lora_index_to_id: list[Optional[int]],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int, extra_vocab_size: int,
...@@ -228,8 +228,8 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -228,8 +228,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self, self,
indices: torch.Tensor, indices: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
): ):
"""Applies bias to output """Applies bias to output
...@@ -259,7 +259,7 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -259,7 +259,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
@property @property
def prefill_metadata( def prefill_metadata(
self self
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
""" """
This property provides a convenient way to access the necessary This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations. metadata for prefill-related kernel computations.
...@@ -323,7 +323,7 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -323,7 +323,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
def update_metadata( def update_metadata(
self, self,
mapping: "LoRAMapping", mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]], lora_index_to_id: list[Optional[int]],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int, extra_vocab_size: int,
...@@ -341,8 +341,8 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -341,8 +341,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
self.is_prefill = False self.is_prefill = False
@abstractmethod @abstractmethod
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, **kwargs) -> Optional[torch.Tensor]: scale: float, **kwargs) -> Optional[torch.Tensor]:
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
...@@ -352,9 +352,9 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -352,9 +352,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
y[i] += (x @ lora_a_stacked[i]) * scale y[i] += (x @ lora_a_stacked[i]) * scale
Args: Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation scale (float): Scaling factor for the operation
""" """
...@@ -364,10 +364,10 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -364,10 +364,10 @@ class PunicaWrapperBase(PunicaWrapperABC):
@abstractmethod @abstractmethod
def add_expand(self, def add_expand(self,
y: torch.Tensor, y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs) -> Optional[torch.Tensor]: **kwargs) -> Optional[torch.Tensor]:
...@@ -384,11 +384,11 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -384,11 +384,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight bias's weight
output_slices (Tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
offset_start (int): The starting position of y, defaults to 0 offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
...@@ -422,13 +422,13 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -422,13 +422,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
*, *,
buffer: Optional[Tuple[torch.Tensor, ...]] = None, buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> Optional[torch.Tensor]: **kwargs) -> Optional[torch.Tensor]:
""" """
Applicable to linear-related lora. Applicable to linear-related lora.
...@@ -445,12 +445,12 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -445,12 +445,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
Args: Args:
y (torch.Tensor): Output tensor. Will be changed in-place. y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
""" """
# TODO: implement it based on torch ops # TODO: implement it based on torch ops
raise NotImplementedError raise NotImplementedError
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union from typing import Callable, Optional, Union
import torch import torch
...@@ -150,8 +150,8 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -150,8 +150,8 @@ class PunicaWrapperCPU(PunicaWrapperBase):
shrink_fun(y, x, w_t_all, scale) shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org) y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, **kwargs): scale: float, **kwargs):
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
...@@ -165,9 +165,9 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -165,9 +165,9 @@ class PunicaWrapperCPU(PunicaWrapperBase):
y[i] += (x @ lora_a_stacked[i]) * scale y[i] += (x @ lora_a_stacked[i]) * scale
Args: Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation scale (float): Scaling factor for the operation
""" """
...@@ -179,10 +179,10 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -179,10 +179,10 @@ class PunicaWrapperCPU(PunicaWrapperBase):
def add_expand(self, def add_expand(self,
y: torch.Tensor, y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs) -> None: **kwargs) -> None:
...@@ -198,11 +198,11 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -198,11 +198,11 @@ class PunicaWrapperCPU(PunicaWrapperBase):
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight bias's weight
output_slices (Tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
y_org = y y_org = y
...@@ -250,13 +250,13 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -250,13 +250,13 @@ class PunicaWrapperCPU(PunicaWrapperBase):
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
*, *,
buffer: Optional[Tuple[torch.Tensor, ...]] = None, buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> None: **kwargs) -> None:
""" """
Applicable to linear-related lora. Applicable to linear-related lora.
...@@ -273,12 +273,12 @@ class PunicaWrapperCPU(PunicaWrapperBase): ...@@ -273,12 +273,12 @@ class PunicaWrapperCPU(PunicaWrapperBase):
Args: Args:
y (torch.Tensor): Output tensor. Will be changed in-place. y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
""" """
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
......
...@@ -6,7 +6,7 @@ Punica: Multi-Tenant LoRA Serving. ...@@ -6,7 +6,7 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final from typing import TYPE_CHECKING, Optional, Union, final
import torch import torch
...@@ -57,7 +57,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -57,7 +57,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
def update_metadata( def update_metadata(
self, self,
mapping: LoRAMapping, mapping: LoRAMapping,
lora_index_to_id: List[Optional[int]], lora_index_to_id: list[Optional[int]],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int, extra_vocab_size: int,
...@@ -74,7 +74,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -74,7 +74,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
def add_shrink(self, y: torch.Tensor, x: torch.Tensor, def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, lora_a_stacked: tuple[torch.Tensor,
...], scale: float, **kwargs): ...], scale: float, **kwargs):
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
...@@ -86,7 +86,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -86,7 +86,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Args: Args:
y (torch.Tensor): Output tensors y (torch.Tensor): Output tensors
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation scale (float): Scaling factor for the operation
""" """
...@@ -102,9 +102,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -102,9 +102,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
def add_expand(self, def add_expand(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs) -> None: **kwargs) -> None:
...@@ -121,10 +121,10 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -121,10 +121,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors x (torch.Tensor): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight bias's weight
output_slices (Tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
y_org = y y_org = y
...@@ -181,11 +181,11 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -181,11 +181,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
*, *,
buffer: Optional[torch.Tensor] = None, buffer: Optional[torch.Tensor] = None,
**kwargs) -> None: **kwargs) -> None:
...@@ -204,11 +204,11 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -204,11 +204,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Args: Args:
y (torch.Tensor): Output tensor. Will be changed in-place. y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None. buffer (Optional[torch.Tensor]): Defaults to None.
""" """
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final from typing import TYPE_CHECKING, Optional, Union, final
import torch import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
...@@ -28,7 +28,7 @@ class PunicaWrapperHPU(PunicaWrapperBase): ...@@ -28,7 +28,7 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def _update_base_metadata( def _update_base_metadata(
self, self,
mapping: "LoRAMapping", mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]], lora_index_to_id: list[Optional[int]],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int, extra_vocab_size: int,
...@@ -48,9 +48,9 @@ class PunicaWrapperHPU(PunicaWrapperBase): ...@@ -48,9 +48,9 @@ class PunicaWrapperHPU(PunicaWrapperBase):
# graph accumulation. Hence HPU appends `lora_offset` to a list and # graph accumulation. Hence HPU appends `lora_offset` to a list and
# converts it to a tensor only after it is ready. # converts it to a tensor only after it is ready.
if long_lora_context: if long_lora_context:
index_mapping_indices: List[int] = list( index_mapping_indices: list[int] = list(
mapping.index_mapping).copy() mapping.index_mapping).copy()
long_lora_offsets: List[int] = [] long_lora_offsets: list[int] = []
for i in range(len(index_mapping_indices)): for i in range(len(index_mapping_indices)):
lora_offset: int = long_lora_context.offsets_by_lora_id.get( lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0) index_mapping_indices[i], 0)
...@@ -85,13 +85,13 @@ class PunicaWrapperHPU(PunicaWrapperBase): ...@@ -85,13 +85,13 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
*, *,
buffer: Optional[Tuple[torch.Tensor, ...]] = None, buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> None: **kwargs) -> None:
y_org = y y_org = y
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
...@@ -122,9 +122,9 @@ class PunicaWrapperHPU(PunicaWrapperBase): ...@@ -122,9 +122,9 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def add_shrink( def add_shrink(
self, self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, scale: float,
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -133,10 +133,10 @@ class PunicaWrapperHPU(PunicaWrapperBase): ...@@ -133,10 +133,10 @@ class PunicaWrapperHPU(PunicaWrapperBase):
def add_expand( def add_expand(
self, self,
y: torch.Tensor, y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs, **kwargs,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union from typing import Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -77,8 +77,8 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -77,8 +77,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self._get_token_lora_indices(x), y_offset, self._get_token_lora_indices(x), y_offset,
y_slice_size, add_inputs) y_slice_size, add_inputs)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, **kwargs) -> Optional[torch.Tensor]: scale: float, **kwargs) -> Optional[torch.Tensor]:
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
...@@ -88,9 +88,9 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -88,9 +88,9 @@ class PunicaWrapperTPU(PunicaWrapperBase):
y[i] += (x @ lora_a_stacked[i]) * scale y[i] += (x @ lora_a_stacked[i]) * scale
Args: Args:
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor x (torch.Tensor): Input tensor
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation scale (float): Scaling factor for the operation
""" """
...@@ -106,10 +106,10 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -106,10 +106,10 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def add_expand(self, def add_expand(self,
y: torch.Tensor, y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
offset_start: int = 0, offset_start: int = 0,
add_inputs=True, add_inputs=True,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
...@@ -125,11 +125,11 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -125,11 +125,11 @@ class PunicaWrapperTPU(PunicaWrapperBase):
Args: Args:
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight bias's weight
output_slices (Tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): Defaults to True.
""" """
y_org = y y_org = y
...@@ -177,13 +177,13 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -177,13 +177,13 @@ class PunicaWrapperTPU(PunicaWrapperBase):
def add_lora_linear(self, def add_lora_linear(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float, scale: float,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
*, *,
buffer: Optional[Tuple[torch.Tensor, ...]] = None, buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
""" """
Applicable to linear-related lora. Applicable to linear-related lora.
...@@ -200,12 +200,12 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -200,12 +200,12 @@ class PunicaWrapperTPU(PunicaWrapperBase):
Args: Args:
y (torch.Tensor): Output tensor. Will not be changed in-place. y (torch.Tensor): Output tensor. Will not be changed in-place.
x (torch.Tensor): Input tensor (T, E) x (torch.Tensor): Input tensor (T, E)
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor. scale (float): Scaling factor.
output_slices (Tuple[int, ...]): Every slice's size. output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
""" """
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
...@@ -284,8 +284,8 @@ class PunicaWrapperTPU(PunicaWrapperBase): ...@@ -284,8 +284,8 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self, self,
indices: torch.Tensor, indices: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
output_slices: Tuple[int, ...], output_slices: tuple[int, ...],
lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
): ):
"""Applies bias to output """Applies bias to output
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Optional, Union
import torch import torch
...@@ -12,7 +12,7 @@ if TYPE_CHECKING: ...@@ -12,7 +12,7 @@ if TYPE_CHECKING:
def compute_meta( def compute_meta(
token_lora_tensor: torch.Tensor token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
""" """
Get the information required for the sgmv kernel. With the features: Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function 1. If consecutive requests in the batch use the same LoRA, this function
...@@ -43,14 +43,14 @@ def compute_meta( ...@@ -43,14 +43,14 @@ def compute_meta(
# TODO see if this can be vectorized # TODO see if this can be vectorized
def convert_mapping( def convert_mapping(
mapping: "LoRAMapping", mapping: "LoRAMapping",
lora_index_to_id: List[Optional[int]], lora_index_to_id: list[Optional[int]],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int, extra_vocab_size: int,
device: torch.device, device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None, long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]: Optional[torch.Tensor], list[int]]:
"""Converts LoRAMapping to index tensors. """Converts LoRAMapping to index tensors.
Args: Args:
...@@ -84,7 +84,7 @@ def convert_mapping( ...@@ -84,7 +84,7 @@ def convert_mapping(
(base_indices, sampler_indices, sampler_indices_padded, (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices). embeddings_indices, long_lora_indices).
""" """
index_mapping_indices: List[int] = list(mapping.index_mapping).copy() index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy() embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None long_lora_offsets: Optional[torch.Tensor] = None
...@@ -92,7 +92,7 @@ def convert_mapping( ...@@ -92,7 +92,7 @@ def convert_mapping(
long_lora_offsets = torch.zeros(len(index_mapping_indices), long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device, device=device,
dtype=torch.long) dtype=torch.long)
prompt_mapping: List[int] = [ prompt_mapping: list[int] = [
lora_index_to_id.index(x) if x > 0 else -1 lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping for x in mapping.prompt_mapping
] ]
...@@ -109,7 +109,7 @@ def convert_mapping( ...@@ -109,7 +109,7 @@ def convert_mapping(
index_mapping_indices[i], 0) index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [ indices_list: list[Union[list[int], torch.Tensor]] = [
index_mapping_indices, index_mapping_indices,
lora_indices, lora_indices,
embedding_indices, embedding_indices,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import AbstractSet, Dict, Optional from typing import Optional
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -40,9 +41,9 @@ class LoRAResolver(ABC): ...@@ -40,9 +41,9 @@ class LoRAResolver(ABC):
@dataclass @dataclass
class _LoRAResolverRegistry: class _LoRAResolverRegistry:
resolvers: Dict[str, LoRAResolver] = field(default_factory=dict) resolvers: dict[str, LoRAResolver] = field(default_factory=dict)
def get_supported_resolvers(self) -> AbstractSet[str]: def get_supported_resolvers(self) -> Set[str]:
"""Get all registered resolver names.""" """Get all registered resolver names."""
return self.resolvers.keys() return self.resolvers.keys()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import os import os
import re import re
from typing import List, Optional, Set, Tuple, Type, Union from typing import Optional, Union
import huggingface_hub import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
...@@ -37,7 +37,7 @@ from vllm.model_executor.models.utils import WeightsMapper ...@@ -37,7 +37,7 @@ from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__) logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
...@@ -58,7 +58,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { ...@@ -58,7 +58,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
def from_layer(layer: nn.Module, def from_layer(layer: nn.Module,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: List, packed_modules_list: list,
model_config: Optional[PretrainedConfig] = None) -> nn.Module: model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes: for lora_cls in _all_lora_classes:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
...@@ -99,7 +99,7 @@ def replace_submodule(model: nn.Module, module_name: str, ...@@ -99,7 +99,7 @@ def replace_submodule(model: nn.Module, module_name: str,
def parse_fine_tuned_lora_name( def parse_fine_tuned_lora_name(
name: str, name: str,
weights_mapper: Optional[WeightsMapper] = None weights_mapper: Optional[WeightsMapper] = None
) -> Tuple[str, bool, bool]: ) -> tuple[str, bool, bool]:
"""Parse the name of lora weights. """Parse the name of lora weights.
args: args:
...@@ -108,7 +108,7 @@ def parse_fine_tuned_lora_name( ...@@ -108,7 +108,7 @@ def parse_fine_tuned_lora_name(
weights_mapper: maps the name of weight, e.g. weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`, `model.` -> `language_model.model.`,
return: return:
Tuple(module_name, is_lora_a): tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1, module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b. is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias. is_bias whether the tensor is lora bias.
...@@ -147,8 +147,8 @@ def parse_fine_tuned_lora_name( ...@@ -147,8 +147,8 @@ def parse_fine_tuned_lora_name(
raise ValueError(f"{name} is unsupported LoRA weight") raise ValueError(f"{name} is unsupported LoRA weight")
def is_regex_target_modules(load_modules: Union[str, List[str]], def is_regex_target_modules(load_modules: Union[str, list[str]],
expected_lora_modules: List[str]) -> bool: expected_lora_modules: list[str]) -> bool:
""" """
PEFT supports passing `target_modules` in the form of regular expressions, PEFT supports passing `target_modules` in the form of regular expressions,
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
...@@ -179,11 +179,11 @@ def is_regex_target_modules(load_modules: Union[str, List[str]], ...@@ -179,11 +179,11 @@ def is_regex_target_modules(load_modules: Union[str, List[str]],
return False return False
def get_supported_lora_modules(model: nn.Module) -> List[str]: def get_supported_lora_modules(model: nn.Module) -> list[str]:
""" """
In vLLM, all linear layers support LoRA. In vLLM, all linear layers support LoRA.
""" """
supported_lora_modules: Set[str] = set() supported_lora_modules: set[str] = set()
# step1: traverse the model to get all the linear subfixes. # step1: traverse the model to get all the linear subfixes.
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, (LinearBase, )): if isinstance(module, (LinearBase, )):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union from typing import Any, Literal, Optional, Union
import torch import torch
...@@ -27,7 +27,7 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -27,7 +27,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
Every request, the requested LoRAs will be loaded (unless they are already Every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded.""" loaded), and every other LoRA will be unloaded."""
_manager_cls: Type[LoRAModelManager] = LoRAModelManager _manager_cls: type[LoRAModelManager] = LoRAModelManager
def __init__( def __init__(
self, self,
...@@ -36,9 +36,9 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -36,9 +36,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
vocab_size: int, vocab_size: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
device: torch.device, device: torch.device,
embedding_modules: Dict[str, str], embedding_modules: dict[str, str],
embedding_padding_modules: List[str], embedding_padding_modules: list[str],
lora_model_cls: Type[LoRAModel] = LoRAModel, lora_model_cls: type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None, max_position_embeddings: Optional[int] = None,
): ):
self._lora_model_cls = lora_model_cls self._lora_model_cls = lora_model_cls
...@@ -88,7 +88,7 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -88,7 +88,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
self._adapter_manager.supported_lora_modules) self._adapter_manager.supported_lora_modules)
packed_modules_mapping = ( packed_modules_mapping = (
self._adapter_manager.packed_modules_mapping) self._adapter_manager.packed_modules_mapping)
expected_lora_modules: List[str] = [] expected_lora_modules: list[str] = []
for module in supported_lora_modules: for module in supported_lora_modules:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend( expected_lora_modules.extend(
...@@ -162,12 +162,12 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -162,12 +162,12 @@ class WorkerLoRAManager(AbstractWorkerManager):
def pin_adapter(self, adapter_id: int) -> bool: def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id) return self._adapter_manager.pin_adapter(adapter_id)
def set_active_adapters(self, requests: Set[Any], def set_active_adapters(self, requests: set[Any],
mapping: Optional[Any]) -> None: mapping: Optional[Any]) -> None:
set_active_adapters_worker(requests, mapping, self._apply_adapters, set_active_adapters_worker(requests, mapping, self._apply_adapters,
self._adapter_manager.set_adapter_mapping) self._adapter_manager.set_adapter_mapping)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None: def _apply_adapters(self, adapter_requests: set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters, apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots, self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter) self.remove_adapter, self.add_adapter)
...@@ -184,7 +184,7 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -184,7 +184,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
def remove_all_adapters(self): def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters() self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> Set[int]: def list_adapters(self) -> set[int]:
return list_adapters_worker(self._adapter_manager.list_adapters) return list_adapters_worker(self._adapter_manager.list_adapters)
...@@ -195,7 +195,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -195,7 +195,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
(unless they are already loaded) and least recently used LoRAs will (unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity.""" be unloaded if the cache is above capacity."""
_manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager( def create_lora_manager(
self, self,
...@@ -213,7 +213,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -213,7 +213,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
self._adapter_manager = lora_manager self._adapter_manager = lora_manager
return lora_manager.model return lora_manager.model
def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
loras_map = { loras_map = {
lora_request.lora_int_id: lora_request lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request for lora_request in lora_requests if lora_request
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment