Unverified Commit 1700c543 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix LoRA weight sharding (#10450)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 17d8fc18
...@@ -230,7 +230,7 @@ steps: ...@@ -230,7 +230,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/lora - vllm/lora
- tests/lora - tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py
parallelism: 4 parallelism: 4
- label: "PyTorch Fullgraph Smoke Test" # 9min - label: "PyTorch Fullgraph Smoke Test" # 9min
...@@ -475,18 +475,23 @@ steps: ...@@ -475,18 +475,23 @@ steps:
- pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py - pytest -v -s distributed/test_pipeline_parallel.py
- label: LoRA Long Context (Distributed) # 11min - label: LoRA TP Test (Distributed)
# This test runs llama 13B, so it is required to run on 4 GPUs.
num_gpus: 4 num_gpus: 4
soft_fail: true soft_fail: true
source_file_dependencies: source_file_dependencies:
- vllm/lora - vllm/lora
- tests/lora/test_long_context - tests/lora
commands: commands:
# FIXIT: find out which code initialize cuda before running the test # FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it # before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
# This test runs llama 13B, so it is required to run on 4 GPUs.
- pytest -v -s -x lora/test_long_context.py - pytest -v -s -x lora/test_long_context.py
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- label: Weight Loading Multiple GPU Test # 33min - label: Weight Loading Multiple GPU Test # 33min
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
......
from typing import List from typing import List
import vllm import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "THUDM/chatglm3-6b" MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
...@@ -20,7 +29,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: ...@@ -20,7 +29,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
), ),
] ]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate( outputs = llm.generate(
prompts, prompts,
...@@ -37,23 +45,58 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: ...@@ -37,23 +45,58 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts return generated_texts
@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files): def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH, llm = vllm.LLM(MODEL_PATH,
max_model_len=1024, max_model_len=1024,
enable_lora=True, enable_lora=True,
max_loras=4, max_loras=4,
max_lora_rank=64, max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True) trust_remote_code=True)
expected_lora_output = [ output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
"SELECT count(*) FROM singer", for i in range(len(EXPECTED_LORA_OUTPUT)):
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 assert output1[i] == EXPECTED_LORA_OUTPUT[i]
"SELECT name , country , age FROM singer ORDER BY age", output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
] for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=False)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(expected_lora_output)): for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == expected_lora_output[i] assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(expected_lora_output)): for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == expected_lora_output[i] assert output2[i] == EXPECTED_LORA_OUTPUT[i]
from typing import List from typing import List
import pytest
import ray import ray
import vllm import vllm
from vllm.distributed import cleanup_dist_env_and_memory from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "meta-llama/Llama-2-7b-hf" MODEL_PATH = "meta-llama/Llama-2-7b-hf"
EXPECTED_NO_LORA_OUTPUT = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
]
EXPECTED_LORA_OUTPUT = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
...@@ -37,88 +55,31 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: ...@@ -37,88 +55,31 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts return generated_texts
@pytest.mark.parametrize("tp_size", [1, 2, 4]) @fork_new_process_for_each_test
def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): def test_llama_lora(sql_lora_files):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(MODEL_PATH, llm = vllm.LLM(MODEL_PATH,
enable_lora=True, enable_lora=True,
max_num_seqs=16, max_num_seqs=16,
max_loras=4, max_loras=4,
tensor_parallel_size=tp_size) tensor_parallel_size=1)
expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
]
expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]
print("lora adapter created") print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 1") print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
print("no lora") print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 2") print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
print("removing lora") print("removing lora")
def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available): @fork_new_process_for_each_test
if num_gpus_available < 4:
pytest.skip("Not enough GPUs for tensor parallelism 4")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
del llm_tp1
cleanup_dist_env_and_memory()
llm_tp2 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2)
output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
del llm_tp2
cleanup_dist_env_and_memory()
assert output_tp1 == output_tp2
llm_tp4 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4)
output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
del llm_tp4
cleanup_dist_env_and_memory()
assert output_tp1 == output_tp4
def test_llama_lora_warmup(sql_lora_files): def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and """Test that the LLM initialization works with a warmup LORA path and
is more conservative""" is more conservative"""
...@@ -144,3 +105,57 @@ def test_llama_lora_warmup(sql_lora_files): ...@@ -144,3 +105,57 @@ def test_llama_lora_warmup(sql_lora_files):
"conservative than without lora, therefore the number of " "conservative than without lora, therefore the number of "
"memory blocks for the KV cache should be " "memory blocks for the KV cache should be "
"less when using lora than when not using lora") "less when using lora than when not using lora")
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4(sql_lora_files):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
)
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
print("removing lora")
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
)
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
print("removing lora")
...@@ -44,6 +44,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -44,6 +44,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
""" """
# For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
# their `lora_a` and `lora_b` have different sharding patterns. After
# completing the `lora_a` GEMM , a gather operation is performed.
# Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2] shard_size = self.lora_a_stacked.shape[2]
......
...@@ -451,6 +451,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -451,6 +451,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None: def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__() super().__init__()
# The base_layer type is ColumnParallelLinear or
# MergedColumnParallelLinear, their weight sharding logic is
# inconsistent when TP is greater than 1.
self.is_merged_col_linear = type(
base_layer) is MergedColumnParallelLinear
self.base_layer = base_layer self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size = self.base_layer.input_size self.input_size = self.base_layer.input_size
...@@ -508,14 +514,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -508,14 +514,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
return lora_a return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tensor_model_parallel_rank = get_tensor_model_parallel_rank() # Applicable to cases where the base_layer is
shard_size = self.output_dim # MergedColumnParallelLinear.
start_idx = tensor_model_parallel_rank * shard_size if self.is_merged_col_linear:
end_idx = (tensor_model_parallel_rank + 1) * shard_size tp_rank = get_tensor_model_parallel_rank()
lora_b = lora_b[:, start_idx:end_idx] shard_size = self.output_size // 2
offset = lora_b.shape[-1] // 2
left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
shard_size]
right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size]
lora_b = torch.cat([left_weight, right_weight], dim=1)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
return lora_b return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
# TODO: Fix the slicing logic of bias.
if bias is None: if bias is None:
return bias return bias
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
...@@ -779,7 +801,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -779,7 +801,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
""" """
ColumnParallelLinear layer that is specifically designed for ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b, qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer. only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b During inference with Tensor Parallel, the weights of lora_b
......
...@@ -760,7 +760,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -760,7 +760,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
# Initialize VL # Initialize VL
if hasattr(config, "visual"): if hasattr(config, "visual"):
return ChatGLM(vllm_config=vllm_config, prefix=prefix) return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
# Initialize LLM # Initialize LLM
else: else:
return ChatGLMV(vllm_config=vllm_config, prefix=prefix) return ChatGLM(vllm_config=vllm_config, prefix=prefix)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment