Unverified Commit 7ecee343 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Kernel][RFC] Refactor the punica kernel based on Triton (#5036)

parent 7eb0cb4a
...@@ -181,9 +181,6 @@ class cmake_build_ext(build_ext): ...@@ -181,9 +181,6 @@ class cmake_build_ext(build_ext):
# match. # match.
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
if _install_punica():
cmake_args += ['-DVLLM_INSTALL_PUNICA_KERNELS=ON']
# #
# Setup parallelism and build tool # Setup parallelism and build tool
# #
...@@ -274,10 +271,6 @@ def _build_custom_ops() -> bool: ...@@ -274,10 +271,6 @@ def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu() return _is_cuda() or _is_hip() or _is_cpu()
def _install_punica() -> bool:
return envs.VLLM_INSTALL_PUNICA_KERNELS
def get_hipcc_rocm_version(): def get_hipcc_rocm_version():
# Run the hipcc --version command # Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'], result = subprocess.run(['hipcc', '--version'],
...@@ -446,9 +439,6 @@ if _is_cuda() or _is_hip(): ...@@ -446,9 +439,6 @@ if _is_cuda() or _is_hip():
if _build_custom_ops(): if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C")) ext_modules.append(CMakeExtension(name="vllm._C"))
if _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
package_data = { package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
} }
......
import gc import gc
from unittest.mock import patch
import pytest import pytest
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential, from vllm.model_executor.layers.ops.sample import (_sample_triton,
_uniform_to_exponential,
sample) sample)
from vllm.model_executor.sampling_metadata import SamplingTensors from vllm.model_executor.sampling_metadata import SamplingTensors
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.triton_utils.libentry import LibEntry
from vllm.triton_utils.sample import (MAX_TRITON_N_COLS, from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
get_num_triton_sampler_splits) get_num_triton_sampler_splits)
...@@ -76,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of, ...@@ -76,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
seeds = torch.randint(1, seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, bs), torch.iinfo(torch.long).max, (n_splits, bs),
device="cuda").mul_(random_sampling_mask) device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( #The current _sample_triton does not utilize the
probs=probs, # libentry decoration. The purpose of adding this patch is to test
logprobs=logprobs, # the correctness of libentry.
sample_indices=sample_indices, with patch("vllm.model_executor.layers.ops.sample._sample_triton",
seeds=seeds, LibEntry(_sample_triton)):
max_best_of=max_best_of, sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
modify_greedy_probs=modify_greedy_probs, probs=probs,
save_logprobs=save_logprobs, logprobs=logprobs,
_save_modified_probs=True) sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
_save_modified_probs=True)
assert sampled_tokens.shape == (bs, max_best_of) assert sampled_tokens.shape == (bs, max_best_of)
for i in range(bs): for i in range(bs):
assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
...@@ -130,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of, ...@@ -130,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of, def test_sample_prompt_logprobs(random_sampling, max_best_of,
modify_greedy_probs, seed, vocab_size): modify_greedy_probs, seed, vocab_size):
set_random_seed(seed) set_random_seed(seed)
prompt_sizes = [16, 32, 64, 128] * 2 prompt_sizes = [16, 32, 64, 128] * 2
samples = 8 samples = 8
...@@ -157,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of, ...@@ -157,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
seeds = torch.randint(1, seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, samples), torch.iinfo(torch.long).max, (n_splits, samples),
device="cuda").mul_(random_sampling_mask) device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, _ = sample( #ditto
probs=probs, with patch("vllm.model_executor.layers.ops.sample._sample_triton",
logprobs=logprobs, LibEntry(_sample_triton)):
sample_indices=sample_indices, sampled_tokens, sampled_logprobs, _ = sample(
seeds=seeds, probs=probs,
max_best_of=max_best_of, logprobs=logprobs,
modify_greedy_probs=modify_greedy_probs, sample_indices=sample_indices,
save_logprobs=True) seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=True)
assert sampled_tokens.shape == (samples, max_best_of) assert sampled_tokens.shape == (samples, max_best_of)
assert sampled_logprobs.shape == (samples, max_best_of) assert sampled_logprobs.shape == (samples, max_best_of)
for i, t in enumerate(sample_indices): for i, t in enumerate(sample_indices):
......
...@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files): ...@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files):
expected_lora_output = [ expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n", "more important than knowledge.\nAuthor: Albert Einstein\n",
"everyone else is already taken.\nAuthor: Oscar Wilde\n", "everyone else is already taken.\nAuthor: Oscar Wilde\n",
"so little time\nAuthor: Frank Zappa\n", "so little time.\nAuthor: Frank Zappa\n",
] ]
output1 = do_sample(llm, gemma_lora_files, lora_id=1) output1 = do_sample(llm, gemma_lora_files, lora_id=1)
......
...@@ -26,7 +26,8 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, ...@@ -26,7 +26,8 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) VocabParallelEmbeddingWithLoRA)
# yapf: enable # yapf: enable
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights, convert_mapping) PackedLoRALayerWeights)
from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -47,6 +48,9 @@ TOLERANCES = { ...@@ -47,6 +48,9 @@ TOLERANCES = {
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
# We will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]
def get_random_id_to_index(num_loras: int, def get_random_id_to_index(num_loras: int,
...@@ -182,10 +186,12 @@ def create_random_inputs( ...@@ -182,10 +186,12 @@ def create_random_inputs(
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: @pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -204,7 +210,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -204,7 +210,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
embedding, lora_embedding = create_random_embedding_layer() embedding, lora_embedding = create_random_embedding_layer()
lora_embedding.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
layer=lora_embedding, layer=lora_embedding,
...@@ -217,12 +223,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -217,12 +223,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info)
lora_result = lora_embedding(torch.cat(inputs)) lora_result = lora_embedding(torch.cat(inputs))
...@@ -255,12 +261,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -255,12 +261,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(inputs)) lora_result = lora_embedding(torch.cat(inputs))
expected_result = embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs))
...@@ -278,11 +284,13 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -278,11 +284,13 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device, def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size) -> None: vocab_size, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -318,6 +326,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -318,6 +326,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
generate_embeddings_tensor=256, generate_embeddings_tensor=256,
) )
lora_embedding.set_mapping(punica_wrapper)
# All embeddings tensors have the same shape. # All embeddings tensors have the same shape.
embeddings_tensors = [ embeddings_tensors = [
lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
...@@ -334,8 +343,12 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -334,8 +343,12 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size,
lora_config.lora_extra_vocab_size)
original_inputs = deepcopy(inputs) original_inputs = deepcopy(inputs)
# Force some of the inputs to be in the extended embeddings range # Force some of the inputs to be in the extended embeddings range
...@@ -349,11 +362,6 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -349,11 +362,6 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
(embedding_id + 1) * embeddings_tensor_len - 1) (embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = vocab_size + embeddings_tensor_len - 1 original_input_[-2] = vocab_size + embeddings_tensor_len - 1
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
expanded_embedding.weight[vocab_size:vocab_size + expanded_embedding.weight[vocab_size:vocab_size +
(embeddings_tensor_len * (embeddings_tensor_len *
max_loras)] = torch.cat(embeddings_tensors) max_loras)] = torch.cat(embeddings_tensors)
...@@ -390,15 +398,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -390,15 +398,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
original_inputs = deepcopy(inputs) original_inputs = deepcopy(inputs)
lora_mapping = LoRAMapping(index_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, prompt_mapping,
is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(original_inputs)) lora_result = lora_embedding(torch.cat(original_inputs))
expected_result = expanded_embedding(torch.cat(inputs)) expected_result = expanded_embedding(torch.cat(inputs))
...@@ -413,11 +419,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -413,11 +419,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_lm_head_logits_processor(dist_init, num_loras, device, @pytest.mark.parametrize("stage", STAGES)
vocab_size) -> None: def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -443,7 +451,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -443,7 +451,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, logits_processor, lora_logits_processor = _pretest() linear, logits_processor, lora_logits_processor = _pretest()
lora_logits_processor.set_mapping(punica_wrapper)
# NOTE: all the generated loras share the same embeddings tensor. # NOTE: all the generated loras share the same embeddings tensor.
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
...@@ -461,17 +469,17 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -461,17 +469,17 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
input_ = torch.rand(20, 1024) is_prefill=stage)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_logits_processor.set_mapping(*mapping_info, ) input_ = torch.rand(20, 1024)
lora_result = lora_logits_processor._get_logits( lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs), hidden_states=torch.cat(inputs),
...@@ -510,12 +518,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -510,12 +518,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, is_prefill=stage)
vocab_size, punica_wrapper.update_metadata(
lora_config.lora_extra_vocab_size) lora_mapping,
lora_logits_processor.set_mapping(*mapping_info, ) id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_logits_processor._get_logits( lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs), hidden_states=torch.cat(inputs),
...@@ -538,10 +550,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -538,10 +550,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
@pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device) -> None: device, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -575,7 +589,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -575,7 +589,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_parallel_layer() linear, lora_linear = create_random_linear_parallel_layer()
lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
layer=lora_linear, layer=lora_linear,
...@@ -589,16 +603,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -589,16 +603,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping( is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
512, 512,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_linear.set_mapping(*mapping_info, )
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
...@@ -628,11 +642,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -628,11 +642,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size) 512, lora_config.lora_extra_vocab_size)
lora_linear.set_mapping(*mapping_info, )
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0]
...@@ -649,10 +664,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -649,10 +664,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device) -> None: device, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -707,7 +724,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -707,7 +724,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_column_parallel_packed_layer() linear, lora_linear = create_column_parallel_packed_layer()
lora_linear.set_mapping(punica_wrapper)
lora_dict, sublora_dict = populate_loras( lora_dict, sublora_dict = populate_loras(
id_to_index, id_to_index,
layer=lora_linear, layer=lora_linear,
...@@ -722,16 +739,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -722,16 +739,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
512, 512,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_linear.set_mapping(*mapping_info)
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
...@@ -762,16 +780,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -762,16 +780,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
512, 512,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_linear.set_mapping(*mapping_info) # lora_linear.set_mapping(*mapping_info)
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0]
...@@ -803,7 +823,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -803,7 +823,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -825,6 +845,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -825,6 +845,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
is_neox_style, is_neox_style,
) )
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
lora_rope.set_mapping(punica_wrapper)
lora_rope.create_lora_weights(max_loras, lora_config) lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base, linear_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, { is_neox_style, {
...@@ -840,6 +861,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -840,6 +861,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
input_range=(0, lora_config.lora_extra_vocab_size), input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
long_lora_context = LongContextLoRAContext(list(scaling_factors), long_lora_context = LongContextLoRAContext(list(scaling_factors),
rotary_dim) rotary_dim)
...@@ -854,7 +876,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -854,7 +876,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
for i in range(len(scaling_factors)): for i in range(len(scaling_factors)):
long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
scaling_factors[i], 0) scaling_factors[i], 0)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
...@@ -862,7 +884,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -862,7 +884,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
long_lora_context=long_lora_context, long_lora_context=long_lora_context,
) )
lora_rope.set_mapping(*mapping_info) # lora_rope.set_mapping(*mapping_info)
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, query = torch.randn(batch_size,
......
import pytest
import torch
from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice
from .utils import DummyLoRAManager
TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4]
QKV_TENSOR_SIZES = [
(8192, 1024, 1024),
(8192 // 8, 1024 // 8, 1024 // 8),
(4096, 4096, 4096),
(4096 // 2, 4096 // 2, 4096 // 2),
]
BATCH_SIZES = [8, 32, 256]
RANKS = [8]
DTYPES = [torch.float16]
TOLERANCES = {
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora(m, n, k, rank, dtype) -> None:
manager = DummyLoRAManager()
module_name = "module"
weight = torch.rand([m, n], device="cuda", dtype=dtype)
manager.init_random_lora(module_name, weight, rank=rank)
lora = manager.get_module_lora(module_name)
input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = input @ lora.lora_a @ lora.lora_b * lora.scaling
lora_a_stack = torch.zeros(8,
1,
lora.lora_a.shape[1],
lora.lora_a.shape[0],
device="cuda",
dtype=dtype)
lora_b_stack = torch.zeros(8,
1,
lora.lora_b.shape[1],
lora.lora_b.shape[0],
device="cuda",
dtype=dtype)
for i in range(lora_a_stack.shape[0]):
lora_a_stack[i][0] = lora.lora_a.T
lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T
output = torch.zeros(k, m, device="cuda", dtype=dtype)
_apply_lora(
input, lora_a_stack, lora_b_stack,
torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"),
output)
rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
output[:] = 0
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.full((len(input), ), -1, device="cuda"), output)
assert torch.allclose(torch.zeros_like(output), output)
manager.reset_lora()
@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
if m % 2 != 0:
pytest.skip("m must be divisible by 2")
if m // 2 not in TENSOR_SIZES:
pytest.skip("m//2 must be in TENSOR_SIZES")
manager = DummyLoRAManager()
module_name = "module"
weight = torch.rand([m // 2, n], device="cuda", dtype=dtype)
manager.init_random_lora(module_name + "1", weight, rank=rank)
lora_1 = manager.get_module_lora(module_name + "1")
manager.init_random_lora(module_name + "2", weight, rank=rank)
lora_2 = manager.get_module_lora(module_name + "2")
input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = torch.cat([
input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling,
input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling
],
dim=1)
lora_a_stacks = [
torch.zeros(8,
1,
lora_1.lora_a.shape[1],
lora_1.lora_a.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(8,
1,
lora_1.lora_b.shape[1],
lora_1.lora_b.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_1.lora_a.T
lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
lora_a_stacks[1][i][0] = lora_2.lora_a.T
lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T
output = torch.zeros(k, m, device="cuda", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="cuda"), output, (m // 2, m // 2))
rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="cuda"),
output, (m // 2, m // 2))
assert torch.allclose(torch.zeros_like(output), output)
manager.reset_lora()
@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
manager = DummyLoRAManager()
module_name = "module"
weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype)
weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype)
manager.init_random_lora(module_name + "q", weight_q, rank=rank)
lora_q = manager.get_module_lora(module_name + "q")
manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
lora_k = manager.get_module_lora(module_name + "k")
manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
lora_v = manager.get_module_lora(module_name + "v")
input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = torch.cat([
input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling,
input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling,
input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling
],
dim=1)
lora_a_stacks = [
torch.zeros(8,
1,
lora_q.lora_a.shape[1],
lora_q.lora_a.shape[0],
device="cuda",
dtype=dtype)
] + [
torch.zeros(8,
1,
lora_k.lora_a.shape[1],
lora_k.lora_a.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(8,
1,
lora_q.lora_b.shape[1],
lora_q.lora_b.shape[0],
device="cuda",
dtype=dtype)
] + [
torch.zeros(8,
1,
lora_k.lora_b.shape[1],
lora_k.lora_b.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_q.lora_a.T
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
lora_a_stacks[1][i][0] = lora_k.lora_a.T
lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T
lora_a_stacks[2][i][0] = lora_v.lora_a.T
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T
output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="cuda"), output, (qkv[0], qkv[1], qkv[2]))
rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="cuda"),
output, (qkv[0], qkv[1], qkv[2]))
assert torch.allclose(torch.zeros_like(output), output)
manager.reset_lora()
# Based on code from https://github.com/punica-ai/punica
import pytest
import torch
import vllm.lora.punica as punica
def assert_close(a, b):
rtol, atol = {
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
torch.float32: (None, None),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def _lora_ref_impl(
y_final: torch.Tensor,
x: torch.Tensor,
wa_T_all: torch.Tensor,
wb_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
y_stage_1 = torch.empty(
(x.size(0), wa_T_all.size(-2)),
dtype=torch.float32,
device=x.device,
)
bs = x.shape[0]
s = torch.tensor(scale, dtype=torch.float32, device=x.device)
for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
xi = x[i].unsqueeze(0).to(torch.float32)
wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
if wb_T_all is not None:
wb = wb_T_all[lora_idx, layer_idx].transpose(-1,
-2).to(torch.float32)
tmp = xi @ wa
y_stage_1[i] = tmp.squeeze(0)
y_final[i] += ((tmp @ wb).squeeze(0) *
s if wb_T_all is not None else y_stage_1[i])
return y_final, y_stage_1
H1 = H2 = [
128,
256,
512,
896,
1024,
1152,
1216,
1280,
1536,
1664,
2048,
2240,
2304,
2368,
2432,
2560,
2752,
3072,
3328,
3456,
3584,
3712,
4096,
4480,
4608,
4736,
4864,
5120,
5504,
5632,
5888,
6144,
6400,
6848,
6912,
7168,
7424,
8192,
8960,
9216,
9472,
10240,
11008,
11264,
13824,
14336,
14784,
14848,
15360,
18944,
22016,
22528,
24576,
27392,
27648,
29568,
29696,
32000,
32256,
32512,
32768,
33024,
36864,
43264,
49152,
49408,
60544,
60672,
64000,
64256,
102400,
102656,
128000,
128256,
]
H2 = [64] + H2
R = [1, 2, 4]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("r", R)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
bs = 32
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")
wa_T_all = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype, device=device)
y = torch.randn(bs, r, dtype=dtype, device=device)
y_ref = y.clone()
_lora_ref_impl(
y_ref,
x,
wa_T_all,
None,
indices,
layer_idx,
1.0,
)
y_our = y.clone()
punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)
assert_close(y_ref, y_our)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_lora_correctness(dtype_str, h1, h2, seed, device):
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
r = 8
bs = 32
scale = 0.123
dtype = getattr(torch, dtype_str)
torch.set_default_device(device)
wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype)
y = torch.randn(bs, h2, dtype=dtype)
y_ref = y.clone()
_lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
y_our = y.clone()
punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx,
scale)
assert_close(y_ref, y_our)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_lora_correctness_slice(dtype_str, h1, h2, seed, device):
if h2 % 3 != 0 or h2 // 3 not in H1:
pytest.skip("h2 must be divisible by 3 and in supported shapes")
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
r = 8
bs = 32
scale = 0.123
dtype = getattr(torch, dtype_str)
torch.set_default_device(device)
wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype)
y = torch.randn(bs, h2, dtype=dtype)
s = h2 // 3
y_ref = y.clone()
_lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices,
layer_idx, scale)
_lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices,
layer_idx, scale)
_lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices,
layer_idx, scale)
y_our = y.clone()
punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices,
layer_idx, scale, 0, s)
punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices,
layer_idx, scale, s, s)
punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices,
layer_idx, scale, s * 2, s)
assert_close(y_ref[:, :s], y_our[:, :s])
assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2])
assert_close(y_ref[:, s * 2:], y_our[:, s * 2:])
"""
This script is mainly used to tests various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
import random
from unittest.mock import patch
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
HIDDEN_SIZES = [
128,
256,
512,
896,
1024,
1152,
1216,
1280,
1536,
1664,
2048,
2240,
2304,
2368,
2432,
2560,
2752,
3072,
3328,
3456,
3584,
3712,
4096,
4480,
4608,
4736,
4864,
5120,
5504,
5632,
5888,
6144,
6400,
6848,
6912,
7168,
7424,
8192,
8960,
9216,
9472,
10240,
11008,
11264,
13824,
14336,
14784,
14848,
15360,
18944,
22016,
22528,
24576,
27392,
27648,
29568,
29696,
32000,
32256,
32512,
32768,
33024,
36864,
43264,
49152,
49408,
60544,
60672,
64000,
64256,
102400,
102656,
128000,
128256,
]
#The size of TP
divisibility = [1, 2, 4, 8, 16, 32, 64]
all_hidden_size = []
for div in divisibility:
for hidden_size in HIDDEN_SIZES:
all_hidden_size.append(hidden_size // div)
HIDDEN_SIZES = list(set(all_hidden_size))
BATCHES = [4]
NUM_LORA = [4]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [32]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
max_seq_length = seq_len_tensor.max()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
if op_type == "shrink":
sgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
scaling,
)
else:
sgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 1
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,
lora_weights_lst,
our_outputs,
ref_outputs,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
max_seq_length = seq_len_tensor.max()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
if op_type == "sgmv":
sgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
slice_offset,
hidden_size,
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
1.0,
op_type="expand",
)
slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)
"""
This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and
maximum ranks.
"""
import random
from unittest.mock import patch
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
HIDDEN_SIZES = [3424, 4096, 4097]
BATCHES = [1, 4, 16, 32]
NUM_LORA = [1, 4, 8, 16, 32, 64, 128]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
max_seq_length = seq_len_tensor.max()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
if op_type == "shrink":
sgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
scaling,
)
else:
sgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
scaling: float,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 1
(
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
op_type,
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
random.seed(seed)
torch.set_default_device(device)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
inputs_tensor,
lora_weights_lst,
our_outputs,
ref_outputs,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
max_seq_length = seq_len_tensor.max()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
if op_type == "sgmv":
sgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
slice_offset,
hidden_size,
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
1.0,
op_type="expand",
)
slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)
if __name__ == "__main__":
from itertools import product
lst = list(
product(
BATCHES,
NUM_LORA,
MAX_RANKS,
[1.0],
[torch.float16],
["expand"],
SEED,
CUDA_DEVICES,
))
for ele in lst:
test_punica_bgmv(*ele)
print(f"{ele},pass")
...@@ -64,14 +64,16 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size): ...@@ -64,14 +64,16 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
# if torch.cuda.device_count() < tp_size: # if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(model=model.model_path, llm = vllm.LLM(
enable_lora=True, model=model.model_path,
max_num_seqs=16, enable_lora=True,
max_loras=4, max_num_seqs=16,
max_model_len=400, max_loras=4,
tensor_parallel_size=tp_size, max_model_len=400,
quantization=model.quantization, tensor_parallel_size=tp_size,
trust_remote_code=True) gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
if model.quantization is None: if model.quantization is None:
expected_no_lora_output = [ expected_no_lora_output = [
...@@ -156,24 +158,28 @@ def test_quant_model_tp_equality(tinyllama_lora_files, model): ...@@ -156,24 +158,28 @@ def test_quant_model_tp_equality(tinyllama_lora_files, model):
# if torch.cuda.device_count() < 2: # if torch.cuda.device_count() < 2:
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}") # pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
llm_tp1 = vllm.LLM(model=model.model_path, llm_tp1 = vllm.LLM(
enable_lora=True, model=model.model_path,
max_num_seqs=16, enable_lora=True,
max_loras=4, max_num_seqs=16,
tensor_parallel_size=1, max_loras=4,
quantization=model.quantization, tensor_parallel_size=1,
trust_remote_code=True) gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
del llm_tp1 del llm_tp1
cleanup() cleanup()
llm_tp2 = vllm.LLM(model=model.model_path, llm_tp2 = vllm.LLM(
enable_lora=True, model=model.model_path,
max_num_seqs=16, enable_lora=True,
max_loras=4, max_num_seqs=16,
tensor_parallel_size=2, max_loras=4,
quantization=model.quantization) tensor_parallel_size=2,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
del llm_tp2 del llm_tp2
......
...@@ -86,3 +86,151 @@ class DummyLoRAManager: ...@@ -86,3 +86,151 @@ class DummyLoRAManager:
packed_lora = PackedLoRALayerWeights.pack(base_loras) packed_lora = PackedLoRALayerWeights.pack(base_loras)
self.set_module_lora(module_name, packed_lora) self.set_module_lora(module_name, packed_lora)
return packed_lora return packed_lora
def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def ref_torch_groupgemm(
out_tensor,
inputs,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling,
op_type,
) -> torch.Tensor:
out_list = []
current_offset = 0
for lora_index, b_length in zip(range(batches), seq_len_tensor):
input_weight = inputs[current_offset:b_length + current_offset, :]
current_offset += b_length
lora_weight = lora_weights[lora_indices_tensor[lora_index]]
result = torch.nn.functional.linear(input_weight, lora_weight)
result *= scaling
out_list.append(result)
cat_result = torch.cat(out_list, dim=0)
if op_type == "expand":
out_tensor += cat_result
else:
out_tensor.copy_(cat_result)
return
def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype,
op_type, device):
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
dim=0,
).to(device)
total_tokens = seq_len_tensor.sum()
if op_type == "shrink":
inputs_tensor = torch.rand((total_tokens, hidden_size),
dtype=dtype).to(device)
lora_weights = torch.rand(
(lora_nums, max_rank, hidden_size), # col-major
dtype=dtype,
).to(device)
# shrink op need atomic_add, so output is initinized by 0
ref_out_tensor = torch.zeros((total_tokens, max_rank),
dtype=dtype,
device=inputs_tensor.device)
# NOTE shrink kernel using torch.float32 as output type
our_out_tensor = torch.zeros((total_tokens, max_rank),
dtype=torch.float32).to(device)
else:
inputs_tensor = torch.rand(
(total_tokens, max_rank),
dtype=dtype,
).to(device)
lora_weights = torch.rand(
(lora_nums, hidden_size, max_rank), # col-major
dtype=dtype,
).to(device)
# expand op needs to complete y+=a@lora_b, so output is
# initinized randomly
ref_out_tensor = torch.rand(
(total_tokens, hidden_size),
dtype=dtype,
).to(device)
# Ensure the same input.
our_out_tensor = ref_out_tensor.clone()
lora_indices_tensor = torch.randint(0,
lora_nums - 1 if lora_nums > 1 else 1,
(batches, )).to(device)
indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
current_offset = 0
for b_id in range(batches):
lora_index = lora_indices_tensor[b_id]
indices[current_offset:current_offset +
seq_len_tensor[b_id]].copy_(lora_index)
current_offset += seq_len_tensor[b_id].item()
return (
inputs_tensor,
lora_weights,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
)
def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank,
seq_length, dtype, nslices, device):
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
(batches, )).to(device)
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
dim=0,
).to(device)
total_tokens = seq_len_tensor.sum()
inputs_tensor = torch.rand(
(total_tokens, max_rank),
dtype=dtype,
).to(device)
lora_weights_lst = []
for _ in range(nslices):
lora_weights_lst.append(
torch.rand(
(lora_nums, hidden_size, max_rank), # col-major
dtype=dtype,
).to(device))
# expand op needs to complete y+=a@lora_b, so output is
# initinized randomly
ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices),
dtype=dtype).to(device)
# Ensure the same input.
our_out_tensor = ref_out_tensor.clone()
lora_indices_tensor = torch.randint(0,
lora_nums - 1 if lora_nums > 1 else 1,
(batches, ))
indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
current_offset = 0
for b_id in range(batches):
lora_index = lora_indices_tensor[b_id]
indices[current_offset:current_offset +
seq_len_tensor[b_id]] = lora_index.item()
current_offset += seq_len_tensor[b_id].item()
lora_indices_tensor = lora_indices_tensor.to(device)
return (
inputs_tensor,
lora_weights_lst,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
)
...@@ -13,12 +13,9 @@ try: ...@@ -13,12 +13,9 @@ try:
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError):
import vllm._moe_C
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
# ruff: noqa: F401 # ruff: noqa: F401
import vllm._punica_C import vllm._moe_C
def is_custom_op_supported(op_name: str) -> bool: def is_custom_op_supported(op_name: str) -> bool:
...@@ -519,43 +516,6 @@ def register_graph_buffers(fa: int, handles: List[str], ...@@ -519,43 +516,6 @@ def register_graph_buffers(fa: int, handles: List[str],
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
# punica
def dispatch_bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
) -> None:
torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
scale)
def dispatch_bgmv_low_level(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
h_in: int,
h_out: int,
y_offset: int,
) -> None:
torch.ops._punica_C.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
h_in,
h_out,
y_offset,
)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0 # TODO: remove this in v0.6.0
names_and_values = globals() names_and_values = globals()
......
...@@ -45,7 +45,6 @@ if TYPE_CHECKING: ...@@ -45,7 +45,6 @@ if TYPE_CHECKING:
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
VLLM_INSTALL_PUNICA_KERNELS: bool = False
VLLM_NO_DEPRECATION_WARNING: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False VERBOSE: bool = False
...@@ -94,10 +93,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -94,10 +93,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_PRECOMPILED": "VLLM_USE_PRECOMPILED":
lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")),
# If set, vllm will install Punica kernels
"VLLM_INSTALL_PUNICA_KERNELS":
lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))),
# CMake build type # CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo" # If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo"
......
...@@ -14,7 +14,6 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA, ...@@ -14,7 +14,6 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora, QKVParallelLinearWithLora,
RowParallelLinearWithLoRA) RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
...@@ -28,7 +27,7 @@ def _fully_sharded_can_replace(can_replace): ...@@ -28,7 +27,7 @@ def _fully_sharded_can_replace(can_replace):
def dec(*args, **kwargs): def dec(*args, **kwargs):
return (can_replace(*args, **kwargs) return (can_replace(*args, **kwargs)
and kwargs['lora_config'].fully_sharded_loras) and kwargs["lora_config"].fully_sharded_loras)
return dec return dec
...@@ -59,25 +58,30 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -59,25 +58,30 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), buffer = torch.zeros(
dtype=torch.float32, (x.shape[0], self.lora_a_stacked.shape[2]),
device=x.device) dtype=torch.float32,
device=x.device,
bgmv(buffer, x, self.lora_a_stacked, )
self.indices[:self.indices_len[0]], 0, 1.0) self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer) buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked, self.punica_wrapper.add_expand(output,
self.indices[:self.indices_len[0]], 0, 1.0) buffer,
self.lora_b_stacked,
add_input=True)
# now have column partitioned output # now have column partitioned output
output = output.view(*out_orig_shape) output = output.view(*out_orig_shape)
return output return output
@classmethod @classmethod
@_fully_sharded_can_replace @_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
source_layer=source_layer, source_layer=source_layer,
...@@ -88,14 +92,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...@@ -88,14 +92,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
) )
def _mcp_apply(x, bias, layer): def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
""" """
MergedColumnParallelLinearWithShardedLoRA and MergedColumnParallelLinearWithShardedLoRA and
MergedQKVParallelLinearWithShardedLora share the same MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method. LoRa weight application method.
The main difference is the step by shard_size for lora_b which can The main difference is the step by shard_size for lora_b which can
vary for MergedQKVParallelLinearWithShardedLora but is constant for vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA. MergedColumnParallelLinearWithShardedLoRA.
""" """
# expecting 2 for column parallel and 3 for qkv # expecting 2 for column parallel and 3 for qkv
...@@ -104,21 +108,27 @@ def _mcp_apply(x, bias, layer): ...@@ -104,21 +108,27 @@ def _mcp_apply(x, bias, layer):
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]), buffers = torch.zeros(
dtype=torch.float32, (n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
device=x.device) dtype=torch.float32,
device=x.device,
)
for idx in range(n): for idx in range(n):
bgmv(buffers[idx], x, layer.lora_a_stacked[idx], layer.punica_wrapper.add_shrink(buffers[idx], x,
layer.indices[:layer.indices_len[0]], 0, 1.0) layer.lora_a_stacked[idx], 1.0)
buffers = tensor_model_parallel_all_gather(buffers) buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0 left_offset = 0
for idx in range(n): for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2] shard_size = layer.lora_b_stacked[idx].shape[2]
dispatch_bgmv_low_level(output, buffers[idx], layer.punica_wrapper.add_expand_slice(
layer.lora_b_stacked[idx], output,
layer.indices[:layer.indices_len[0]], 0, 1.0, buffers[idx],
left_offset, shard_size) layer.lora_b_stacked[idx],
left_offset,
shard_size,
add_input=True,
)
left_offset += shard_size left_offset += shard_size
output = output.view(*out_orig_shape) output = output.view(*out_orig_shape)
...@@ -129,7 +139,7 @@ def _mcp_apply(x, bias, layer): ...@@ -129,7 +139,7 @@ def _mcp_apply(x, bias, layer):
class MergedColumnParallelLinearWithShardedLoRA( class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA): MergedColumnParallelLinearWithLoRA):
""" """
Differs from MergedColumnParallelLinearWithLoRA by slicing the Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also. LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
...@@ -145,7 +155,8 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -145,7 +155,8 @@ class MergedColumnParallelLinearWithShardedLoRA(
lora_a = [ lora_a = [
lora_a[0][:, lora_a[0][:,
output_start_idx:output_start_idx + output_shard_size], output_start_idx:output_start_idx + output_shard_size],
lora_a[1][:, output_start_idx:output_start_idx + output_shard_size] lora_a[1][:,
output_start_idx:output_start_idx + output_shard_size],
] ]
return lora_a return lora_a
...@@ -155,9 +166,13 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -155,9 +166,13 @@ class MergedColumnParallelLinearWithShardedLoRA(
@classmethod @classmethod
@_fully_sharded_can_replace @_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
source_layer=source_layer, source_layer=source_layer,
...@@ -170,7 +185,7 @@ class MergedColumnParallelLinearWithShardedLoRA( ...@@ -170,7 +185,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
""" """
Differs from QKVParallelLinearWithLora by slicing the Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also. LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim. Based on S-LoRA, slicing happens along the rank dim.
...@@ -193,14 +208,13 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): ...@@ -193,14 +208,13 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32, dtype=torch.float32,
device=x.device) device=x.device)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_gather(buffer) buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked, self.punica_wrapper.add_expand(output,
self.indices[:self.indices_len[0]], 0, 1.0) buffer,
self.lora_b_stacked,
add_input=True)
# now have column partitioned output # now have column partitioned output
output = output.view(*out_orig_shape) output = output.view(*out_orig_shape)
return output return output
...@@ -237,7 +251,7 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): ...@@ -237,7 +251,7 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
lora_a = [ lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]], lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]], lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]] lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
] ]
return lora_a return lora_a
...@@ -247,9 +261,13 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): ...@@ -247,9 +261,13 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
@classmethod @classmethod
@_fully_sharded_can_replace @_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
source_layer=source_layer, source_layer=source_layer,
...@@ -262,11 +280,11 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): ...@@ -262,11 +280,11 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
""" """
Differs from RowParallelLinearWithLoRA by slicing the Differs from RowParallelLinearWithLoRA by slicing the
LoRA B's also. LoRA B's also.
Based on S-LoRA, slicing happens along the output dim. Based on S-LoRA, slicing happens along the output dim.
This yields a combined partial sum from the row parallel base This yields a combined partial sum from the row parallel base
layer and column partitioned output from the LoRA. layer and column partitioned output from the LoRA.
""" """
...@@ -283,11 +301,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -283,11 +301,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), buffer = torch.zeros(
dtype=torch.float32, (x.shape[0], self.lora_a_stacked.shape[2]),
device=x.device) dtype=torch.float32,
bgmv(buffer, x, self.lora_a_stacked, device=x.device,
self.indices[:self.indices_len[0]], 0, 1.0) )
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_reduce(buffer) buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce # following S-LoRA, allows the fusing of all_gather and all_reduce
...@@ -298,18 +318,21 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): ...@@ -298,18 +318,21 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# reduced before being used # reduced before being used
shard_size = self.lora_b_stacked.shape[2] shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked, self.punica_wrapper.add_expand_slice(output, buffer,
self.indices[:self.indices_len[0]], 0, 1.0, self.lora_b_stacked, start_idx,
start_idx, shard_size) shard_size)
output = output.view(*out_orig_shape) output = output.view(*out_orig_shape)
return output return output
@classmethod @classmethod
@_fully_sharded_can_replace @_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
source_layer=source_layer, source_layer=source_layer,
......
...@@ -17,7 +17,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -17,7 +17,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather) tensor_model_parallel_gather)
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -55,88 +55,17 @@ def _not_fully_sharded_can_replace(can_replace): ...@@ -55,88 +55,17 @@ def _not_fully_sharded_can_replace(can_replace):
""" """
def dec(*args, **kwargs): def dec(*args, **kwargs):
decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
condition = (not kwargs['lora_config'].fully_sharded_loras condition = (not kwargs["lora_config"].fully_sharded_loras
if decorate else True) if decorate else True)
return can_replace(*args, **kwargs) and condition return can_replace(*args, **kwargs) and condition
return dec return dec
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
indices: torch.Tensor,
output: torch.Tensor,
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
lora_b_stacked: (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
return output.view_as(org_output)
def _apply_lora_packed_nslice(
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
This method is used for layers that are composed of multiple sublayers
(slices) packed together.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx in range(len(output_slices)):
add_lora_slice(output, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
return output.view_as(org_output)
@dataclass @dataclass
class LoRAMapping(AdapterMapping): class LoRAMapping(AdapterMapping):
pass is_prefill: bool = False
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
...@@ -154,10 +83,11 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -154,10 +83,11 @@ class BaseLayerWithLoRA(nn.Module):
... ...
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None,
) -> None:
"""Initializes lora matrices.""" """Initializes lora matrices."""
... ...
...@@ -177,20 +107,18 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -177,20 +107,18 @@ class BaseLayerWithLoRA(nn.Module):
def set_mapping( def set_mapping(
self, self,
base_indices: torch.Tensor, punica_wrapper: PunicaWrapper,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
): ):
"""Sets the mapping indices.""" self.punica_wrapper: PunicaWrapper = punica_wrapper
...
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError raise NotImplementedError
...@@ -259,10 +187,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -259,10 +187,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2], self.lora_a_stacked.shape[2],
) )
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.embeddings_indices: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -285,40 +209,27 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -285,40 +209,27 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
if embeddings_tensor is not None: if embeddings_tensor is not None:
self.embeddings_tensors[ self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor. index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1]].copy_(embeddings_tensor, non_blocking=True) shape[1], ].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None: if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy # TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part # everything, just the modified part
embeddings = self.embeddings_tensors.view( embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[0] *
self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2] self.embeddings_tensors.shape[2],
)[self.embeddings_slice[0]:self.embeddings_slice[1]] )[self.embeddings_slice[0]:self.embeddings_slice[1]]
assert self.embeddings_weights is not None assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.embeddings_indices = embeddings_indices
self.indices_len = indices_len
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
embedding_len = self.indices_len[3] embeddings_indices = self.punica_wrapper.embeddings_indices
indices = self.embeddings_indices[1][:embedding_len].view_as(x) indices = embeddings_indices[1].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices, x + indices,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = self.embeddings_indices[0][:embedding_len].view_as(x) indices = embeddings_indices[0].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))
...@@ -329,22 +240,32 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -329,22 +240,32 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
if full_lora_a_embeddings.ndim == 3: if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view( full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1], -1) full_lora_a_embeddings.shape[1],
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, -1,
self.indices[:self.indices_len[0]], 0, 1.0) )
# Embedding layer only need expand op
self.punica_wrapper.add_expand(full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
add_input=True)
return full_output.view_as(full_output_org) return full_output.view_as(full_output_org)
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is VocabParallelEmbedding return type(source_layer) is VocabParallelEmbedding
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
""" """
LoRA on top of ColumnParallelLinear layer. LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism. LoRA B is sliced for tensor parallelism.
""" """
...@@ -357,10 +278,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -357,10 +278,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.device = _get_lora_device(self.base_layer) self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
lora_a_output_size_per_partition = ( lora_a_output_size_per_partition = (
...@@ -384,10 +306,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -384,10 +306,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
) )
self.output_dim = self.lora_b_stacked.shape[2] self.output_dim = self.lora_b_stacked.shape[2]
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0 self.lora_b_stacked[index] = 0
...@@ -423,28 +341,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -423,28 +341,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora( self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
x, self.lora_b_stacked, 1.0)
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output return output
def forward(self, input_): def forward(self, input_):
...@@ -473,9 +374,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -473,9 +374,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ColumnParallelLinear or ( return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 1) and len(packed_modules_list) == 1)
...@@ -494,10 +399,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -494,10 +399,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
super().__init__(base_layer) super().__init__(base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config self.lora_config = lora_config
n_slices = 2 n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices if not (len(self.base_layer.output_sizes) == n_slices
...@@ -533,8 +439,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -533,8 +439,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) for _ in range(n_slices)) ) for _ in range(n_slices))
self.output_dim = self.lora_b_stacked[0].shape[2] self.output_dim = self.lora_b_stacked[0].shape[2]
# Lazily initialized.
self.indices: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0 self.lora_a_stacked[0][index] = 0
...@@ -556,7 +460,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -556,7 +460,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
lora_b = [ lora_b = [
lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx] lora_b[0][:, start_idx:end_idx],
lora_b[1][:, start_idx:end_idx],
] ]
return lora_b return lora_b
...@@ -591,34 +496,33 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -591,34 +496,33 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice( self.punica_wrapper.add_lora_packed_nslice(
x, output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
self.lora_a_stacked, (self.output_dim, self.output_dim))
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
(self.output_dim, self.output_dim),
)
return output return output
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
return type(source_layer) is MergedColumnParallelLinear and len( lora_config: LoRAConfig,
packed_modules_list) == 2 packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 2)
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
""" """
ColumnParallelLinear layer that is specifically designed for ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b, qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer. only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks. must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have Q slice may have different shape than K and V slices (which both have
the same shape). the same shape).
""" """
...@@ -696,10 +600,11 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -696,10 +600,11 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
super().__init__(base_layer) super().__init__(base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -767,11 +672,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -767,11 +672,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
), ),
) )
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, self.output_slices = (
self.kv_proj_shard_size) self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.packed_indices: Optional[torch.Tensor] = None self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None
# lazily initialized. # lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int] self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
...@@ -794,15 +703,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -794,15 +703,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
if lora_b[0] is not None: if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size * lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)] (self.q_shard_id + 1), ]
if lora_b[1] is not None: if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size * lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size * self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)] (self.kv_shard_id + 1), ]
if lora_b[2] is not None: if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size * lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size * self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)] (self.kv_shard_id + 1), ]
lora_b = [lora_b_q, lora_b_k, lora_b_v] lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b return lora_b
...@@ -851,23 +760,23 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -851,23 +760,23 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def apply(self, x: torch.Tensor, def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice( self.punica_wrapper.add_lora_packed_nslice(output, x,
x, self.lora_a_stacked,
self.lora_a_stacked, self.lora_b_stacked, 1.0,
self.lora_b_stacked, self.output_slices)
self.indices[:self.indices_len[0]],
output,
self.output_slices,
)
return output return output
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
return type(source_layer) is QKVParallelLinear and len( lora_config: LoRAConfig,
packed_modules_list) == 3 packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is QKVParallelLinear
and len(packed_modules_list) == 3)
class RowParallelLinearWithLoRA(BaseLayerWithLoRA): class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...@@ -880,10 +789,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -880,10 +789,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.device = _get_lora_device(self.base_layer) self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None: model_config: Optional[PretrainedConfig] = None,
) -> None:
self.lora_config = lora_config self.lora_config = lora_config
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
...@@ -911,9 +821,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -911,9 +821,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) )
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -950,27 +857,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -950,27 +857,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
0, :lora_b.shape[1], :lora_b.shape[0]].copy_( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply(self, x: torch.Tensor) -> torch.Tensor: def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x) output = self.base_layer.quant_method.apply(self.base_layer, x)
_apply_lora( self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
x, self.lora_b_stacked, 1.0)
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output return output
def forward(self, input_): def forward(self, input_):
...@@ -1013,14 +903,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -1013,14 +903,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@property @property
def weight(self): def weight(self):
return self.base_layer.weight if hasattr( return (self.base_layer.weight if hasattr(self.base_layer, "weight")
self.base_layer, "weight") else self.base_layer.qweight else self.base_layer.qweight)
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is RowParallelLinear return type(source_layer) is RowParallelLinear
...@@ -1125,10 +1019,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1125,10 +1019,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=torch.long) dtype=torch.long)
else: else:
self.sharded_to_full_mapping_gpu = None self.sharded_to_full_mapping_gpu = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
self.indices_padded: torch.Tensor
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -1154,19 +1044,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1154,19 +1044,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
index, :embeddings_tensor.shape[0], :embeddings_tensor. index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1], ] = embeddings_tensor shape[1], ] = embeddings_tensor
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = sampler_indices
self.indices_padded = sampler_indices_padded
self.indices_len = indices_len
def _get_logits( def _get_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -1212,38 +1089,37 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1212,38 +1089,37 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
out=lora_logits[:-1]) out=lora_logits[:-1])
lora_logits[-1] = float("-inf") lora_logits[-1] = float("-inf")
lora_logits = lora_logits.mT lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded
lora_logits = (lora_logits.reshape( lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2], lora_logits.shape[2],
).index_select(0, ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
self.indices_padded[:self.indices_len[2]]).nan_to_num_( posinf=float("inf"),
nan=float("-inf"), neginf=float("-inf")))
posinf=float("inf"),
neginf=float("-inf")))
logits[:, logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits lora_logits.shape[1], ] = lora_logits
_apply_lora( # LogitsProcessorWithLoRA always using bgmv
hidden_states, self.punica_wrapper.add_lora_logits(logits, hidden_states,
self.lora_a_stacked, self.lora_a_stacked,
self.lora_b_stacked, self.lora_b_stacked, 1.0)
self.indices[:self.indices_len[1]],
logits,
)
# Remove paddings in vocab (if any). # Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size] logits = logits[:, :self.base_layer.vocab_size]
return logits return logits
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs) return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
# Special handling for the LogitsProcessor. # Special handling for the LogitsProcessor.
return False return False
...@@ -1259,9 +1135,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): ...@@ -1259,9 +1135,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
def __init__(self, base_layer: RotaryEmbedding) -> None: def __init__(self, base_layer: RotaryEmbedding) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
# Lazily initialized
self.long_lora_indices: torch.Tensor
self.indices_len: List[int]
@property @property
def scaling_factors(self): def scaling_factors(self):
...@@ -1277,9 +1150,8 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): ...@@ -1277,9 +1150,8 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None, model_config: Optional[PretrainedConfig] = None,
) -> None: ) -> None:
scaling_factors = list( scaling_factors = (list(lora_config.long_lora_scaling_factors)
lora_config.long_lora_scaling_factors if lora_config.long_lora_scaling_factors else [])
) if lora_config.long_lora_scaling_factors else []
base_scaling_factor = (self.base_layer.scaling_factor if isinstance( base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
self.base_layer, LinearScalingRotaryEmbedding) else 1.0) self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
scaling_factors = sorted( scaling_factors = sorted(
...@@ -1306,18 +1178,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): ...@@ -1306,18 +1178,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
): ):
... ...
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.long_lora_indices = long_lora_indices
self.indices_len = indices_len
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -1328,19 +1188,24 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): ...@@ -1328,19 +1188,24 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
positions, positions,
query, query,
key, key,
offsets=self.long_lora_indices[:self.indices_len[4]]) offsets=self.punica_wrapper.long_lora_indices,
)
@property @property
def scaling_factor_to_offset(self) -> Dict[float, int]: def scaling_factor_to_offset(self) -> Dict[float, int]:
return self.base_layer.scaling_factor_to_offset return self.base_layer.scaling_factor_to_offset
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(
lora_config: LoRAConfig, packed_modules_list: List, cls,
model_config: Optional[PretrainedConfig]) -> bool: source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
return type(source_layer) is LinearScalingRotaryEmbedding or type( return (type(source_layer) is LinearScalingRotaryEmbedding
source_layer) is RotaryEmbedding or type(source_layer) is RotaryEmbedding)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return self.base_layer.extra_repr() return self.base_layer.extra_repr()
...@@ -4,7 +4,7 @@ import math ...@@ -4,7 +4,7 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, List, Optional, Type
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -21,6 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ...@@ -21,6 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora, LinearScalingRotaryEmbeddingWithLora,
LoRAMapping) LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.model_executor.models.interfaces import SupportsLoRA
...@@ -43,115 +44,6 @@ class LongContextLoRAContext: ...@@ -43,115 +44,6 @@ class LongContextLoRAContext:
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
def convert_mapping(
mapping: LoRAMapping,
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional[LongContextLoRAContext] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors.
Used to index into each tensor. It contains length for
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices). If long_lora doesn't
exist, it only contains first 4 entries.
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices, lora_indices, embedding_indices
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices, indices_len)
def get_lora_id(): def get_lora_id():
global _GLOBAL_LORA_ID global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1 _GLOBAL_LORA_ID += 1
...@@ -422,29 +314,12 @@ class LoRAModelManager(AdapterModelManager): ...@@ -422,29 +314,12 @@ class LoRAModelManager(AdapterModelManager):
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None self.long_lora_context: Optional[LongContextLoRAContext] = None
self.base_indices = torch.empty(self.max_num_batched_tokens, self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
dtype=torch.long, max_batches=self.max_num_seqs,
device="cuda") device="cuda")
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.embeddings_indices = torch.empty(2,
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
# Scaling factor -> offset to the sin_cos_cache to it. # Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora. # Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {} self.scaling_factor_to_offset: Dict[float, int] = {}
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4
super().__init__(model) super().__init__(model)
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy( self.supported_lora_modules = copy.deepcopy(
...@@ -536,28 +411,16 @@ class LoRAModelManager(AdapterModelManager): ...@@ -536,28 +411,16 @@ class LoRAModelManager(AdapterModelManager):
"Pinning is not supported in LoRAModelManager." "Pinning is not supported in LoRAModelManager."
"Use LRUCacheLoRAModelManager for pinning") # type: ignore "Use LRUCacheLoRAModelManager for pinning") # type: ignore
# TODO see if this can be vectorized
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded, # update lora states
embeddings_indices, long_lora_offsets_tensor, self.punica_wrapper.update_metadata(
indices_len) = convert_mapping(mapping, self.lora_index_to_id, mapping,
self.lora_slots + 1, self.vocab_size, self.lora_index_to_id,
self.lora_config.lora_extra_vocab_size, self.lora_slots + 1,
self.long_lora_context) self.vocab_size,
self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.lora_config.lora_extra_vocab_size,
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) self.long_lora_context,
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( )
sampler_indices_padded)
self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self.long_lora_indices.zero_()
# Maintain the reference
self.indices_len[:] = indices_len
def remove_all_adapters(self): def remove_all_adapters(self):
"""Remove all LoRAModels from the manager.""" """Remove all LoRAModels from the manager."""
...@@ -595,10 +458,8 @@ class LoRAModelManager(AdapterModelManager): ...@@ -595,10 +458,8 @@ class LoRAModelManager(AdapterModelManager):
self.model.config)) self.model.config))
self.register_module(module_name, new_module) self.register_module(module_name, new_module)
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices, # All lora layers share the same punica_wrapper based on reference.
self.sampler_indices_padded, new_module.set_mapping(self.punica_wrapper)
self.embeddings_indices,
self.long_lora_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA) assert isinstance(module, BaseLayerWithLoRA)
......
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Dict, Optional
import torch
import triton
import triton.language as tl
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
performance
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
current_n_c = tl.max_contiguous(current_n, BLOCK_N)
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n_c[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
override_config: Optional[Dict[str, int]] = None,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
if override_config:
config = override_config
else:
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Dict, Optional
import torch
import triton
import triton.language as tl
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_expand_slice_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
slice_offset,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_N: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
performance
"""
pid_sn = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_k = tl.arange(0, BLOCK_K)
offset_n = tl.arange(0, BLOCK_N)
if EVEN_K:
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
offset_k * xk_stride, ) # [BLOCK_K]
else:
tiled_a = tl.load(
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
mask=offset_k < K,
other=0,
) # [BLOCK_K]
# N must be divisible by SPLIT_N
split_n_length = tl.cdiv(N, SPLIT_N)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
slice_offset * cn_stride)
for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
< K)
c_mask = current_n < split_n_length
tiled_b = tl.load(
b_ptr + current_n[:, None] * lora_k_stride +
offset_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
@torch.inference_mode()
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
override_config: Optional[Dict[str, int]] = None,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'b weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert slice_size == lora_b_weights.size(-2)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_K = triton.next_power_of_2(K)
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
batches = lora_indices_tensor.size(0)
if override_config:
config = override_config
else:
config = get_lora_op_configs("expand", batches, N)
grid = lambda META: (
META["SPLIT_N"],
batches,
)
_bgmv_expand_slice_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
slice_offset,
BLOCK_K=BLOCK_K,
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Dict, Optional
import torch
import triton
import triton.language as tl
from .utils import get_lora_op_configs
@triton.jit
def _bgmv_shrink_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
lora_indices,
scaling,
xm_stride,
xk_stride,
l0_stride,
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
"""
GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's
performance
"""
pid_sk = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
offset_n = tl.arange(0, BLOCK_N)
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
a_ptr = input_ptr + cur_batch * xm_stride
b_ptr = lora_ptr + l0_stride * lora_index
accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
for k in range(0, K, BLOCK_K * SPLIT_K):
current_k = k + offset_k
current_k_c = tl.max_contiguous(current_k, BLOCK_K)
tiled_a = tl.load(
a_ptr + current_k_c,
mask=current_k < K,
other=0.0,
) # [BLOCK_K]
b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
tiled_b = tl.load(
b_ptr + offset_n[:, None] * lora_k_stride +
current_k[None, :] * lora_n_stride,
mask=b_ptr_mask,
other=0.0,
) # [BLOCK_N,BLOCK_K]
accumulator += tl.sum(tiled_a * tiled_b, 1)
accumulator *= scaling
offset_cn = tl.arange(0, BLOCK_N)
c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
c_mask = offset_cn < N
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
@torch.inference_mode()
def bgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
override_config: Optional[Dict[str, int]] = None,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
scaling (float): Scaling factor.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_a_weights.size(-1)
assert inputs.is_contiguous()
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
assert lora_a_weights.size(1) == 1
lora_a_weights = lora_a_weights.squeeze(dim=1)
else:
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
assert lora_a_weights.is_contiguous()
assert output_tensor.is_contiguous()
# TODO tuning this config
batches = lora_indices_tensor.size(0)
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
BLOCK_N = triton.next_power_of_2(N)
if override_config:
config = override_config
else:
# First try to load optimal config from the file
config = get_lora_op_configs("bgmv_shrink", batches, K)
grid = lambda META: (
META["SPLIT_K"],
batches,
)
_bgmv_shrink_kernel[grid](
inputs,
lora_a_weights,
output_tensor,
N,
K,
lora_indices_tensor,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_a_weights.stride(0),
lora_a_weights.stride(1),
lora_a_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_N=BLOCK_N,
**config,
)
return
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
N,
K,
b_seq_start_loc,
seq_lens,
lora_indices,
xm_stride,
xk_stride, # 1
l0_stride, # hidden_size*max_rank
lora_k_stride,
lora_n_stride,
cm_stride,
cn_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
):
"""
The sgmv's expand triton kernel is based on GroupGEMM.
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num
M = tl.load(seq_lens + cur_batch)
if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
if lora_index == -1:
return
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_k = tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
offset_k[None, :] * xk_stride, )
b_ptr = (lora_ptr + l0_stride * lora_index +
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
if EVEN_K:
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr)
else:
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :] < K - k * BLOCK_K,
other=0)
tiled_b = tl.load(b_ptr,
mask=offset_k[:, None] < K - k * BLOCK_K,
other=0)
if CAST_TYPE:
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * xk_stride
b_ptr += BLOCK_K * lora_n_stride
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
offset_cn[None, :] * cn_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@torch.inference_mode()
def sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
add_inputs: bool = False,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weights.size(1) == 1
lora_b_weights = lora_b_weights.squeeze(dim=1)
else:
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weights.is_contiguous()
# TODO tuning this config
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
BLOCK_M = 32
BLOCK_N = 32
BLOCK_K = 16
EVEN_K = K % BLOCK_K == 0
ADD_INPUTS = add_inputs
CAST_TYPE = False
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
batches,
)
_sgmv_expand_kernel[grid](
inputs,
lora_b_weights,
output_tensor,
N,
K,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
inputs.stride(0),
inputs.stride(1),
lora_b_weights.stride(0),
lora_b_weights.stride(1),
lora_b_weights.stride(2),
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
)
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment