Unverified commit 8af890a8, authored by Jee Li, committed by GitHub
Browse files

Enable more models to inference based on LoRA (#3382)


Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent dfeb2ecc
...@@ -16,10 +16,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -16,10 +16,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \
f(in_T, out_T, W_T, narrow, 1280) \ f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1536) \
f(in_T, out_T, W_T, narrow, 1728) \ f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \ f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \ f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2304) \
f(in_T, out_T, W_T, narrow, 2560) \ f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 2816) \
...@@ -27,10 +30,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -27,10 +30,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4608) \
f(in_T, out_T, W_T, narrow, 5120) \ f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \ f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \ f(in_T, out_T, W_T, narrow, 8192) \
...@@ -45,6 +50,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -45,6 +50,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32256) \
......
...@@ -134,6 +134,16 @@ def gemma_lora_files(): ...@@ -134,6 +134,16 @@ def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
@pytest.fixture(scope="session")
def chatglm3_lora_files():
return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")
@pytest.fixture(scope="session")
def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
@pytest.fixture @pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module: def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup() cleanup()
......
import pytest
import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup
MODEL_PATH = "baichuan-inc/Baichuan-7B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> list:
    """Generate SQL completions for three fixed text2sql prompts.

    The prompt list and every prompt/completion pair are printed to stdout.

    Args:
        llm: Engine object exposing
            ``generate(prompts, sampling_params, lora_request=...)``.
        lora_path: Location of the LoRA adapter files.
        lora_id: Adapter ID. A falsy value (e.g. ``0``) sends the request
            without any LoRA adapter.

    Returns:
        A list of ``str``: the whitespace-stripped text of the first
        completion of each output, in the order ``llm.generate`` returned
        them.
    """
    queries = (
        "How many singers do we have?",
        "What is the average, minimum, and maximum age of all singers from France?",  # noqa: E501
        "Show name, country, age for all singers ordered by age from the oldest to the youngest.",  # noqa: E501
    )
    prompts = [PROMPT_TEMPLATE.format(query=query) for query in queries]
    print(prompts)

    # temperature=0 is requested so the tests can compare against fixed text.
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=lora_request)

    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
def test_baichuan_lora(baichuan_lora_files):
    """Check that the LoRA adapter reproduces the expected SQL statements.

    The same adapter files are sampled under LoRA IDs 1 and 2; each run
    must match ``expected_lora_output`` entry by entry.
    """
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) ,  min(age) ,  max(age) FROM singer WHERE Country  =  'France'",  # noqa: E501
        "SELECT name ,  country ,  age FROM singer ORDER BY age ASC",
    ]

    for lora_id in (1, 2):
        generated = do_sample(llm, baichuan_lora_files, lora_id=lora_id)
        for idx, expected in enumerate(expected_lora_output):
            assert generated[idx] == expected
@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(baichuan_lora_files):
    """Compare LoRA outputs across tensor-parallel sizes 1, 2 and 4.

    Despite the name, the engine is built from ``MODEL_PATH`` and the
    ``baichuan_lora_files`` fixture. The outputs for sizes 2 and 4 must
    each equal the output for size 1.
    """
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < 4:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    def sample_with_tp(tp_size, lora_id):
        # Build an engine, sample once, then drop the engine and run
        # cleanup() before the next engine is created.
        engine = vllm.LLM(MODEL_PATH,
                          enable_lora=True,
                          max_num_seqs=16,
                          max_loras=4,
                          max_lora_rank=64,
                          tensor_parallel_size=tp_size,
                          trust_remote_code=True)
        texts = do_sample(engine, baichuan_lora_files, lora_id=lora_id)
        del engine
        cleanup()
        return texts

    output_tp1 = sample_with_tp(1, 1)

    output_tp2 = sample_with_tp(2, 2)
    assert output_tp1 == output_tp2

    output_tp4 = sample_with_tp(4, 2)
    assert output_tp1 == output_tp4
\ No newline at end of file
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> list:
    """Generate SQL completions for three fixed text2sql prompts.

    The prompt list and every prompt/completion pair are printed to stdout.

    Args:
        llm: Engine object exposing
            ``generate(prompts, sampling_params, lora_request=...)``.
        lora_path: Location of the LoRA adapter files.
        lora_id: Adapter ID. A falsy value (e.g. ``0``) sends the request
            without any LoRA adapter.

    Returns:
        A list of ``str``: the whitespace-stripped text of the first
        completion of each output, in the order ``llm.generate`` returned
        them.
    """
    queries = (
        "How many singers do we have?",
        "What is the average, minimum, and maximum age of all singers from France?",  # noqa: E501
        "Show name, country, age for all singers ordered by age from the oldest to the youngest.",  # noqa: E501
    )
    prompts = [PROMPT_TEMPLATE.format(query=query) for query in queries]
    print(prompts)

    # temperature=0 is requested so the tests can compare against fixed text.
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=lora_request)

    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
def test_chatglm3_lora(chatglm3_lora_files):
    """Check that the LoRA adapter reproduces the expected SQL statements.

    The same adapter files are sampled under LoRA IDs 1 and 2; each run
    must match ``expected_lora_output`` entry by entry.
    """
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) ,  min(age) ,  max(age) FROM singer WHERE country  =  'France'",  # noqa: E501
        "SELECT name ,  country ,  age FROM singer ORDER BY age",
    ]

    for lora_id in (1, 2):
        generated = do_sample(llm, chatglm3_lora_files, lora_id=lora_id)
        for idx, expected in enumerate(expected_lora_output):
            assert generated[idx] == expected
...@@ -8,12 +8,16 @@ import torch ...@@ -8,12 +8,16 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA, LoRAMapping, LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora, QKVParallelLinearWithLora,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights, from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
convert_mapping) convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -93,8 +97,7 @@ def populate_loras( ...@@ -93,8 +97,7 @@ def populate_loras(
lora_dict: Dict[int, LoRALayerWeights] = dict() lora_dict: Dict[int, LoRALayerWeights] = dict()
# Dictionary that maps the lora ID to the # Dictionary that maps the lora ID to the
# corresponding subloras. Only useful when # corresponding subloras.
# repeats > 1.
sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()
for slot_idx, lora_id in enumerate(id_to_index): for slot_idx, lora_id in enumerate(id_to_index):
...@@ -607,7 +610,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None: ...@@ -607,7 +610,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3]) @pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
...@@ -623,6 +626,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: ...@@ -623,6 +626,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
bias=False) bias=False)
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedColumnParallelLinearWithLoRA(linear) lora_linear = MergedColumnParallelLinearWithLoRA(linear)
elif repeats == 3:
linear = QKVParallelLinear(4096, 64, 32, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedQKVParallelLinearWithLora(linear)
else: else:
linear = QKVParallelLinear(4096, 64, 32, bias=False) linear = QKVParallelLinear(4096, 64, 32, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
......
...@@ -43,9 +43,10 @@ def _lora_ref_impl( ...@@ -43,9 +43,10 @@ def _lora_ref_impl(
H1 = H2 = [ H1 = H2 = [
128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120, 128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
22016, 24576, 32000, 32256, 32512, 32768, 33024 10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
32768, 33024
] ]
SEED = [0xabcdabcd987] SEED = [0xabcdabcd987]
......
# pylint: disable=unused-argument # pylint: disable=unused-argument
import inspect
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -114,8 +115,11 @@ class LoRAMapping: ...@@ -114,8 +115,11 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, def create_lora_weights(
model_config: PretrainedConfig) -> None: self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
"""Initializes lora matrices.""" """Initializes lora matrices."""
... ...
...@@ -144,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -144,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module):
"""Sets the mapping indices.""" """Sets the mapping indices."""
... ...
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer.

    This base implementation has no answer of its own and always raises
    NotImplementedError.
    """
    raise NotImplementedError
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...@@ -278,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -278,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.indices[:self.indices_len[0]], 0, 1.0) self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org) return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Return True only when ``source_layer`` is exactly a
    VocabParallelEmbedding (subclasses do not match)."""
    layer_type = type(source_layer)
    return layer_type is VocabParallelEmbedding
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None: def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -309,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -309,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[1] self.output_dim = self.lora_b_stacked.shape[2]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -323,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -323,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True) lora_a.T, non_blocking=True)
...@@ -383,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -383,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def linear_weights(self): def linear_weights(self):
return self.base_layer.linear_weights return self.base_layer.linear_weights
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Return True for an exact ColumnParallelLinear, or for an exact
    MergedColumnParallelLinear whose packed-module list has one entry."""
    layer_type = type(source_layer)
    if layer_type is ColumnParallelLinear:
        return True
    return (layer_type is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices) """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
...@@ -485,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -485,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) )
return output return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Return True for an exact MergedColumnParallelLinear whose
    packed-module list has two entries."""
    if type(source_layer) is not MergedColumnParallelLinear:
        return False
    return len(packed_modules_list) == 2
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        layer = self.base_layer
        # "total" sizes use the layer's total head counts, "shard" sizes
        # use its num_heads / num_kv_heads counts.
        self.q_proj_total_size = layer.total_num_heads * layer.head_size
        self.q_proj_shard_size = layer.num_heads * layer.head_size
        self.kv_proj_shard_size = layer.num_kv_heads * layer.head_size
        self.kv_proj_total_size = layer.total_num_kv_heads * layer.head_size

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load one LoRA (A and B matrices) into slot ``index``.

        With tensor parallelism enabled, ``lora_b`` is first reduced to
        the Q, K and V column ranges selected by this rank's shard IDs.
        ``embeddings_tensor`` is not used here.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            tp_rank = get_tensor_model_parallel_rank()
            self.q_shard_id = tp_rank
            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

            # Column bounds of this rank's shard inside each projection.
            q_begin = self.q_proj_shard_size * self.q_shard_id
            q_end = self.q_proj_shard_size * (self.q_shard_id + 1)
            kv_begin = self.kv_proj_shard_size * self.kv_shard_id
            kv_end = self.kv_proj_shard_size * (self.kv_shard_id + 1)

            # K columns start after all Q columns; V after all K columns.
            k_offset = self.q_proj_total_size
            v_offset = k_offset + self.kv_proj_total_size

            shards = [
                lora_b[:, q_begin:q_end],
                lora_b[:, k_offset + kv_begin:k_offset + kv_end],
                lora_b[:, v_offset + kv_begin:v_offset + kv_end],
            ]
            lora_b = torch.cat(shards, dim=1)

        # The transposed matrices are copied into the leading region of
        # the slot's stacked buffers.
        a_slot = self.lora_a_stacked[index,
                                     0, :lora_a.shape[1], :lora_a.shape[0]]
        a_slot.copy_(lora_a.T, non_blocking=True)
        b_slot = self.lora_b_stacked[index,
                                     0, :lora_b.shape[1], :lora_b.shape[0]]
        b_slot.copy_(lora_b.T, non_blocking=True)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Return True for an exact QKVParallelLinear whose packed-module
        list has one entry."""
        if type(source_layer) is not QKVParallelLinear:
            return False
        return len(packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices) """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj). (q_proj + k_proj + v_proj -> qkv_proj).
...@@ -654,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -654,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
) )
return output return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Return True for an exact QKVParallelLinear whose packed-module
    list has three entries."""
    if type(source_layer) is not QKVParallelLinear:
        return False
    return len(packed_modules_list) == 3
class RowParallelLinearWithLoRA(BaseLayerWithLoRA): class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...@@ -780,6 +890,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -780,6 +890,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def weight(self): def weight(self):
return self.base_layer.weight return self.base_layer.weight
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Return True only when ``source_layer`` is exactly a
    RowParallelLinear (subclasses do not match)."""
    layer_type = type(source_layer)
    return layer_type is RowParallelLinear
class LogitsProcessorWithLoRA(BaseLayerWithLoRA): class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...@@ -900,7 +1016,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -900,7 +1016,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
embedding: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Optional[torch.Tensor]:
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t()) logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None: if embedding_bias is not None:
...@@ -949,22 +1065,30 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -949,22 +1065,30 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs) return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
                      lora_config: LoRAConfig, packed_modules_list: List,
                      model_config: Optional[PretrainedConfig]) -> bool:
    """Always return False, whatever ``source_layer`` is."""
    # Special handling for the LogitsProcessor.
    can_replace = False
    return can_replace
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
cls
for cls in globals().values() if inspect.isclass(cls)
and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}
def from_layer( def from_layer(layer: nn.Module,
layer: nn.Module,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: packed_modules_list: List,
supported_layer_types = { model_config: Optional[PretrainedConfig] = None) -> nn.Module:
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, for lora_cls in _all_lora_classes:
ColumnParallelLinear: ColumnParallelLinearWithLoRA, if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
QKVParallelLinear: QKVParallelLinearWithLora, model_config):
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, ret = lora_cls(layer)
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer)
ret.create_lora_weights(max_loras, lora_config, model_config) ret.create_lora_weights(max_loras, lora_config, model_config)
return ret return ret
return layer return layer
......
...@@ -413,11 +413,12 @@ class LoRAModelManager: ...@@ -413,11 +413,12 @@ class LoRAModelManager:
for module_name, module in self.model.named_modules(): for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name): if not self._match_target_modules(module_name):
continue continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule( new_module = replace_submodule(
self.model, module_name, self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config, from_layer(module, self.lora_slots, self.lora_config,
self.model.config)) packed_moduled_lst, self.model.config))
# (yard1): TODO make this more robust # (yard1): TODO make this more robust
if "lm_head" in module_name: if "lm_head" in module_name:
logits_processor_module = self.model.get_submodule( logits_processor_module = self.model.get_submodule(
...@@ -510,8 +511,10 @@ class LoRAModelManager: ...@@ -510,8 +511,10 @@ class LoRAModelManager:
def _register_packed_modules(self, module_full_name: str) -> None: def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".") parts = module_full_name.split(".")
module_name = parts[-1] module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name) replacements = self.packed_modules_mapping.get(module_name, [])
if not replacements: # When replacements is less than or equal to 1, it indicates that this
# module is not a packed module.
if len(replacements) <= 1:
return return
prefix = ".".join(parts[:-1]) prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [ self.packed_modules[module_full_name] = [
......
...@@ -26,6 +26,7 @@ from torch import nn ...@@ -26,6 +26,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -282,11 +283,30 @@ class BaiChuanModel(nn.Module): ...@@ -282,11 +283,30 @@ class BaiChuanModel(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module): class BaiChuanBaseForCausalLM(nn.Module):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, def __init__(
self,
config, config,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
...@@ -371,19 +391,25 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -371,19 +391,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B.""" """Baichuan 13B and Baichuan2 7B/13B."""
def __init__(self, def __init__(
self,
config, config,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method) super().__init__(config, "ROPE", linear_method, lora_config)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method) super().__init__(config, "ALIBI", linear_method, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 7B.""" """Baichuan 7B."""
def __init__(self, def __init__(
self,
config, config,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None,
super().__init__(config, "ROPE", linear_method) lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, "ROPE", linear_method, lora_config)
...@@ -9,6 +9,7 @@ from torch import nn ...@@ -9,6 +9,7 @@ from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -317,11 +318,25 @@ class ChatGLMModel(nn.Module): ...@@ -317,11 +318,25 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module): class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__( def __init__(
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment