Unverified commit 71221692 authored by Ying Sheng, committed by GitHub

[Feature] Initial support for multi-LoRA serving (#1307)

parent c33d82a2
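For context, a minimal usage sketch of the feature added by this commit, built only from the arguments introduced in this diff; the model and adapter names are placeholders, and LoRA currently requires disabling CUDA graph and radix cache (see the ServerArgs check further down):

import json

from sglang.srt.server import Runtime

# Launch an engine with one or more LoRA adapters preloaded (placeholder paths).
runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-hf",
    lora_paths=["winddude/wizardLM-LlaMA-LoRA-7B"],
    max_loras_per_batch=4,
    disable_cuda_graph=True,   # LoRA + CUDA graph compatibility is still in progress
    disable_radix_cache=True,  # same for radix cache
)

# Per-request adapter selection: lora_path=None falls back to the base model.
out = runtime.generate(
    "### Instruction:\nTell me about llamas.\n### Response:\n",
    lora_path="winddude/wizardLM-LlaMA-LoRA-7B",
    sampling_params={"max_new_tokens": 32, "temperature": 0},
)
print(json.loads(out)["text"])
runtime.shutdown()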
......@@ -27,7 +27,7 @@ srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "intere
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate"]
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"]
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple
import safetensors.torch
import torch
from torch import nn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
super().__init__()
self.base_layer = base_layer
self.segment_gemm = segment_gemm
self.lora_rank = lora_rank
self.scaling = scaling
self.set_lora = False
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
def set_lora_info(self, *args):
pass
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# TODO
return output
def forward(self, input_: torch.Tensor):
# duplicate the logic in ColumnParallelLinear
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_, bias
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_)
if self.base_layer.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seq_lens = seq_lens
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
weight_indices=self.weight_indices,
)
# FIXME
assert lora_a_output.shape[-1] == self.lora_rank * 2
lora_output = torch.empty_like(base_output)
output_dim = lora_output.shape[-1] // 2
for i in range(2):
left = output_dim * i
right = left + output_dim
lora_output[:, left:right] = self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * i : self.lora_rank * (i + 1)
].contiguous(),
weights=self.B_buffer[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
def set_lora_info(
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv
self.bs = bs
self.seq_lens = seq_lens
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer_qkv,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
weight_indices=self.weight_indices,
)
# FIXME parallelize qkv
lora_output = torch.empty_like(base_output)
# q
output_dim_q = self.B_buffer_q.shape[-2]
lora_output[:, :output_dim_q] = self.segment_gemm.run(
x=lora_a_output[:, : self.lora_rank].contiguous(),
weights=self.B_buffer_q,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
weight_indices=self.weight_indices,
)
# kv
output_dim_kv = self.B_buffer_kv.shape[-2] // 2
for i in range(2):
left = output_dim_kv * i
right = left + output_dim_kv
lora_output[:, output_dim_q + left : output_dim_q + right] = (
self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
].contiguous(),
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
weight_indices=self.weight_indices,
)
)
return base_output + lora_output * self.scaling
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seq_lens = seq_lens
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
weight_indices=self.weight_indices,
)
lora_output = self.segment_gemm.run(
x=lora_output,
weights=self.B_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_lens=self.seq_lens,
weight_indices=self.weight_indices,
)
return base_output + lora_output * self.scaling
def forward(self, input_):
# duplicate the logic in RowParallelLinear
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
def get_lora_layer(
layer: nn.Module, segment_gemm, lora_rank, scaling
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
def get_mapped_params(module_names):
ret = set()
for module_name in module_names:
ret.add(params_mapping(module_name))
return list(ret)
class LoRALayer(nn.Module):
def __init__(self, config, base_hf_config):
super().__init__()
self.config = config
self.base_hf_config = base_hf_config
self.weights = {}
self.weight_gpu = {}
def load_to_gpu(self):
for name, weight in self.weights.items():
self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
def offload_from_gpu(self):
for name, weight in self.weights.items():
self.weight_gpu[name] = None
class LoRAAdapter(nn.Module):
def __init__(self, uid, config, base_hf_config, load_config):
super().__init__()
self.uid = uid
self.config = config
assert self.config.hf_config["peft_type"].lower() == "lora"
self.base_hf_config = base_hf_config
self.load_config = load_config
self.scaling = self.config.lora_alpha / self.config.r
self.layers = nn.ModuleList(
[
LoRALayer(config, base_hf_config)
for i in range(base_hf_config.num_hidden_layers)
]
)
self.weights = {}
self.weights_gpu = {}
def get_stacked_multiply(self, module_name):
stacked_rank = {
"qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
def load_to_gpu(self):
for name, weight in self.weights.items():
self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
for layer in self.layers:
layer.load_to_gpu()
def offload_from_gpu(self):
for name, weight in self.weights.items():
self.weights_gpu[name] = None
for layer in self.layers:
layer.offload_from_gpu()
# initialize the LoRA weights on CPU
def initialize_weights(self):
model_path = self.config.path
loader = DefaultModelLoader(self.load_config)
revision = getattr(self.config.hf_config, "revision", None)
for name, loaded_weight in loader._get_weights_iterator(
model_path, revision=revision, fall_back_to_pt=True
):
match = re.search(r"layers\.(\d+)\.", name)
if match is not None:
layer_id = int(match.group(1))
self.layers[layer_id].weights[name] = loaded_weight.cpu()
else:
self.weights[name] = loaded_weight.cpu()
# stack kv_proj and gate_up_proj
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
for weight_name in weight_names:
if "k_proj" in weight_name:
q_name = weight_name.replace("k_proj", "q_proj")
v_name = weight_name.replace("k_proj", "v_proj")
kv_name = weight_name.replace("k_proj", "kv_proj")
qkv_name = weight_name.replace("k_proj", "qkv_proj")
if "lora_A" in weight_name:
layer.weights[qkv_name] = torch.cat(
(
layer.weights[q_name],
layer.weights[weight_name],
layer.weights[v_name],
),
0,
)
layer.weights.pop(q_name)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
else:
layer.weights[kv_name] = torch.cat(
(
layer.weights[weight_name],
layer.weights[v_name],
),
0,
)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
elif "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
layer.weights.pop(weight_name)
layer.weights.pop(up_name)
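For reference (not part of the commit), the dense math the segment GEMM calls above implement per request: a PEFT checkpoint stores lora_A with shape (r, hidden_in) and lora_B with shape (hidden_out, r), and the LoRA delta scaling * (x @ A^T) @ B^T is added to the base projection. A toy sketch:

import torch

def lora_dense_reference(x, base_weight, lora_A, lora_B, scaling):
    # base projection: (tokens, hidden_in) @ (hidden_in, hidden_out)
    base_out = x @ base_weight.T
    # low-rank path: down-project to rank r, then up-project back
    lora_out = (x @ lora_A.T) @ lora_B.T
    return base_out + scaling * lora_out

x = torch.randn(3, 16)   # 3 tokens, hidden_in = 16
W = torch.randn(16, 16)  # base weight (hidden_out, hidden_in)
A = torch.randn(4, 16)   # lora_A, rank r = 4
B = torch.randn(16, 4)   # lora_B
print(lora_dense_reference(x, W, A, B, scaling=2.0).shape)  # torch.Size([3, 16])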
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import os
from huggingface_hub import snapshot_download
class LoRAConfig:
def __init__(
self,
path: str,
) -> None:
self.path = path
self.hf_config = self.get_lora_config()
self.target_modules = self.hf_config["target_modules"]
self.r = self.hf_config["r"]
self.lora_alpha = self.hf_config["lora_alpha"]
def get_lora_config(self, dummy=False):
if dummy:
raise NotImplementedError()
else:
if not os.path.isdir(self.path):
weights_dir = snapshot_download(self.path, allow_patterns=["*.json"])
else:
weights_dir = self.path
config_name = "adapter_config.json"
with open(os.path.join(weights_dir, config_name), "r") as f:
return json.load(f)
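A quick sketch of how this config class is consumed (the adapter id below is one of the adapters referenced later in this diff):

from sglang.srt.lora.lora_config import LoRAConfig

# Accepts either a local adapter directory or a Hugging Face Hub id;
# for a Hub id, only the *.json files are downloaded here.
config = LoRAConfig("winddude/wizardLM-LlaMA-LoRA-7B")
print(config.r, config.lora_alpha, config.target_modules)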
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"
import re
from dataclasses import dataclass
import torch
from flashinfer import SegmentGEMMWrapper
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import replace_submodule
def get_stacked_name(name):
# original module name -> (stacked name for A, stacked name for B)
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
"k_proj": ("qkv_proj", "kv_proj"),
"v_proj": ("qkv_proj", "kv_proj"),
"gate_proj": ("gate_up_proj", "gate_up_proj"),
"up_proj": ("gate_up_proj", "gate_up_proj"),
}
return params_mapping.get(name, (name, name))
def get_layer_id(name):
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
return None
return int(match.group(1))
class LoRAManager:
def __init__(
self,
base_model,
lora_paths,
base_hf_config,
max_loras_per_batch,
load_config,
dtype,
):
self.base_model = base_model
self.lora_paths = lora_paths
self.base_hf_config = base_hf_config
self.max_loras_per_batch = max_loras_per_batch
self.load_config = load_config
self.dtype = dtype
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
self.init_loras()
self.init_lora_memory_pool()
self.init_lora_batch()
def match_target_modules(self, module_name):
for target_module in self.target_modules:
if module_name.split(".")[-1] == target_module:
return True
return False
def get_target_modules(self):
modules = []
for module_name, module in self.base_model.named_modules():
if self.match_target_modules(module_name):
modules.append((module_name, module))
return modules
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.segment_gemm, self.max_lora_dim, self.scaling
)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def init_loras(self):
# get configs and target modules
self.configs = {}
self.origin_target_modules = set()
for path in self.lora_paths:
self.configs[path] = LoRAConfig(path)
self.origin_target_modules = set(self.origin_target_modules) | set(
self.configs[path].target_modules
)
self.target_modules = set(
[
self.base_model.get_module_name(module)
for module in self.origin_target_modules
]
)
self.target_weights = set(
[get_stacked_name(module) for module in self.origin_target_modules]
)
# load all weights to cpu
self.loras = []
self.lora_id = {}
for path in self.lora_paths:
self.lora_id[path] = len(self.loras)
self.loras.append(
LoRAAdapter(
path, self.configs[path], self.base_hf_config, self.load_config
)
)
self.loras[-1].initialize_weights()
# misc lora configs
self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
self.scaling = self.loras[0].scaling
# FIXME remove the restrictions
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
assert all(x.scaling == self.scaling for x in self.loras)
# monkey patch to use the LoRA version
self.lora_modules = []
for module_name, module in self.get_target_modules():
self.lora_modules.append(
(module_name, self.set_lora_module(module_name, module))
)
def init_lora_memory_pool(self):
# preallocate lora memory pool
self.A_buffer = {}
self.B_buffer = {}
num_layer = self.base_hf_config.num_hidden_layers
for module_A, module_B in self.target_weights:
# init A tensor, column_major=True
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
c = self.loras[-1].get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
(
self.max_loras_per_batch,
self.max_lora_dim * c,
hidden_dim_A,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]
# init B tensor, column_major=True
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
c = self.loras[-1].get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [
torch.empty(
(
self.max_loras_per_batch,
hidden_dim_B * c,
self.max_lora_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]
def init_lora_batch(self):
self.active_uids = set() # set of active loras
self.buffer_id = {} # lora uid -> idx in memory pool
def get_weight_name(self, name, idx):
for target_weight_name in self.target_weights:
if target_weight_name[idx] in name:
return target_weight_name[idx]
def load_lora(self, uid, buffer_id):
num_layer = self.base_hf_config.num_hidden_layers
if uid is None:
for i in range(num_layer):
for k in self.A_buffer.keys():
self.A_buffer[k][i][buffer_id] *= 0
return
for i in range(num_layer):
layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
for name, weights in layer_weights.items():
if "lora_A" in name:
lora_weight_name = self.get_weight_name(name, 0)
if lora_weight_name:
self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
else:
lora_weight_name = self.get_weight_name(name, 1)
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
def prepare_lora_batch(self, batch, extend_seq_lens=None):
# load active loras into lora memory pool
cur_uids = set([req.lora_path for req in batch.reqs])
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
evictable_uids = list(self.active_uids)
for uid in cur_uids:
if uid not in self.active_uids:
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
i += 1
if i < len(evictable_uids):
self.active_uids.remove(evictable_uids[i])
self.buffer_id.pop(evictable_uids[i])
self.load_lora(uid, i)
self.active_uids.add(uid)
self.buffer_id[uid] = i
i += 1
if cur_uids == set([None]):
return
# setup lora in forward modules
bs = len(batch.reqs)
seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, req in enumerate(batch.reqs):
weight_indices[i] = self.buffer_id[req.lora_path]
for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
if "qkv_proj" not in module_name:
weight_name = self.get_weight_name(module_name, 0)
module.set_lora_info(
self.A_buffer[weight_name][layer_id],
self.B_buffer[weight_name][layer_id],
bs,
seg_lens,
weight_indices,
)
else:
module.set_lora_info(
self.A_buffer["qkv_proj"][layer_id],
self.B_buffer["q_proj"][layer_id],
self.B_buffer["kv_proj"][layer_id],
bs,
seg_lens,
weight_indices,
)
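To make the batch metadata concrete, an illustrative sketch (values invented) of what prepare_lora_batch hands to the segment GEMM for a three-request extend batch:

import torch

# Prompt lengths of the three requests; their tokens are flattened into one
# (5 + 2 + 7, hidden_dim) activation tensor before each segment GEMM call.
seg_lens = torch.tensor([5, 2, 7], dtype=torch.int64)

# Requests 0 and 2 share the adapter sitting in memory-pool slot 0; request 1
# uses the adapter in slot 1. A base-only request points at a zeroed slot.
weight_indices = torch.tensor([0, 1, 0], dtype=torch.int64)

# A_buffer[name][layer_id] has shape (max_loras_per_batch, r * stack, hidden);
# the kernel applies A_buffer[name][layer_id][weight_indices[j]] to the tokens
# of segment j, so each request uses its own adapter within one batched call.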
......@@ -55,6 +55,9 @@ class GenerateReqInput:
is_single: bool = True
# LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
......@@ -184,6 +187,9 @@ class TokenizedGenerateReqInput:
# Modalities of the input images
modalites: Optional[List[str]] = None
# LoRA related
lora_path: Optional[str] = None # None means just use the base model
@dataclass
class EmbeddingReqInput:
......
......@@ -98,7 +98,7 @@ class FINISH_ABORT(BaseFinishReason):
class Req:
"""Store all inforamtion of a request."""
def __init__(self, rid, origin_input_text, origin_input_ids):
def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
# Input and output info
self.rid = rid
self.origin_input_text = origin_input_text
......@@ -106,6 +106,7 @@ class Req:
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
self.lora_path = lora_path
# Memory info
self.req_pool_idx = None
......
......@@ -266,6 +266,11 @@ class TokenizerManager:
top_logprobs_num,
obj.stream,
modalities,
(
obj.lora_path[index]
if isinstance(obj.lora_path, list)
else obj.lora_path
),
)
else: # is embedding
tokenized_obj = TokenizedEmbeddingReqInput(
......@@ -364,6 +369,11 @@ class TokenizerManager:
obj.top_logprobs_num[index],
obj.stream,
modalities,
(
obj.lora_path[index]
if isinstance(obj.lora_path, list)
else obj.lora_path
),
)
else:
tokenized_obj = TokenizedEmbeddingReqInput(
......
......@@ -87,6 +87,8 @@ class ModelTpServer:
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
# Init model and tokenizer
self.model_config = ModelConfig(
......@@ -323,7 +325,15 @@ class ModelTpServer:
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
if isinstance(recv_req, TokenizedGenerateReqInput):
req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
lora_path=recv_req.lora_path,
)
else:
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params
req.pixel_values = recv_req.pixel_values
......@@ -442,10 +452,27 @@ class ModelTpServer:
self.current_inflight_req
)
if self.lora_paths is not None:
lora_set = (
set([req.lora_path for req in self.running_batch.reqs])
if self.running_batch is not None
else set([])
)
for req in self.waiting_queue:
if adder.no_remaining_tokens():
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
if (
self.lora_paths is not None
and len(
lora_set
| set([req.lora_path for req in adder.can_run_list])
| set([req.lora_path])
)
> self.max_loras_per_batch
):
break
res = adder.add_one_req(req)
if (
not res
......
......@@ -41,6 +41,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
......@@ -107,6 +108,8 @@ class ModelRunner:
# Init components
min_per_gpu_memory = self.init_torch_distributed()
self.load_model()
if server_args.lora_paths is not None:
self.init_lora_manager()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_running_requests,
......@@ -312,6 +315,17 @@ class ModelRunner:
logger.info("Update weights end.")
return True, "Succeeded to update model weights"
def init_lora_manager(self):
self.lora_manager = LoRAManager(
base_model=self.model,
lora_paths=self.server_args.lora_paths,
base_hf_config=self.model_config.hf_config,
max_loras_per_batch=self.server_args.max_loras_per_batch,
load_config=self.load_config,
dtype=self.dtype,
)
logger.info("LoRA manager ready.")
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
......@@ -450,6 +464,8 @@ class ModelRunner:
@torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch):
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(batch)
if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
......@@ -466,6 +482,9 @@ class ModelRunner:
@torch.inference_mode()
def forward_extend(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(self, batch)
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
if self.is_generation:
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
......
......@@ -324,6 +324,51 @@ class LlamaForCausalLM(nn.Module):
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_hidden_dim(self, module_name):
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return self.config.hidden_size, self.config.hidden_size
elif module_name in ["kv_proj"]:
return self.config.hidden_size, self.config.hidden_size // (
self.config.num_attention_heads // self.config.num_key_value_heads
)
elif module_name == "gate_up_proj":
return self.config.hidden_size, self.config.intermediate_size
elif module_name == "down_proj":
return self.config.intermediate_size, self.config.hidden_size
else:
raise NotImplementedError()
def get_module_name(self, name):
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return params_mapping.get(name, name)
def get_module_name_from_weight_name(self, name):
stacked_params_mapping = [
# (param_name, shard_name, shard_id, num_shard)
("qkv_proj", "q_proj", "q", 3),
("qkv_proj", "k_proj", "k", 3),
("qkv_proj", "v_proj", "v", 3),
("gate_up_proj", "gate_proj", 0, 2),
("gate_up_proj", "up_proj", 1, 2),
]
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
num_shard,
)
return name[: -len(".weight")], 1
def get_num_params(self):
params_dict = dict(self.named_parameters())
return len(params_dict)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
......@@ -611,6 +611,7 @@ class Runtime:
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
):
json_data = {
"text": prompt,
......@@ -618,7 +619,9 @@ class Runtime:
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
"lora_path": lora_path,
}
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
response = requests.post(
self.url + "/generate",
json=json_data,
......
......@@ -101,6 +101,10 @@ class ServerArgs:
enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
......@@ -522,6 +526,21 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
# LoRA options
parser.add_argument(
"--lora-paths",
type=str,
nargs="*",
default=None,
help="The list of LoRA adapters.",
)
parser.add_argument(
"--max-loras-per-batch",
type=int,
default=8,
help="Maximum number of adapters for a running batch, include base-only request",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
......@@ -539,6 +558,12 @@ class ServerArgs:
assert not (
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
def prepare_server_args(argv: List[str]) -> ServerArgs:
......
......@@ -35,6 +35,7 @@ import torch
import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from torch import nn
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
FileCacheManager,
......@@ -714,3 +715,14 @@ def configure_logger(server_args, prefix: str = ""):
datefmt="%H:%M:%S",
force=True,
)
# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
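A small self-contained sketch of how replace_submodule is used (the toy module below is illustrative; LoRAManager.set_lora_module applies the same pattern to the real parallel layers):

from torch import nn
from sglang.srt.utils import replace_submodule

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

model = Toy()
new_proj = nn.Linear(8, 8)
# Swap model.block.proj in place and get a handle to the new module back.
assert replace_submodule(model, "block.proj", new_proj) is model.block.proj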
......@@ -21,6 +21,7 @@ from typing import List, Union
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
......@@ -52,6 +53,7 @@ def get_dtype_str(torch_dtype):
def get_top_logprobs(logits, k):
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
del logits
logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
return logprobs
......@@ -71,8 +73,10 @@ class HFRunner:
model_path,
torch_dtype,
is_generation,
output_str_only=False,
):
self.is_generation = is_generation
self.output_str_only = output_str_only
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()
......@@ -95,7 +99,7 @@ class HFRunner:
)
if self.is_generation:
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
......@@ -110,13 +114,16 @@ class HFRunner:
)
while True:
prompts, max_new_tokens = in_queue.get()
prompts, max_new_tokens, lora_paths = in_queue.get()
if lora_paths is not None:
assert len(prompts) == len(lora_paths)
if prompts is not None:
if self.is_generation:
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
for p in prompts:
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = self.tokenizer.encode(
p, return_tensors="pt"
......@@ -124,6 +131,16 @@ class HFRunner:
else:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
self.model = PeftModel.from_pretrained(
self.base_model,
lora_paths[i],
torch_dtype=torch_dtype,
is_trainable=False,
)
else:
self.model = self.base_model
outputs = self.model.generate(
input_ids,
do_sample=False,
......@@ -131,25 +148,30 @@ class HFRunner:
top_p=None,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=True,
output_scores=(not self.output_str_only),
)
output_strs.append(
self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
)
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = self.model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
)
del input_logits
if not self.output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(
logits[0], NUM_TOP_LOGPROBS
).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = self.model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(
input_logits, NUM_TOP_LOGPROBS
).tolist()
)
del input_logits
out_queue.put(
ModelOutput(
......@@ -160,6 +182,7 @@ class HFRunner:
)
else:
assert not self.output_str_only
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
......@@ -167,8 +190,9 @@ class HFRunner:
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
lora_paths=None,
):
self.in_queue.put((prompts, max_new_tokens))
self.in_queue.put((prompts, max_new_tokens, lora_paths))
return self.out_queue.get()
def terminate(self):
......@@ -191,6 +215,10 @@ class SRTRunner:
is_generation,
tp_size=1,
port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths=None,
max_loras_per_batch=4,
disable_cuda_graph=False,
disable_radix_cache=False,
):
self.is_generation = is_generation
self.runtime = Runtime(
......@@ -201,12 +229,17 @@ class SRTRunner:
mem_fraction_static=0.69,
trust_remote_code=False,
is_embedding=not self.is_generation,
lora_paths=lora_paths,
max_loras_per_batch=max_loras_per_batch,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
)
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
lora_paths=None,
):
if self.is_generation:
# the return value contains logprobs from prefill
......@@ -214,9 +247,10 @@ class SRTRunner:
top_input_logprobs = []
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for prompt in prompts:
for i, prompt in enumerate(prompts):
response = self.runtime.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
......@@ -256,6 +290,37 @@ class SRTRunner:
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
def batch_forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
lora_paths=None,
):
"""
Test serving by sending all prompts at once.
Only output strings are returned; no logprobs.
"""
if self.is_generation:
# unlike forward(), no prefill logprobs are returned here
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = self.runtime.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
response = json.loads(response)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
)
else:
response = self.runtime.encode(prompts)
response = json.loads(response)
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
def __enter__(self):
return self
......
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
ADAPTER = "/home/ying/test_lora"
HF_TOKEN = "..."
prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""
tokenizer = LlamaTokenizer.from_pretrained(MODEL)
base_model = LlamaForCausalLM.from_pretrained(
MODEL,
device_map="auto",
# load_in_8bit=True,
torch_dtype=torch.float16,
# use_auth_token=HF_TOKEN,
).cuda()
# base model generate
with torch.no_grad():
output_tensors = base_model.generate(
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
max_new_tokens=32,
do_sample=False,
)[0]
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= base output ========")
print(output)
# peft model generate
model = PeftModel.from_pretrained(
base_model,
ADAPTER,
torch_dtype=torch.float16,
is_trainable=False,
)
with torch.no_grad():
output_tensors = model.generate(
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
max_new_tokens=32,
do_sample=False,
)[0]
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= peft output ========")
print(output)
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER = "/home/ying/test_lora"
prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""
llm = LLM(model=MODEL, enable_lora=True)
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
)
prompts = [prompt]
outputs = llm.generate(
prompts, sampling_params, lora_request=LoRARequest("test_lora", 1, ADAPTER)
)
print(outputs[0].prompt)
print(outputs[0].outputs[0].text)
import json
import openai
import requests
import sglang as sgl
lora_path = "/home/ying/test_lora"
prompt_file = "/home/ying/test_prompt/dialogue_choice_prompts.json"
server_url = "http://127.0.0.1:30000"
client = openai.Client(base_url=server_url + "/v1", api_key="EMPTY")
# @sgl.function
# def generate(s, prompt):
# s += prompt
# s += sgl.gen("ans")
# sgl.set_default_backend(sgl.RuntimeEndpoint(server_url))
def generate(prompt, lora_path):
json_data = {
"text": prompt,
"sampling_params": {},
"return_logprob": False,
"logprob_start_len": None,
"top_logprobs_num": None,
"lora_path": lora_path,
}
response = requests.post(
server_url + "/generate",
json=json_data,
)
return json.dumps(response.json())
with open(prompt_file, "r") as f:
samples = json.load(f)
for sample in samples[:1]:
assert sample[0]["role"] == "user"
prompt = sample[0]["content"]
assert sample[1]["role"] == "assistant"
ref = sample[1]["content"]
state = generate(prompt, lora_path)
print("================================")
print(ref)
print("--------------------------------")
# print(state["ans"])
print(state)
print()
"""
Used for debugging via tensor comparison.
Dump {name: tensor} entries into "log_hf.jsonl" and "log_srt.jsonl",
using the same name for two tensors that are supposed to be close.
Recommended name format: "layer 2 after mlp".
"""
import json
import sys
import torch
if len(sys.argv) > 1:
assert sys.argv[1] == "base"
hf_log = "base_log_hf.jsonl"
srt_log = "base_log_srt.jsonl"
else:
hf_log = "log_hf.jsonl"
srt_log = "log_srt.jsonl"
def load_data(filepath):
tensors = {}
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
data = json.loads(line)
for k, v in data.items():
tensors[k] = torch.tensor(v)
return tensors
hf_tensors = load_data(hf_log)
srt_tensors = load_data(srt_log)
def get_diff(t1, t2):
t1 = t1.reshape(t2.shape)
max_diff = torch.max(abs(t1 - t2))
l2_dis = torch.dist(t1, t2, p=2)
return l2_dis, max_diff
for k, _ in srt_tensors.items():
l2_dis, max_diff = get_diff(hf_tensors[k], srt_tensors[k])
print(f"{k} {l2_dis=} {max_diff=}")
if k == "layer 1 attn":
print(hf_tensors[k])
print(srt_tensors[k])
if k == "layer 0 prefill k":
print(srt_tensors[k].shape)
print(hf_tensors[k].shape)
......@@ -76,6 +76,7 @@ class TestGenerationModels(unittest.TestCase):
) -> None:
if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
prompts = prompts[:-1]
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation=True
) as hf_runner:
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import multiprocessing as mp
import unittest
import uuid
import torch
from sglang.test.runners import HFRunner, SRTRunner
LORA_SETS = [
# {
# "base": "meta-llama/Llama-2-7b-hf",
# "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"],
# },
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
# {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
# {
# "base": "mistralai/Mistral-7B-Instruct-v0.3",
# "loras": [
# "/home/ying/test_lora",
# "/home/ying/test_lora_1",
# "/home/ying/test_lora_2",
# "/home/ying/test_lora_3",
# "/home/ying/test_lora_4",
# ],
# },
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["yard1/llama-2-7b-sql-lora-test"]},
]
TORCH_DTYPES = [torch.float16]
PROMPTS = [
"""
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
""",
"""
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
]
# import json
#
# with open("/home/ying/test_prompt/dialogue_choice_prompts.json", "r") as f:
# samples = json.load(f)
# for sample in samples[:5]:
# assert sample[0]["role"] == "user"
# PROMPTS.append(sample[0]["content"][:2000])
class TestLoRA(unittest.TestCase):
def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing inference =======================")
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = [None]
i = 0
for _ in range(len(prompts) - 1):
batch_lora_paths.append(all_lora_paths[i])
i = (i + 1) % len(all_lora_paths)
with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
lora_paths=all_lora_paths,
max_loras_per_batch=3,
disable_cuda_graph=True,
disable_radix_cache=True,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
for i in range(len(prompts)):
# compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i])
srt_no_lora_logprobs = torch.Tensor(
srt_no_lora_outputs.top_input_logprobs[i]
)
print(
"max input diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and srt_lora",
torch.max(abs(srt_no_lora_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and hf_base",
torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)),
)
print(
"max input diff between hf_lora and hf_base",
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
)
# compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
# print(
# "\noutput logprobs diff",
# [
# float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
# for j in range(max_new_tokens)
# ],
# )
print(
"max output diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
"\n",
)
# compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
print(f"{hf_no_lora_outputs.output_strs=}")
print(f"{srt_no_lora_outputs.output_strs=}")
for i in range(len(prompts)):
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
srt_outputs.output_strs[i].strip(" "),
hf_outputs.output_strs[i],
)
# assert (
# srt_no_lora_outputs.output_strs[i].strip(" ")
# == hf_no_lora_outputs.output_strs[i]
# ), (
# srt_no_lora_outputs.output_strs[i].strip(" "),
# hf_no_lora_outputs.output_strs[i],
# )
def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing serving =======================")
# test batch forward
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = [None]
i = 0
for _ in range(len(prompts) - 1):
batch_lora_paths.append(all_lora_paths[i])
i = (i + 1) % len(all_lora_paths)
with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
lora_paths=all_lora_paths,
max_loras_per_batch=3,
disable_cuda_graph=True,
disable_radix_cache=True,
) as srt_runner:
srt_outputs = srt_runner.batch_forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
with HFRunner(
base_path,
torch_dtype=torch_dtype,
is_generation=True,
output_str_only=True,
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
# compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
for i in range(len(prompts)):
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
srt_outputs.output_strs[i].strip(" "),
hf_outputs.output_strs[i],
)
def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing base inference =======================")
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = [None] * len(prompts)
with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation=True,
lora_paths=all_lora_paths,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
for i in range(len(prompts)):
srt_no_lora_logprobs = torch.Tensor(
srt_no_lora_outputs.top_input_logprobs[i]
)
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
print(f"{srt_no_lora_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
for i in range(len(prompts)):
# a LoRA-enabled server serving base-only requests should match the plain base server
assert (
srt_outputs.output_strs[i].strip(" ")
== srt_no_lora_outputs.output_strs[i].strip(" ")
), (
srt_outputs.output_strs[i].strip(" "),
srt_no_lora_outputs.output_strs[i].strip(" "),
)
def test_all(self):
for lora_set in LORA_SETS:
# self.load_lora_adapter(lora_set, 1)
for torch_dtype in TORCH_DTYPES:
tp_size = 1
max_new_tokens = 32
self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
# self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
# self.base_inference(
# PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
# )
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")