Commit 0e9164b4 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
...@@ -77,27 +77,27 @@ def ref_single_query_cached_kv_attention( ...@@ -77,27 +77,27 @@ def ref_single_query_cached_kv_attention(
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs = query.shape[0] num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist() block_tables_lst = block_tables.cpu().tolist()
seq_lens = seq_lens.cpu().tolist() seq_lens_lst = seq_lens.cpu().tolist()
for i in range(num_seqs): for i in range(num_seqs):
q = query[i].unsqueeze(0) q = query[i].unsqueeze(0)
block_table = block_tables[i] block_table = block_tables_lst[i]
seq_len = int(seq_lens[i]) seq_len = int(seq_lens_lst[i])
keys = [] keys_lst: List[torch.Tensor] = []
values = [] values_lst: List[torch.Tensor] = []
for j in range(seq_len): for j in range(seq_len):
block_number = int(block_table[j // block_size]) block_number = int(block_table[j // block_size])
block_offset = j % block_size block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :] k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_kv_heads, head_size) k = k.reshape(num_kv_heads, head_size)
keys.append(k) keys_lst.append(k)
v = value_cache[block_number, :, :, block_offset] v = value_cache[block_number, :, :, block_offset]
values.append(v) values_lst.append(v)
keys = torch.stack(keys, dim=0) keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values, dim=0) values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1: if num_queries_per_kv > 1:
# Handle MQA and GQA # Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
...@@ -432,7 +432,7 @@ def test_varlen_blocksparse_attention_prefill( ...@@ -432,7 +432,7 @@ def test_varlen_blocksparse_attention_prefill(
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
ref_output = ref_multi_query_kv_attention( ref_output = ref_multi_query_kv_attention(
cu_seq_lens, cu_seq_lens.tolist(),
query, query,
key, key,
value, value,
......
import random import random
from typing import Tuple from typing import List, Tuple
import pytest import pytest
import torch import torch
...@@ -63,7 +63,7 @@ def test_copy_blocks( ...@@ -63,7 +63,7 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = [] block_mapping: List[Tuple[int, int]] = []
for i in range(num_mappings): for i in range(num_mappings):
src = src_blocks[i] src = src_blocks[i]
dst1 = dst_blocks[2 * i] dst1 = dst_blocks[2 * i]
...@@ -131,8 +131,8 @@ def test_reshape_and_cache( ...@@ -131,8 +131,8 @@ def test_reshape_and_cache(
torch.set_default_device(device) torch.set_default_device(device)
# Create a random slot mapping. # Create a random slot mapping.
num_slots = block_size * num_blocks num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens) slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long) slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype) qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
_, key, value = qkv.unbind(dim=1) _, key, value = qkv.unbind(dim=1)
...@@ -170,12 +170,12 @@ def test_reshape_and_cache( ...@@ -170,12 +170,12 @@ def test_reshape_and_cache(
# Run the reference implementation. # Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies = block_indicies.cpu().tolist() block_indicies_lst = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist() block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens): for i in range(num_tokens):
block_idx = block_indicies[i] block_idx = block_indicies_lst[i]
block_offset = block_offsets[i] block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i]
...@@ -224,8 +224,10 @@ def test_reshape_and_cache_flash( ...@@ -224,8 +224,10 @@ def test_reshape_and_cache_flash(
# Create a random slot mapping. # Create a random slot mapping.
num_slots = block_size * num_blocks num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens) slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
qkv = torch.randn(num_tokens, qkv = torch.randn(num_tokens,
3, 3,
...@@ -257,13 +259,13 @@ def test_reshape_and_cache_flash( ...@@ -257,13 +259,13 @@ def test_reshape_and_cache_flash(
slot_mapping, kv_cache_dtype) slot_mapping, kv_cache_dtype)
# Run the reference implementation. # Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies = block_indicies.cpu().tolist() block_indicies_lst = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist() block_offsets_lst = block_offsets.cpu().tolist()
for i in range(num_tokens): for i in range(num_tokens):
block_idx = block_indicies[i] block_idx = block_indicies_lst[i]
block_offset = block_offsets[i] block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i] cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i] cloned_value_cache[block_idx, block_offset, :, :] = value[i]
......
...@@ -17,13 +17,13 @@ capability = torch.cuda.get_device_capability() ...@@ -17,13 +17,13 @@ capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.tensor): def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp( return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.tensor): def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
......
...@@ -25,7 +25,7 @@ def ref_paged_attn( ...@@ -25,7 +25,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy() block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape _, block_size, num_kv_heads, head_size = key_cache.shape
outputs = [] outputs: List[torch.Tensor] = []
start_idx = 0 start_idx = 0
for i in range(num_seqs): for i in range(num_seqs):
query_len = query_lens[i] query_len = query_lens[i]
...@@ -70,7 +70,7 @@ def ref_paged_attn( ...@@ -70,7 +70,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode @torch.inference_mode
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
kv_lens: List[Tuple[int, int]], kv_lens: List[int],
num_heads: Tuple[int, int], num_heads: Tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
......
from itertools import accumulate, product from itertools import accumulate, product
from typing import List, Optional from typing import Dict, List, Optional
import pytest import pytest
import torch import torch
...@@ -126,7 +126,7 @@ def test_batched_rotary_embedding( ...@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
query, query,
key, key,
offsets=torch.zeros(batch_size * seq_len, offsets=torch.zeros(batch_size * seq_len,
dtype=int, dtype=torch.long,
device=device)) device=device))
# Compare the results. # Compare the results.
assert torch.allclose(out_query, assert torch.allclose(out_query,
...@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
def test_rope_module_cache(): def test_rope_module_cache():
MAX_POSITIONS = [123, 1234] MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000] BASES = [10000, 1000000]
ROPE_SCALINGS = [ ROPE_SCALINGS = (None, {
None, {
"type": "linear", "type": "linear",
"factor": (1, ) "factor": (1, )
}, { }, {
"type": "dynamic", "type": "dynamic",
"factor": 1 "factor": 1
} })
] settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
settings = [ ROPE_SCALINGS, DTYPES)
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, rope_setting_id_map: Dict[str, int] = {}
ROPE_SCALINGS, DTYPES
]
rope_setting_id_map = {}
for setting in product(*settings): for setting in product(*settings):
head_size, rotary_dim, max_position, base, \ head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting is_neox_stype, rope_scaling, dtype = setting
......
...@@ -2,6 +2,7 @@ import contextlib ...@@ -2,6 +2,7 @@ import contextlib
import gc import gc
import tempfile import tempfile
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
...@@ -24,7 +25,18 @@ from vllm.model_executor.layers.sampler import Sampler ...@@ -24,7 +25,18 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
LONG_LORA_INFOS = [{
class ContextIDInfo(TypedDict):
lora_id: int
context_length: str
class ContextInfo(TypedDict):
lora: str
context_length: str
LONG_LORA_INFOS: List[ContextIDInfo] = [{
"lora_id": 1, "lora_id": 1,
"context_length": "16k", "context_length": "16k",
}, { }, {
...@@ -207,7 +219,7 @@ def long_context_infos(long_context_lora_files_16k_1, ...@@ -207,7 +219,7 @@ def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2, long_context_lora_files_16k_2,
long_context_lora_files_32k): long_context_lora_files_32k):
cleanup() cleanup()
infos = {} infos: Dict[int, ContextInfo] = {}
for lora_checkpoint_info in LONG_LORA_INFOS: for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"] lora_id = lora_checkpoint_info["lora_id"]
if lora_id == 1: if lora_id == 1:
...@@ -226,7 +238,7 @@ def long_context_infos(long_context_lora_files_16k_1, ...@@ -226,7 +238,7 @@ def long_context_infos(long_context_lora_files_16k_1,
@pytest.fixture @pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module: def llama_2_7b_engine_extra_embeddings():
cleanup() cleanup()
get_model_old = get_model get_model_old = get_model
...@@ -244,7 +256,6 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: ...@@ -244,7 +256,6 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
@pytest.fixture @pytest.fixture
def llama_2_7b_model_extra_embeddings( def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
llama_2_7b_engine_extra_embeddings) -> nn.Module:
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model) model_runner.model)
This diff is collapsed.
from typing import List
import pytest import pytest
import vllm import vllm
...@@ -10,7 +12,7 @@ MODEL_PATH = "baichuan-inc/Baichuan-7B" ...@@ -10,7 +12,7 @@ MODEL_PATH = "baichuan-inc/Baichuan-7B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. 
concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str: def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format( PROMPT_TEMPLATE.format(
...@@ -30,7 +32,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str: ...@@ -30,7 +32,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None) if lora_id else None)
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text.strip() generated_text = output.outputs[0].text.strip()
......
from typing import List
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -6,7 +8,7 @@ MODEL_PATH = "THUDM/chatglm3-6b" ...@@ -6,7 +8,7 @@ MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. 
concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str: def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format( PROMPT_TEMPLATE.format(
...@@ -26,7 +28,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str: ...@@ -26,7 +28,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None) if lora_id else None)
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text.strip() generated_text = output.outputs[0].text.strip()
......
from typing import List
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
MODEL_PATH = "google/gemma-7b" MODEL_PATH = "google/gemma-7b"
def do_sample(llm, lora_path: str, lora_id: int) -> str: def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
"Quote: Imagination is", "Quote: Imagination is",
"Quote: Be yourself;", "Quote: Be yourself;",
...@@ -17,7 +19,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str: ...@@ -17,7 +19,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None) if lora_id else None)
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text.strip() generated_text = output.outputs[0].text.strip()
......
...@@ -26,7 +26,7 @@ def get_lora_model(model_id: str, target_modules: List[str], rank: int): ...@@ -26,7 +26,7 @@ def get_lora_model(model_id: str, target_modules: List[str], rank: int):
return lora_model return lora_model
def do_sample(llm, def do_sample(llm: vllm.LLM,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
lora_id: Optional[int] = None, lora_id: Optional[int] = None,
logprobs: int = 0, logprobs: int = 0,
...@@ -42,8 +42,8 @@ def do_sample(llm, ...@@ -42,8 +42,8 @@ def do_sample(llm,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None) if lora_id else None)
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
generated_logprobs = [] generated_logprobs: List[List[List[int]]] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
......
...@@ -109,7 +109,7 @@ def populate_loras( ...@@ -109,7 +109,7 @@ def populate_loras(
for slot_idx, lora_id in enumerate(id_to_index): for slot_idx, lora_id in enumerate(id_to_index):
if lora_id is not None: if lora_id is not None:
subloras = [] subloras: List[LoRALayerWeights] = []
sublora_len = layer_weights.shape[0] // repeats sublora_len = layer_weights.shape[0] // repeats
for i in range(repeats): for i in range(repeats):
sublora = DummyLoRAManager().init_random_lora( sublora = DummyLoRAManager().init_random_lora(
...@@ -158,7 +158,10 @@ def create_random_inputs( ...@@ -158,7 +158,10 @@ def create_random_inputs(
low, high = input_range low, high = input_range
inputs, index_mapping, prompt_mapping = [], [], [] inputs: List[torch.Tensor] = []
index_mapping: List[int] = []
prompt_mapping: List[int] = []
for _ in range(num_inputs): for _ in range(num_inputs):
if input_type == torch.int: if input_type == torch.int:
inputs.append( inputs.append(
...@@ -222,7 +225,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -222,7 +225,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
lora_result = lora_embedding(torch.cat(inputs)) lora_result = lora_embedding(torch.cat(inputs))
expected_results = [] expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id] lora = lora_dict[lora_id]
result = embedding(input_) result = embedding(input_)
...@@ -356,7 +359,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -356,7 +359,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
lora_result = lora_embedding(torch.cat(original_inputs)) lora_result = lora_embedding(torch.cat(original_inputs))
expected_results = [] expected_results: List[torch.Tensor] = []
for input_, original_input_, lora_id in zip(inputs, original_inputs, for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping): prompt_mapping):
lora = lora_dict[lora_id] lora = lora_dict[lora_id]
...@@ -482,7 +485,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -482,7 +485,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
logits_processor.org_vocab_size = (vocab_size + logits_processor.org_vocab_size = (vocab_size +
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
expected_results = [] expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id] lora = lora_dict[lora_id]
result = logits_processor._get_logits(hidden_states=input_, result = logits_processor._get_logits(hidden_states=input_,
...@@ -598,7 +601,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -598,7 +601,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
expected_results = [] expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id] lora = lora_dict[lora_id]
result = linear(input_)[0] result = linear(input_)[0]
...@@ -729,7 +732,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -729,7 +732,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
expected_results = [] expected_results: List[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
result = linear(input_)[0] result = linear(input_)[0]
subloras = sublora_dict[lora_id] subloras = sublora_dict[lora_id]
...@@ -885,9 +888,9 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): ...@@ -885,9 +888,9 @@ def test_vocab_parallel_embedding_indices(tp_size, seed):
computed_added_vocab_size = 0 computed_added_vocab_size = 0
vocab_size_padded = -1 vocab_size_padded = -1
all_org_tokens = [] all_org_tokens: List[int] = []
all_added_tokens = [] all_added_tokens: List[int] = []
token_ids = [] token_ids: List[int] = []
for tp_rank in range(tp_size): for tp_rank in range(tp_size):
with patch( with patch(
......
from typing import List
import pytest import pytest
import ray import ray
...@@ -9,7 +11,7 @@ from .conftest import cleanup ...@@ -9,7 +11,7 @@ from .conftest import cleanup
MODEL_PATH = "meta-llama/Llama-2-7b-hf" MODEL_PATH = "meta-llama/Llama-2-7b-hf"
def do_sample(llm, lora_path: str, lora_id: int): def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
...@@ -27,7 +29,7 @@ def do_sample(llm, lora_path: str, lora_id: int): ...@@ -27,7 +29,7 @@ def do_sample(llm, lora_path: str, lora_id: int):
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None) if lora_id else None)
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
......
...@@ -77,7 +77,7 @@ def evaluate_json_response(model_response, golden_response): ...@@ -77,7 +77,7 @@ def evaluate_json_response(model_response, golden_response):
def generate( def generate(
llm, llm: vllm.LLM,
inputs: Tuple[str, SamplingParams, Optional[LoRARequest]], inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
): ):
prompts, sampling_param, lora_request = inputs prompts, sampling_param, lora_request = inputs
...@@ -159,7 +159,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos): ...@@ -159,7 +159,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
non-batched generation. non-batched generation.
""" """
# Create non batched results first to compare against batched results # Create non batched results first to compare against batched results
non_batched_results = [] non_batched_results: List[str] = []
for lora_id, info in long_context_infos.items(): for lora_id, info in long_context_infos.items():
context_len = info["context_length"] context_len = info["context_length"]
...@@ -172,7 +172,8 @@ def test_batched_rope_kernel(lora_llm, long_context_infos): ...@@ -172,7 +172,8 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
# Create batched results # Create batched results
# Each element of the batch must be # Each element of the batch must be
# (prompt, prompt_sampling_params, prompt_lora_request) # (prompt, prompt_sampling_params, prompt_lora_request)
batched_prompts = [] batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id, info in long_context_infos.items(): for lora_id, info in long_context_infos.items():
context_len = info["context_length"] context_len = info["context_length"]
batched_prompts.extend([ batched_prompts.extend([
...@@ -196,7 +197,8 @@ def test_self_consistency(lora_llm, long_context_infos): ...@@ -196,7 +197,8 @@ def test_self_consistency(lora_llm, long_context_infos):
num_loras = len(long_context_infos) num_loras = len(long_context_infos)
# Create results in order of long_context_infos # Create results in order of long_context_infos
batched_prompts = [] batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id, info in long_context_infos.items(): for lora_id, info in long_context_infos.items():
context_len = info["context_length"] context_len = info["context_length"]
batched_prompts.extend([ batched_prompts.extend([
...@@ -244,7 +246,7 @@ def test_quality(lora_llm, long_context_infos): ...@@ -244,7 +246,7 @@ def test_quality(lora_llm, long_context_infos):
The test is expected to run for about 1 minute on a p4de.24xlarge The test is expected to run for about 1 minute on a p4de.24xlarge
instance. instance.
""" """
scores = [] scores: List[float] = []
for lora_id, info in long_context_infos.items(): for lora_id, info in long_context_infos.items():
context_len = info["context_length"] context_len = info["context_length"]
for prompt_and_response in prompts_and_responses[context_len]: for prompt_and_response in prompts_and_responses[context_len]:
...@@ -277,7 +279,8 @@ def test_max_len(lora_llm, long_context_infos): ...@@ -277,7 +279,8 @@ def test_max_len(lora_llm, long_context_infos):
generate(lora_llm, (bad_prompt, sampling_params, lora_request)) generate(lora_llm, (bad_prompt, sampling_params, lora_request))
# Also test batched # Also test batched
batched_prompts = [] batched_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]] = []
for lora_id_with_bad_inputs in long_context_infos: for lora_id_with_bad_inputs in long_context_infos:
for lora_id, info in long_context_infos.items(): for lora_id, info in long_context_infos.items():
context_len = info["context_length"] context_len = info["context_length"]
......
from typing import List
import pytest import pytest
from vllm.lora.models import LoRAModel from vllm.lora.models import LoRAModel
...@@ -17,7 +19,7 @@ def test_load_checkpoints( ...@@ -17,7 +19,7 @@ def test_load_checkpoints(
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules = [] expected_lora_modules: List[str] = []
for module in supported_lora_modules: for module in supported_lora_modules:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_modules.extend(packed_modules_mapping[module])
......
import os import os
from typing import List from typing import Dict, List
import pytest import pytest
import torch import torch
...@@ -62,7 +62,7 @@ def test_from_lora_tensors(sql_lora_files): ...@@ -62,7 +62,7 @@ def test_from_lora_tensors(sql_lora_files):
def create_lora(lora_id: int, model: nn.Module, def create_lora(lora_id: int, model: nn.Module,
sub_modules: List[str]) -> LoRAModel: sub_modules: List[str]) -> LoRAModel:
loras = {} loras: Dict[str, LoRALayerWeights] = {}
for name in sub_modules: for name in sub_modules:
w = model.get_submodule(name).weight w = model.get_submodule(name).weight
loras[name] = LoRALayerWeights( loras[name] = LoRALayerWeights(
...@@ -83,7 +83,7 @@ def create_packed_lora( ...@@ -83,7 +83,7 @@ def create_packed_lora(
empty_replaced_module_name=None, empty_replaced_module_name=None,
) -> LoRAModel: ) -> LoRAModel:
w = model.get_submodule(module_name).weight w = model.get_submodule(module_name).weight
loras = {} loras: Dict[str, LoRALayerWeights] = {}
for replaced_module_name in replaced_module_names: for replaced_module_name in replaced_module_names:
if replaced_module_name == empty_replaced_module_name: if replaced_module_name == empty_replaced_module_name:
continue continue
......
from typing import List
import pytest import pytest
import torch import torch
...@@ -7,7 +9,7 @@ from vllm.lora.request import LoRARequest ...@@ -7,7 +9,7 @@ from vllm.lora.request import LoRARequest
MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
def do_sample(llm, lora_path: str, lora_id: int): def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501 "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501 "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
...@@ -20,7 +22,7 @@ def do_sample(llm, lora_path: str, lora_id: int): ...@@ -20,7 +22,7 @@ def do_sample(llm, lora_path: str, lora_id: int):
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None) if lora_id else None)
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text.strip() generated_text = output.outputs[0].text.strip()
......
from typing import List
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -6,7 +8,7 @@ MODEL_PATH = "microsoft/phi-2" ...@@ -6,7 +8,7 @@ MODEL_PATH = "microsoft/phi-2"
PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str: def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [ prompts = [
PROMPT_TEMPLATE.format( PROMPT_TEMPLATE.format(
sql_prompt= sql_prompt=
...@@ -35,7 +37,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str: ...@@ -35,7 +37,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str:
if lora_id else None, if lora_id else None,
) )
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text.strip() generated_text = output.outputs[0].text.strip()
......
...@@ -25,7 +25,10 @@ MODELS: List[ModelWithQuantization] = [ ...@@ -25,7 +25,10 @@ MODELS: List[ModelWithQuantization] = [
] ]
def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): def do_sample(llm: vllm.LLM,
lora_path: str,
lora_id: int,
max_tokens: int = 256) -> List[str]:
raw_prompts = [ raw_prompts = [
"Give me an orange-ish brown color", "Give me an orange-ish brown color",
"Give me a neon pink color", "Give me a neon pink color",
...@@ -45,7 +48,7 @@ def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): ...@@ -45,7 +48,7 @@ def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None) if lora_id else None)
# Print the outputs. # Print the outputs.
generated_texts = [] generated_texts: List[str] = []
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
......
from typing import List, Optional from typing import Dict, List, Optional
import torch import torch
...@@ -9,13 +9,13 @@ class DummyLoRAManager: ...@@ -9,13 +9,13 @@ class DummyLoRAManager:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._loras = {} self._loras: Dict[str, LoRALayerWeights] = {}
def set_module_lora(self, module_name: str, lora: LoRALayerWeights): def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora self._loras[module_name] = lora
def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]: def get_module_lora(self, module_name: str) -> LoRALayerWeights:
return self._loras.get(module_name, None) return self._loras[module_name]
def init_random_lora(self, def init_random_lora(self,
module_name: str, module_name: str,
...@@ -68,11 +68,11 @@ class DummyLoRAManager: ...@@ -68,11 +68,11 @@ class DummyLoRAManager:
module_name: str, module_name: str,
input_dim: int, input_dim: int,
output_dims: List[int], output_dims: List[int],
noop_lora_index: List[int] = None, noop_lora_index: Optional[List[int]] = None,
rank=8, rank: int = 8,
): ):
base_loras = [] base_loras: List[LoRALayerWeights] = []
noop_lora_index = set(noop_lora_index or []) noop_lora_index_set = set(noop_lora_index or [])
for i, out_dim in enumerate(output_dims): for i, out_dim in enumerate(output_dims):
base_lora = self.init_lora( base_lora = self.init_lora(
...@@ -80,7 +80,7 @@ class DummyLoRAManager: ...@@ -80,7 +80,7 @@ class DummyLoRAManager:
input_dim, input_dim,
out_dim, out_dim,
rank=rank, rank=rank,
noop=i in noop_lora_index, noop=i in noop_lora_index_set,
) )
base_loras.append(base_lora) base_loras.append(base_lora)
packed_lora = PackedLoRALayerWeights.pack(base_loras) packed_lora = PackedLoRALayerWeights.pack(base_loras)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment