Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
if TYPE_CHECKING:
pass
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs['lora_config'].fully_sharded_loras)
return dec
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return this tensor-parallel rank's slice of ``lora_a``.

        The slice is taken along dim 1 and its width equals
        ``self.lora_a_stacked.shape[2]`` (the per-partition size the
        stacked buffer was allocated with).
        """
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        return lora_a[:, start_idx:start_idx + shard_size]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer, then add the sharded LoRA contribution.

        The parent class dispatches through ``apply()`` and the base layer
        exposes ``quant_method.apply()``; overriding ``apply`` (rather than
        the former ``apply_weights``) is what makes the sharded path run.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # Flatten leading dims so both tensors are 2-D for the kernels;
        # remember the original output shape to restore it at the end.
        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        # NOTE(review): bgmv is assumed to write/accumulate into its first
        # argument - confirm against vllm.lora.punica.
        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # Each rank holds only a slice of the rank dim; gather the pieces
        # before applying LoRA B.
        buffer = tensor_model_parallel_all_gather(buffer)
        bgmv(output, buffer, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # now have column partitioned output

        output = output.view(*out_orig_shape)
        return output

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x, bias)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, additionally requiring
        ``lora_config.fully_sharded_loras`` (added by the decorator)."""
        # specifying kwargs so they can be easily accessed in decorator
        # decorate=False tells the parent's decorator to skip its own
        # "not fully sharded" condition.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
def _mcp_apply_weights(x, bias, layer):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    QKVParallelLinearWithShardedLora share the same
    LoRa weight application method.

    The main difference is the step by shard_size for lora_b which can
    vary for QKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.

    Args:
        x: input activations; flattened to 2-D internally.
        bias: bias forwarded unchanged to the base layer.
        layer: the LoRA layer providing ``base_layer``,
            ``lora_a_stacked``, ``lora_b_stacked``, ``indices`` and
            ``indices_len``.

    Returns:
        The base layer output with the LoRA contribution applied,
        reshaped to the base output's original shape.
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    # The base layer exposes quant_method.apply() (linear_method was
    # renamed), matching how the non-sharded LoRA layers call it.
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    # One buffer per packed slice, sized by the first slice's
    # per-partition dimension.
    buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
                          dtype=torch.float32,
                          device=x.device)
    active_indices = layer.indices[:layer.indices_len[0]]
    for idx in range(n):
        bgmv(buffers[idx], x, layer.lora_a_stacked[idx], active_indices, 0,
             1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        # Each slice lands in its own window of the packed output, starting
        # at left_offset and shard_size wide.
        dispatch_bgmv_low_level(output, buffers[idx],
                                layer.lora_b_stacked[idx], active_indices, 0,
                                1.0, left_offset, shard_size)
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
            self, lora_a: List[Optional[torch.Tensor]]
    ) -> List[Optional[torch.Tensor]]:
        """Slice both packed LoRA A tensors for this tensor-parallel rank.

        ``None`` entries are passed through unchanged: the parent's
        ``set_lora`` treats ``None`` as "no weights for this slice" and
        calls this method before checking for it.
        """
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        output_end_idx = output_start_idx + output_shard_size
        return [
            part[:, output_start_idx:output_end_idx]
            if part is not None else None
            for part in (lora_a[0], lora_a[1])
        ]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply base weights plus sharded LoRA via the shared helper.

        Named ``apply`` because that is the hook the parent LoRA layers
        dispatch through.
        """
        return _mcp_apply_weights(x, bias, self)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x, bias)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, additionally requiring
        ``lora_config.fully_sharded_loras`` (added by the decorator)."""
        # specifying kwargs so they can be easily accessed in decorator
        # decorate=False tells the parent's decorator to skip its own
        # "not fully sharded" condition.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
            self, lora_a: List[Optional[torch.Tensor]]
    ) -> List[Optional[torch.Tensor]]:
        """Slice the q/k/v LoRA A tensors for this tensor-parallel rank.

        Each entry is sliced along dim 1 using the per-partition size of
        its own stacked buffer; ``None`` entries are passed through.
        """
        sliced: List[Optional[torch.Tensor]] = []
        for i in range(3):
            shard_size = self.lora_a_stacked[i].shape[2]
            start_idx = self.tp_rank * shard_size
            part = lora_a[i]
            sliced.append(part[:, start_idx:start_idx + shard_size]
                          if part is not None else None)
        return sliced

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply base weights plus sharded LoRA via the shared helper.

        Named ``apply`` because that is the hook the parent LoRA layers
        dispatch through.
        """
        return _mcp_apply_weights(x, bias, self)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x, bias)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, additionally requiring
        ``lora_config.fully_sharded_loras`` (added by the decorator)."""
        # specifying kwargs so they can be easily accessed in decorator
        # decorate=False tells the parent's decorator to skip its own
        # "not fully sharded" condition.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this tensor-parallel rank's slice of ``lora_b``.

        The slice is taken along dim 1 and its width equals
        ``self.lora_b_stacked.shape[2]``.
        """
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return lora_b[:, start_idx:end_idx]

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Run the base layer, then add the sharded LoRA contribution.

        The parent class dispatches through ``apply()`` and the base layer
        exposes ``quant_method.apply()``; overriding ``apply`` (rather than
        the former ``apply_weights``) is what makes the sharded path run.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        # Flatten leading dims so both tensors are 2-D for the kernels;
        # remember the original output shape to restore it at the end.
        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        # NOTE(review): bgmv is assumed to write/accumulate into its first
        # argument - confirm against vllm.lora.punica.
        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
                                self.indices[:self.indices_len[0]], 0, 1.0,
                                start_idx, shard_size)

        output = output.view(*out_orig_shape)
        return output

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, additionally requiring
        ``lora_config.fully_sharded_loras`` (added by the decorator)."""
        # specifying kwargs so they can be easily accessed in decorator
        # decorate=False is forwarded as an extra keyword to the parent's
        # can_replace_layer.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
# pylint: disable=unused-argument # pylint: disable=unused-argument
import inspect
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type from typing import TYPE_CHECKING, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather) tensor_model_parallel_gather)
from vllm.distributed.utils import divide
from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) VocabParallelEmbedding)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
...@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: ...@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
raise ValueError(f"Unsupported base layer: {base_layer}") raise ValueError(f"Unsupported base layer: {base_layer}")
def _not_fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
condition = (not kwargs['lora_config'].fully_sharded_loras
if decorate else True)
return can_replace(*args, **kwargs) and condition
return dec
def _apply_lora( def _apply_lora(
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: torch.Tensor, lora_a_stacked: torch.Tensor,
...@@ -130,6 +145,14 @@ class LoRAMapping: ...@@ -130,6 +145,14 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
    """Hook for slicing LoRA A under tensor parallelism.

    This base version has an empty body and returns ``None``.
    """
    ...
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
    """Hook for slicing LoRA B under tensor parallelism.

    This base version has an empty body and returns ``None``.
    """
    ...
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
...@@ -176,6 +199,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -176,6 +199,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None: def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -233,9 +258,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -233,9 +258,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2], self.lora_a_stacked.shape[2],
) )
self.indices: Optional[torch.Tensor] = None # Lazily initialized.
self.indices_len: Optional[List[int]] = None self.indices: torch.Tensor
self.embeddings_indices = None self.indices_len: List[int]
self.embeddings_indices: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -267,6 +293,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -267,6 +293,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2] self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]] )[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping( def set_mapping(
...@@ -313,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -313,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
"""
def __init__(self, base_layer: ColumnParallelLinear) -> None: def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__() super().__init__()
...@@ -327,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -327,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -343,15 +380,27 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -343,15 +380,27 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[2] self.output_dim = self.lora_b_stacked.shape[2]
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0 self.lora_b_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
    """Return ``lora_a`` unchanged; this layer does not slice LoRA A."""
    return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
    """Keep only this tensor-parallel rank's window of ``lora_b``.

    The window is ``self.output_dim`` wide along dim 1 and starts at
    ``rank * self.output_dim``.
    """
    rank = get_tensor_model_parallel_rank()
    width = self.output_dim
    first = rank * width
    last = (rank + 1) * width
    return lora_b[:, first:last]
def set_lora( def set_lora(
self, self,
index: int, index: int,
...@@ -360,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -360,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank() lora_a = self.slice_lora_a(lora_a)
shard_size = self.output_dim lora_b = self.slice_lora_b(lora_b)
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True) lora_a.T, non_blocking=True)
...@@ -384,10 +432,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -384,10 +432,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices self.indices = base_indices
self.indices_len = indices_len self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.base_layer, x, bias)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -411,7 +458,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -411,7 +458,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if not self.base_layer.skip_bias_add else None) if not self.base_layer.skip_bias_add else None)
# Matrix multiply. # Matrix multiply.
output_parallel = self.apply_weights(input_, bias) output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output: if self.base_layer.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
...@@ -422,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -422,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
return output, output_bias return output, output_bias
@classmethod @classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
...@@ -447,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -447,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
n_slices = 2 n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices if not (len(self.base_layer.output_sizes) == n_slices
and self.base_layer.output_sizes[0] and self.base_layer.output_sizes[0]
...@@ -455,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -455,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"LoRAColumnParallelLinear2Slice requires 2 slices with " "LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.") "the same size.")
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = tuple( self.lora_a_stacked = tuple(
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -475,8 +529,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -475,8 +529,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
device=self.device, device=self.device,
) for _ in range(n_slices)) ) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2] self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0 self.lora_a_stacked[0][index] = 0
...@@ -484,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -484,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0 self.lora_b_stacked[1][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
    """Return the packed LoRA A list unchanged; no slicing is done here."""
    return lora_a
def slice_lora_b(
        self, lora_b: List[Optional[torch.Tensor]]
) -> List[Optional[torch.Tensor]]:
    """Slice both packed LoRA B tensors for this tensor-parallel rank.

    Each tensor is cut along dim 1 to the window
    ``[tp_rank * output_dim, (tp_rank + 1) * output_dim)``. ``None``
    entries are passed through unchanged instead of raising, matching
    how the QKV variant guards its entries.
    """
    shard_size = self.output_dim
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size
    return [
        part[:, start_idx:end_idx] if part is not None else None
        for part in (lora_b[0], lora_b[1])
    ]
def set_lora( def set_lora(
self, self,
index: int, index: int,
...@@ -494,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -494,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank() lora_a = self.slice_lora_a(lora_a)
shard_size = self.output_dim lora_b = self.slice_lora_b(lora_b)
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[0][:,
start_idx:end_idx], lora_b[1][:,
start_idx:end_idx]
if lora_a[0] is not None: if lora_a[0] is not None:
self.lora_a_stacked[0][ self.lora_a_stacked[0][
...@@ -517,10 +579,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -517,10 +579,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True) lora_b[1].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -532,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -532,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
return output return output
@classmethod @classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
...@@ -623,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -623,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads * self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size) self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads * self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size) self.base_layer.head_size)
self.q_shard_id = tp_rank self.q_shard_id = self.tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v # q, k, v
self.lora_a_stacked = ( self.lora_a_stacked = (
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -645,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -645,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -653,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -653,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_a_output_size_per_partition,
self.input_size, self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
...@@ -690,7 +756,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -690,7 +756,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size) self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None # lazily initialized.
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0 self.lora_a_stacked[0][index] = 0
...@@ -700,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -700,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[2][index] = 0 self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
    """Return the q/k/v LoRA A list unchanged; no slicing is done here."""
    return lora_a
def slice_lora_b(
        self, lora_b: List[Optional[torch.Tensor]]
) -> List[Optional[torch.Tensor]]:
    """Slice the q/k/v LoRA B tensors for this tensor-parallel rank.

    q is cut along dim 1 using ``q_proj_shard_size`` / ``q_shard_id``;
    k and v use ``kv_proj_shard_size`` / ``kv_shard_id``. A ``None``
    entry yields ``None`` in the result, so a missing projection no
    longer leaves a local unbound when the result list is built.
    """
    q_start = self.q_proj_shard_size * self.q_shard_id
    q_end = self.q_proj_shard_size * (self.q_shard_id + 1)
    kv_start = self.kv_proj_shard_size * self.kv_shard_id
    kv_end = self.kv_proj_shard_size * (self.kv_shard_id + 1)

    # (start, end) column window for q, k, v respectively.
    windows = ((q_start, q_end), (kv_start, kv_end), (kv_start, kv_end))
    return [
        part[:, start:end] if part is not None else None
        for part, (start, end) in zip((lora_b[0], lora_b[1], lora_b[2]),
                                      windows)
    ]
def set_lora( def set_lora(
self, self,
index: int, index: int,
...@@ -710,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -710,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1: if self.tp_size > 1:
if lora_b[0] is not None: lora_a = self.slice_lora_a(lora_a)
lora_b_q = lora_b[0][:, self.q_proj_shard_size * lora_b = self.slice_lora_b(lora_b)
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)] if lora_b[0] is not None:
self.lora_b_stacked[0][ lora_b_q = lora_b[0]
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( self.lora_b_stacked[0][
lora_b_q.T, non_blocking=True) index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
if lora_b[1] is not None: lora_b_q.T, non_blocking=True)
lora_b_k = lora_b[1][:, self.kv_proj_shard_size * if lora_b[1] is not None:
self.kv_shard_id:self.kv_proj_shard_size * lora_b_k = lora_b[1]
(self.kv_shard_id + 1)] self.lora_b_stacked[1][
self.lora_b_stacked[1][ index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( lora_b_k.T, non_blocking=True)
lora_b_k.T, non_blocking=True) if lora_b[2] is not None:
if lora_b[2] is not None: lora_b_v = lora_b[2]
lora_b_v = lora_b[2][:, self.kv_proj_shard_size * self.lora_b_stacked[2][
self.kv_shard_id:self.kv_proj_shard_size * index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
(self.kv_shard_id + 1)] lora_b_v.T, non_blocking=True)
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
else:
if lora_b[0] is not None:
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_b[1] is not None:
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if lora_b[2] is not None:
self.lora_b_stacked[2][
index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
lora_b[2].T, non_blocking=True)
if lora_a[0] is not None: if lora_a[0] is not None:
self.lora_a_stacked[0][ self.lora_a_stacked[0][
...@@ -758,10 +828,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -758,10 +828,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True) lora_a[2].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -773,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -773,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
return output return output
@classmethod @classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
...@@ -794,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -794,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_rank = get_tensor_model_parallel_rank()
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
...@@ -804,23 +876,40 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -804,23 +876,40 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
tp_size = get_tensor_model_parallel_world_size()
lora_b_output_size_per_partition = (
self.output_size if not lora_config.fully_sharded_loras else
divide(self.output_size, tp_size))
self.lora_b_stacked = torch.zeros( self.lora_b_stacked = torch.zeros(
( (
max_loras, max_loras,
1, 1,
self.output_size, lora_b_output_size_per_partition,
lora_config.max_lora_rank, lora_config.max_lora_rank,
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None # Lazily initialized
self.indices_len: Optional[List[int]] = None self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0 self.lora_b_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
    """Return the rows of ``lora_a`` owned by this tensor-parallel rank.

    The window is ``self.input_size`` rows tall and starts at
    ``rank * self.input_size``; every column is kept.
    """
    rank = get_tensor_model_parallel_rank()
    rows_per_rank = self.input_size
    first_row = rank * rows_per_rank
    return lora_a[first_row:first_row + rows_per_rank, :]
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
    """Return ``lora_b`` unchanged.

    This implementation performs no slicing of LoRA B; the same tensor
    object that was passed in is handed back.
    """
    return lora_b
def set_lora( def set_lora(
self, self,
index: int, index: int,
...@@ -829,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -829,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
): ):
self.reset_lora(index) self.reset_lora(index)
if self.base_layer.tp_size > 1: if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank() lora_a = self.slice_lora_a(lora_a)
shard_size = self.input_size lora_b = self.slice_lora_b(lora_b)
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
...@@ -854,9 +941,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -854,9 +941,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices self.indices = base_indices
self.indices_len = indices_len self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor) -> torch.Tensor: def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.quant_method.apply(self.base_layer, x)
self.base_layer, x)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
...@@ -889,7 +975,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -889,7 +975,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
output_parallel = self.apply_weights(input_parallel) output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1: if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
...@@ -911,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -911,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.base_layer, "weight") else self.base_layer.qweight self.base_layer, "weight") else self.base_layer.qweight
@classmethod @classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
...@@ -991,9 +1078,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -991,9 +1078,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
) )
self.indices = None # Lazily initialized.
self.indices_padded = None self.indices: torch.Tensor
self.indices_len = None self.indices_len: List[int]
self.indices_padded: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -1091,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1091,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig]) -> bool: model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor. # Special handling for the LogitsProcessor.
return False return False
# Registry of LoRA layer wrappers, built by scanning this module's globals
# (classes defined here as well as imported ones) and keeping every class
# that is a strict subclass of BaseLayerWithLoRA; the base class itself is
# excluded. The result is a set, so its iteration order is not the order
# in which the classes appear in the module.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
cls
for cls in globals().values() if inspect.isclass(cls)
and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap ``layer`` in the first registered LoRA class that accepts it.

    Every class in ``_all_lora_classes`` is asked, through its
    ``can_replace_layer`` classmethod, whether it can stand in for
    ``layer``. The first class that answers true is instantiated around
    ``layer``, has ``create_lora_weights`` called on it, and is returned.

    Args:
        layer: The module to (possibly) wrap.
        max_loras: Forwarded to ``create_lora_weights``.
        lora_config: LoRA configuration, passed both to the replaceability
            check and to ``create_lora_weights``.
        packed_modules_list: Forwarded to ``can_replace_layer``.
        model_config: Optional model configuration, forwarded to both calls.

    Returns:
        The LoRA wrapper around ``layer``, or ``layer`` itself when no
        registered class can replace it.
    """
    for lora_cls in _all_lora_classes:
        # Pass keyword arguments: the can_replace_layer decorators look the
        # config up as kwargs['lora_config'], so a positional call would
        # raise KeyError inside the wrapper.
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Build a ``LogitsProcessorWithLoRA`` around ``layer``.

    The wrapper is constructed from ``layer`` together with the embedding
    dimension of ``lm_head`` and the dtype and device of ``lm_head.weight``;
    ``create_lora_weights`` is then called on it before it is returned.
    """
    hidden_size = lm_head.embedding_dim
    weight_dtype = lm_head.weight.dtype
    weight_device = lm_head.weight.device
    wrapped = LogitsProcessorWithLoRA(layer, hidden_size, weight_dtype,
                                      weight_device)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped
...@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self, self,
module_name: str, module_name: str,
rank: int, rank: int,
lora_alphas: List[int], lora_alphas: List[Optional[int]],
lora_a: List[torch.Tensor], lora_a: List[Optional[torch.Tensor]],
lora_b: List[torch.Tensor], lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None, scaling: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0, lora_alpha=0,
lora_a=lora_a, lora_a=lora_a,
lora_b=lora_b, lora_b=lora_b,
scaling=scaling, scaling=scaling, # type: ignore
embeddings_tensor=None, embeddings_tensor=None,
) )
self.lora_alphas = lora_alphas self.lora_alphas = lora_alphas
if scaling is None: if scaling is None:
self.scaling = [ self.scaling = [ # type: ignore
lora_alpha / self.rank for lora_alpha in self.lora_alphas lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
] ]
@classmethod @classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": def pack(
cls, loras: List[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA. """Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA. If LoRA is None, it signifies that the submodule does not have a LoRA.
...@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras]) scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
])
return obj return obj
def optimize(self) -> "PackedLoRALayerWeights": def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b.""" """Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)): for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None: if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue continue
self.lora_b[i] *= self.scaling[i] self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 self.scaling[i] = 1 # type: ignore
return self return self
@property @property
......
...@@ -3,7 +3,7 @@ import json ...@@ -3,7 +3,7 @@ import json
import math import math
import os import os
import re import re
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type from typing import Callable, Dict, List, Optional, Tuple, Type
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -11,10 +11,10 @@ from torch import nn ...@@ -11,10 +11,10 @@ from torch import nn
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer, from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from_layer_logits_processor)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -53,44 +53,46 @@ def convert_mapping( ...@@ -53,44 +53,46 @@ def convert_mapping(
embeddings. embeddings.
indices_len: List of lengths of the above tensors. indices_len: List of lengths of the above tensors.
""" """
indices = list(mapping.index_mapping).copy() index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = indices.copy() embedding_indices = index_mapping_indices.copy()
lora_indices = indices.copy() lora_indices = index_mapping_indices.copy()
prompt_mapping = [ prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1 lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping for x in mapping.prompt_mapping
] ]
lora_idx = None lora_idx = None
for i in range(len(indices)): for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize # TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i]) lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if indices[i] > 0 else -1) if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0 embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
indices[i] = i index_mapping_indices[i] = i
lora_indices[i] = lora_idx lora_indices[i] = lora_idx
indices = torch.tensor([indices, lora_indices, embedding_indices], indices = torch.tensor(
dtype=torch.long, [index_mapping_indices, lora_indices, embedding_indices],
device="cuda") dtype=torch.long,
prompt_mapping = torch.tensor(prompt_mapping, device="cuda")
device="cuda", prompt_mapping_tensor = torch.tensor(prompt_mapping,
dtype=torch.long) device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([ embeddings_indices = torch.stack([
indices[2] * extra_vocab_size, indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size) indices[2] * (vocab_size + extra_vocab_size)
]) ])
embeddings_indices[embeddings_indices == -1] = max_loras - 1 embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1] base_indices = indices[1]
sampler_indices = prompt_mapping sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone() sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = ( sampler_indices_padded = (
torch.arange( torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded))) (sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], indices_len = [
sampler_indices_padded.shape[-1], base_indices.shape[-1], sampler_indices.shape[-1],
embeddings_indices.shape[-1]) sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
return (base_indices, sampler_indices, sampler_indices_padded, return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len) embeddings_indices, indices_len)
...@@ -149,6 +151,7 @@ class LoRAModel: ...@@ -149,6 +151,7 @@ class LoRAModel:
if module_name not in loras: if module_name not in loras:
lora_embeddings_tensor = None lora_embeddings_tensor = None
if embeddings: if embeddings:
assert embedding_modules is not None
embeddings_module = next( embeddings_module = next(
(k for k in embedding_modules if k in module_name), (k for k in embedding_modules if k in module_name),
None) None)
...@@ -171,6 +174,7 @@ class LoRAModel: ...@@ -171,6 +174,7 @@ class LoRAModel:
else: else:
loras[module_name].lora_b = tensor.to(device=device, loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t() dtype=dtype).t()
assert embedding_padding_modules is not None
if any(name in module_name if any(name in module_name
for name in embedding_padding_modules for name in embedding_padding_modules
) and target_embedding_padding is not None: ) and target_embedding_padding is not None:
...@@ -295,11 +299,10 @@ class LoRAModelManager: ...@@ -295,11 +299,10 @@ class LoRAModelManager:
self.max_num_batched_tokens, self.max_num_batched_tokens,
dtype=torch.long, dtype=torch.long,
device="cuda") device="cuda")
self.offsets = []
# 4 is the number of indicies tensors defined above # 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded, # base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices # embeddings_indices
self.indices_len = [None] * 4 self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
...@@ -312,7 +315,7 @@ class LoRAModelManager: ...@@ -312,7 +315,7 @@ class LoRAModelManager:
self._registered_loras: Dict[int, LoRAModel] = {} self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {} self._active_loras: Dict[int, None] = {}
self._last_mapping = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
...@@ -342,8 +345,8 @@ class LoRAModelManager: ...@@ -342,8 +345,8 @@ class LoRAModelManager:
index, _ = first_free_slot index, _ = first_free_slot
self._active_loras[lora_id] = None self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id] lora_model = self._registered_loras[lora_id]
logger.debug( logger.debug("Activating LoRA. int id: %d, slot index: %d",
f"Activating LoRA. int id: {lora_model.id}, slot index: {index}") lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items(): for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name) module_lora = lora_model.get_lora(module_name)
...@@ -370,7 +373,7 @@ class LoRAModelManager: ...@@ -370,7 +373,7 @@ class LoRAModelManager:
return True return True
return False return False
def _add_lora(self, lora: LoRAModel) -> bool: def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora) self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora self._registered_loras[lora.id] = lora
...@@ -418,7 +421,7 @@ class LoRAModelManager: ...@@ -418,7 +421,7 @@ class LoRAModelManager:
def get_lora(self, lora_id: int) -> Optional[LoRAModel]: def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None) return self._registered_loras.get(lora_id, None)
def remove_all_loras(self) -> bool: def remove_all_loras(self):
"""Remove all LoRAModels from the manager.""" """Remove all LoRAModels from the manager."""
self._registered_loras.clear() self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots self.lora_index_to_id = [None] * self.lora_slots
...@@ -467,6 +470,7 @@ class LoRAModelManager: ...@@ -467,6 +470,7 @@ class LoRAModelManager:
continue continue
parts = module_name.split(".") parts = module_name.split(".")
if module_name not in self.packed_modules: if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules: if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size + input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if self.lora_config.lora_extra_vocab_size if
...@@ -500,7 +504,7 @@ class LoRAModelManager: ...@@ -500,7 +504,7 @@ class LoRAModelManager:
else: else:
parts = module_name.split(".") parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]] replacements = self.packed_modules_mapping[parts[-1]]
subloras = [] subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements): for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights( lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r, module_name + "." + r,
...@@ -538,7 +542,7 @@ class LoRAModelManager: ...@@ -538,7 +542,7 @@ class LoRAModelManager:
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items(): for module_name, new_module_names in self.packed_modules.items():
replacement_loras = [] replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False has_replacement = False
for r in new_module_names: for r in new_module_names:
lora = lora_model.get_lora(r) lora = lora_model.get_lora(r)
...@@ -557,13 +561,13 @@ class LoRAModelManager: ...@@ -557,13 +561,13 @@ class LoRAModelManager:
class LoRALRUCache(LRUCache[LoRAModel]): class LoRALRUCache(LRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
None]): bool]):
super().__init__(capacity) super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: LoRAModel): def _on_remove(self, key: int, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}") logger.debug("Removing LoRA. int id: %d", key)
self.deactivate_lora_fn(key) self.deactivate_lora_fn(key)
return super()._on_remove(key, value) return super()._on_remove(key, value)
......
...@@ -49,6 +49,49 @@ def bgmv( ...@@ -49,6 +49,49 @@ def bgmv(
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
                            w_t_all: torch.Tensor, indicies: torch.LongTensor,
                            layer_idx: int, scale: float, y_offset: int,
                            y_slice_size: int):
    """
    Variant of `bgmv` that writes into a column slice of y.

    The whole y tensor is passed in; `y_offset` and `y_slice_size` select
    the columns that receive the result.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
        all of the transposed LoRA matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
    """
    # Imported lazily so a missing compiled extension is reported through
    # the module's own import-error helper.
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    # The kernel takes the input width, then the output slice width, then
    # the column offset into y.
    input_width = x.size(1)
    punica_kernels.dispatch_bgmv_low_level(y, x, w_t_all, indicies, layer_idx,
                                           scale, input_width, y_slice_size,
                                           y_offset)
def add_lora(y: torch.Tensor, def add_lora(y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
wa_t_all: torch.Tensor, wa_t_all: torch.Tensor,
......
from typing import Tuple from typing import List, Optional, Set, Tuple, Type
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__) logger = init_logger(__name__)
# Explicit registry of the LoRA layer wrapper classes that from_layer
# iterates over when looking for a replacement for a module. It holds both
# the classes imported from vllm.lora.layers and the sharded variants
# imported from vllm.lora.fully_sharded_layers. This is a set, so its
# iteration order is not the order written below.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Return a LoRA wrapper for ``layer``, or ``layer`` if none applies.

    Each class in ``_all_lora_classes`` is queried via ``can_replace_layer``.
    The first one that accepts the layer is instantiated around it, has
    ``create_lora_weights(max_loras, lora_config, model_config)`` called on
    it, and is returned. If no class accepts, the input layer is returned
    untouched.
    """
    for candidate_cls in _all_lora_classes:
        # Keyword arguments are used on purpose so that decorators wrapping
        # can_replace_layer can read them by name.
        accepted = candidate_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config)
        if not accepted:
            continue
        wrapped = candidate_cls(layer)
        wrapped.create_lora_weights(max_loras, lora_config, model_config)
        return wrapped
    return layer
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap ``layer`` in a ``LogitsProcessorWithLoRA``.

    The wrapper receives ``layer``, ``lm_head.embedding_dim`` and the dtype
    and device of ``lm_head.weight``. ``create_lora_weights`` is called on
    the wrapper before it is returned.
    """
    hidden_size = lm_head.embedding_dim
    weight_dtype = lm_head.weight.dtype
    weight_device = lm_head.weight.device
    lora_processor = LogitsProcessorWithLoRA(layer, hidden_size, weight_dtype,
                                             weight_device)
    lora_processor.create_lora_weights(max_loras, lora_config, model_config)
    return lora_processor
def replace_submodule(model: nn.Module, module_name: str, def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module: new_module: nn.Module) -> nn.Module:
......
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Optional, Set, Type from typing import Any, Dict, List, Set, Type
import torch import torch
...@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC): ...@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
... ...
@abstractmethod @abstractmethod
def set_active_loras(self, lora_requests: List[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: lora_mapping: LoRAMapping) -> None:
... ...
...@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC): ...@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
... ...
@abstractmethod @abstractmethod
def remove_all_loras(self) -> bool: def remove_all_loras(self):
... ...
@abstractmethod @abstractmethod
...@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_padding_modules: List[str], embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel, lora_model_cls: Type[LoRAModel] = LoRAModel,
): ):
self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device) lora_config, device)
...@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora_config=self.lora_config, lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls, lora_manager_cls=self._lora_manager_cls,
) )
self._lora_manager: LoRAModelManager = lora_manager self._lora_manager = lora_manager
return lora_manager.model return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
...@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id) return self._lora_manager.remove_lora(lora_id)
def remove_all_loras(self) -> bool: def remove_all_loras(self):
self._lora_manager.remove_all_loras() self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
...@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
lora_config=self.lora_config, lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
) )
self._lora_manager: LRUCacheLoRAModelManager = lora_manager self._lora_manager = lora_manager
return lora_manager.model return lora_manager.model
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = { loras_map = {
lora_request.lora_int_id: lora_request lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request for lora_request in lora_requests if lora_request
...@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
if lora_request.lora_int_id not in self.list_loras(): if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory # Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity: if len(self._lora_manager) + 1 > self._lora_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
self._lora_manager.remove_oldest_lora() self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request) lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora) loaded = self._lora_manager.add_lora(lora)
else: else:
# If the lora is already loaded, just touch it to # If the lora is already loaded, just touch it to
# update its position in the caches # update its position in the caches
loaded = self._lora_manager.get_lora(lora_request.lora_int_id) loaded = self._lora_manager.get_lora(
lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id) self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded return loaded
...@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: ...@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
return schema return schema
if isinstance(schema, BaseModel): if isinstance(schema, BaseModel):
return schema.model_json_schema() return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")
@lru_cache @lru_cache
......
...@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
result = await loop.run_in_executor(global_thread_pool, result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide, _get_cached_logits_processor, guide,
tokenizer, mode) tokenizer, mode,
request.guided_whitespace_pattern)
logits_processor = copy(result) logits_processor = copy(result)
# reset logits processor's internal state # reset logits processor's internal state
...@@ -117,9 +118,10 @@ def _get_guide_and_mode( ...@@ -117,9 +118,10 @@ def _get_guide_and_mode(
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode): mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]):
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer) return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer) return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR: elif mode == GuidedDecodingMode.GRAMMAR:
......
...@@ -18,7 +18,7 @@ import json ...@@ -18,7 +18,7 @@ import json
import math import math
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union from typing import Callable, DefaultDict, Dict, List, Union
import torch import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
...@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor): ...@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
class JSONLogitsProcessor(RegexLogitsProcessor): class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, def __init__(self, schema: Union[str, Dict, BaseModel],
schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None): whitespace_pattern: Union[str, None]):
"""Compile the FSM that drives the JSON-guided generation. """Compile the FSM that drives the JSON-guided generation.
Parameters Parameters
......
...@@ -67,6 +67,9 @@ class GeluAndMul(nn.Module): ...@@ -67,6 +67,9 @@ class GeluAndMul(nn.Module):
ops.gelu_tanh_and_mul(out, x) ops.gelu_tanh_and_mul(out, x)
return out return out
def extra_repr(self) -> str:
    """Return the ``approximate=...`` fragment describing this module.

    The value is the ``repr`` of ``self.approximate``.
    """
    return f'approximate={repr(self.approximate)}'
class NewGELU(nn.Module): class NewGELU(nn.Module):
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
...@@ -203,14 +203,15 @@ def moe_align_block_size( ...@@ -203,14 +203,15 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible - The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations. by block_size for proper block matrix operations.
""" """
sorted_ids = torch.empty( max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
(topk_ids.numel() + num_experts * (block_size - 1), ), sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel()) sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1), num_tokens_post_pad = torch.empty((1),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
...@@ -220,8 +221,9 @@ def moe_align_block_size( ...@@ -220,8 +221,9 @@ def moe_align_block_size(
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B_scale: torch.Tensor, topk_weights: torch.Tensor, A_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor, B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
...@@ -232,10 +234,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -232,10 +234,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if not use_fp8: if not use_fp8:
A_scale = None assert A_scale is None
assert B_scale is None assert B_scale is None
else: else:
A, A_scale = ops.scaled_fp8_quant(A) A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
...@@ -296,8 +298,8 @@ def get_moe_configs(E: int, N: int, ...@@ -296,8 +298,8 @@ def get_moe_configs(E: int, N: int,
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path) as f: with open(config_file_path) as f:
logger.info( logger.info("Using configuration from %s for MoE layer.",
f"Using configuration from {config_file_path} for MoE layer.") config_file_path)
# If a configuration has been found, return it # If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()} return {int(key): val for key, val in json.load(f).items()}
...@@ -318,6 +320,8 @@ def fused_moe( ...@@ -318,6 +320,8 @@ def fused_moe(
use_fp8: bool = False, use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -430,10 +434,13 @@ def fused_moe( ...@@ -430,10 +434,13 @@ def fused_moe(
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E) topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states, invoke_fused_moe_kernel(hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
a1_scale,
w1_scale, w1_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
...@@ -443,7 +450,7 @@ def fused_moe( ...@@ -443,7 +450,7 @@ def fused_moe(
False, False,
topk_ids.shape[1], topk_ids.shape[1],
config, config,
compute_type=tl.float16, compute_type=compute_type,
use_fp8=use_fp8) use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
...@@ -451,6 +458,7 @@ def fused_moe( ...@@ -451,6 +458,7 @@ def fused_moe(
invoke_fused_moe_kernel(intermediate_cache2, invoke_fused_moe_kernel(intermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
a2_scale,
w2_scale, w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
...@@ -460,7 +468,7 @@ def fused_moe( ...@@ -460,7 +468,7 @@ def fused_moe(
True, True,
1, 1,
config, config,
compute_type=tl.float16, compute_type=compute_type,
use_fp8=use_fp8) use_fp8=use_fp8)
if inplace: if inplace:
......
...@@ -64,3 +64,8 @@ class RMSNorm(nn.Module): ...@@ -64,3 +64,8 @@ class RMSNorm(nn.Module):
self.variance_epsilon, self.variance_epsilon,
) )
return out return out
def extra_repr(self) -> str:
    """Return the extra text shown for this module in its printed repr.

    Reports the length of the first dimension of ``self.weight`` as
    ``hidden_size`` and ``self.variance_epsilon`` as ``eps``.
    """
    hidden_size = self.weight.data.size(0)
    return f"hidden_size={hidden_size}, eps={self.variance_epsilon}"
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import List, Optional from typing import List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -28,7 +29,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset): ...@@ -28,7 +29,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
class LinearMethodBase(ABC): class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
@abstractmethod @abstractmethod
...@@ -53,22 +54,15 @@ class LinearMethodBase(ABC): ...@@ -53,22 +54,15 @@ class LinearMethodBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor. """Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer.""" Expects create_weights to have been called before on the layer."""
raise NotImplementedError raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class UnquantizedLinearMethod(LinearMethodBase): class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization. """Linear method without quantization.
...@@ -96,10 +90,10 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -96,10 +90,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs) set_weight_attrs(weight, extra_weight_attrs)
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight weight = layer.weight
if self.separate_bias_add: if self.separate_bias_add:
if bias is not None: if bias is not None:
...@@ -116,8 +110,8 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -116,8 +110,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, weight, bias) return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module): class LinearBase(torch.nn.Module):
"""Replicated linear layer. """Base linear layer.
Args: Args:
input_size: input dimension of the linear layer. input_size: input dimension of the linear layer.
...@@ -125,17 +119,16 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -125,17 +119,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias. bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it. skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
""" """
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
output_size: int, output_size: int,
bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -146,12 +139,46 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -146,12 +139,46 @@ class ReplicatedLinear(torch.nn.Module):
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
if linear_method is None: if quant_config is None:
linear_method = UnquantizedLinearMethod() self.quant_method: Optional[
self.linear_method = linear_method QuantizeMethodBase] = UnquantizedLinearMethod()
self.linear_method.create_weights(self, self.input_size, else:
[self.output_size], self.input_size, self.quant_method = quant_config.get_quant_method(self)
self.output_size, self.params_dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Placeholder forward pass for the base linear layer.

    The base class provides no computation: calling this always raises
    ``NotImplementedError``. Concrete linear layers supply their own
    ``forward``.
    """
    raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype)) torch.empty(self.output_size, dtype=self.params_dtype))
...@@ -161,12 +188,19 @@ class ReplicatedLinear(torch.nn.Module): ...@@ -161,12 +188,19 @@ class ReplicatedLinear(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self, x, bias) assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
def extra_repr(self) -> str:
    """Return the extra text shown for this module in its printed repr.

    Lists the input size, the output size, and whether a bias is set
    (``self.bias is not None``).
    """
    fields = [
        f"in_features={self.input_size}",
        f"output_features={self.output_size}",
        f"bias={self.bias is not None}",
    ]
    return ", ".join(fields)
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along The linear layer is defined as Y = XA + b. A is parallelized along
...@@ -183,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -183,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. we bias can be fused with other element-wise operations. we
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3. the list would be size 3.
""" """
...@@ -196,34 +230,28 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -196,34 +230,28 @@ class ColumnParallelLinear(torch.nn.Module):
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None, output_sizes: Optional[List[int]] = None,
): ):
super().__init__() super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output self.gather_output = gather_output
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size) self.output_size_per_partition = divide(output_size, tp_size)
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
if output_sizes is None: if output_sizes is None:
output_sizes = [output_size] output_sizes = [output_size]
self.linear_method = linear_method # All the linear layer supports quant method.
self.linear_method.create_weights(self, assert self.quant_method is not None
self.input_size, self.quant_method.create_weights(self,
[x // tp_size for x in output_sizes], self.input_size,
self.input_size, [x // tp_size for x in output_sizes],
self.output_size, self.input_size,
self.params_dtype, self.output_size,
weight_loader=self.weight_loader) self.params_dtype,
weight_loader=self.weight_loader)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
...@@ -237,6 +265,10 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -237,6 +265,10 @@ class ColumnParallelLinear(torch.nn.Module):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
param_data = param.data param_data = param.data
...@@ -245,6 +277,12 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -245,6 +277,12 @@ class ColumnParallelLinear(torch.nn.Module):
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn: if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
...@@ -255,7 +293,8 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -255,7 +293,8 @@ class ColumnParallelLinear(torch.nn.Module):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. # Matrix multiply.
output_parallel = self.linear_method.apply_weights(self, input_, bias) assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
...@@ -264,6 +303,14 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -264,6 +303,14 @@ class ColumnParallelLinear(torch.nn.Module):
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
def extra_repr(self) -> str:
    """Return the extra text shown for this module in its printed repr.

    Lists the input size, the per-partition output size, whether a bias is
    set, the tensor-parallel world size, and the ``gather_output`` flag.
    """
    fields = [
        f"in_features={self.input_size}",
        f"output_features={self.output_size_per_partition}",
        f"bias={self.bias is not None}",
        f"tp_size={get_tensor_model_parallel_world_size()}",
        f"gather_output={self.gather_output}",
    ]
    return ", ".join(fields)
class MergedColumnParallelLinear(ColumnParallelLinear): class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism. """Packed linear layers with column parallelism.
...@@ -283,7 +330,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -283,7 +330,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we bias can be fused with other element-wise operations. we
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
""" """
def __init__( def __init__(
...@@ -294,13 +341,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -294,13 +341,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
# Merge conflict resolved in favour of the v0.4.2 side: this constructor's
# signature now takes `quant_config`, and `linear_method` is no longer a
# parameter, so the HEAD variant would raise NameError.
super().__init__(input_size, sum(output_sizes), bias, gather_output,
                 skip_bias_add, params_dtype, quant_config,
                 self.output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
...@@ -311,7 +363,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -311,7 +363,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False) is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already packed. # Loaded weight is already packed.
if output_dim is None: if output_dim is None:
...@@ -325,14 +382,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -325,14 +382,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset += output_size current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets: for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
if packed_dim == output_dim: if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -347,15 +403,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -347,15 +403,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim: if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -368,11 +423,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -368,11 +423,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
# Special case for AQLM codebooks.
elif is_metadata: elif is_metadata:
# metadata indicates fixed size concatenated along dim 0 # metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0] shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size) param_data = param_data.narrow(0, shard_offset, shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
else: else:
ignore_warning = getattr(param, "ignore_warning", False) ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning: if not ignore_warning:
...@@ -413,7 +474,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -413,7 +474,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we bias can be fused with other element-wise operations. we
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configure.
""" """
def __init__( def __init__(
...@@ -425,7 +486,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -425,7 +486,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
...@@ -453,8 +514,12 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -453,8 +514,12 @@ class QKVParallelLinear(ColumnParallelLinear):
] ]
# Merge conflict resolved: pass `quant_config` (the parameter this
# constructor now declares; `linear_method` no longer exists) while keeping
# the HEAD-only LLAMA_NN flag assignment that followed the call.
super().__init__(input_size, output_size, bias, False, skip_bias_add,
                 params_dtype, quant_config, output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -462,7 +527,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -462,7 +527,11 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False) is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already packed. # Loaded weight is already packed.
...@@ -480,14 +549,14 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -480,14 +549,14 @@ class QKVParallelLinear(ColumnParallelLinear):
] ]
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets: for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
if packed_dim == output_dim: if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to # Special case for Marlin.
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -509,6 +578,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -509,6 +578,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (self.num_heads + shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size shard_size = self.num_kv_heads * self.head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
...@@ -516,8 +586,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -516,8 +586,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to # Special case for Marlin.
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -534,12 +603,17 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -534,12 +603,17 @@ class QKVParallelLinear(ColumnParallelLinear):
start_idx = shard_id * shard_size start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
# Special case for AQLM codebooks.
elif is_metadata: elif is_metadata:
# metadata indicates fixed size concatenated along dim 0 # metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0] shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id) shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size, param_data = param_data.narrow(0, shard_index * shard_size,
shard_size) shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
else: else:
ignore_warning = getattr(param, "ignore_warning", False) ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning: if not ignore_warning:
...@@ -559,7 +633,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -559,7 +633,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
class RowParallelLinear(torch.nn.Module): class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism. """Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along The linear layer is defined as Y = XA + b. A is parallelized along
...@@ -582,7 +656,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -582,7 +656,7 @@ class RowParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. bias can be fused with other element-wise operations.
We skip adding bias but instead return it. We skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method. quant_config: Quantization configuration.
""" """
def __init__( def __init__(
...@@ -594,32 +668,26 @@ class RowParallelLinear(torch.nn.Module): ...@@ -594,32 +668,26 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True, reduce_results: bool = True,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__(input_size, output_size, skip_bias_add, params_dtype,
# Keep input parameters quant_config)
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add # All the linear layer supports quant method.
if linear_method is None: assert self.quant_method is not None
linear_method = UnquantizedLinearMethod() self.quant_method.create_weights(self,
self.linear_method = linear_method self.input_size_per_partition,
self.linear_method.create_weights(self, [self.output_size],
self.input_size_per_partition, self.input_size,
[self.output_size], self.output_size,
self.input_size, self.params_dtype,
self.output_size, weight_loader=self.weight_loader)
self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
...@@ -637,6 +705,10 @@ class RowParallelLinear(torch.nn.Module): ...@@ -637,6 +705,10 @@ class RowParallelLinear(torch.nn.Module):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
param_data = param.data param_data = param.data
...@@ -645,6 +717,12 @@ class RowParallelLinear(torch.nn.Module): ...@@ -645,6 +717,12 @@ class RowParallelLinear(torch.nn.Module):
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx, loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size) shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn: if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
...@@ -662,8 +740,8 @@ class RowParallelLinear(torch.nn.Module): ...@@ -662,8 +740,8 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
output_parallel = self.linear_method.apply_weights( assert self.quant_method is not None
self, input_parallel) output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
...@@ -676,3 +754,11 @@ class RowParallelLinear(torch.nn.Module): ...@@ -676,3 +754,11 @@ class RowParallelLinear(torch.nn.Module):
output = output_ output = output_
output_bias = self.bias output_bias = self.bias
return output, output_bias return output, output_bias
def extra_repr(self) -> str:
    """Build the comma-separated summary of this layer's configuration.

    The summary lists the per-partition input size, the output size,
    whether a bias is present, the tensor-parallel size and whether the
    results are reduced.
    """
    described = (
        ("input_features", self.input_size_per_partition),
        ("output_features", self.output_size),
        ("bias", self.bias is not None),
        ("tp_size", self.tp_size),
        ("reduce_results", self.reduce_results),
    )
    return ", ".join(f"{label}={value}" for label, value in described)
...@@ -70,6 +70,12 @@ class LogitsProcessor(nn.Module): ...@@ -70,6 +70,12 @@ class LogitsProcessor(nn.Module):
logits = logits[:, :self.org_vocab_size] logits = logits[:, :self.org_vocab_size]
return logits return logits
def extra_repr(self) -> str:
    """Return the comma-separated summary of this module's settings.

    The summary reports ``vocab_size``, ``org_vocab_size``, ``scale`` and
    ``logits_as_input`` as ``name=value`` pairs.
    """
    s = f"vocab_size={self.vocab_size}"
    # The label previously read "forg_vocab_size": a stray "f" had slipped
    # inside the f-string, so the printed field name did not match the
    # attribute it reports.
    s += f", org_vocab_size={self.org_vocab_size}"
    s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
    return s
def _prune_hidden_states( def _prune_hidden_states(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -83,30 +89,27 @@ def _apply_logits_processors( ...@@ -83,30 +89,27 @@ def _apply_logits_processors(
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False found_logits_processors = False
for i, seq_group in enumerate(sampling_metadata.seq_groups): logits_processed = 0
seq_ids, sampling_params = seq_group for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors logits_processors = sampling_params.logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
logits_row_idx += sampling_metadata.prompt_lens[i] - 1
if logits_processors: if logits_processors:
found_logits_processors = True found_logits_processors = True
for seq_id in seq_ids: for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx] logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids token_ids = seq_group.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors: for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row) logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row logits[logits_row_idx] = logits_row
logits_row_idx += 1
else: logits_processed += len(seq_group.sample_indices) + len(
logits_row_idx += len(seq_ids) seq_group.prompt_logprob_indices)
if found_logits_processors: if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly # verifies that no rows in logits were missed unexpectedly
assert logits_row_idx == logits.shape[0] assert logits_processed == logits.shape[0]
return logits return logits
from typing import Type from typing import Dict, Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"fp8": FP8Config, "fp8": Fp8Config,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"gptq_marlin": GPTQMarlinConfig,
"marlin": MarlinConfig, "marlin": MarlinConfig,
} }
......
...@@ -8,11 +8,11 @@ import torch ...@@ -8,11 +8,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
def get_int_dtype(nbits: int) -> torch.dtype: def get_int_dtype(nbits: int) -> torch.dtype:
...@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig): ...@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
return cls(in_group_size, nbits_per_codebook, num_code_books, return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size) out_group_size)
def get_linear_method(self) -> "AQLMLinearMethod": def get_quant_method(
return AQLMLinearMethod(self) self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase):
return AQLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase): ...@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs) set_weight_attrs(scales, extra_weight_attrs)
def apply_weights( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -4,10 +4,10 @@ import torch ...@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
...@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig): ...@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
zero_point = cls.get_from_keys(config, ["zero_point"]) zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point) return cls(weight_bits, group_size, zero_point)
def get_linear_method(self) -> "AWQLinearMethod": def get_quant_method(
return AWQLinearMethod(self) self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
...@@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs) set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight qweight = layer.qweight
scales = layer.scales scales = layer.scales
qzeros = layer.qzeros qzeros = layer.qzeros
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
import torch import torch
from torch import nn
from vllm.model_executor.layers.linear import LinearMethodBase
class QuantizeMethodBase(ABC):
    """Abstract interface implemented by every quantization method.

    Subclasses must provide ``create_weights`` and ``apply``; overriding
    ``process_weights_after_loading`` is optional.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        """Create the weights for ``layer``.

        Implementations are expected to attach the created weights to the
        layer as attributes.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Apply the weights held by ``layer`` to the input tensor.

        ``create_weights`` must already have been called on the layer.
        """
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Hook invoked once the weights have been loaded.

        The default implementation does nothing. Subclasses may override
        it, for example to transpose weights for computation.
        """
class QuantizationConfig(ABC): class QuantizationConfig(ABC):
...@@ -51,8 +76,16 @@ class QuantizationConfig(ABC): ...@@ -51,8 +76,16 @@ class QuantizationConfig(ABC):
"quantization config.") "quantization config.")
@abstractmethod @abstractmethod
def get_linear_method(self) -> LinearMethodBase: def get_quant_method(
"""Get the linear method to use for the quantized linear layer.""" self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment