Commit 038e9be4 (unverified), authored by Andy Lo, committed by GitHub
Browse files

[LoRA] Much faster startup when LoRA is enabled (#23777)


Signed-off-by: Andy Lo <andy@mistral.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 68a34911
...@@ -2213,6 +2213,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2213,6 +2213,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode: bool = False, uniform_decode: bool = False,
skip_eplb: bool = False, skip_eplb: bool = False,
is_profile: bool = False, is_profile: bool = False,
remove_lora: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Run a dummy forward pass to warm up/profile run or capture the Run a dummy forward pass to warm up/profile run or capture the
...@@ -2230,6 +2231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2230,6 +2231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode: If True, the batch is a uniform decode batch. uniform_decode: If True, the batch is a uniform decode batch.
skip_eplb: If True, skip EPLB state update. skip_eplb: If True, skip EPLB state update.
is_profile: If True, this is a profile run. is_profile: If True, this is a profile run.
remove_lora: If False, dummy LoRAs are not destroyed after the run
""" """
assert cudagraph_runtime_mode in { assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
...@@ -2317,7 +2319,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2317,7 +2319,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens, remove_lora):
if self.supports_mm_inputs: if self.supports_mm_inputs:
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
...@@ -2708,11 +2710,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2708,11 +2710,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention, force_attention=force_attention,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
skip_eplb=True) skip_eplb=True,
remove_lora=False)
self._dummy_run(num_tokens, self._dummy_run(num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
skip_eplb=True) skip_eplb=True,
remove_lora=False)
self.maybe_remove_all_loras(self.lora_config)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
""" """
......
...@@ -308,7 +308,10 @@ class Worker(WorkerBase): ...@@ -308,7 +308,10 @@ class Worker(WorkerBase):
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True): for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size) logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size, skip_eplb=True) self.model_runner._dummy_run(size,
skip_eplb=True,
remove_lora=False)
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
# Warmup and tune the kernels used during model execution before # Warmup and tune the kernels used during model execution before
# cuda graph capture. # cuda graph capture.
......
...@@ -5,7 +5,7 @@ Define LoRA functionality mixin for model runners. ...@@ -5,7 +5,7 @@ Define LoRA functionality mixin for model runners.
""" """
from contextlib import contextmanager from contextlib import contextmanager
from typing import Union from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -87,7 +87,9 @@ class LoRAModelRunnerMixin: ...@@ -87,7 +87,9 @@ class LoRAModelRunnerMixin:
lora_requests) lora_requests)
@contextmanager @contextmanager
def maybe_setup_dummy_loras(self, lora_config): def maybe_setup_dummy_loras(self,
lora_config: Optional[LoRAConfig],
remove_lora: bool = True):
if lora_config is None: if lora_config is None:
yield yield
else: else:
...@@ -114,10 +116,11 @@ class LoRAModelRunnerMixin: ...@@ -114,10 +116,11 @@ class LoRAModelRunnerMixin:
yield yield
# __exit__ code # __exit__ code
if remove_lora:
self.lora_manager.remove_all_adapters() self.lora_manager.remove_all_adapters()
@contextmanager @contextmanager
def maybe_select_dummy_loras(self, lora_config: LoRAConfig, def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig],
num_scheduled_tokens: np.ndarray): num_scheduled_tokens: np.ndarray):
if lora_config is None: if lora_config is None:
yield yield
...@@ -151,13 +154,22 @@ class LoRAModelRunnerMixin: ...@@ -151,13 +154,22 @@ class LoRAModelRunnerMixin:
yield yield
@contextmanager @contextmanager
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, def maybe_dummy_run_with_lora(self,
num_scheduled_tokens: np.ndarray): lora_config: Optional[LoRAConfig],
with self.maybe_setup_dummy_loras( num_scheduled_tokens: np.ndarray,
lora_config), self.maybe_select_dummy_loras( remove_lora: bool = True):
lora_config, num_scheduled_tokens): with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(lora_config,
num_scheduled_tokens),
):
yield yield
def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]):
if lora_config is None:
return
self.lora_manager.remove_all_adapters()
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment