Unverified Commit 66581596 authored by Lianmin Zheng, committed by GitHub

Enable cuda graph by default (#612)

parent 396a6924
"""
Usage:
python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
"""
import argparse
import json
import time
import numpy as np
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args()
if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
elif args.backend == "ginfer":
args.port = 9988
else:
raise ValueError(f"Invalid backend: {args.backend}")
def run_one_batch_size(bs):
url = f"{args.host}:{args.port}"
a = 20
max_new_tokens = args.max_tokens
a = 20
prompt = f"{a, }"
tic = time.time()
if args.backend == "srt":
if args.input_len:
inputs = {"input_ids": [
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
]}
else:
inputs = {"text": [
f"{i, }" for i in range(bs)
]}
response = requests.post(
url + "/generate",
json={
"text": [prompt] * args.batch_size,
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
**inputs,
},
)
elif args.backend == "lightllm":
@@ -91,5 +89,41 @@ if __name__ == "__main__":
ret = response.json()
print(ret)
speed = args.batch_size * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
output_throughput = bs * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")
with open("tmp_output.txt", "a") as fout:
res = {
"input_len": args.input_len,
"output_len": args.max_tokens,
"batch_size": bs,
"latency": latency,
"output_throughput": output_throughput
}
fout.write(json.dumps(res) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--input-len", type=int, default=None)
parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args()
if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
elif args.backend == "ginfer":
args.port = 9988
else:
raise ValueError(f"Invalid backend: {args.backend}")
for bs in args.batch_size:
run_one_batch_size(bs)
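# Hedged helper (not part of this commit) showing how the JSON lines that
# run_one_batch_size appends to tmp_output.txt could be summarized after a
# sweep such as the one shown in the usage string at the top of this file.
import json

def summarize(path="tmp_output.txt"):
    with open(path) as fin:
        for line in fin:
            r = json.loads(line)
            print(f"bs={r['batch_size']:4d}  latency={r['latency']:.2f}s  "
                  f"throughput={r['output_throughput']:.2f} token/s")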
@@ -30,7 +30,6 @@ import argparse
import dataclasses
import logging
import multiprocessing
import os
import time
......
@@ -8,36 +8,40 @@ class GlobalConfig:
# 2: output final text after every run
self.verbosity = 0
# Default backend of the language
self.default_backend = None
# Output configs
# Runtime constants: Request dependency time due to network delay
self.request_dependency_delay = 0.02
self.wait_for_new_request_delay = 0.0006
# Runtime constants: New generation token ratio estimation
self.base_new_token_ratio = 0.4
self.base_min_new_token_ratio = 0.2
self.new_token_ratio_decay = 0.0001
self.new_token_ratio_recovery = 0.05
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
# This can improve the speed for large batch sizes during prefill.
self.layer_sync_threshold = 8192
# Runtime constants: Flashinfer
self.flashinfer_workspace_size = 192 * 1024 * 1024
# Output tokenization configs
self.skip_special_tokens_in_output = True
self.spaces_between_special_tokens_in_out = True
# Optimization configs
# Interpreter optimization configs
self.eager_fill_image = False
self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True
self.enable_parallel_decoding = True
# Deprecated
# Choices: ["no_adjust", "adjust_cache"]
# no_adjust: Do not adjust the position embedding of KV cache.
# adjust_cache: Adjust the position embedding of KV cache.
self.concate_and_append_mode = "no_adjust"
# Request dependency time due to network delay
self.request_dependency_delay = 0.02
self.wait_for_new_request_delay = 0.0006
# New generation token ratio estimation
self.base_new_token_ratio = 0.4
self.base_min_new_token_ratio = 0.2
self.new_token_ratio_decay = 0.0001
self.new_token_ratio_recovery = 0.05
# The threshold (number of tokens) to trigger layer-wise cuda sync.
# This can improve the speed for large batch sizes during prefill.
self.layer_sync_threshold = 8192
global_config = GlobalConfig()
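# Hedged usage sketch (not part of this commit): other modules import the
# global_config singleton and read these runtime constants directly; for
# example, ModelRunner below sizes its FlashInfer workspace buffers from
# flashinfer_workspace_size.
from sglang.global_config import global_config

workspace_bytes = global_config.flashinfer_workspace_size   # 192 * 1024 * 1024 by default
assert global_config.skip_special_tokens_in_output is True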
"""Run the model with cuda graph."""
import bisect
import torch
from vllm.distributed.parallel_state import graph_capture
from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
Batch, ForwardMode, InputMetadata, init_flashinfer_args
)
class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture):
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
self.output_buffers = {}
self.flashinfer_handlers = {}
self.graph_memory_pool = None
# Common inputs
self.max_bs = max_batch_size_to_capture
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
# FlashInfer inputs
self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
self.flashinfer_kv_indptr = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.flashinfer_kv_indices = torch.zeros(
(self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
)
self.flashinfer_kv_last_page_len = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
def can_run(self, batch_size):
return batch_size < self.max_bs
def capture(self, batch_size_list):
self.batch_size_list = batch_size_list
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in batch_size_list:
graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
self.graphs[bs] = graph
self.input_buffers[bs] = input_buffers
self.output_buffers[bs] = output_buffers
self.flashinfer_handlers[bs] = flashinfer_handler
def capture_one_batch_size(self, bs):
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
graph = torch.cuda.CUDAGraph()
stream = self.stream
# Common inputs
input_ids = self.input_ids[:bs]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
position_ids_offsets = self.position_ids_offsets[:bs]
out_cache_loc = self.out_cache_loc[:bs]
# FlashInfer inputs
if not _grouped_size_compiled_for_decode_kernels(
self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
paged_kv_indices_buffer=self.flashinfer_kv_indices,
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
)
init_flashinfer_args(
ForwardMode.DECODE,
self.model_runner,
req_pool_indices,
seq_lens,
None,
flashinfer_decode_wrapper,
)
# Run and capture
def run_once():
input_metadata = InputMetadata.create(
self.model_runner,
forward_mode=ForwardMode.DECODE,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=None,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
out_cache_cont_start=None,
out_cache_cont_end=None,
return_logprob=False,
top_logprobs_nums=0,
skip_flashinfer_init=True,
)
input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
return self.model_runner.model.forward(
input_ids, input_metadata.positions, input_metadata
)
for _ in range(2):
run_once()
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
out = run_once()
torch.cuda.synchronize()
self.graph_memory_pool = graph.pool()
return graph, None, out, flashinfer_decode_wrapper
def replay(self, batch: Batch):
assert batch.out_cache_loc is not None
assert not batch.return_logprob
raw_bs = len(batch.reqs)
# Pad
index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
# Common inputs
self.input_ids[:raw_bs] = batch.input_ids
self.req_pool_indices[:raw_bs] = batch.req_pool_indices
self.seq_lens[:raw_bs] = batch.seq_lens
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
# FlashInfer inputs
init_flashinfer_args(
ForwardMode.DECODE,
self.model_runner,
self.req_pool_indices[:bs],
self.seq_lens[:bs],
None,
self.flashinfer_handlers[bs],
)
# Replay
self.graphs[bs].replay()
output = self.output_buffers[bs]
# Unpad
if bs == raw_bs:
return output
else:
output = LogitProcessorOutput(
next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
)
return output
\ No newline at end of file
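# A minimal, self-contained sketch (not from this commit) of the capture/replay
# pattern CudaGraphRunner relies on: keep input/output buffers at fixed
# addresses, warm up, capture one forward pass into a torch.cuda.CUDAGraph,
# then serve later steps by copying new data into the same buffers and replaying.
import torch

def cuda_graph_pattern_demo():
    layer = torch.nn.Linear(16, 16).cuda()
    static_x = torch.zeros(8, 16, device="cuda")        # input buffer with a fixed address

    # Warm up on a side stream before capture (mirrors the two run_once() calls above).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            layer(static_x)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):                        # capture one forward pass
        static_y = layer(static_x)                       # output buffer is fixed too

    # Serve a new request by overwriting the captured input buffer, then replaying.
    static_x.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    return static_y                                      # now holds results for the new input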
@@ -675,7 +675,11 @@ class Batch:
# TODO(lmzheng): apply penalty
probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
sampled_index = torch.multinomial(probs_sort, num_samples=1)
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError as e:
warnings.warn(f"Ignore errors in sampling: {e}")
sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-1
)
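# Hedged illustration (toy tensors, not the real _top_p_top_k output) of the
# sample-then-gather step above: probs_sort holds the filtered, renormalized
# probabilities sorted in descending order and probs_idx holds the matching
# vocabulary ids, so the multinomial draw over probs_sort is mapped back to
# vocab ids with torch.gather. The try/except above guards against
# torch.multinomial rejecting a distribution that contains inf/nan values.
import torch

probs_sort = torch.tensor([[0.7, 0.3, 0.0]])    # filtered, renormalized probabilities
probs_idx = torch.tensor([[421, 97, 5]])        # vocab ids in the same sorted order
sampled_index = torch.multinomial(probs_sort, num_samples=1)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)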
@@ -757,9 +761,11 @@ class InputMetadata:
out_cache_cont_end=None,
top_logprobs_nums=None,
return_logprob=False,
skip_flashinfer_init=False,
):
if not model_runner.server_args.disable_flashinfer:
init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
model_runner.flashinfer_decode_wrapper)
batch_size = len(req_pool_indices)
@@ -826,7 +832,8 @@ class InputMetadata:
return ret
def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
flashinfer_decode_wrapper):
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
head_dim = model_runner.model_config.head_dim
@@ -857,8 +864,8 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
)
if forward_mode == ForwardMode.DECODE:
model_runner.flashinfer_decode_wrapper.end_forward()
model_runner.flashinfer_decode_wrapper.begin_forward(
flashinfer_decode_wrapper.end_forward()
flashinfer_decode_wrapper.begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
......
@@ -15,6 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
@@ -90,6 +91,9 @@ class ModelRunner:
self.init_cublas()
self.init_flash_infer()
# Capture cuda graphs
self.init_cuda_graphs()
def load_model(self):
logger.info(
f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -203,29 +207,46 @@ class ModelRunner:
else:
use_tensor_cores = False
workspace_buffers = torch.empty(
3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
self.flashinfer_workspace_buffers = torch.empty(
2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
)
self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffers[0], "NHD"
self.flashinfer_workspace_buffers[0], "NHD"
)
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffers[1], "NHD"
self.flashinfer_workspace_buffers[1], "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
)
def init_cuda_graphs(self):
from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
self.cuda_graph_runner = None
return
logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
self.cuda_graph_runner.capture(batch_size_list)
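# Hedged illustration of how a decode batch is later matched to a captured
# graph: CudaGraphRunner.replay pads the raw batch size up to the nearest size
# in batch_size_list with bisect, runs that graph, then slices the outputs back
# down to the real batch size.
import bisect

batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]   # sizes captured above
raw_bs = 10
padded_bs = batch_size_list[bisect.bisect_left(batch_size_list, raw_bs)]
assert padded_bs == 16        # replays the bs=16 graph, keeps only 10 rows of output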
@torch.inference_mode()
def forward_extend(self, batch: Batch):
def forward_decode(self, batch: Batch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
forward_mode=ForwardMode.DECODE,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
out_cache_cont_start=batch.out_cache_cont_start,
out_cache_cont_end=batch.out_cache_cont_end,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
)
@@ -234,17 +255,15 @@ class ModelRunner:
)
@torch.inference_mode()
def forward_decode(self, batch: Batch):
def forward_extend(self, batch: Batch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.DECODE,
forward_mode=ForwardMode.EXTEND,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
out_cache_cont_start=batch.out_cache_cont_start,
out_cache_cont_end=batch.out_cache_cont_end,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
)
......
@@ -98,7 +98,7 @@ class ModelTpServer:
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = (
4096
8192
if server_args.max_prefill_tokens is None
else server_args.max_prefill_tokens
)
@@ -314,11 +314,9 @@ class ModelTpServer:
self.forward_queue.append(req)
def get_new_fill_batch(self) -> Optional[Batch]:
if (
self.running_batch is not None
and len(self.running_batch.reqs) > self.max_running_requests
):
return None
running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
if running_bs > self.max_running_requests:
return
# Compute matched prefix length
for req in self.forward_queue:
@@ -394,6 +392,10 @@ class ModelTpServer:
new_batch_input_tokens += req.extend_input_len
else:
break
if running_bs + len(can_run_list) > self.max_running_requests:
break
if len(can_run_list) == 0:
return None
......
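# Hedged, simplified sketch of the admission cap added above: new prefill
# requests are only admitted while the running batch plus the candidates
# selected so far stays within max_running_requests (values are illustrative,
# and the loop body is reduced to the cap check only).
max_running_requests = 4
running_bs = 2
can_run_list = []
for req in ["r0", "r1", "r2", "r3"]:               # stand-ins for queued requests
    if running_bs + len(can_run_list) + 1 > max_running_requests:
        break
    can_run_list.append(req)
assert can_run_list == ["r0", "r1"]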
@@ -38,7 +38,10 @@ class ReqToTokenPool:
class TokenToKVPool:
def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
self.size = size
# mem_state is the reference counter.
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
self.total_ref_ct = 0
# [size, key/value, head_num, head_dim] for each layer
@@ -47,6 +50,8 @@ class TokenToKVPool:
for _ in range(layer_num)
]
self.clear()
def get_key_buffer(self, layer_id):
return self.kv_data[layer_id][:, 0]
@@ -101,3 +106,6 @@ class TokenToKVPool:
def clear(self):
self.mem_state.fill_(0)
self.total_ref_ct = 0
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.add_refs(torch.tensor([0], dtype=torch.int32))
\ No newline at end of file
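# Hedged sketch (toy sizes, illustrative helper bodies) of the idea behind the
# change above: mem_state keeps one int16 reference count per KV-cache slot
# plus one extra entry, and clear() permanently pins slot 0 so the zeroed
# out_cache_loc entries of padded CUDA-graph batches can write there without
# that slot ever being handed out or freed.
import torch

size = 4
mem_state = torch.zeros((size + 1,), dtype=torch.int16)

def add_refs(indices):
    mem_state[indices] += 1

add_refs(torch.tensor([0]))                        # the dummy slot stays referenced forever
available = (mem_state == 0).nonzero().flatten()   # slots that can still be allocated
assert 0 not in available.tolist()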
@@ -146,6 +146,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["NCCL_CUMEM_ENABLE"] = "0"
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_disk_cache:
......
@@ -29,7 +29,7 @@ class ServerArgs:
max_prefill_tokens: Optional[int] = None
max_running_requests: Optional[int] = None
schedule_heuristic: str = "lpm"
schedule_conservativeness: float = 1.0
schedule_conservativeness: float = 0.8
# Other runtime options
tp_size: int = 1
@@ -68,13 +68,13 @@ class ServerArgs:
self.tokenizer_path = self.model_path
if self.mem_fraction_static is None:
if self.tp_size >= 8:
self.mem_fraction_static = 0.80
self.mem_fraction_static = 0.78
elif self.tp_size >= 4:
self.mem_fraction_static = 0.82
self.mem_fraction_static = 0.80
elif self.tp_size >= 2:
self.mem_fraction_static = 0.85
else:
self.mem_fraction_static = 0.90
self.mem_fraction_static = 0.88
if isinstance(self.additional_ports, int):
self.additional_ports = [self.additional_ports]
elif self.additional_ports is None:
......