Unverified Commit 9c064bf7 authored by Ying Sheng, committed by GitHub

[LoRA, Performance] Speedup multi-LoRA serving - Step 1 (#1587)

parent 58d1082e
 import argparse
 import os
-NUM_LORAS = 128
+NUM_LORAS = 8
 LORA_PATH = {
     "base": "mistralai/Mistral-7B-Instruct-v0.3",
     "lora": "/home/ying/test_lora",
@@ -11,12 +11,11 @@ LORA_PATH = {
 def launch_server(args):
     base_path = LORA_PATH["base"]
     lora_path = LORA_PATH["lora"]
-    max_loras_per_batch = 4
     if args.base_only:
-        cmd = f"python -m sglang.launch_server --model {base_path} "
+        cmd = f"python3 -m sglang.launch_server --model {base_path} "
     else:
-        cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
+        cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
         for i in range(NUM_LORAS):
             lora_name = f"lora{i}"
             cmd += f"{lora_name}={lora_path} "
@@ -29,11 +28,6 @@ def launch_server(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--num-loras",
-        type=int,
-        default=128,
-    )
     parser.add_argument(
         "--base-only",
         action="store_true",
...
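For reference, the launcher above registers each adapter under its own name via repeated `name=path` pairs in `--lora-paths`. A small sketch of the command string the script assembles, reusing the script's placeholder paths and using NUM_LORAS = 2 for brevity (illustrative only):

```python
base_path = "mistralai/Mistral-7B-Instruct-v0.3"
lora_path = "/home/ying/test_lora"
cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
cmd += " ".join(f"lora{i}={lora_path}" for i in range(2))
# -> python3 -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3
#      --lora-paths lora0=/home/ying/test_lora lora1=/home/ying/test_lora
```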
@@ -101,12 +101,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -115,11 +115,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # FIXME
-        assert lora_a_output.shape[-1] == self.lora_rank * 2
         lora_output = torch.empty_like(base_output)
         output_dim = lora_output.shape[-1] // 2
         for i in range(2):
@@ -132,7 +131,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 weights=self.B_buffer[:, left:right, :].contiguous(),
                 batch_size=self.bs,
                 weight_column_major=True,
-                seg_lens=self.seq_lens,
+                seg_indptr=self.seg_indptr,
                 weight_indices=self.weight_indices,
             )
         return base_output + lora_output * self.scaling
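A note on the two-iteration loop in `apply_lora` above: the merged column-parallel layer fuses two sub-projections (e.g. gate and up), so the LoRA-A output carries `2 * lora_rank` columns (this is what the removed assert checked) and each half of the final output gets its own rank-r slice and its own slice of `B_buffer`. A minimal single-adapter sketch of that bookkeeping; the shapes here are simplifying assumptions, not the module's actual buffer layout:

```python
import torch

def merged_lora_delta(a_out: torch.Tensor, B: torch.Tensor, r: int) -> torch.Tensor:
    # a_out: (num_tokens, 2 * r)   LoRA-A outputs for the two merged sub-projections, concatenated
    # B:     (2 * output_dim, r)   LoRA-B weights for the two sub-projections, stacked row-wise
    num_tokens, output_dim = a_out.shape[0], B.shape[0] // 2
    delta = torch.empty(num_tokens, 2 * output_dim, dtype=a_out.dtype)
    for i in range(2):  # i = 0 -> first sub-projection, i = 1 -> second
        left, right = i * output_dim, (i + 1) * output_dim
        delta[:, left:right] = a_out[:, i * r : (i + 1) * r] @ B[left:right, :].T
    return delta

# Example shapes: 10 tokens, rank 16, per-sub-projection output dim 4096.
delta = merged_lora_delta(torch.randn(10, 32), torch.randn(8192, 16), r=16)
```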
@@ -145,14 +144,14 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

     def set_lora_info(
-        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
+        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
         self.B_buffer_q = B_buffer_q
         self.B_buffer_kv = B_buffer_kv
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +160,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer_qkv,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # FIXME parallelize qkv
@@ -173,7 +172,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.B_buffer_q,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # kv
@@ -189,7 +188,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                 batch_size=self.bs,
                 weight_column_major=True,
-                seg_lens=self.seq_lens,
+                seg_indptr=self.seg_indptr,
                 weight_indices=self.weight_indices,
             )
         )
@@ -202,12 +201,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -216,7 +215,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         lora_output = self.segment_gemm.run(
@@ -224,7 +223,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.B_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         return base_output + lora_output * self.scaling
...
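For readers unfamiliar with the segmented GEMM interface used throughout this file: all requests' tokens are stacked into one matrix, `seg_indptr` gives each request's row range, and `weight_indices` picks which LoRA buffer slot that request multiplies against. A plain-PyTorch loop equivalent, as an illustrative sketch with a simplified weight layout rather than the fused `segment_gemm.run` kernel:

```python
import torch

def segment_gemm_reference(
    x: torch.Tensor,               # (total_tokens, in_dim), all requests stacked
    weights: torch.Tensor,         # (num_slots, out_dim, in_dim), one matrix per buffer slot
    seg_indptr: torch.Tensor,      # (bs + 1,), request i owns rows seg_indptr[i]:seg_indptr[i + 1]
    weight_indices: torch.Tensor,  # (bs,), buffer slot used by request i
) -> torch.Tensor:
    out = x.new_empty(x.shape[0], weights.shape[1])
    for i in range(weight_indices.shape[0]):
        s, e = int(seg_indptr[i]), int(seg_indptr[i + 1])
        out[s:e] = x[s:e] @ weights[weight_indices[i]].transpose(0, 1)
    return out
```

The real kernel presumably does the same work in a single grouped launch rather than a per-request Python loop.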
@@ -274,18 +274,24 @@ class LoRAManager:
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
+        j = len(self.active_uids)
         evictable_uids = list(self.active_uids)
         for uid in cur_uids:
             if uid not in self.active_uids:
-                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                    i += 1
-                if i < len(evictable_uids):
-                    self.active_uids.remove(evictable_uids[i])
-                    self.buffer_id.pop(evictable_uids[i])
-                self.load_lora(uid, i)
-                self.active_uids.add(uid)
-                self.buffer_id[uid] = i
-                i += 1
+                if j < self.max_loras_per_batch:
+                    index = j
+                    j += 1
+                else:
+                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
+                        i += 1
+                    assert i < len(evictable_uids)
+                    self.active_uids.remove(evictable_uids[i])
+                    self.buffer_id.pop(evictable_uids[i])
+                    index = i
+                    i += 1
+                self.load_lora(uid, index)
+                self.active_uids.add(uid)
+                self.buffer_id[uid] = index

         if cur_uids == set([None]):
             return
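The rewritten slot assignment above can be read as a small standalone policy: a newly needed adapter first takes any still-unused buffer slot (`j` counts up from the number of already-loaded adapters); only when all `max_loras_per_batch` slots are occupied does the cursor `i` scan the previously loaded adapters and evict one that the current batch does not need. A self-contained sketch of that policy (illustrative; the real code also calls `load_lora`, and here the freed slot is looked up from `buffer_id` rather than assumed to equal the scan position):

```python
def assign_slots(cur_uids, active_uids, buffer_id, max_slots):
    i = 0                      # scan cursor over previously active adapters (eviction candidates)
    j = len(active_uids)       # next free slot while some slots are still unused
    evictable = list(active_uids)
    for uid in cur_uids:
        if uid in active_uids:
            continue                         # adapter already resident, keep its slot
        if j < max_slots:
            index = j                        # take a free slot
            j += 1
        else:
            while i < len(evictable) and evictable[i] in cur_uids:
                i += 1                       # never evict an adapter the current batch needs
            assert i < len(evictable)
            victim = evictable[i]
            active_uids.remove(victim)
            index = buffer_id.pop(victim)    # reuse the evicted adapter's slot
            i += 1
        active_uids.add(uid)
        buffer_id[uid] = index
    return buffer_id

# Example: two resident adapters, capacity 2, batch needs "a" and "c" -> "b" gets evicted.
print(assign_slots({"a", "c"}, {"a", "b"}, {"a": 0, "b": 1}, max_slots=2))  # {'a': 0, 'c': 1}
```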
@@ -295,8 +301,11 @@ class LoRAManager:
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs)
+            else torch.ones(bs, device="cuda")
         )
+        # FIXME: reuse the data rather than recompute
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]
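A tiny worked example of the `seg_indptr` construction added above (numbers are illustrative): for three extend requests of lengths 3, 5 and 2, the cumulative sum with a leading zero yields the row offsets the segmented GEMM calls consume, i.e. request i owns rows `seg_indptr[i]:seg_indptr[i+1]` of the stacked hidden states.

```python
import torch

seg_lens = torch.tensor([3, 5, 2])            # per-request extend lengths (made up)
seg_indptr = torch.zeros(seg_lens.numel() + 1, dtype=torch.int32)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
print(seg_indptr.tolist())                    # [0, 3, 8, 10]
```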
@@ -310,7 +319,7 @@ class LoRAManager:
                     self.A_buffer[weight_name][layer_id],
                     self.B_buffer[weight_name][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                     weight_indices,
                 )
             else:
@@ -319,6 +328,6 @@ class LoRAManager:
                     self.B_buffer["q_proj"][layer_id],
                     self.B_buffer["kv_proj"][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                     weight_indices,
                 )