"docs/source/vscode:/vscode.git/clone" did not exist on "bcf9923b32dacd8b1f06b156855b49b268b70dcc"
Unverified Commit 8c0cfca8 authored by Qiaolin Yu's avatar Qiaolin Yu Committed by GitHub
Browse files

Feat: support cuda graph for LoRA (#4115)


Co-authored-by: default avatarBeichen Ma <mabeichen12@gmail.com>
parent 2c3ea294
...@@ -19,7 +19,7 @@ def launch_server(args): ...@@ -19,7 +19,7 @@ def launch_server(args):
for i in range(NUM_LORAS): for i in range(NUM_LORAS):
lora_name = f"lora{i}" lora_name = f"lora{i}"
cmd += f"{lora_name}={lora_path} " cmd += f"{lora_name}={lora_path} "
cmd += f"--disable-radix --disable-cuda-graph " cmd += f"--disable-radix "
cmd += f"--max-loras-per-batch {args.max_loras_per_batch} " cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
cmd += f"--max-running-requests {args.max_running_requests} " cmd += f"--max-running-requests {args.max_running_requests} "
cmd += f"--lora-backend {args.lora_backend} " cmd += f"--lora-backend {args.lora_backend} "
......
...@@ -77,7 +77,7 @@ ...@@ -77,7 +77,7 @@
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" --max-loras-per-batch 1 --lora-backend triton \\\n", " --max-loras-per-batch 1 --lora-backend triton \\\n",
" --disable-cuda-graph --disable-radix-cache\n", " --disable-radix-cache\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
...@@ -136,7 +136,7 @@ ...@@ -136,7 +136,7 @@
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n", " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-cuda-graph --disable-radix-cache\n", " --disable-radix-cache\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
...@@ -182,7 +182,7 @@ ...@@ -182,7 +182,7 @@
"source": [ "source": [
"## Future Works\n", "## Future Works\n",
"\n", "\n",
"The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently Cuda graph and radix attention are not incompatible with LoRA and must be manually disabled. Other features, including Unified Paging, Cutlass backend, and dynamic loading/unloadingm, are still under development." "The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Currently radix attention is incompatible with LoRA and must be manually disabled. Other features, including Unified Paging, Cutlass backend, and dynamic loading/unloadingm, are still under development."
] ]
} }
], ],
......
...@@ -160,7 +160,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s ...@@ -160,7 +160,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults | | Arguments | Description | Defaults |
|----------|-------------|---------| |----------|-------------|---------|
| `lora_paths` | List of adapters to apply to your model. Each batch element uses the proper LoRA adapter. `cuda_graph` and `radix_attention` are not supported with this, so they must be disabled manually. See related [issues](https://github.com/sgl-project/sglang/issues/2929). | None | | `lora_paths` | List of adapters to apply to your model. Each batch element uses the proper LoRA adapter. `radix_attention` is not supported with this, so it must be disabled manually. See related [issues](https://github.com/sgl-project/sglang/issues/2929). | None |
| `max_loras_per_batch` | Maximum number of LoRAs allowed in a running batch, including the base model. | `8` | | `max_loras_per_batch` | Maximum number of LoRAs allowed in a running batch, including the base model. | `8` |
| `lora_backend` | Backend used to run GEMM kernels for LoRA modules. Can be `triton` or `flashinfer`. | `triton` | | `lora_backend` | Backend used to run GEMM kernels for LoRA modules. Can be `triton` or `flashinfer`. | `triton` |
......
...@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.set_lora = True self.set_lora = True
self.A_buffer_gate_up = A_buffer self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b: if self.lora_backend.fuse_stacked_lora_b:
# TODO: avoid using contiguous() in GPU.
# B_buffer_gate_up: (num_lora, 2 * output_dim, r) # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
self.B_buffer_gate_up = torch.cat( if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
(B_buffer[0], B_buffer[1]), dim=-2 self.B_buffer_gate_up = torch.empty(
).contiguous() (
B_buffer[0].shape[0],
2 * B_buffer[0].shape[1],
B_buffer[0].shape[2],
),
dtype=B_buffer[0].dtype,
device=B_buffer[0].device,
)
self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
else: else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1]) self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
...@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def init__( def __init__(
self, self,
base_layer: QKVParallelLinear, base_layer: QKVParallelLinear,
lora_backend: BaseLoRABackend, lora_backend: BaseLoRABackend,
...@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat( if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2 self.B_buffer_qkv = torch.empty(
).contiguous() (
B_buffer_q[0].shape[0],
output_dim_q + 2 * output_dim_kv,
B_buffer_q[0].shape[2],
),
dtype=B_buffer_q[0].dtype,
device=B_buffer_q[0].device,
)
self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
B_buffer_kv[0]
)
self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
B_buffer_kv[1]
)
# Offsets of q/k/v in output dimension # Offsets of q/k/v in output dimension
self.output_offset = torch.tensor( if not hasattr(self, "output_offset") or self.output_offset is None:
self.output_offset = torch.empty(
4, dtype=torch.int32, device=B_buffer_q.device
)
self.output_offset[:4] = torch.tensor(
[ [
0, 0,
output_dim_q, output_dim_q,
......
...@@ -72,6 +72,23 @@ class LoRAManager: ...@@ -72,6 +72,23 @@ class LoRAManager:
self.init_loras() self.init_loras()
self.init_lora_memory_pool() self.init_lora_memory_pool()
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
with torch.device("cuda"):
self.cuda_graph_batch_info = LoRABatchInfo(
bs=self.max_bs_in_cuda_graph,
seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
seg_indptr=torch.zeros(
self.max_bs_in_cuda_graph + 1, dtype=torch.int32
),
max_len=0,
weight_indices=torch.zeros(
self.max_bs_in_cuda_graph, dtype=torch.int32
),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
)
def init_loras(self): def init_loras(self):
# Config of each LoRA adapter # Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {} self.configs: Dict[str, LoRAConfig] = {}
...@@ -140,39 +157,73 @@ class LoRAManager: ...@@ -140,39 +157,73 @@ class LoRAManager:
if cur_uids == set([None]): if cur_uids == set([None]):
return return
# set up batch info shared by all lora moruldes # set up batch info shared by all lora modules
bs = forward_batch.batch_size bs = forward_batch.batch_size
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
lora_ranks = torch.empty( if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda" # Do in-place updates when CUDA graph is enabled. Note that
) # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
scalings = torch.empty( # will also use these preallocated buffers, no matter whether
(self.max_loras_per_batch,), dtype=torch.float, device="cuda" # the batch can use CUDA graph or not.
) self.cuda_graph_batch_info.bs = bs
for i, lora_path in enumerate(forward_batch.lora_paths): if forward_batch.forward_mode.is_extend():
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) self.cuda_graph_batch_info.seg_lens[:bs].copy_(
lora = self.loras[lora_path] forward_batch.extend_seq_lens
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] )
scalings[weight_indices[i]] = lora.scaling else:
self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
batch_info = LoRABatchInfo( torch.cumsum(
bs=bs, self.cuda_graph_batch_info.seg_lens[:bs],
seg_lens=seg_lens, dim=0,
seg_indptr=seg_indptr, out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
max_len=max_len, )
weight_indices=weight_indices, self.cuda_graph_batch_info.max_len = int(
lora_ranks=lora_ranks, torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
scalings=scalings, )
)
for i, lora_path in enumerate(forward_batch.lora_paths):
self.cuda_graph_batch_info.weight_indices[i] = (
self.memory_pool.get_buffer_id(lora_path)
)
lora = self.loras[lora_path]
self.cuda_graph_batch_info.lora_ranks[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.config.hf_config["r"]
self.cuda_graph_batch_info.scalings[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.scaling
batch_info = self.cuda_graph_batch_info
else:
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
lora_ranks = torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
)
scalings = torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device="cuda"
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
)
self.lora_backend.set_batch_info(batch_info) self.lora_backend.set_batch_info(batch_info)
# call set_lora_info for each lora modules # call set_lora_info for each lora modules
......
...@@ -220,6 +220,9 @@ class CudaGraphRunner: ...@@ -220,6 +220,9 @@ class CudaGraphRunner:
if self.enable_torch_compile: if self.enable_torch_compile:
set_torch_compile_config() set_torch_compile_config()
if self.model_runner.server_args.lora_paths is not None:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
# Graph inputs # Graph inputs
with torch.device("cuda"): with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
...@@ -403,6 +406,13 @@ class CudaGraphRunner: ...@@ -403,6 +406,13 @@ class CudaGraphRunner:
self.capture_hidden_mode = ( self.capture_hidden_mode = (
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
) )
if self.model_runner.server_args.lora_paths is not None:
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
# values if lora is enabled.
lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
else:
lora_paths = None
forward_batch = ForwardBatch( forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode, forward_mode=self.capture_forward_mode,
...@@ -424,8 +434,12 @@ class CudaGraphRunner: ...@@ -424,8 +434,12 @@ class CudaGraphRunner:
spec_algorithm=self.model_runner.spec_algorithm, spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info, spec_info=spec_info,
capture_hidden_mode=self.capture_hidden_mode, capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
) )
if lora_paths is not None:
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
# Attention backend # Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs, bs,
......
...@@ -1242,7 +1242,6 @@ class ServerArgs: ...@@ -1242,7 +1242,6 @@ class ServerArgs:
assert ( assert (
self.max_loras_per_batch > 0 self.max_loras_per_batch > 0
# FIXME # FIXME
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache) and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress" ), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
......
...@@ -24,7 +24,7 @@ from utils import ( ...@@ -24,7 +24,7 @@ from utils import (
DEFAULT_PROMPTS, DEFAULT_PROMPTS,
TORCH_DTYPES, TORCH_DTYPES,
LoRAModelCase, LoRAModelCase,
run_batch_lora_test, run_lora_test_one_by_one,
) )
from sglang.test.test_utils import CustomTestCase, is_in_ci from sglang.test.test_utils import CustomTestCase, is_in_ci
...@@ -42,7 +42,7 @@ class TestLoRABackend(CustomTestCase): ...@@ -42,7 +42,7 @@ class TestLoRABackend(CustomTestCase):
) )
for torch_dtype in TORCH_DTYPES: for torch_dtype in TORCH_DTYPES:
for backend in BACKENDS: for backend in BACKENDS:
run_batch_lora_test( run_lora_test_one_by_one(
prompts, prompts,
model_case, model_case,
torch_dtype, torch_dtype,
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
import unittest
from typing import List
from utils import (
ALL_OTHER_LORA_MODELS,
CI_LORA_MODELS,
DEFAULT_PROMPTS,
TORCH_DTYPES,
LoRAModelCase,
run_lora_test_by_batch,
run_lora_test_one_by_one,
)
from sglang.test.test_utils import CustomTestCase, is_in_ci
TEST_CUDA_GRAPH_PADDING_PROMPTS = [
"AI is a field of computer science focused on",
"""
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
"Computer science is the study of",
]
class TestLoRACudaGraph(CustomTestCase):
def _run_without_cuda_graph_on_model_cases(self, model_cases: List[LoRAModelCase]):
# Since we have already enabled CUDA graph by default in other lora tests,
# we only need to run lora tests without CUDA graph here.
for model_case in model_cases:
# If skip_long_prompt is True, filter out prompts longer than 1000 characters
prompts = (
DEFAULT_PROMPTS
if not model_case.skip_long_prompt
else [p for p in DEFAULT_PROMPTS if len(p) < 1000]
)
for torch_dtype in TORCH_DTYPES:
run_lora_test_one_by_one(
prompts,
model_case,
torch_dtype,
max_new_tokens=32,
backend="triton",
disable_cuda_graph=True,
test_tag="without_cuda_graph",
)
def _run_cuda_graph_padding_on_model_cases(self, model_cases: List[LoRAModelCase]):
for model_case in model_cases:
# Run a batch size of 3, which will not be captured by CUDA graph and need padding
prompts = TEST_CUDA_GRAPH_PADDING_PROMPTS
for torch_dtype in TORCH_DTYPES:
run_lora_test_by_batch(
prompts,
model_case,
torch_dtype,
max_new_tokens=32,
backend="triton",
disable_cuda_graph=False,
test_tag="cuda_graph_padding",
)
def test_ci_lora_models(self):
self._run_without_cuda_graph_on_model_cases(CI_LORA_MODELS)
self._run_cuda_graph_padding_on_model_cases(CI_LORA_MODELS)
def test_all_lora_models(self):
if is_in_ci():
return
# Retain ONLY_RUN check here
filtered_models = []
for model_case in ALL_OTHER_LORA_MODELS:
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
continue
filtered_models.append(model_case)
self._run_without_cuda_graph_on_model_cases(filtered_models)
self._run_cuda_graph_padding_on_model_cases(filtered_models)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")
...@@ -23,7 +23,7 @@ from utils import ( ...@@ -23,7 +23,7 @@ from utils import (
DEFAULT_PROMPTS, DEFAULT_PROMPTS,
TORCH_DTYPES, TORCH_DTYPES,
LoRAModelCase, LoRAModelCase,
run_batch_lora_test, run_lora_test_one_by_one,
) )
from sglang.test.test_utils import CustomTestCase, is_in_ci from sglang.test.test_utils import CustomTestCase, is_in_ci
...@@ -43,7 +43,7 @@ class TestLoRATP(CustomTestCase): ...@@ -43,7 +43,7 @@ class TestLoRATP(CustomTestCase):
for tp_size in tp_list: for tp_size in tp_list:
model_case.tp_size = tp_size model_case.tp_size = tp_size
for torch_dtype in TORCH_DTYPES: for torch_dtype in TORCH_DTYPES:
run_batch_lora_test( run_lora_test_one_by_one(
prompts, prompts,
model_case, model_case,
torch_dtype, torch_dtype,
......
...@@ -22,7 +22,7 @@ from utils import ( ...@@ -22,7 +22,7 @@ from utils import (
TORCH_DTYPES, TORCH_DTYPES,
LoRAAdaptor, LoRAAdaptor,
LoRAModelCase, LoRAModelCase,
run_batch_lora_test, run_lora_test_one_by_one,
) )
from sglang.test.test_utils import CustomTestCase, is_in_ci from sglang.test.test_utils import CustomTestCase, is_in_ci
...@@ -89,7 +89,7 @@ class TestMultiLoRABackend(CustomTestCase): ...@@ -89,7 +89,7 @@ class TestMultiLoRABackend(CustomTestCase):
) )
for torch_dtype in TORCH_DTYPES: for torch_dtype in TORCH_DTYPES:
for backend in BACKENDS: for backend in BACKENDS:
run_batch_lora_test( run_lora_test_one_by_one(
batch_prompts, batch_prompts,
model_case, model_case,
torch_dtype, torch_dtype,
......
...@@ -94,19 +94,20 @@ ALL_OTHER_LORA_MODELS = [ ...@@ -94,19 +94,20 @@ ALL_OTHER_LORA_MODELS = [
] ]
def run_batch_lora_test( def run_lora_test_one_by_one(
prompts: List[str], prompts: List[str],
model_case: LoRAModelCase, model_case: LoRAModelCase,
torch_dtype: torch.dtype, torch_dtype: torch.dtype,
max_new_tokens: int, max_new_tokens: int,
backend: str, backend: str,
disable_cuda_graph: bool = True, disable_cuda_graph: bool = False,
disable_radix_cache: bool = True, disable_radix_cache: bool = True,
mem_fraction_static: float = 0.88, mem_fraction_static: float = 0.88,
test_tag: str = "", test_tag: str = "",
): ):
""" """
Run Lora test for a forward batch. Input a batch of prompts, and run lora tests one by one with several generate requests
(each request will have bs=1).
For prompt0, prompt1, ..., promptN, For prompt0, prompt1, ..., promptN,
we will use adaptor0, adaptor1, ..., adaptorN included in model case, we will use adaptor0, adaptor1, ..., adaptorN included in model case,
We will then compare the outputs of HF and SRT with and without LoRA. We will then compare the outputs of HF and SRT with and without LoRA.
...@@ -119,7 +120,7 @@ def run_batch_lora_test( ...@@ -119,7 +120,7 @@ def run_batch_lora_test(
torch_dtype (torch.dtype): The torch dtype to use. torch_dtype (torch.dtype): The torch dtype to use.
max_new_tokens (int): The maximum number of new tokens to generate. max_new_tokens (int): The maximum number of new tokens to generate.
backend (str): The lora backend to use. backend (str): The lora backend to use.
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to True. disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True. disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88. mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
test_tag (str, optional): The tag to use for the test. Defaults to "". test_tag (str, optional): The tag to use for the test. Defaults to "".
...@@ -237,3 +238,112 @@ def run_batch_lora_test( ...@@ -237,3 +238,112 @@ def run_batch_lora_test(
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} " f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
f"for base '{base_path}', adaptor '{adaptor_names}', backend '{backend}', prompt: '{prompts[0][:50]}...'" f"for base '{base_path}', adaptor '{adaptor_names}', backend '{backend}', prompt: '{prompts[0][:50]}...'"
) )
def run_lora_test_by_batch(
prompts: List[str],
model_case: LoRAModelCase,
torch_dtype: torch.dtype,
max_new_tokens: int,
backend: str,
disable_cuda_graph: bool = False,
disable_radix_cache: bool = True,
mem_fraction_static: float = 0.88,
test_tag: str = "",
):
"""
Run lora tests as a batch.
For prompt0, prompt1, ..., promptN,
we will use adaptor0, adaptor1, ..., adaptorN included in model case,
We will then compare the outputs of HF and SRT with LoRA.
If number of prompts is larger than number of adaptors,
the prompt i will use adaptor i % (number of adaptors).
Args:
prompts (List[str]): The batch of prompts to test.
model_case (LoRAModelCase): The model case to test.
torch_dtype (torch.dtype): The torch dtype to use.
max_new_tokens (int): The maximum number of new tokens to generate.
backend (str): The lora backend to use.
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to True.
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
test_tag (str, optional): The tag to use for the test. Defaults to "".
"""
base_path = model_case.base
# Create used adaptors for each prompt in batch
i, adaptors = 0, []
for _ in range(len(prompts)):
adaptors.append(model_case.adaptors[i])
i = (i + 1) % len(model_case.adaptors)
adaptor_names = [adaptor.name for adaptor in adaptors]
print(
f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
)
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
tp_size=model_case.tp_size,
lora_paths=[adaptor.name for adaptor in model_case.adaptors],
max_loras_per_batch=model_case.max_loras_per_batch,
lora_backend=backend,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
mem_fraction_static=mem_fraction_static,
) as srt_runner:
srt_outputs = srt_runner.batch_forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
)
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
tp_size=model_case.tp_size,
mem_fraction_static=mem_fraction_static,
) as srt_runner:
srt_no_lora_outputs = srt_runner.batch_forward(
prompts, max_new_tokens=max_new_tokens
)
with HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
)
with HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts,
max_new_tokens=max_new_tokens,
)
for i in range(len(prompts)):
srt_output_str = srt_outputs.output_strs[i].strip()
hf_output_str = hf_outputs.output_strs[i].strip()
rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
print("ROUGE-L score:", rouge_score)
print("SRT output:", srt_output_str)
print("HF output:", hf_output_str)
print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
" "
), (
srt_outputs.output_strs[i].strip(" "),
hf_outputs.output_strs[i].strip(" "),
)
assert srt_no_lora_outputs.output_strs[i].strip(
" "
) == hf_no_lora_outputs.output_strs[i].strip(" "), (
srt_no_lora_outputs.output_strs[i].strip(" "),
hf_no_lora_outputs.output_strs[i].strip(" "),
)
...@@ -80,6 +80,7 @@ suites = { ...@@ -80,6 +80,7 @@ suites = {
TestFile("test_vlm_accuracy.py", 60), TestFile("test_vlm_accuracy.py", 60),
TestFile("test_vision_openai_server.py", 637), TestFile("test_vision_openai_server.py", 637),
TestFile("test_w8a8_quantization.py", 46), TestFile("test_w8a8_quantization.py", 46),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
], ],
"per-commit-2-gpu": [ "per-commit-2-gpu": [
TestFile("models/lora/test_lora_tp.py", 116), TestFile("models/lora/test_lora_tp.py", 116),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment