{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LoRA Serving"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SGLang enables the use of [LoRA adapters](https://arxiv.org/abs/2106.09685) with a base model. By incorporating techniques from [S-LoRA](https://arxiv.org/pdf/2311.03285) and [Punica](https://arxiv.org/pdf/2310.18547), SGLang can efficiently support multiple LoRA adapters for different sequences within a single batch of inputs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Arguments for LoRA Serving"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following server arguments are relevant for multi-LoRA serving:\n",
    "\n",
    "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
    "\n",
    "* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}.\n",
    "\n",
    "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
    "\n",
    "* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
    "\n",
    "* `lora_eviction_policy`: LoRA adapter eviction policy when GPU memory pool is full. `lru`: Least Recently Used (default, better cache efficiency). `fifo`: First-In-First-Out.\n",
    "\n",
    "* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we support Triton LoRA backend (`triton`) and Chunked SGMV backend (`csgmv`). In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
    "\n",
    "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
    "\n",
    "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n",
    "\n",
    "* `--max-lora-chunk-size`: Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when --lora-backend is 'csgmv'. Choosing a larger value might improve performance. Please tune this value based on your hardware and workload as needed. Defaults to 16.\n",
    "\n",
    "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
    "\n",
    "From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage\n",
    "\n",
    "### Serving Single Adaptor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** SGLang supports LoRA adapters through two APIs:\n",
    "\n",
    "1. **OpenAI-Compatible API** (`/v1/chat/completions`, `/v1/completions`): Use the `model:adapter-name` syntax. See [OpenAI API with LoRA](../basic_usage/openai_api_completions.ipynb#Using-LoRA-Adapters) for examples.\n",
    "\n",
    "2. **Native API** (`/generate`): Pass `lora_path` in the request body (shown below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import requests\n",
    "\n",
    "from sglang.test.doc_patch import launch_server_cmd\n",
    "from sglang.utils import wait_for_server, terminate_process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_process, port = launch_server_cmd(\n",
    "    \"\"\"\n",
    "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
    "    --enable-lora \\\n",
    "    --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
    "    --max-loras-per-batch 1 \\\n",
    "    --log-level warning \\\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(f\"http://localhost:{port}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = f\"http://127.0.0.1:{port}\"\n",
    "json_data = {\n",
    "    \"text\": [\n",
    "        \"List 3 countries and their capitals.\",\n",
    "        \"List 3 countries and their capitals.\",\n",
    "    ],\n",
    "    \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
    "    # The first input uses lora0, and the second input uses the base model\n",
    "    \"lora_path\": [\"lora0\", None],\n",
    "}\n",
    "response = requests.post(\n",
    "    url + \"/generate\",\n",
    "    json=json_data,\n",
    ")\n",
    "print(f\"Output 0: {response.json()[0]['text']}\")\n",
    "print(f\"Output 1: {response.json()[1]['text']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Serving Multiple Adaptors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_process, port = launch_server_cmd(\n",
    "    \"\"\"\n",
    "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
    "    --enable-lora \\\n",
    "    --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
    "    lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
    "    --max-loras-per-batch 2 \\\n",
    "    --log-level warning \\\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(f\"http://localhost:{port}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = f\"http://127.0.0.1:{port}\"\n",
    "json_data = {\n",
    "    \"text\": [\n",
    "        \"List 3 countries and their capitals.\",\n",
    "        \"List 3 countries and their capitals.\",\n",
    "    ],\n",
    "    \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
    "    # The first input uses lora0, and the second input uses lora1\n",
    "    \"lora_path\": [\"lora0\", \"lora1\"],\n",
    "}\n",
    "response = requests.post(\n",
    "    url + \"/generate\",\n",
    "    json=json_data,\n",
    ")\n",
    "print(f\"Output 0: {response.json()[0]['text']}\")\n",
    "print(f\"Output 1: {response.json()[1]['text']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dynamic LoRA loading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
    "\n",
    "When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you would have to ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\"  # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
    "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"  # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
    "lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\"  # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
    "\n",
    "\n",
    "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
    "# We are adding it here just to demonstrate usage.\n",
    "server_process, port = launch_server_cmd(\n",
    "    \"\"\"\n",
    "    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
    "    --enable-lora \\\n",
    "    --cuda-graph-max-bs 2 \\\n",
    "    --max-loras-per-batch 2 \\\n",
    "    --max-lora-rank 256\n",
    "    --lora-target-modules all\n",
    "    --log-level warning\n",
    "    \"\"\"\n",
    ")\n",
    "\n",
    "url = f\"http://127.0.0.1:{port}\"\n",
    "wait_for_server(url)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load adapter lora0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = requests.post(\n",
    "    url + \"/load_lora_adapter\",\n",
    "    json={\n",
    "        \"lora_name\": \"lora0\",\n",
    "        \"lora_path\": lora0,\n",
    "    },\n",
    ")\n",
    "\n",
    "if response.status_code == 200:\n",
    "    print(\"LoRA adapter loaded successfully.\", response.json())\n",
    "else:\n",
    "    print(\"Failed to load LoRA adapter.\", response.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load adapter lora1:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = requests.post(\n",
    "    url + \"/load_lora_adapter\",\n",
    "    json={\n",
    "        \"lora_name\": \"lora1\",\n",
    "        \"lora_path\": lora1,\n",
    "    },\n",
    ")\n",
    "\n",
    "if response.status_code == 200:\n",
    "    print(\"LoRA adapter loaded successfully.\", response.json())\n",
    "else:\n",
    "    print(\"Failed to load LoRA adapter.\", response.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check inference output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = f\"http://127.0.0.1:{port}\"\n",
    "json_data = {\n",
    "    \"text\": [\n",
    "        \"List 3 countries and their capitals.\",\n",
    "        \"List 3 countries and their capitals.\",\n",
    "    ],\n",
    "    \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
    "    # The first input uses lora0, and the second input uses lora1\n",
    "    \"lora_path\": [\"lora0\", \"lora1\"],\n",
    "}\n",
    "response = requests.post(\n",
    "    url + \"/generate\",\n",
    "    json=json_data,\n",
    ")\n",
    "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
    "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unload lora0 and replace it with a different adapter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = requests.post(\n",
    "    url + \"/unload_lora_adapter\",\n",
    "    json={\n",
    "        \"lora_name\": \"lora0\",\n",
    "    },\n",
    ")\n",
    "\n",
    "response = requests.post(\n",
    "    url + \"/load_lora_adapter\",\n",
    "    json={\n",
    "        \"lora_name\": \"lora0\",\n",
    "        \"lora_path\": lora0_new,\n",
    "    },\n",
    ")\n",
    "\n",
    "if response.status_code == 200:\n",
    "    print(\"LoRA adapter loaded successfully.\", response.json())\n",
    "else:\n",
    "    print(\"Failed to load LoRA adapter.\", response.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check output again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = f\"http://127.0.0.1:{port}\"\n",
    "json_data = {\n",
    "    \"text\": [\n",
    "        \"List 3 countries and their capitals.\",\n",
    "        \"List 3 countries and their capitals.\",\n",
    "    ],\n",
    "    \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
    "    # The first input uses lora0, and the second input uses lora1\n",
    "    \"lora_path\": [\"lora0\", \"lora1\"],\n",
    "}\n",
    "response = requests.post(\n",
    "    url + \"/generate\",\n",
    "    json=json_data,\n",
    ")\n",
    "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
    "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OpenAI-compatible API usage\n",
    "\n",
    "You can use LoRA adapters via the OpenAI-compatible APIs by specifying the adapter in the `model` field using the `base-model:adapter-name` syntax (for example, `qwen/qwen2.5-0.5b-instruct:adapter_a`). For more details and examples, see the “Using LoRA Adapters” section in the OpenAI API documentation: [openai_api_completions.ipynb](../basic_usage/openai_api_completions.ipynb).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LoRA GPU Pinning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another advanced option is to specify adapters as `pinned` during loading. When an adapter is pinned, it is permanently assigned to one of the available GPU pool slots (as configured by `--max-loras-per-batch`) and will not be evicted from GPU memory during runtime. Instead, it remains resident until it is explicitly unloaded.\n",
    "\n",
    "This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
    "\n",
    "In the example below, we start a server with `lora1` loaded as pinned, `lora2` and `lora3` loaded as regular (unpinned) adapters. Please note that, we intentionally specify `lora2` and `lora3` in two different formats to demonstrate that both are supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_process, port = launch_server_cmd(\n",
    "    \"\"\"\n",
    "    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
    "    --enable-lora \\\n",
    "    --cuda-graph-max-bs 8 \\\n",
    "    --max-loras-per-batch 3 \\\n",
    "    --max-lora-rank 256 \\\n",
    "    --lora-target-modules all \\\n",
    "    --lora-paths \\\n",
    "        {\"lora_name\":\"lora0\",\"lora_path\":\"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\"pinned\":true} \\\n",
    "        {\"lora_name\":\"lora1\",\"lora_path\":\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"} \\\n",
    "        lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora\n",
    "    --log-level warning\n",
    "    \"\"\"\n",
    ")\n",
    "\n",
    "\n",
    "url = f\"http://127.0.0.1:{port}\"\n",
    "wait_for_server(url)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also specify adapter as pinned during dynamic adapter loading. In the example below, we reload `lora2` as pinned adapter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = requests.post(\n",
    "    url + \"/unload_lora_adapter\",\n",
    "    json={\n",
    "        \"lora_name\": \"lora1\",\n",
    "    },\n",
    ")\n",
    "\n",
    "response = requests.post(\n",
    "    url + \"/load_lora_adapter\",\n",
    "    json={\n",
    "        \"lora_name\": \"lora1\",\n",
    "        \"lora_path\": \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\",\n",
    "        \"pinned\": True,  # Pin the adapter to GPU\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Verify that the results are expected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = f\"http://127.0.0.1:{port}\"\n",
    "json_data = {\n",
    "    \"text\": [\n",
    "        \"List 3 countries and their capitals.\",\n",
    "        \"List 3 countries and their capitals.\",\n",
    "        \"List 3 countries and their capitals.\",\n",
    "    ],\n",
    "    \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
    "    # The first input uses lora0, and the second input uses lora1\n",
    "    \"lora_path\": [\"lora0\", \"lora1\", \"lora2\"],\n",
    "}\n",
    "response = requests.post(\n",
    "    url + \"/generate\",\n",
    "    json=json_data,\n",
    ")\n",
    "print(f\"Output from lora0 (pinned): \\n{response.json()[0]['text']}\\n\")\n",
    "print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")\n",
    "print(f\"Output from lora2 (not pinned): \\n{response.json()[2]['text']}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choosing LoRA Backend\n",
    "\n",
    "SGLang supports two LoRA backends that you can choose from using the `--lora-backend` argument:\n",
    "\n",
    "- `triton`: Default basic Triton-based backend.\n",
    "- `csgmv`: Chunked SGMV backend optimized for high concurrency scenarios.\n",
    "\n",
    "The `csgmv` backend was recently introduced to improve performance especially at high-concurrency scenarios. Our benchmark shows that it achieves 20% to 80% latency improvements over the basic triton backend.\n",
    "Currently it is at preview phase, we expect to make it our the default LoRA backend in future release. Before that, you can adopt it by manually setting the `--lora-backend` server config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_process, port = launch_server_cmd(\n",
    "    \"\"\"\n",
    "    python3 -m sglang.launch_server \\\n",
    "    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
    "    --enable-lora \\\n",
    "    --lora-backend csgmv \\\n",
    "    --max-loras-per-batch 16 \\\n",
    "    --lora-paths lora1=path/to/lora1 lora2=path/to/lora2\n",
    "    \"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Future Works\n",
    "\n",
    "The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Other features, including Embedding Layer, Unified Paging, Cutlass backend are still under development."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}