Unverified Commit 8bddb735 authored by Akshat Tripathi's avatar Akshat Tripathi Committed by GitHub
Browse files

[Hardware][CPU] Multi-LoRA implementation for the CPU backend (#11100)


Signed-off-by: default avatarAkshat Tripathi <akshat@krai.ai>
Signed-off-by: default avatarOleg Mosalov <oleg@krai.ai>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarOleg Mosalov <oleg@krai.ai>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent f967e51f
......@@ -75,6 +75,12 @@ function cpu_tests() {
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
# Run multi-lora tests
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
set -e
pytest -s -v \
tests/lora/test_qwen2vl.py"
}
# All of CPU tests are expected to be finished less than 25 mins.
......
......@@ -359,7 +359,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
- ✅
- ✅
- ✅
- [✗](gh-pr:4830)
-
- ✅
* - <abbr title="Prompt Adapter">prmpt adptr</abbr>
- ✅
......
......@@ -21,6 +21,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.platforms import current_platform
class ContextIDInfo(TypedDict):
......@@ -65,13 +66,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
@pytest.fixture
def dist_init():
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
backend = "nccl"
if current_platform.is_cpu():
backend = "gloo"
init_distributed_environment(world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend=backend)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory(shutdown_ray=True)
......@@ -81,13 +85,15 @@ def dist_init():
def dist_init_torch_only():
if torch.distributed.is_initialized():
return
backend = "nccl"
if current_platform.is_cpu():
backend = "gloo"
temp_file = tempfile.mkstemp()[1]
torch.distributed.init_process_group(
backend="nccl",
world_size=1,
rank=0,
init_method=f"file://{temp_file}",
)
torch.distributed.init_process_group(world_size=1,
rank=0,
init_method=f"file://{temp_file}",
backend=backend)
@pytest.fixture
......
......@@ -48,10 +48,14 @@ TOLERANCES = {
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
# TODO: Modify this based on platform
DEVICES = [
pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
reason="Backend not supported")
DEVICES = ([
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
] if current_platform.is_cuda_alike() else ["cpu"])
#For GPU, we will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
......@@ -198,6 +202,10 @@ def check_punica_wrapper(punica_wrapper) -> bool:
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
return type(punica_wrapper) is PunicaWrapperGPU
elif current_platform.is_cpu():
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
return type(punica_wrapper) is PunicaWrapperCPU
else:
return False
......@@ -211,7 +219,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
# device, see: https://github.com/triton-lang/triton/issues/2925
# Same below.
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
......@@ -313,7 +322,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size, stage) -> None:
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
......@@ -450,7 +461,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
stage) -> None:
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
......@@ -582,7 +595,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
......@@ -695,7 +710,9 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage, bias_enabled) -> None:
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
......@@ -818,7 +835,9 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage, bias_enabled) -> None:
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
......@@ -971,6 +990,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only CUDA backends are supported")
def test_rotary_embedding_long_context(dist_init, num_loras, device,
scaling_factors, max_position,
is_neox_style, rotary_dim, head_size,
......
......@@ -20,6 +20,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform
EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
......@@ -28,9 +29,9 @@ EMBEDDING_MODULES = {
EMBEDDING_PADDING_MODULES = ["lm_head"]
CUDA_DEVICES = [
DEVICES = ([
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
] if current_platform.is_cuda_alike() else ["cpu"])
def test_peft_helper(sql_lora_files):
......@@ -83,7 +84,7 @@ def test_peft_helper(sql_lora_files):
PEFTHelper.from_dict(config)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
......@@ -171,7 +172,7 @@ def test_replace_submodules(dist_init, dummy_model):
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device("cuda"))
torch.device(DEVICES[0]))
model = manager.model
assert isinstance(model.get_submodule("dense1"),
......@@ -183,7 +184,7 @@ def test_replace_submodules(dist_init, dummy_model):
RowParallelLinearWithLoRA)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
......@@ -244,7 +245,7 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.punica_wrapper.device == device
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
......@@ -336,7 +337,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
assert manager.device == device
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
......@@ -466,7 +467,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
assert manager.device == device
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
......@@ -545,7 +546,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
device)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
# Should remove every LoRA not specified in the request.
......@@ -621,7 +622,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
device)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
......
......@@ -5,6 +5,7 @@ import torch
import vllm
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
......@@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora(mixtral_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
if torch.cuda.device_count(
) < tp_size and tp_size > 1 and current_platform.is_cuda_alike():
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
prompts = [
......
......@@ -9,17 +9,16 @@ from threading import Lock
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
import vllm.lora.ops.triton_ops # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from .utils import (assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices, ref_torch_groupgemm)
generate_data_for_nslices)
HIDDEN_SIZES = [
128,
......@@ -113,7 +112,7 @@ DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [32]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]
DEVICES = [f"cuda:{0}"]
_dict_lock = Lock()
......@@ -127,7 +126,7 @@ _dict_lock = Lock()
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
......@@ -174,7 +173,7 @@ def test_punica_sgmv(
# Preventing cache error pointer.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
sgmv_shrink(
torch.ops.vllm.sgmv_shrink(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
......@@ -187,20 +186,23 @@ def test_punica_sgmv(
scaling,
)
for index in range(nslices):
ref_torch_groupgemm(
ref_out_tensor[index],
sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
lora_indices_tensor,
ref_out_tensor[index],
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
op_type,
)
else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
sgmv_expand(
torch.ops.vllm.sgmv_expand(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
......@@ -213,21 +215,39 @@ def test_punica_sgmv(
offset_start=0,
add_inputs=True,
)
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
ref_torch_groupgemm(
ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
inputs_tensor[index],
lora_weights,
lora_indices_tensor,
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
1.0,
op_type,
max_seq_length,
token_nums,
add_inputs=True,
)
slice_offset += hidden_size
else:
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
assert_close(our_out_tensor, ref_out_tensor)
......@@ -240,7 +260,7 @@ def test_punica_sgmv(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
......@@ -276,31 +296,38 @@ def test_punica_bgmv(
device,
)
if op_type == "shrink":
bgmv_shrink(
torch.ops.vllm.bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
scaling,
)
else:
bgmv_expand(
torch.ops.vllm.bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
bgmv_expand(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
add_inputs=True,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
......@@ -313,7 +340,7 @@ def test_punica_bgmv(
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv_expand_nslices(
batches: int,
num_loras: int,
......@@ -350,7 +377,7 @@ def test_punica_bgmv_expand_nslices(
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
bgmv_expand_slice(
torch.ops.vllm.bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
......@@ -359,15 +386,14 @@ def test_punica_bgmv_expand_nslices(
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
bgmv_expand_slice(
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
1.0,
op_type="expand",
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
......
......@@ -9,19 +9,18 @@ import pytest
import torch
# Enable custom op register
import vllm.lora.ops.bgmv_expand
import vllm.lora.ops.bgmv_expand_slice
import vllm.lora.ops.bgmv_shrink
import vllm.lora.ops.sgmv_expand
import vllm.lora.ops.sgmv_shrink # noqa: F401
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
import vllm.lora.ops.triton_ops # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from .utils import (assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices, ref_torch_groupgemm)
generate_data_for_nslices)
HIDDEN_SIZES = [4097]
HIDDEN_SIZES = [2049]
BATCHES = [1, 4, 16, 32]
NUM_LORA = [1, 8, 32, 128]
......@@ -29,15 +28,7 @@ DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]
# Unlike test_punica_sizes.py, we directly utilize custom op for
# testing, which verifies the correct registration of these ops.
bgmv_expand = torch.ops.vllm.bgmv_expand
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
bgmv_shrink = torch.ops.vllm.bgmv_shrink
sgmv_expand = torch.ops.vllm.sgmv_expand
sgmv_shrink = torch.ops.vllm.sgmv_shrink
DEVICES = [f"cuda:{0}"]
_dict_lock = Lock()
......@@ -51,7 +42,7 @@ _dict_lock = Lock()
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
......@@ -98,7 +89,7 @@ def test_punica_sgmv(
# Preventing cache error pointer.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
sgmv_shrink(
torch.ops.vllm.sgmv_shrink(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
......@@ -111,20 +102,23 @@ def test_punica_sgmv(
scaling,
)
for index in range(nslices):
ref_torch_groupgemm(
ref_out_tensor[index],
sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
lora_indices_tensor,
ref_out_tensor[index],
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
op_type,
)
else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
sgmv_expand(
torch.ops.vllm.sgmv_expand(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
......@@ -137,21 +131,39 @@ def test_punica_sgmv(
offset_start=0,
add_inputs=True,
)
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
ref_torch_groupgemm(
ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
inputs_tensor[index],
lora_weights,
lora_indices_tensor,
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
1.0,
op_type,
max_seq_length,
token_nums,
add_inputs=True,
)
slice_offset += hidden_size
else:
for index in range(nslices):
lora_weights = lora_weights_lst[index]
sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
ref_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
assert_close(our_out_tensor, ref_out_tensor)
......@@ -164,7 +176,7 @@ def test_punica_sgmv(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
......@@ -176,7 +188,6 @@ def test_punica_bgmv(
seed: int,
device: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
......@@ -201,32 +212,38 @@ def test_punica_bgmv(
device,
)
if op_type == "shrink":
bgmv_shrink(
torch.ops.vllm.bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
bgmv_expand(
bgmv_shrink(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
scaling,
)
else:
torch.ops.vllm.bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
bgmv_expand(
inputs_tensor,
lora_weights,
ref_out_tensor,
indices,
add_inputs=True,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
......@@ -239,7 +256,7 @@ def test_punica_bgmv(
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv_expand_nslices(
batches: int,
num_loras: int,
......@@ -276,7 +293,7 @@ def test_punica_bgmv_expand_nslices(
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
bgmv_expand_slice(
torch.ops.vllm.bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
......@@ -285,15 +302,14 @@ def test_punica_bgmv_expand_nslices(
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
bgmv_expand_slice(
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
1.0,
op_type="expand",
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
slice_offset += hidden_size
......
......@@ -72,7 +72,8 @@ def do_sample(llm: vllm.LLM,
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
tp_size):
if num_gpus_available < tp_size:
if num_gpus_available < tp_size and \
tp_size > 1 and current_platform.is_cuda_alike():
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(
......
......@@ -104,33 +104,6 @@ def assert_close(a, b):
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def ref_torch_groupgemm(
out_tensor,
inputs,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling,
op_type,
) -> torch.Tensor:
out_list = []
current_offset = 0
for lora_index, b_length in zip(range(batches), seq_len_tensor):
input_weight = inputs[current_offset:b_length + current_offset, :]
current_offset += b_length
lora_weight = lora_weights[lora_indices_tensor[lora_index]]
result = torch.nn.functional.linear(input_weight, lora_weight)
result *= scaling
out_list.append(result)
cat_result = torch.cat(out_list, dim=0)
if op_type == "expand":
out_tensor += cat_result
else:
out_tensor.copy_(cat_result)
return
def generate_data(
batches,
hidden_size,
......
......@@ -22,9 +22,6 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
assert self.lora_config is None, "cpu backend doesn't support LoRA"
#
# Environment variables for CPU executor
......
from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401
from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
sgmv_expand, sgmv_expand_slice,
sgmv_shrink)
__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink",
]
import torch
def sgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
add_inputs)
def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
selected_loras = lora_b_weights[lora_indices_tensor].to(
dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
if add_inputs:
output_tensor[:, :outputs.shape[1]] += outputs[:limit, :]
else:
output_tensor[:, :outputs.shape[1]] = outputs[:limit, :]
def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
scaling)
def bgmv_shrink(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
selected_loras = lora_b_weights[lora_indices_tensor].to(
dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
output_tensor[:, :outputs.shape[1]] = scaling * outputs[:]
def sgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False):
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
seq_len_tensor)
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
slice_offset, slice_size, add_inputs)
def bgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
selected_loras = lora_b_weights[lora_indices_tensor].to(
dtype=output_tensor.dtype)
inputs = inputs.to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
if add_inputs:
output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:]
else:
output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:]
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401
__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_shrink",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment