"docs/source/features/prompt_embeds.md" did not exist on "221cfc2feaf6ad60b59882249de80eeba66446a6"
Unverified Commit 0e9164b4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
......@@ -77,27 +77,27 @@ def ref_single_query_cached_kv_attention(
block_size = value_cache.shape[3]
num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist()
seq_lens = seq_lens.cpu().tolist()
block_tables_lst = block_tables.cpu().tolist()
seq_lens_lst = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
seq_len = int(seq_lens[i])
block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i])
keys = []
values = []
keys_lst: List[torch.Tensor] = []
values_lst: List[torch.Tensor] = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_kv_heads, head_size)
keys.append(k)
keys_lst.append(k)
v = value_cache[block_number, :, :, block_offset]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
values_lst.append(v)
keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
......@@ -432,7 +432,7 @@ def test_varlen_blocksparse_attention_prefill(
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
ref_output = ref_multi_query_kv_attention(
cu_seq_lens,
cu_seq_lens.tolist(),
query,
key,
value,
......
import random
from typing import Tuple
from typing import List, Tuple
import pytest
import torch
......@@ -63,7 +63,7 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = []
block_mapping: List[Tuple[int, int]] = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
......@@ -131,8 +131,8 @@ def test_reshape_and_cache(
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
_, key, value = qkv.unbind(dim=1)
......@@ -170,12 +170,12 @@ def test_reshape_and_cache(
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies = block_indicies.cpu().tolist()
block_indicies_lst = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
block_idx = block_indicies_lst[i]
block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
......@@ -224,8 +224,10 @@ def test_reshape_and_cache_flash(
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
qkv = torch.randn(num_tokens,
3,
......@@ -257,13 +259,13 @@ def test_reshape_and_cache_flash(
slot_mapping, kv_cache_dtype)
# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_indicies = block_indicies.cpu().tolist()
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies_lst = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
block_idx = block_indicies_lst[i]
block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
......
......@@ -17,13 +17,13 @@ capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.tensor):
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.tensor):
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
......
......@@ -25,7 +25,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs = []
outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
......@@ -70,7 +70,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_flash_attn_with_paged_kv(
kv_lens: List[Tuple[int, int]],
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
......
from itertools import accumulate, product
from typing import List, Optional
from typing import Dict, List, Optional
import pytest
import torch
......@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
query,
key,
offsets=torch.zeros(batch_size * seq_len,
dtype=int,
dtype=torch.long,
device=device))
# Compare the results.
assert torch.allclose(out_query,
......@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = [
None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
}
]
settings = [
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES
]
rope_setting_id_map = {}
ROPE_SCALINGS = (None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES)
rope_setting_id_map: Dict[str, int] = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
......
......@@ -2,6 +2,7 @@ import contextlib
import gc
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
......@@ -24,7 +25,18 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
LONG_LORA_INFOS = [{
class ContextIDInfo(TypedDict):
lora_id: int
context_length: str
class ContextInfo(TypedDict):
lora: str
context_length: str
LONG_LORA_INFOS: List[ContextIDInfo] = [{
"lora_id": 1,
"context_length": "16k",
}, {
......@@ -207,7 +219,7 @@ def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2,
long_context_lora_files_32k):
cleanup()
infos = {}
infos: Dict[int, ContextInfo] = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
if lora_id == 1:
......@@ -226,7 +238,7 @@ def long_context_infos(long_context_lora_files_16k_1,
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
def llama_2_7b_engine_extra_embeddings():
cleanup()
get_model_old = get_model
......@@ -244,7 +256,6 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
@pytest.fixture
def llama_2_7b_model_extra_embeddings(
llama_2_7b_engine_extra_embeddings) -> nn.Module:
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)
This diff is collapsed.
from typing import List
import pytest
import vllm
......@@ -10,7 +12,7 @@ MODEL_PATH = "baichuan-inc/Baichuan-7B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str:
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
......@@ -30,7 +32,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
......
from typing import List
import vllm
from vllm.lora.request import LoRARequest
......@@ -6,7 +8,7 @@ MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str:
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
......@@ -26,7 +28,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
......
from typing import List
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "google/gemma-7b"
def do_sample(llm, lora_path: str, lora_id: int) -> str:
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"Quote: Imagination is",
"Quote: Be yourself;",
......@@ -17,7 +19,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
......
......@@ -26,7 +26,7 @@ def get_lora_model(model_id: str, target_modules: List[str], rank: int):
return lora_model
def do_sample(llm,
def do_sample(llm: vllm.LLM,
lora_path: Optional[str] = None,
lora_id: Optional[int] = None,
logprobs: int = 0,
......@@ -42,8 +42,8 @@ def do_sample(llm,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
generated_logprobs = []
generated_texts: List[str] = []
generated_logprobs: List[List[List[int]]] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
......
......@@ -109,7 +109,7 @@ def populate_loras(
for slot_idx, lora_id in enumerate(id_to_index):
if lora_id is not None:
subloras = []
subloras: List[LoRALayerWeights] = []
sublora_len = layer_weights.shape[0] // repeats
for i in range(repeats):
sublora = DummyLoRAManager().init_random_lora(
......@@ -158,7 +158,10 @@ def create_random_inputs(
low, high = input_range
inputs, index_mapping, prompt_mapping = [], [], []
inputs: List[torch.Tensor] = []
index_mapping: List[int] = []
prompt_mapping: List[int] = []
for _ in range(num_inputs):
if input_type == torch.int:
inputs.append(
......@@ -222,7 +225,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
lora_result = lora_embedding(torch.cat(inputs))
expected_results = []
expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = embedding(input_)
......@@ -356,7 +359,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
lora_result = lora_embedding(torch.cat(original_inputs))
expected_results = []
expected_results: List[torch.Tensor] = []
for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping):
lora = lora_dict[lora_id]
......@@ -482,7 +485,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
logits_processor.org_vocab_size = (vocab_size +
lora_config.lora_extra_vocab_size)
expected_results = []
expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = logits_processor._get_logits(hidden_states=input_,
......@@ -598,7 +601,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
lora_result = lora_linear(torch.cat(inputs))[0]
expected_results = []
expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
......@@ -729,7 +732,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
lora_result = lora_linear(torch.cat(inputs))[0]
expected_results = []
expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
result = linear(input_)[0]
subloras = sublora_dict[lora_id]
......@@ -885,9 +888,9 @@ def test_vocab_parallel_embedding_indices(tp_size, seed):
computed_added_vocab_size = 0
vocab_size_padded = -1
all_org_tokens = []
all_added_tokens = []
token_ids = []
all_org_tokens: List[int] = []
all_added_tokens: List[int] = []
token_ids: List[int] = []
for tp_rank in range(tp_size):
with patch(
......
from typing import List
import pytest
import ray
......@@ -9,7 +11,7 @@ from .conftest import cleanup
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
def do_sample(llm, lora_path: str, lora_id: int):
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
......@@ -27,7 +29,7 @@ def do_sample(llm, lora_path: str, lora_id: int):
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
......
......@@ -77,7 +77,7 @@ def evaluate_json_response(model_response, golden_response):
def generate(
llm,
llm: vllm.LLM,
inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
):
prompts, sampling_param, lora_request = inputs
......@@ -159,7 +159,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
non-batched generation.
"""
# Create non batched results first to compare against batched results
non_batched_results = []
non_batched_results: List[str] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
......@@ -172,7 +172,8 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
# Create batched results
# Each element of the batch must be
# (prompt, prompt_sampling_params, prompt_lora_request)
batched_prompts = []
batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
......@@ -196,7 +197,8 @@ def test_self_consistency(lora_llm, long_context_infos):
num_loras = len(long_context_infos)
# Create results in order of long_context_infos
batched_prompts = []
batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
......@@ -244,7 +246,7 @@ def test_quality(lora_llm, long_context_infos):
The test is expected to run for about 1 minute on a p4de.24xlarge
instance.
"""
scores = []
scores: List[float] = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
for prompt_and_response in prompts_and_responses[context_len]:
......@@ -277,7 +279,8 @@ def test_max_len(lora_llm, long_context_infos):
generate(lora_llm, (bad_prompt, sampling_params, lora_request))
# Also test batched
batched_prompts = []
batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id_with_bad_inputs in long_context_infos:
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
......
from typing import List
import pytest
from vllm.lora.models import LoRAModel
......@@ -17,7 +19,7 @@ def test_load_checkpoints(
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules = []
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
......
import os
from typing import List
from typing import Dict, List
import pytest
import torch
......@@ -62,7 +62,7 @@ def test_from_lora_tensors(sql_lora_files):
def create_lora(lora_id: int, model: nn.Module,
sub_modules: List[str]) -> LoRAModel:
loras = {}
loras: Dict[str, LoRALayerWeights] = {}
for name in sub_modules:
w = model.get_submodule(name).weight
loras[name] = LoRALayerWeights(
......@@ -83,7 +83,7 @@ def create_packed_lora(
empty_replaced_module_name=None,
) -> LoRAModel:
w = model.get_submodule(module_name).weight
loras = {}
loras: Dict[str, LoRALayerWeights] = {}
for replaced_module_name in replaced_module_names:
if replaced_module_name == empty_replaced_module_name:
continue
......
from typing import List
import pytest
import torch
......@@ -7,7 +9,7 @@ from vllm.lora.request import LoRARequest
MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
def do_sample(llm, lora_path: str, lora_id: int):
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
......@@ -20,7 +22,7 @@ def do_sample(llm, lora_path: str, lora_id: int):
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
......
from typing import List
import vllm
from vllm.lora.request import LoRARequest
......@@ -6,7 +8,7 @@ MODEL_PATH = "microsoft/phi-2"
PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str:
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(
sql_prompt=
......@@ -35,7 +37,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
if lora_id else None,
)
# Print the outputs.
generated_texts = []
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
......
......@@ -25,7 +25,10 @@ MODELS: List[ModelWithQuantization] = [
]
def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
def do_sample(llm: vllm.LLM,
lora_path: str,
lora_id: int,
max_tokens: int = 256) -> List[str]:
raw_prompts = [
"Give me an orange-ish brown color",
"Give me a neon pink color",
......@@ -45,7 +48,7 @@ def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
......
from typing import List, Optional
from typing import Dict, List, Optional
import torch
......@@ -9,13 +9,13 @@ class DummyLoRAManager:
def __init__(self):
super().__init__()
self._loras = {}
self._loras: Dict[str, LoRALayerWeights] = {}
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora
def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
return self._loras.get(module_name, None)
def get_module_lora(self, module_name: str) -> LoRALayerWeights:
return self._loras[module_name]
def init_random_lora(self,
module_name: str,
......@@ -68,11 +68,11 @@ class DummyLoRAManager:
module_name: str,
input_dim: int,
output_dims: List[int],
noop_lora_index: List[int] = None,
rank=8,
noop_lora_index: Optional[List[int]] = None,
rank: int = 8,
):
base_loras = []
noop_lora_index = set(noop_lora_index or [])
base_loras: List[LoRALayerWeights] = []
noop_lora_index_set = set(noop_lora_index or [])
for i, out_dim in enumerate(output_dims):
base_lora = self.init_lora(
......@@ -80,7 +80,7 @@ class DummyLoRAManager:
input_dim,
out_dim,
rank=rank,
noop=i in noop_lora_index,
noop=i in noop_lora_index_set,
)
base_loras.append(base_lora)
packed_lora = PackedLoRALayerWeights.pack(base_loras)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment