Unverified Commit ddd1ef66 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix JambaForCausalLM LoRA (#14370)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent e5e03c2c
...@@ -6,7 +6,6 @@ from typing import TypedDict ...@@ -6,7 +6,6 @@ from typing import TypedDict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import safetensors
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
...@@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules(): ...@@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules():
return snapshot_download(repo_id="dyang415/mixtral-lora-v0") return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
@pytest.fixture(scope="session")
def jamba_lora_files():
# some of the adapters have unnecessary weights for serving,
# hence we remove them
def remove_unnecessary_weights(path):
lora_path = f"{adapter_path}/adapter_model.safetensors"
tensors = safetensors.torch.load_file(lora_path)
nonlora_keys = []
for k in list(tensors.keys()):
if "lora" not in k:
nonlora_keys.append(k)
for k in nonlora_keys:
del tensors[k]
safetensors.torch.save_file(tensors, lora_path)
adapter_path = snapshot_download(
repo_id=
"hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
remove_unnecessary_weights(adapter_path)
return adapter_path
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def gemma_lora_files(): def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
MAX_TOKENS = 40
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
prompts: list[str]) -> list[str]:
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("tp_size", [4])
def test_jamba_lora(jamba_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
prompts = ["Write a story about a sheep and a goat."]
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
)
expected_jamba_output = [
"""Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501
]
assert do_sample(llm, jamba_lora_files, lora_id=1,
prompts=prompts) == expected_jamba_output
...@@ -632,6 +632,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage, ...@@ -632,6 +632,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_replicated_layer() linear, lora_linear = create_random_linear_replicated_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper) lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
...@@ -757,6 +758,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -757,6 +758,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_parallel_layer() linear, lora_linear = create_random_linear_parallel_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper) lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
...@@ -904,6 +906,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -904,6 +906,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_column_parallel_packed_layer() linear, lora_linear = create_column_parallel_packed_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper) lora_linear.set_mapping(punica_wrapper)
lora_dict, sublora_dict = populate_loras( lora_dict, sublora_dict = populate_loras(
id_to_index, id_to_index,
......
...@@ -274,6 +274,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -274,6 +274,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
) -> bool: ) -> bool:
return type(source_layer) is VocabParallelEmbedding return type(source_layer) is VocabParallelEmbedding
@property
def weight(self):
return self.base_layer.weight
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
...@@ -409,6 +413,34 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -409,6 +413,34 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
self.output_slices) self.output_slices)
return output return output
@property
def weight(self) -> torch.Tensor:
# unquantizedLinear
if hasattr(self.base_layer, "weight"):
return self.base_layer.weight
# Compressed Tensor
elif hasattr(self.base_layer, "weight_packed"):
return self.base_layer.weight_packed
# GPTQ/AWQ
elif hasattr(self.base_layer, "qweight"):
return self.base_layer.qweight
# marlin
elif hasattr(self.base_layer, "B"):
return self.base_layer.B
# HQQ marlin
elif hasattr(self.base_layer, "W_q"):
return self.base_layer.W_q
else:
raise ValueError(f"Unsupported base layer: {self.base_layer}")
@property
def bias(self) -> Optional[torch.Tensor]:
if hasattr(self.base_layer, "bias"):
return self.base_layer.bias
else:
return None
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
...@@ -902,11 +934,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -902,11 +934,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
return output, output_bias return output, output_bias
@property
def weight(self):
return (self.base_layer.weight if hasattr(self.base_layer, "weight")
else self.base_layer.qweight)
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment