"vscode:/vscode.git/clone" did not exist on "379689d533642cfc1d3ab2cf4dc02f09a8318a5f"
Unverified Commit 8af890a8 authored by Jee Li's avatar Jee Li Committed by GitHub
Browse files

Enable more models to inference based on LoRA (#3382)


Co-authored-by: default avatarAntoni Baum <antoni.baum@protonmail.com>
parent dfeb2ecc
......@@ -16,10 +16,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1536) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2304) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
......@@ -27,10 +30,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4608) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
......@@ -45,6 +50,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
......
......@@ -134,6 +134,16 @@ def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
@pytest.fixture(scope="session")
def chatglm3_lora_files():
return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")
@pytest.fixture(scope="session")
def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
......
import pytest
import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup
MODEL_PATH = "baichuan-inc/Baichuan-7B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True)
expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age ASC",
]
output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]
@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(baichuan_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)
output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
del llm_tp1
cleanup()
llm_tp2 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True)
output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
del llm_tp2
cleanup()
assert output_tp1 == output_tp2
llm_tp4 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True)
output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
del llm_tp4
cleanup()
assert output_tp1 == output_tp4
\ No newline at end of file
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True)
expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]
......@@ -8,12 +8,16 @@ import torch
import torch.nn.functional as F
from vllm.config import LoRAConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -93,8 +97,7 @@ def populate_loras(
lora_dict: Dict[int, LoRALayerWeights] = dict()
# Dictionary that maps the lora ID to the
# corresponding subloras. Only useful when
# repeats > 1.
# corresponding subloras.
sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()
for slot_idx, lora_id in enumerate(id_to_index):
......@@ -607,7 +610,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
......@@ -623,6 +626,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedColumnParallelLinearWithLoRA(linear)
elif repeats == 3:
linear = QKVParallelLinear(4096, 64, 32, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedQKVParallelLinearWithLora(linear)
else:
linear = QKVParallelLinear(4096, 64, 32, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
......
......@@ -43,9 +43,10 @@ def _lora_ref_impl(
H1 = H2 = [
128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
22016, 24576, 32000, 32256, 32512, 32768, 33024
128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
32768, 33024
]
SEED = [0xabcdabcd987]
......
# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
import torch
import torch.nn as nn
......@@ -114,8 +115,11 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module):
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
model_config: PretrainedConfig) -> None:
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
"""Initializes lora matrices."""
...
......@@ -144,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module):
"""Sets the mapping indices."""
...
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
......@@ -278,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is VocabParallelEmbedding
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
def create_lora_weights(
self,
......@@ -309,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[1]
self.output_dim = self.lora_b_stacked.shape[2]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
......@@ -323,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
......@@ -383,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def linear_weights(self):
return self.base_layer.linear_weights
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 1)
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
......@@ -485,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is MergedColumnParallelLinear and len(
packed_modules_list) == 2
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
self.tp_size = get_tensor_model_parallel_world_size()
self.q_proj_total_size = (self.base_layer.total_num_heads *
self.base_layer.head_size)
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
self.base_layer.head_size)
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
......@@ -654,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
)
return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 3
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
......@@ -780,6 +890,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def weight(self):
return self.base_layer.weight
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is RowParallelLinear
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
......@@ -900,7 +1016,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
hidden_states: torch.Tensor,
embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
......@@ -949,22 +1065,30 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
def from_layer(
layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
supported_layer_types = {
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLora,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
cls
for cls in globals().values() if inspect.isclass(cls)
and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
......
......@@ -413,11 +413,12 @@ class LoRAModelManager:
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name):
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
self.model.config))
packed_moduled_lst, self.model.config))
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module = self.model.get_submodule(
......@@ -510,8 +511,10 @@ class LoRAModelManager:
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name)
if not replacements:
replacements = self.packed_modules_mapping.get(module_name, [])
# When replacements is less than or equal to 1, it indicates that this
# module is not a packed module.
if len(replacements) <= 1:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [
......
......@@ -26,6 +26,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
......@@ -282,11 +283,30 @@ class BaiChuanModel(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
def __init__(
self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
......@@ -371,19 +391,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B."""
def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method)
super().__init__(config, "ROPE", linear_method, lora_config)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method)
super().__init__(config, "ALIBI", linear_method, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 7B."""
def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ROPE", linear_method)
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, "ROPE", linear_method, lora_config)
......@@ -9,6 +9,7 @@ from torch import nn
from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
......@@ -317,11 +318,25 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment