Unverified Commit f2d68ded authored by Baizhou Zhang, committed by GitHub

Rename lora_path to lora_id in batches (#8437)

parent 3b87a9e8
@@ -191,11 +191,7 @@ class LoRAManager:
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # Load active loras into lora memory pool
-        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
-        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
-        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
-        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
-        cur_uids = set(forward_batch.lora_paths)
+        cur_uids = set(forward_batch.lora_ids)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
@@ -211,10 +207,10 @@ class LoRAManager:
         Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
         to device (CUDA) asynchronously.
         """
-        weight_indices = [0] * len(forward_batch.lora_paths)
+        weight_indices = [0] * len(forward_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, uid in enumerate(forward_batch.lora_paths):
+        for i, uid in enumerate(forward_batch.lora_ids):
             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
             if uid is not None:
                 lora = self.loras[uid]
...
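For context on the two hunks above, here is a minimal sketch (not SGLang's implementation) of how per-request LoRA IDs can be mapped to memory-pool buffer slots and flattened into per-batch metadata arrays. `ToyMemoryPool`, `ToyAdapter`, and `prepare_metadata` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToyAdapter:
    rank: int
    scaling: float


class ToyMemoryPool:
    """Assigns each active LoRA ID (or None for the base model) a buffer slot."""

    def __init__(self, max_loras_per_batch: int):
        self.max_loras_per_batch = max_loras_per_batch
        self.slot_of: Dict[Optional[str], int] = {}

    def get_buffer_id(self, uid: Optional[str]) -> int:
        if uid not in self.slot_of:
            assert len(self.slot_of) < self.max_loras_per_batch
            self.slot_of[uid] = len(self.slot_of)
        return self.slot_of[uid]


def prepare_metadata(
    lora_ids: List[Optional[str]],
    adapters: Dict[str, ToyAdapter],
    pool: ToyMemoryPool,
):
    # One weight index per request; one rank/scaling entry per buffer slot.
    weight_indices = [0] * len(lora_ids)
    lora_ranks = [0] * pool.max_loras_per_batch
    scalings = [0.0] * pool.max_loras_per_batch
    for i, uid in enumerate(lora_ids):
        weight_indices[i] = pool.get_buffer_id(uid)
        if uid is not None:  # None means "base model only" for this request
            adapter = adapters[uid]
            lora_ranks[weight_indices[i]] = adapter.rank
            scalings[weight_indices[i]] = adapter.scaling
    return weight_indices, lora_ranks, scalings


if __name__ == "__main__":
    pool = ToyMemoryPool(max_loras_per_batch=4)
    adapters = {"id-a": ToyAdapter(rank=8, scaling=2.0), "id-b": ToyAdapter(rank=16, scaling=1.0)}
    print(prepare_metadata(["id-a", None, "id-b", "id-a"], adapters, pool))
    # -> ([0, 1, 2, 0], [8, 0, 16, 0], [2.0, 0.0, 1.0, 0.0])
```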
@@ -101,8 +101,10 @@ class GenerateReqInput:
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # The path to the LoRA
+    # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # The uid of LoRA adaptors, should be initialized by tokenizer manager
+    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
@@ -500,7 +502,7 @@ class TokenizedGenerateReqInput:
     stream: bool
     # LoRA related
-    lora_path: Optional[str] = None  # None means just use the base model
+    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
...
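Illustrative only: a stripped-down request object showing the intended split between the user-facing `lora_path` (the adapter name a client sends) and the internal `lora_id` (a unique ID filled in before the request reaches the scheduler). `ToyGenerateReq` is a hypothetical stand-in, not the real `GenerateReqInput`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyGenerateReq:
    text: str
    lora_path: Optional[str] = None  # set by the client, e.g. "my-adapter"
    lora_id: Optional[str] = None    # set internally before scheduling


req = ToyGenerateReq(text="hello", lora_path="my-adapter")
req.lora_id = "lora-1234"  # normally assigned by the tokenizer manager via the registry
```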
@@ -423,7 +423,7 @@ class Req:
         token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
-        lora_path: Optional[str] = None,
+        lora_id: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
@@ -467,7 +467,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
-        self.lora_path = lora_path
+        self.lora_id = lora_id
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -1750,7 +1750,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
             encoder_out_cache_loc=self.encoder_out_cache_loc,
-            lora_paths=[req.lora_path for req in self.reqs],
+            lora_ids=[req.lora_id for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
             token_type_ids=self.token_type_ids,
@@ -1891,7 +1891,7 @@ class ModelWorkerBatch:
     encoder_out_cache_loc: Optional[torch.Tensor]
     # For LoRA
-    lora_paths: Optional[List[str]]
+    lora_ids: Optional[List[str]]
     # Sampling info
     sampling_info: SamplingBatchInfo
...
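A minimal sketch, under assumed simplified types, of the gathering step visible above: each request carries its own `lora_id`, and the batch-level `lora_ids` list is built by collecting them in request order. `ToyReq` and `ToyModelWorkerBatch` are hypothetical names for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyReq:
    rid: str
    lora_id: Optional[str] = None  # None -> base model


@dataclass
class ToyModelWorkerBatch:
    lora_ids: Optional[List[Optional[str]]]


reqs = [ToyReq("r0", "lora-a"), ToyReq("r1", None), ToyReq("r2", "lora-b")]
batch = ToyModelWorkerBatch(lora_ids=[req.lora_id for req in reqs])
assert batch.lora_ids == ["lora-a", None, "lora-b"]
```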
@@ -1090,7 +1090,7 @@ class Scheduler(
             top_logprobs_num=recv_req.top_logprobs_num,
             token_ids_logprob=recv_req.token_ids_logprob,
             stream=recv_req.stream,
-            lora_path=recv_req.lora_path,
+            lora_id=recv_req.lora_id,
             input_embeds=recv_req.input_embeds,
             custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
@@ -1534,7 +1534,7 @@ class Scheduler(
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
         if self.enable_lora:
-            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+            lora_set = set([req.lora_id for req in self.running_batch.reqs])
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
@@ -1542,8 +1542,8 @@ class Scheduler(
                 self.enable_lora
                 and len(
                     lora_set
-                    | set([req.lora_path for req in adder.can_run_list])
-                    | set([req.lora_path])
+                    | set([req.lora_id for req in adder.can_run_list])
+                    | set([req.lora_id])
                 )
                 > self.max_loras_per_batch
             ):
...
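A simplified sketch of the admission rule the scheduler hunks above participate in: a waiting request only joins the new prefill batch if the number of distinct LoRA IDs across the running batch, the already-admitted requests, and the candidate itself stays within `max_loras_per_batch`. `can_admit` is a hypothetical helper, not the real scheduler code.

```python
from typing import Iterable, Optional


def can_admit(
    running_lora_ids: Iterable[Optional[str]],
    admitted_lora_ids: Iterable[Optional[str]],
    candidate_lora_id: Optional[str],
    max_loras_per_batch: int,
) -> bool:
    # None (base model) also occupies an entry in the set, mirroring the diff above.
    combined = set(running_lora_ids) | set(admitted_lora_ids) | {candidate_lora_id}
    return len(combined) <= max_loras_per_batch


# With a budget of 2 slots, a third distinct adapter must wait.
assert can_admit(["lora-a"], ["lora-a"], "lora-b", max_loras_per_batch=2)
assert not can_admit(["lora-a"], ["lora-b"], "lora-c", max_loras_per_batch=2)
```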
@@ -556,7 +556,7 @@ class TokenizerManager:
         if self.server_args.enable_lora and obj.lora_path:
             # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
             # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
@@ -665,7 +665,7 @@ class TokenizerManager:
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-            lora_path=obj.lora_path,
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -773,7 +773,7 @@ class TokenizerManager:
         # Mark ongoing LoRA request as finished.
         if self.server_args.enable_lora and obj.lora_path:
-            await self.lora_registry.release(obj.lora_path)
+            await self.lora_registry.release(obj.lora_id)
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
...
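A minimal sketch (not the real `LoRARegistry`) of the acquire/release lifecycle shown above: `acquire` resolves a user-facing adapter name to its unique ID and counts the request as in-flight, and `release` is later called with that ID when the request finishes. `ToyLoRARegistry` and its internal fields are assumptions for illustration.

```python
import asyncio
import uuid


class ToyLoRARegistry:
    def __init__(self):
        self._id_of_name: dict[str, str] = {}
        self._inflight: dict[str, int] = {}

    def register(self, name: str) -> str:
        uid = uuid.uuid4().hex
        self._id_of_name[name] = uid
        self._inflight[uid] = 0
        return uid

    async def acquire(self, name: str) -> str:
        uid = self._id_of_name[name]
        self._inflight[uid] += 1  # track the ongoing request
        return uid

    async def release(self, uid: str) -> None:
        self._inflight[uid] -= 1  # request finished


async def main():
    registry = ToyLoRARegistry()
    registry.register("my-adapter")
    lora_id = await registry.acquire("my-adapter")  # name in, unique ID out
    # ... request runs with `lora_id` ...
    await registry.release(lora_id)  # released by ID, matching the hunk above


asyncio.run(main())
```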
@@ -576,11 +576,11 @@ class CudaGraphRunner:
         )
         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
-            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
-            lora_paths = [None] * bs
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-            lora_paths = None
+            lora_ids = None
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -607,11 +607,11 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-            lora_paths=lora_paths,
+            lora_ids=lora_ids,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
-        if lora_paths is not None:
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
         # Attention backend
...
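An illustrative sketch of the capture-time choice above: when LoRA is enabled, the CUDA graph is captured with a full-length list of `None` IDs so the LoRA kernels are part of the graph (and become cheap no-ops when a batch has no adapter), whereas with LoRA disabled no list is built at all. `make_capture_lora_ids` is a hypothetical helper, not part of the runner.

```python
from typing import List, Optional


def make_capture_lora_ids(enable_lora: bool, bs: int) -> Optional[List[Optional[str]]]:
    if enable_lora:
        # One slot per request in the captured batch; all empty, i.e. base model only.
        return [None] * bs
    return None


assert make_capture_lora_ids(True, 4) == [None, None, None, None]
assert make_capture_lora_ids(False, 4) is None
```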
@@ -248,7 +248,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None
     # For LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_ids: Optional[List[str]] = None
     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -327,7 +327,7 @@ class ForwardBatch:
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-            lora_paths=batch.lora_paths,
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
...
@@ -468,7 +468,7 @@ class TboForwardBatchPreparer:
             "extend_prefix_lens_cpu",
             "extend_seq_lens_cpu",
             "extend_logprob_start_lens_cpu",
-            "lora_paths",
+            "lora_ids",
         ]:
             old_value = getattr(batch, key)
             if old_value is None:
...
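A hedged sketch of the pattern the hunk above participates in: when a batch is split for two-batch overlap, per-request list attributes (now including `lora_ids`) are sliced by request range so each sub-batch keeps only its own entries. `slice_per_req` and the attribute names in `PER_REQ_KEYS` are illustrative, not the real preparer code.

```python
from types import SimpleNamespace

PER_REQ_KEYS = ["extend_seq_lens_cpu", "lora_ids"]


def slice_per_req(batch, start: int, end: int) -> SimpleNamespace:
    child = SimpleNamespace()
    for key in PER_REQ_KEYS:
        old_value = getattr(batch, key)
        # Absent (None) attributes stay None; list attributes are cut to the sub-range.
        setattr(child, key, None if old_value is None else old_value[start:end])
    return child


batch = SimpleNamespace(extend_seq_lens_cpu=[5, 7, 3, 9], lora_ids=["a", None, "b", "a"])
first_half = slice_per_req(batch, 0, 2)
assert first_half.lora_ids == ["a", None]
```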