{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LoRA Serving" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SGLang enables the use of [LoRA adapters](https://arxiv.org/abs/2106.09685) with a base model. By incorporating techniques from [S-LoRA](https://arxiv.org/pdf/2311.03285) and [Punica](https://arxiv.org/pdf/2310.18547), SGLang can efficiently support multiple LoRA adapters for different sequences within a single batch of inputs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Arguments for LoRA Serving" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following server arguments are relevant for multi-LoRA serving:\n", "\n", "* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n", "\n", "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n", "\n", "* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", "\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "\n", "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n", "\n", "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n", "\n", "From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage\n", "\n", "### Serving Single Adaptor" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sglang.test.test_utils import is_in_ci\n", "\n", "if is_in_ci():\n", " from patch import launch_server_cmd\n", "else:\n", " from sglang.utils import launch_server_cmd\n", "\n", "from sglang.utils import wait_for_server, terminate_process\n", "\n", "import json\n", "import requests" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " --max-loras-per-batch 1 --lora-backend triton \\\n", " --disable-radix-cache\n", "\"\"\"\n", ")\n", "\n", "wait_for_server(f\"http://localhost:{port}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "url = f\"http://127.0.0.1:{port}\"\n", "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", " \"AI is a field of computer science focused on\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses the base model\n", " \"lora_path\": [\"lora0\", None],\n", "}\n", "response = requests.post(\n", " url + \"/generate\",\n", " json=json_data,\n", ")\n", "print(f\"Output 0: {response.json()[0]['text']}\")\n", "print(f\"Output 1: {response.json()[1]['text']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Serving Multiple Adaptors" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", " --disable-radix-cache\n", "\"\"\"\n", ")\n", "\n", "wait_for_server(f\"http://localhost:{port}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "url = f\"http://127.0.0.1:{port}\"\n", "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", " \"AI is a field of computer science focused on\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", " \"lora_path\": [\"lora0\", \"lora1\"],\n", "}\n", "response = requests.post(\n", " url + \"/generate\",\n", " json=json_data,\n", ")\n", "print(f\"Output 0: {response.json()[0]['text']}\")\n", "print(f\"Output 1: {response.json()[1]['text']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dynamic LoRA loading" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic Usage\n", "\n", "Instead of specifying all adapters during server startup via `--lora-paths`. 
You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n", "\n", "(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", " --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n", " --cuda-graph-max-bs 2 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", " --disable-radix-cache\n", " \"\"\"\n", ")\n", "\n", "url = f\"http://127.0.0.1:{port}\"\n", "wait_for_server(url)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", " \"lora_name\": \"lora1\",\n", " \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n", " },\n", ")\n", "\n", "if response.status_code == 200:\n", " print(\"LoRA adapter loaded successfully.\", response.json())\n", "else:\n", " print(\"Failed to load LoRA adapter.\", response.json())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = requests.post(\n", " url + \"/generate\",\n", " json={\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " \"lora_path\": [\"lora0\", \"lora1\"],\n", " },\n", ")\n", "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", "print(f\"Output from lora1: {response.json()[1]['text']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = requests.post(\n", " url + \"/unload_lora_adapter\",\n", " json={\n", " \"lora_name\": \"lora0\",\n", " },\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", " \"lora_name\": \"lora2\",\n", " \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n", " },\n", ")\n", "\n", "if response.status_code == 200:\n", " print(\"LoRA adapter loaded successfully.\", response.json())\n", "else:\n", " print(\"Failed to load LoRA adapter.\", response.json())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = requests.post(\n", " url + \"/generate\",\n", " json={\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " \"lora_path\": [\"lora1\", \"lora2\"],\n", " },\n", ")\n", "print(f\"Output from lora1: {response.json()[0]['text']}\")\n", "print(f\"Output from lora2: {response.json()[1]['text']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Advanced: hosting adapters of different shapes\n", "\n", "In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. 
To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n", "\n", "For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths`, or are strictly \"smaller\"." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lora0 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"  # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", "lora1 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\"  # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", "\n", "\n", "# The `--max-lora-rank` and `--lora-target-modules` params below are technically not needed, as the server\n", "# will infer them from lora0, which already has the largest rank and covers all the target modules.\n", "# We are adding them here just to demonstrate usage.\n", "server_process, port = launch_server_cmd(\n", "    f\"\"\"\n", "    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", "        --lora-paths lora0={lora0} \\\n", "        --cuda-graph-max-bs 2 \\\n", "        --max-loras-per-batch 2 --lora-backend triton \\\n", "        --disable-radix-cache \\\n", "        --max-lora-rank 64 \\\n", "        --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n", "    \"\"\"\n", ")\n", "\n", "url = f\"http://127.0.0.1:{port}\"\n", "wait_for_server(url)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = requests.post(\n", "    url + \"/load_lora_adapter\",\n", "    json={\n", "        \"lora_name\": \"lora1\",\n", "        \"lora_path\": lora1,\n", "    },\n", ")\n", "\n", "if response.status_code == 200:\n", "    print(\"LoRA adapter loaded successfully.\", response.json())\n", "else:\n", "    print(\"Failed to load LoRA adapter.\", response.json())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "url = f\"http://127.0.0.1:{port}\"\n", "json_data = {\n", "    \"text\": [\n", "        \"List 3 countries and their capitals.\",\n", "        \"AI is a field of computer science focused on\",\n", "    ],\n", "    \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", "    # The first input uses lora0, and the second input uses lora1\n", "    \"lora_path\": [\"lora0\", \"lora1\"],\n", "}\n", "response = requests.post(\n", "    url + \"/generate\",\n", "    json=json_data,\n", ")\n", "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", "print(f\"Output from lora1: {response.json()[1]['text']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Future Work\n", "\n", "The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently, radix attention is incompatible with LoRA and must be manually disabled (hence the `--disable-radix-cache` flag in the examples above). Other features, including Unified Paging, a Cutlass backend, and dynamic loading/unloading, are still under development."
] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 2 }