Unverified Commit f12c3b5b authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Model] Add Phi-2 LoRA support (#4886)

parent d130b573
...@@ -118,7 +118,7 @@ Alongside each architecture, we include some popular models that use it. ...@@ -118,7 +118,7 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`PhiForCausalLM` * - :code:`PhiForCausalLM`
- Phi - Phi
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
- - ✅︎
* - :code:`Phi3ForCausalLM` * - :code:`Phi3ForCausalLM`
- Phi-3 - Phi-3
- :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc. - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
......
...@@ -165,6 +165,11 @@ def tinyllama_lora_files(): ...@@ -165,6 +165,11 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
@pytest.fixture(scope="session")
def phi2_lora_files():
    """Session-scoped fixture providing the Phi-2 SQL LoRA test adapter.

    Returns whatever ``snapshot_download`` yields for the adapter repo
    (presumably the local snapshot directory -- confirm against the
    ``snapshot_download`` imported by this conftest).
    """
    adapter_repo = "isotr0py/phi-2-test-sql-lora"
    return snapshot_download(repo_id=adapter_repo)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def long_context_lora_files_16k_1(): def long_context_lora_files_16k_1():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
......
from typing import List

import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "microsoft/phi-2"
PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> List[str]:
    """Generate completions for three fixed text-to-SQL prompts.

    Args:
        llm: Engine object exposing
            ``generate(prompts, sampling_params, lora_request=...)``;
            the test in this file passes a ``vllm.LLM``.
        lora_path: Location of the LoRA adapter files handed to
            ``LoRARequest``.
        lora_id: Adapter id, also used (stringified) as the request name.
            A falsy value such as ``0`` sends no LoRA request at all.

    Returns:
        One stripped generated text per output, in the order returned by
        ``llm.generate``.
    """
    # (question, schema) pairs rendered through PROMPT_TEMPLATE below.
    questions_with_context = [
        (
            "Which catalog publisher has published the most catalogs?",
            "CREATE TABLE catalogs (catalog_publisher VARCHAR);",
        ),
        (
            "Which trip started from the station with the largest dock count? Give me the trip id.",  # noqa: E501
            "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);",  # noqa: E501
        ),
        (
            "How many marine species are found in the Southern Ocean?",
            "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));",  # noqa: E501
        ),
    ]
    prompts = [
        PROMPT_TEMPLATE.format(sql_prompt=question, context=schema)
        for question, schema in questions_with_context
    ]

    # temperature=0 so the output can be compared against fixed expected
    # strings; generation stops at the "### End" marker or after 64 tokens.
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=64,
                                          stop="### End")

    # lora_id == 0 (falsy) means "run without an adapter".
    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=lora_request,
    )

    # Print the outputs so they show up in the test log.
    generated_texts = []
    for output in outputs:
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
def test_phi2_lora(phi2_lora_files):
    """Check Phi-2 LoRA generation against the expected SQL queries.

    The same adapter files are sampled twice, once under LoRA id 1 and once
    under LoRA id 2; every generated text must start with the corresponding
    expected query.
    """
    # enforce_eager=True is enabled here to reduce VRAM usage for the
    # lora-test CI; otherwise the lora-test fails due to CUDA OOM.
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        enforce_eager=True,
    )

    expected_lora_output = [
        "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;",  # noqa: E501
        "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);",  # noqa: E501
        "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';",  # noqa: E501
    ]

    # Sample and verify one adapter id at a time, in the order 1 then 2.
    for adapter_id in (1, 2):
        generated = do_sample(llm, phi2_lora_files, lora_id=adapter_id)
        for idx, expected_sql in enumerate(expected_lora_output):
            assert generated[idx].startswith(expected_sql)
...@@ -42,7 +42,7 @@ from torch import nn ...@@ -42,7 +42,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -229,11 +229,32 @@ class PhiModel(nn.Module): ...@@ -229,11 +229,32 @@ class PhiModel(nn.Module):
class PhiForCausalLM(nn.Module): class PhiForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
]
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"dense",
"fc1",
"fc2",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self, def __init__(
self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment