"vllm/vscode:/vscode.git/clone" did not exist on "bbc3619dc806b25fd5e14eef90819052ab76e1c6"
Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
if TYPE_CHECKING:
pass
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs['lora_config'].fully_sharded_loras)
return dec
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
"""
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
def _mcp_apply_weights(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
LoRa weight application method.
The main difference is the step by shard_size for lora_b which can
vary for QKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.linear_method.apply_weights(
layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device)
for idx in range(n):
bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]
dispatch_bgmv_low_level(output, buffers[idx],
layer.lora_b_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0,
left_offset, shard_size)
left_offset += shard_size
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output
class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA):
"""
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
for i in range(2)
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
"""
Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[i][:, start_idx[i]:start_idx[i] +
shard_size[i]] if lora_a[i] is not None else None
for i in range(3)
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
Differs from RowParallelLinearWithLoRA by slicing the
LoRA B's also.
Based on S-LoRA, slicing happens along the output dim.
This yields a combined partial sum from the row parallel base
layer and column partitioned output from the LoRA.
"""
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
# tensor, which is a partial sum due to row parallel. All that
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0,
start_idx, shard_size)
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import torch.nn as nn
......@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.distributed.utils import divide
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
......@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
VocabParallelEmbedding)
if TYPE_CHECKING:
pass
......@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _not_fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
condition = (not kwargs['lora_config'].fully_sharded_loras
if decorate else True)
return can_replace(*args, **kwargs) and condition
return dec
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
......@@ -130,6 +145,14 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
"""Slice lora b if splitting with tensor parallelism."""
...
def create_lora_weights(
self,
max_loras: int,
......@@ -176,6 +199,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: Optional[Tuple[int, int]]
self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights(
self,
......@@ -233,9 +258,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.embeddings_indices = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.embeddings_indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
......@@ -267,6 +293,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
......@@ -313,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
"""
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
......@@ -327,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -343,15 +380,27 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[2]
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def set_lora(
self,
index: int,
......@@ -360,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
......@@ -384,10 +432,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
......@@ -411,7 +458,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
output_parallel = self.apply_weights(input_, bias)
output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
......@@ -422,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
return output, output_bias
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
......@@ -447,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
and self.base_layer.output_sizes[0]
......@@ -455,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -475,8 +529,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
device=self.device,
) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
......@@ -484,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
shard_size = self.output_dim
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = [
lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
]
return lora_b
def set_lora(
self,
index: int,
......@@ -494,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[0][:,
start_idx:end_idx], lora_b[1][:,
start_idx:end_idx]
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_a[0] is not None:
self.lora_a_stacked[0][
......@@ -517,10 +579,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
......@@ -532,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
return output
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
......@@ -623,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -645,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -653,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
......@@ -690,7 +756,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# lazily initialized.
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
......@@ -700,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b
def set_lora(
self,
index: int,
......@@ -710,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.reset_lora(index)
if self.tp_size > 1:
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
else:
if lora_b[0] is not None:
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_b[1] is not None:
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if lora_b[2] is not None:
self.lora_b_stacked[2][
index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
lora_b[2].T, non_blocking=True)
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_b[0] is not None:
lora_b_q = lora_b[0]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
lora_b_k = lora_b[1]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
lora_b_v = lora_b[2]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
......@@ -758,10 +828,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
......@@ -773,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
return output
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
......@@ -794,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_rank = get_tensor_model_parallel_rank()
self.lora_a_stacked = torch.zeros(
(
max_loras,
......@@ -804,23 +876,40 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
tp_size = get_tensor_model_parallel_world_size()
lora_b_output_size_per_partition = (
self.output_size if not lora_config.fully_sharded_loras else
divide(self.output_size, tp_size))
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.output_size,
lora_b_output_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.input_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
return lora_b
def set_lora(
self,
index: int,
......@@ -829,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.input_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
......@@ -854,9 +941,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
......@@ -889,7 +975,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
......@@ -911,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.base_layer, "weight") else self.base_layer.qweight
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
......@@ -991,9 +1078,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=self.dtype,
device=self.device,
)
self.indices = None
self.indices_padded = None
self.indices_len = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.indices_padded: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
......@@ -1091,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
cls
for cls in globals().values() if inspect.isclass(cls)
and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_logits_processor(
layer: LogitsProcessor,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
......@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self,
module_name: str,
rank: int,
lora_alphas: List[int],
lora_a: List[torch.Tensor],
lora_b: List[torch.Tensor],
lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
......@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling,
scaling=scaling, # type: ignore
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [
lora_alpha / self.rank for lora_alpha in self.lora_alphas
self.scaling = [ # type: ignore
lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
]
@classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
def pack(
cls, loras: List[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
......@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras])
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
])
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None:
if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
self.lora_b[i] *= self.scaling[i]
self.scaling[i] = 1
self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 # type: ignore
return self
@property
......
......@@ -3,7 +3,7 @@ import json
import math
import os
import re
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
from typing import Callable, Dict, List, Optional, Tuple, Type
import safetensors.torch
import torch
......@@ -11,10 +11,10 @@ from torch import nn
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
from_layer_logits_processor)
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__)
......@@ -53,44 +53,46 @@ def convert_mapping(
embeddings.
indices_len: List of lengths of the above tensors.
"""
indices = list(mapping.index_mapping).copy()
embedding_indices = indices.copy()
lora_indices = indices.copy()
prompt_mapping = [
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(indices)):
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i])
if indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
indices[i] = i
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx
indices = torch.tensor([indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1])
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
......@@ -149,6 +151,7 @@ class LoRAModel:
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name),
None)
......@@ -171,6 +174,7 @@ class LoRAModel:
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
......@@ -295,11 +299,10 @@ class LoRAModelManager:
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.offsets = []
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len = [None] * 4
self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"):
......@@ -312,7 +315,7 @@ class LoRAModelManager:
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping = None
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
......@@ -342,8 +345,8 @@ class LoRAModelManager:
index, _ = first_free_slot
self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id]
logger.debug(
f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
logger.debug("Activating LoRA. int id: %d, slot index: %d",
lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
......@@ -370,7 +373,7 @@ class LoRAModelManager:
return True
return False
def _add_lora(self, lora: LoRAModel) -> bool:
def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
......@@ -418,7 +421,7 @@ class LoRAModelManager:
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
......@@ -467,6 +470,7 @@ class LoRAModelManager:
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
......@@ -500,7 +504,7 @@ class LoRAModelManager:
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras = []
subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
......@@ -538,7 +542,7 @@ class LoRAModelManager:
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras = []
replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
......@@ -557,13 +561,13 @@ class LoRAModelManager:
class LoRALRUCache(LRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}")
def _on_remove(self, key: int, value: LoRAModel):
logger.debug("Removing LoRA. int id: %d", key)
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
......
......@@ -49,6 +49,49 @@ def bgmv(
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, indicies: torch.LongTensor,
layer_idx: int, scale: float, y_offset: int,
y_slice_size: int):
"""
Same as `bgmv` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
all of the transposed LoRA matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
punica_kernels.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
x.size(1),
y_slice_size,
y_offset,
)
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
......
from typing import Tuple
from typing import List, Optional, Set, Tuple, Type
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
# specifying kwargs so they can be easily accessed in decorator
if lora_cls.can_replace_layer(source_layer=layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_logits_processor(
layer: LogitsProcessor,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
......
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Optional, Set, Type
from typing import Any, Dict, List, Set, Type
import torch
......@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@abstractmethod
def set_active_loras(self, lora_requests: List[LoRARequest],
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
......@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@abstractmethod
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
...
@abstractmethod
......@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel,
):
self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
......@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
)
self._lora_manager: LoRAModelManager = lora_manager
self._lora_manager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest],
......@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]:
......@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._lora_manager: LRUCacheLoRAModelManager = lora_manager
self._lora_manager = lora_manager
return lora_manager.model
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
......@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
loaded = self._lora_manager.get_lora(
lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded
......@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")
@lru_cache
......
......@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide,
tokenizer, mode)
tokenizer, mode,
request.guided_whitespace_pattern)
logits_processor = copy(result)
# reset logits processor's internal state
......@@ -117,9 +118,10 @@ def _get_guide_and_mode(
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode):
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]):
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer)
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR:
......
......@@ -18,7 +18,7 @@ import json
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union
from typing import Callable, DefaultDict, Dict, List, Union
import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
......@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self,
schema: Union[str, Dict, BaseModel],
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None):
whitespace_pattern: Union[str, None]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
......
......@@ -67,6 +67,9 @@ class GeluAndMul(nn.Module):
ops.gelu_tanh_and_mul(out, x)
return out
def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'
class NewGELU(nn.Module):
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
......@@ -203,14 +203,15 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
......@@ -220,8 +221,9 @@ def moe_align_block_size(
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
......@@ -232,10 +234,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
assert sorted_token_ids.stride(0) == 1
if not use_fp8:
A_scale = None
assert A_scale is None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A)
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
......@@ -296,8 +298,8 @@ def get_moe_configs(E: int, N: int,
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
f"Using configuration from {config_file_path} for MoE layer.")
logger.info("Using configuration from %s for MoE layer.",
config_file_path)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
......@@ -318,6 +320,8 @@ def fused_moe(
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -430,10 +434,13 @@ def fused_moe(
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
......@@ -443,7 +450,7 @@ def fused_moe(
False,
topk_ids.shape[1],
config,
compute_type=tl.float16,
compute_type=compute_type,
use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
......@@ -451,6 +458,7 @@ def fused_moe(
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
......@@ -460,7 +468,7 @@ def fused_moe(
True,
1,
config,
compute_type=tl.float16,
compute_type=compute_type,
use_fp8=use_fp8)
if inplace:
......
......@@ -64,3 +64,8 @@ class RMSNorm(nn.Module):
self.variance_epsilon,
)
return out
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
return s
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
......@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
......@@ -28,7 +29,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
class LinearMethodBase(ABC):
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
......@@ -53,22 +54,15 @@ class LinearMethodBase(ABC):
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
......@@ -96,10 +90,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
......@@ -116,8 +110,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module):
"""Replicated linear layer.
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
......@@ -125,17 +119,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -146,12 +139,46 @@ class ReplicatedLinear(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
......@@ -161,12 +188,19 @@ class ReplicatedLinear(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
output = self.linear_method.apply_weights(self, x, bias)
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
return s
class ColumnParallelLinear(torch.nn.Module):
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
......@@ -183,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
......@@ -196,34 +230,28 @@ class ColumnParallelLinear(torch.nn.Module):
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
):
super().__init__()
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size)
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
if output_sizes is None:
output_sizes = [output_size]
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
......@@ -237,6 +265,10 @@ class ColumnParallelLinear(torch.nn.Module):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
param_data = param.data
......@@ -245,6 +277,12 @@ class ColumnParallelLinear(torch.nn.Module):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
......@@ -255,7 +293,8 @@ class ColumnParallelLinear(torch.nn.Module):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
output_parallel = self.linear_method.apply_weights(self, input_, bias)
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
......@@ -264,6 +303,14 @@ class ColumnParallelLinear(torch.nn.Module):
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", gather_output={self.gather_output}"
return s
class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism.
......@@ -283,7 +330,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
......@@ -294,13 +341,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
<<<<<<< HEAD
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method,
=======
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, quant_config,
>>>>>>> v0.4.2
self.output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
......@@ -311,7 +363,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
......@@ -325,14 +382,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
......@@ -347,15 +403,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
......@@ -368,11 +423,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
......@@ -413,7 +474,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
......@@ -425,7 +486,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
......@@ -453,8 +514,12 @@ class QKVParallelLinear(ColumnParallelLinear):
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
<<<<<<< HEAD
params_dtype, linear_method, output_sizes)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
=======
params_dtype, quant_config, output_sizes)
>>>>>>> v0.4.2
def weight_loader(self,
param: Parameter,
......@@ -462,7 +527,11 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
if loaded_shard_id is None:
# Loaded weight is already packed.
......@@ -480,14 +549,14 @@ class QKVParallelLinear(ColumnParallelLinear):
]
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
......@@ -509,6 +578,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
......@@ -516,8 +586,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
......@@ -534,12 +603,17 @@ class QKVParallelLinear(ColumnParallelLinear):
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
......@@ -559,7 +633,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight)
class RowParallelLinear(torch.nn.Module):
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
......@@ -582,7 +656,7 @@ class RowParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
......@@ -594,32 +668,26 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
......@@ -637,6 +705,10 @@ class RowParallelLinear(torch.nn.Module):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
param_data = param.data
......@@ -645,6 +717,12 @@ class RowParallelLinear(torch.nn.Module):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
......@@ -662,8 +740,8 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
self, input_parallel)
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
......@@ -676,3 +754,11 @@ class RowParallelLinear(torch.nn.Module):
output = output_
output_bias = self.bias
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s
......@@ -70,6 +70,12 @@ class LogitsProcessor(nn.Module):
logits = logits[:, :self.org_vocab_size]
return logits
def extra_repr(self) -> str:
s = f"vocab_size={self.vocab_size}"
s += f", forg_vocab_size={self.org_vocab_size}"
s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
return s
def _prune_hidden_states(
hidden_states: torch.Tensor,
......@@ -83,30 +89,27 @@ def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
logits_processed = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
logits_row_idx += sampling_metadata.prompt_lens[i] - 1
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
token_ids = seq_group.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_row_idx == logits.shape[0]
assert logits_processed == logits.shape[0]
return logits
from typing import Type
from typing import Dict, Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = {
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": FP8Config,
"fp8": Fp8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"gptq_marlin": GPTQMarlinConfig,
"marlin": MarlinConfig,
}
......
......@@ -8,11 +8,11 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
def get_int_dtype(nbits: int) -> torch.dtype:
......@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size)
def get_linear_method(self) -> "AQLMLinearMethod":
return AQLMLinearMethod(self)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase):
return AQLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
......
......@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class AWQConfig(QuantizationConfig):
......@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
def get_linear_method(self) -> "AWQLinearMethod":
return AWQLinearMethod(self)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
......@@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
......
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import torch
from torch import nn
from vllm.model_executor.layers.linear import LinearMethodBase
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module, *weight_args,
**extra_weight_attrs):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class QuantizationConfig(ABC):
......@@ -51,8 +76,16 @@ class QuantizationConfig(ABC):
"quantization config.")
@abstractmethod
def get_linear_method(self) -> LinearMethodBase:
"""Get the linear method to use for the quantized linear layer."""
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError
@abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment