Unverified Commit 55509c21 authored by ErezSC42's avatar ErezSC42 Committed by GitHub
Browse files

[MODEL] LoRA support for Jamba model (#11209)


Signed-off-by: default avatarErez Schwartz <erezs@ai21.com>
parent 10141809
...@@ -4,6 +4,7 @@ from typing import Dict, List, TypedDict ...@@ -4,6 +4,7 @@ from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import safetensors
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
...@@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules(): ...@@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules():
return snapshot_download(repo_id="dyang415/mixtral-lora-v0") return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
@pytest.fixture(scope="session")
def jamba_lora_files():
# some of the adapters have unnecessary weights for serving,
# hence we remove them
def remove_unnecessary_weights(path):
lora_path = f"{adapter_path}/adapter_model.safetensors"
tensors = safetensors.torch.load_file(lora_path)
nonlora_keys = []
for k in list(tensors.keys()):
if "lora" not in k:
nonlora_keys.append(k)
for k in nonlora_keys:
del tensors[k]
safetensors.torch.save_file(tensors, lora_path)
adapter_path = snapshot_download(
repo_id=
"hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
remove_unnecessary_weights(adapter_path)
return adapter_path
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def gemma_lora_files(): def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
......
from typing import List
import pytest
import torch
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
MAX_TOKENS = 40
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
prompts: List[str]) -> List[str]:
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("tp_size", [4])
def test_jamba_lora(jamba_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
prompts = ["Write a story about a sheep and a goat."]
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
)
expected_jamba_output = [
"""Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501
]
assert do_sample(llm, jamba_lora_files, lora_id=1,
prompts=prompts) == expected_jamba_output
...@@ -42,12 +42,14 @@ class MambaMixer(CustomOp): ...@@ -42,12 +42,14 @@ class MambaMixer(CustomOp):
use_rms_norm: bool, use_rms_norm: bool,
rms_norm_has_weight: bool = True, rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5, rms_norm_eps: float = 1e-5,
activation="silu"): activation="silu",
is_lora_enabled: bool = False):
super().__init__() super().__init__()
self.time_step_rank = time_step_rank self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.use_rms_norm = use_rms_norm self.use_rms_norm = use_rms_norm
self.activation = activation self.activation = activation
self.is_lora_enabled = is_lora_enabled
self.conv1d = ColumnParallelLinear( self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size, input_size=conv_kernel_size,
...@@ -63,6 +65,7 @@ class MambaMixer(CustomOp): ...@@ -63,6 +65,7 @@ class MambaMixer(CustomOp):
self.in_proj = MergedColumnParallelLinear(hidden_size, self.in_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2, [intermediate_size] * 2,
bias=use_bias) bias=use_bias)
# selective projection used to make dt, B and C input dependent # selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear( self.x_proj = RowParallelLinear(
intermediate_size, intermediate_size,
...@@ -170,6 +173,12 @@ class MambaMixer(CustomOp): ...@@ -170,6 +173,12 @@ class MambaMixer(CustomOp):
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C # 3.a. input varying initialization of time_step, B and C
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
ssm_parameters = self.x_proj(
hidden_states.transpose(-2, -1).contiguous())[0]
else:
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split( time_step, B, C = torch.split(
...@@ -222,6 +231,11 @@ class MambaMixer(CustomOp): ...@@ -222,6 +231,11 @@ class MambaMixer(CustomOp):
scan_outputs = scan_outputs.transpose(0, 1) scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection # 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2, if self.is_lora_enabled:
-1))[0] # lora kernel requires contiguous tensor
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1).contiguous())[0]
else:
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states return contextualized_states
...@@ -107,9 +107,11 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -107,9 +107,11 @@ class JambaMambaDecoderLayer(nn.Module):
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None: is_lora_enabled: Optional[bool] = False,
**kwargs) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.is_lora_enabled = is_lora_enabled
self.mamba = MambaMixer(hidden_size= config.hidden_size, self.mamba = MambaMixer(hidden_size= config.hidden_size,
ssm_state_size = config.mamba_d_state, ssm_state_size = config.mamba_d_state,
conv_kernel_size = config.mamba_d_conv, conv_kernel_size = config.mamba_d_conv,
...@@ -120,7 +122,9 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -120,7 +122,9 @@ class JambaMambaDecoderLayer(nn.Module):
use_bias = config.mamba_proj_bias, use_bias = config.mamba_proj_bias,
use_rms_norm=True, use_rms_norm=True,
rms_norm_eps=config.rms_norm_eps, rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act) activation=config.hidden_act,
is_lora_enabled = self.is_lora_enabled
)
num_experts = config.layers_num_experts[layer_idx] num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
...@@ -156,14 +160,13 @@ class JambaMambaDecoderLayer(nn.Module): ...@@ -156,14 +160,13 @@ class JambaMambaDecoderLayer(nn.Module):
class JambaAttentionDecoderLayer(nn.Module): class JambaAttentionDecoderLayer(nn.Module):
def __init__( def __init__(self,
self,
config: JambaConfig, config: JambaConfig,
layer_idx: int, layer_idx: int,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: **kwargs) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -287,17 +290,18 @@ class JambaModel(nn.Module): ...@@ -287,17 +290,18 @@ class JambaModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
def get_layer(prefix: str): def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1]) layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[ layer_class = ALL_DECODER_LAYER_TYPES[
config.layers_block_type[layer_idx]] config.layers_block_type[layer_idx]]
return layer_class( return layer_class(config,
config,
layer_idx, layer_idx,
cache_config, cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
) **extra_kwargs)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
...@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
"k_proj", "k_proj",
"v_proj", "v_proj",
], ],
"in_proj": ["in_proj"],
} }
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [ supported_lora_modules = [
"qkv_proj", "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
"o_proj", "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
"embed_tokens",
"lm_head",
] ]
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",
...@@ -446,7 +449,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -446,7 +449,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs): **kwargs):
if self.mamba_cache is None: if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type( num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba) self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager( self.mamba_cache = MambaCacheManager(
......
...@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module): ...@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: MambaConfig, config: MambaConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
is_lora_enabled: Optional[bool] = False) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba" self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.is_lora_enabled = is_lora_enabled
mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
self.mixer = MambaMixer(hidden_size=config.hidden_size, self.mixer = MambaMixer(hidden_size=config.hidden_size,
ssm_state_size=config.state_size, ssm_state_size=config.state_size,
...@@ -53,7 +55,8 @@ class MambaDecoderLayer(nn.Module): ...@@ -53,7 +55,8 @@ class MambaDecoderLayer(nn.Module):
use_rms_norm=self.is_falcon_mamba, use_rms_norm=self.is_falcon_mamba,
rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_has_weight=not self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps, rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act) activation=config.hidden_act,
is_lora_enabled=self.is_lora_enabled)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -85,6 +88,7 @@ class MambaModel(nn.Module): ...@@ -85,6 +88,7 @@ class MambaModel(nn.Module):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
is_lora_enabled = bool(lora_config)
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -101,8 +105,10 @@ class MambaModel(nn.Module): ...@@ -101,8 +105,10 @@ class MambaModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MambaDecoderLayer( lambda prefix: MambaDecoderLayer(config,
config, cache_config=cache_config, quant_config=quant_config), cache_config=cache_config,
quant_config=quant_config,
is_lora_enabled=is_lora_enabled),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm_f = RMSNorm(config.hidden_size, self.norm_f = RMSNorm(config.hidden_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment