Unverified Commit 021f76e4 authored by Lifu Huang, committed by GitHub

[Perf] Refactor LoRAManager to eliminate stream syncs and redundant computations (#6994)

parent 777688b8
@@ -81,7 +81,7 @@ class LoRAManager:
             seg_indptr=torch.zeros(
                 self.max_bs_in_cuda_graph + 1, dtype=torch.int32
             ),
-            max_len=0,
+            max_len=1,
             weight_indices=torch.zeros(
                 self.max_bs_in_cuda_graph, dtype=torch.int32
             ),
@@ -89,6 +89,17 @@ class LoRAManager:
             scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
         )
+
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+            dim=0,
+            out=self.cuda_graph_batch_info.seg_indptr[
+                1 : self.max_bs_in_cuda_graph + 1
+            ],
+        )

     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
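
The reason this initialization can be hoisted out of the per-batch path: CUDA graph replay is only used for decode batches, where every sequence contributes exactly one token, so seg_lens is all ones and seg_indptr is a fixed prefix sum. A minimal standalone sketch of why the precomputation holds (the batch size here is illustrative, not taken from the commit):

import torch

# Sketch only: for decode-only (CUDA graph) batches every segment has length 1,
# so the segment layout never changes between batches.
max_bs = 8  # hypothetical maximum batch size captured by the CUDA graph
seg_lens = torch.ones(max_bs, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
# seg_indptr is now [0, 1, 2, ..., max_bs] and remains valid for every decode
# batch, which is why the old per-forward recomputation was redundant.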
@@ -159,6 +170,45 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size

+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
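
The transfer_adapter_info helper above is where the "eliminate stream syncs" part of the title comes from: adapter metadata is staged in pinned (page-locked) host tensors and copied to preallocated device tensors with non_blocking=True, so the copies are queued on the current CUDA stream without making the CPU wait. A minimal standalone sketch of that pattern (the helper name and sizes below are illustrative, not the module's API):

import torch

def async_h2d(values, dtype, out: torch.Tensor):
    # Stage the host data in pinned memory so the copy can be truly
    # asynchronous, then enqueue a non-blocking H2D copy into the
    # preallocated device tensor; no stream synchronization occurs here.
    staging = torch.tensor(values, dtype=dtype, pin_memory=True, device="cpu")
    out[: len(values)].copy_(staging, non_blocking=True)

if torch.cuda.is_available():
    weight_indices_gpu = torch.empty(4, dtype=torch.int32, device="cuda")
    async_h2d([2, 0, 1, 2], torch.int32, weight_indices_gpu)
    # Kernels launched later on the same stream see the finished copy, so no
    # explicit synchronization is needed before using weight_indices_gpu.

Without pinned source buffers the host-to-device copy is not actually asynchronous, which is why the helper creates its staging tensors with pin_memory=True.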
@@ -166,51 +216,46 @@
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[:bs],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
-            )
-            self.cuda_graph_batch_info.max_len = 1
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                self.cuda_graph_batch_info.weight_indices[i] = (
-                    self.memory_pool.get_buffer_id(lora_path)
-                )
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    self.cuda_graph_batch_info.lora_ranks[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.config.hf_config["r"]
-                    self.cuda_graph_batch_info.scalings[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.scaling
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
+            )
+
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
             )
+
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
+
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
...
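
A second sync the non-CUDA-graph branch drops: max_len used to be computed as int(torch.max(seg_lens)), which launches a GPU reduction and then copies the scalar back to the host, blocking until the stream catches up. Reading the maximum from the CPU-side copy of the extend lengths avoids that device-to-host transfer. A hedged illustration of the difference (the lengths below are made up for the example):

import torch

extend_seq_lens_cpu = [5, 3, 9, 1]  # host-side copy of the per-request lengths

if torch.cuda.is_available():
    seg_lens = torch.tensor(extend_seq_lens_cpu, dtype=torch.int32, device="cuda")
    # Old pattern: the max is computed on the GPU and int(...) forces a
    # device-to-host copy, stalling the CPU until the stream is drained.
    max_len_old = int(torch.max(seg_lens))
    # New pattern: take the max of the host copy; no GPU work, no sync.
    max_len_new = max(extend_seq_lens_cpu)
    assert max_len_old == max_len_new == 9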
@@ -132,12 +132,13 @@ class LoRAMemoryPool:
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id, ""
+                    return buffer_id

             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                    return buffer_id

             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
@@ -145,9 +146,7 @@
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id, evicted_lora_uid = get_available_buffer_slot()
-                if evicted_lora_uid != "":
-                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                buffer_id = get_available_buffer_slot()
                 self.load_lora_weight_to_buffer(
                     uid, buffer_id, lora_adapters.get(uid, None)
                 )
...
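
The LoRAMemoryPool change narrows get_available_buffer_slot's contract: it now removes the evicted adapter's uid from the reverse lookup itself and returns only the buffer id, instead of returning an (id, evicted_uid) pair that the caller had to clean up. A standalone sketch of that allocation/eviction logic, assuming a buffer_id_to_uid list with "" marking an empty slot and a uid_to_buffer_id dict (names mirror the diff, the free function form is only for illustration):

from typing import Dict, List, Set

def get_available_buffer_slot(
    buffer_id_to_uid: List[str],
    uid_to_buffer_id: Dict[str, int],
    cur_uids: Set[str],
) -> int:
    # Prioritize empty slots.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid == "":
            return buffer_id
    # Otherwise evict a slot whose adapter is not needed by the current batch,
    # dropping its reverse mapping here so callers only deal with the slot id.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids:
            uid_to_buffer_id.pop(uid)
            return buffer_id
    raise ValueError("No available buffer slots found.")

# Example: slot 0 is empty, so it is handed out before any eviction happens.
assert get_available_buffer_slot(["", "lora-a"], {"lora-a": 1}, {"lora-a"}) == 0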