Unverified commit f2d68ded, authored by Baizhou Zhang, committed by GitHub

Rename lora_path to lora_id in batches (#8437)

parent 3b87a9e8
@@ -191,11 +191,7 @@ class LoRAManager:
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # Load active loras into lora memory pool
-        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
-        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
-        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
-        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
-        cur_uids = set(forward_batch.lora_paths)
+        cur_uids = set(forward_batch.lora_ids)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
@@ -211,10 +207,10 @@ class LoRAManager:
         Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
         to device (CUDA) asynchronously.
         """
-        weight_indices = [0] * len(forward_batch.lora_paths)
+        weight_indices = [0] * len(forward_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, uid in enumerate(forward_batch.lora_paths):
+        for i, uid in enumerate(forward_batch.lora_ids):
             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
             if uid is not None:
                 lora = self.loras[uid]
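For context, the renamed `forward_batch.lora_ids` holds one entry per request and is mapped to memory-pool buffer slots before each forward pass. Below is a minimal, self-contained sketch of that indexing scheme; `ToyLoRAMemoryPool` and `build_weight_indices` are hypothetical names, not SGLang's actual API.

```python
# Minimal sketch: map per-request LoRA IDs to buffer slots (hypothetical classes).
from typing import Dict, List, Optional


class ToyLoRAMemoryPool:
    def __init__(self, max_loras_per_batch: int):
        # One buffer slot per adapter that can be active in a single batch.
        self.buffer_id_by_uid: Dict[Optional[str], int] = {}
        self.max_loras_per_batch = max_loras_per_batch

    def prepare_lora_batch(self, cur_uids: set) -> None:
        # Assign each unique LoRA ID (including None for "base model only") a slot.
        for uid in cur_uids:
            if uid not in self.buffer_id_by_uid:
                self.buffer_id_by_uid[uid] = len(self.buffer_id_by_uid)

    def get_buffer_id(self, uid: Optional[str]) -> int:
        return self.buffer_id_by_uid[uid]


def build_weight_indices(lora_ids: List[Optional[str]], pool: ToyLoRAMemoryPool) -> List[int]:
    # Mirrors the idea in `transfer_adapter_info`: one buffer index per request.
    return [pool.get_buffer_id(uid) for uid in lora_ids]


if __name__ == "__main__":
    lora_ids = ["lora-a", None, "lora-a", "lora-b"]  # one entry per request in the batch
    pool = ToyLoRAMemoryPool(max_loras_per_batch=4)
    pool.prepare_lora_batch(set(lora_ids))
    assert len(set(lora_ids)) <= pool.max_loras_per_batch
    print(build_weight_indices(lora_ids, pool))  # e.g. [0, 1, 0, 2], depending on set order
```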
@@ -101,8 +101,10 @@ class GenerateReqInput:
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # The path to the LoRA
+    # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # The uid of LoRA adaptors, should be initialized by tokenizer manager
+    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
@@ -500,7 +502,7 @@ class TokenizedGenerateReqInput:
     stream: bool
     # LoRA related
-    lora_path: Optional[str] = None  # None means just use the base model
+    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
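The schema change above keeps the user-facing `lora_path` field while adding an internal `lora_id` that the tokenizer manager is expected to fill in. A minimal sketch of that split, using hypothetical `Toy*` dataclasses rather than the real request classes:

```python
# Sketch: public `lora_path` vs. internal `lora_id` (hypothetical dataclasses).
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class ToyGenerateReqInput:
    text: str
    # User-facing adapter name(s), as accepted by the public API.
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # Unique adapter ID(s), resolved internally; never set by the caller.
    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None


@dataclass
class ToyTokenizedGenerateReqInput:
    input_ids: List[int]
    lora_id: Optional[str] = None  # None means just use the base model


if __name__ == "__main__":
    req = ToyGenerateReqInput(text="hello", lora_path="my-adapter")
    # Later, a tokenizer-manager-like component resolves the name to an internal ID:
    req.lora_id = "lora-uid-0"  # hypothetical unique ID for "my-adapter"
    print(req)
```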
@@ -423,7 +423,7 @@ class Req:
         token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
-        lora_path: Optional[str] = None,
+        lora_id: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
@@ -467,7 +467,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
-        self.lora_path = lora_path
+        self.lora_id = lora_id
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -1750,7 +1750,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
             encoder_out_cache_loc=self.encoder_out_cache_loc,
-            lora_paths=[req.lora_path for req in self.reqs],
+            lora_ids=[req.lora_id for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
             token_type_ids=self.token_type_ids,
@@ -1891,7 +1891,7 @@ class ModelWorkerBatch:
     encoder_out_cache_loc: Optional[torch.Tensor]
     # For LoRA
-    lora_paths: Optional[List[str]]
+    lora_ids: Optional[List[str]]
     # Sampling info
     sampling_info: SamplingBatchInfo
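At the batching layer, the per-request `lora_id` is simply gathered into a batch-level `lora_ids` list, aligned with request order. A toy illustration of that gathering step (hypothetical classes, not the real `Req`/`ModelWorkerBatch`):

```python
# Sketch: collect per-request LoRA IDs into a batch-level list (hypothetical classes).
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyReq:
    rid: str
    lora_id: Optional[str] = None


@dataclass
class ToyModelWorkerBatch:
    lora_ids: Optional[List[Optional[str]]]


def to_worker_batch(reqs: List[ToyReq]) -> ToyModelWorkerBatch:
    # One entry per request, aligned with batch order; None means base model only.
    return ToyModelWorkerBatch(lora_ids=[req.lora_id for req in reqs])


if __name__ == "__main__":
    batch = to_worker_batch([ToyReq("r0", "lora-a"), ToyReq("r1", None)])
    print(batch.lora_ids)  # ['lora-a', None]
```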
@@ -1090,7 +1090,7 @@ class Scheduler(
             top_logprobs_num=recv_req.top_logprobs_num,
             token_ids_logprob=recv_req.token_ids_logprob,
             stream=recv_req.stream,
-            lora_path=recv_req.lora_path,
+            lora_id=recv_req.lora_id,
             input_embeds=recv_req.input_embeds,
             custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
@@ -1534,7 +1534,7 @@ class Scheduler(
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
         if self.enable_lora:
-            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+            lora_set = set([req.lora_id for req in self.running_batch.reqs])
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
@@ -1542,8 +1542,8 @@ class Scheduler(
                 self.enable_lora
                 and len(
                     lora_set
-                    | set([req.lora_path for req in adder.can_run_list])
-                    | set([req.lora_path])
+                    | set([req.lora_id for req in adder.can_run_list])
+                    | set([req.lora_id])
                 )
                 > self.max_loras_per_batch
             ):
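The scheduler hunk above enforces that admitting a waiting request must not push the number of distinct LoRA IDs in the prefill batch past `max_loras_per_batch`. A minimal sketch of that admission rule, with hypothetical names:

```python
# Sketch: LoRA admission check for a new prefill request (hypothetical helper).
from typing import List, Optional, Set


def can_admit(
    running_lora_ids: Set[Optional[str]],
    already_admitted: List[Optional[str]],
    candidate_lora_id: Optional[str],
    max_loras_per_batch: int,
) -> bool:
    # Union of adapters in the running batch, adapters already admitted this round,
    # and the candidate's adapter must fit within the per-batch limit.
    combined = running_lora_ids | set(already_admitted) | {candidate_lora_id}
    return len(combined) <= max_loras_per_batch


if __name__ == "__main__":
    running = {"lora-a", None}
    print(can_admit(running, ["lora-b"], "lora-c", max_loras_per_batch=3))  # False
    print(can_admit(running, [], "lora-a", max_loras_per_batch=3))          # True
```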
@@ -556,7 +556,7 @@ class TokenizerManager:
         if self.server_args.enable_lora and obj.lora_path:
             # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
             # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
@@ -665,7 +665,7 @@ class TokenizerManager:
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-            lora_path=obj.lora_path,
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -773,7 +773,7 @@ class TokenizerManager:
         # Mark ongoing LoRA request as finished.
         if self.server_args.enable_lora and obj.lora_path:
-            await self.lora_registry.release(obj.lora_path)
+            await self.lora_registry.release(obj.lora_id)
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
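In the tokenizer manager, `lora_registry.acquire()` now writes its result into `obj.lora_id` instead of overwriting `obj.lora_path`, and the matching `release()` call uses the ID. The sketch below illustrates that acquire/release lifecycle with a simplified, hypothetical registry (not the actual `LoRARegistry` implementation):

```python
# Sketch: name -> ID resolution plus in-flight tracking (hypothetical registry).
import asyncio
import uuid
from typing import Dict


class ToyLoRARegistry:
    def __init__(self):
        self._id_by_name: Dict[str, str] = {}
        self._inflight: Dict[str, int] = {}

    async def register(self, lora_name: str) -> str:
        lora_id = uuid.uuid4().hex
        self._id_by_name[lora_name] = lora_id
        self._inflight[lora_id] = 0
        return lora_id

    async def acquire(self, lora_name: str) -> str:
        # Resolve the user-friendly name to its unique ID and count the request.
        lora_id = self._id_by_name[lora_name]
        self._inflight[lora_id] += 1
        return lora_id

    async def release(self, lora_id: str) -> None:
        self._inflight[lora_id] -= 1


async def main() -> None:
    registry = ToyLoRARegistry()
    await registry.register("my-adapter")
    lora_id = await registry.acquire("my-adapter")  # before tokenization
    # ... the request is scheduled and executed using `lora_id` internally ...
    await registry.release(lora_id)                 # when the request finishes


asyncio.run(main())
```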
@@ -576,11 +576,11 @@ class CudaGraphRunner:
         )
         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
-            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
-            lora_paths = [None] * bs
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-            lora_paths = None
+            lora_ids = None
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -607,11 +607,11 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-            lora_paths=lora_paths,
+            lora_ids=lora_ids,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
-        if lora_paths is not None:
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
         # Attention backend
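The CUDA graph change captures the graph with a batch of `None` LoRA IDs so the LoRA kernels are baked into the graph and become no-ops when no adapter is attached. A rough sketch of that convention, with a hypothetical `ToyLoRAKernel` standing in for the real kernels:

```python
# Sketch: capture with placeholder LoRA IDs; kernels no-op when no adapter is set.
from typing import List, Optional


class ToyLoRAKernel:
    def apply(self, lora_ids: List[Optional[str]]) -> str:
        # Always launched when LoRA is enabled, but returns immediately if no
        # request in the batch actually uses an adapter.
        if all(uid is None for uid in lora_ids):
            return "no-op (base model only)"
        return f"apply adapters for {sorted({uid for uid in lora_ids if uid})}"


def capture_placeholder_lora_ids(enable_lora: bool, bs: int) -> Optional[List[Optional[str]]]:
    # Mirrors the branch above: a list of None entries when LoRA is enabled,
    # otherwise no LoRA metadata at all.
    return [None] * bs if enable_lora else None


if __name__ == "__main__":
    kernel = ToyLoRAKernel()
    lora_ids = capture_placeholder_lora_ids(enable_lora=True, bs=4)
    print(kernel.apply(lora_ids))                              # no-op during graph capture
    print(kernel.apply(["lora-a", None, "lora-a", "lora-b"]))  # replay with real adapters
```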
@@ -248,7 +248,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None
     # For LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_ids: Optional[List[str]] = None
     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -327,7 +327,7 @@ class ForwardBatch:
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-            lora_paths=batch.lora_paths,
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -468,7 +468,7 @@ class TboForwardBatchPreparer:
             "extend_prefix_lens_cpu",
             "extend_seq_lens_cpu",
             "extend_logprob_start_lens_cpu",
-            "lora_paths",
+            "lora_ids",
         ]:
             old_value = getattr(batch, key)
             if old_value is None: