Unverified Commit cc7b04a2 authored by Chenxi Li's avatar Chenxi Li Committed by GitHub
Browse files

Feature/Add GET endpoint to query loaded LoRA adapters (#12229)

parent d85d6dba
...@@ -1168,6 +1168,8 @@ async def available_models(): ...@@ -1168,6 +1168,8 @@ async def available_models():
"""Show available models. OpenAI-compatible endpoint.""" """Show available models. OpenAI-compatible endpoint."""
served_model_names = [_global_state.tokenizer_manager.served_model_name] served_model_names = [_global_state.tokenizer_manager.served_model_name]
model_cards = [] model_cards = []
# Add base model
for served_model_name in served_model_names: for served_model_name in served_model_names:
model_cards.append( model_cards.append(
ModelCard( ModelCard(
...@@ -1176,6 +1178,20 @@ async def available_models(): ...@@ -1176,6 +1178,20 @@ async def available_models():
max_model_len=_global_state.tokenizer_manager.model_config.context_len, max_model_len=_global_state.tokenizer_manager.model_config.context_len,
) )
) )
# Add loaded LoRA adapters
if _global_state.tokenizer_manager.server_args.enable_lora:
lora_registry = _global_state.tokenizer_manager.lora_registry
for _, lora_ref in lora_registry.get_all_adapters().items():
model_cards.append(
ModelCard(
id=lora_ref.lora_name,
root=lora_ref.lora_path,
parent=served_model_names[0],
max_model_len=None,
)
)
return ModelList(data=model_cards) return ModelList(data=model_cards)
......
...@@ -54,6 +54,7 @@ class ModelCard(BaseModel): ...@@ -54,6 +54,7 @@ class ModelCard(BaseModel):
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "sglang" owned_by: str = "sglang"
root: Optional[str] = None root: Optional[str] = None
parent: Optional[str] = None
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
......
...@@ -205,3 +205,12 @@ class LoRARegistry: ...@@ -205,3 +205,12 @@ class LoRARegistry:
Returns the total number of LoRA adapters currently registered. Returns the total number of LoRA adapters currently registered.
""" """
return len(self._registry) return len(self._registry)
def get_all_adapters(self) -> Dict[str, LoRARef]:
    """
    Return a snapshot of every LoRA adapter currently registered.

    Returns:
        Dict[str, LoRARef]: mapping from adapter name to its LoRARef entry.
        The result is a shallow copy, so mutating it does not affect the
        registry itself.
    """
    # Build a fresh dict so callers cannot mutate the internal registry.
    return {name: ref for name, ref in self._registry.items()}
...@@ -1328,6 +1328,78 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -1328,6 +1328,78 @@ class TestLoRADynamicUpdate(CustomTestCase):
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
) )
def test_v1_models_endpoint_with_lora(self):
    """
    Verify that the /v1/models endpoint tracks dynamically loaded LoRA
    adapters: the base model is always listed, each loaded adapter appears
    as an extra entry (distinguished by a non-null ``parent``), and an
    unloaded adapter disappears from the listing.
    """
    lora_paths = [
        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
    ]

    def query_models():
        # GET /v1/models and return the parsed JSON body, failing loudly
        # with the response text if the request did not succeed.
        resp = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
        self.assertTrue(resp.ok, resp.text)
        return resp.json()

    def adapter_entries(payload):
        # Adapter cards carry a non-null "parent" (the base model id);
        # the base model card does not.
        return [entry for entry in payload["data"] if entry.get("parent")]

    with LoRAUpdateTestSession(
        testcase=self,
        mode=LoRAUpdateTestSessionMode.SERVER,
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        lora_paths=[],
        max_loras_per_batch=2,
        max_lora_rank=256,
        lora_target_modules=["all"],
        enable_lora=True,
    ) as session:
        # No adapters loaded yet: only the base model should be listed.
        payload = query_models()
        self.assertEqual(payload["object"], "list")
        self.assertEqual(len(payload["data"]), 1)
        base_model = payload["data"][0]
        self.assertIn("meta-llama", base_model["id"].lower())
        self.assertIsNone(base_model.get("parent"))

        # Load the first adapter: base model + 1 adapter expected.
        session.load_lora_adapter(lora_name="adapter1", lora_path=lora_paths[0])
        payload = query_models()
        self.assertEqual(len(payload["data"]), 2)
        loaded = adapter_entries(payload)
        self.assertEqual(len(loaded), 1)
        self.assertEqual(loaded[0]["id"], "adapter1")
        self.assertEqual(loaded[0]["root"], lora_paths[0])
        self.assertIsNotNone(loaded[0]["parent"])

        # Load the second adapter: base model + 2 adapters expected.
        session.load_lora_adapter(lora_name="adapter2", lora_path=lora_paths[1])
        payload = query_models()
        self.assertEqual(len(payload["data"]), 3)
        loaded = adapter_entries(payload)
        self.assertEqual(len(loaded), 2)
        self.assertEqual({entry["id"] for entry in loaded}, {"adapter1", "adapter2"})

        # Unload adapter1: only adapter2 should remain alongside the base model.
        session.unload_lora_adapter(lora_name="adapter1")
        payload = query_models()
        self.assertEqual(len(payload["data"]), 2)
        loaded = adapter_entries(payload)
        self.assertEqual(len(loaded), 1)
        self.assertEqual(loaded[0]["id"], "adapter2")
if __name__ == "__main__": if __name__ == "__main__":
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment