Commit 41d1f677 authored by Lianmin Zheng

Fix flush cache (#627)

parent 56f5fc4a
# Benchmark Latency and Throughput
## SGLang
-### Launch server
+### Launch a server
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
@@ -33,6 +32,11 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
python3 bench_serving.py --backend srt --port 30000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
```
+### Profile with Nsight
+1. To profile a single batch, use `nsys profile --cuda-graph-trace=node python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 64 --input-len 512`
+2. To profile a server, use `nsys profile --cuda-graph-trace=node python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B`.
## Other baselines
### vLLM
@@ -64,4 +68,4 @@ python -m lightllm.server.api_server --model_dir ~/model_weights/Llama-2-7b-chat
```
python3 bench_serving.py --backend lightllm --port 22000 --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10 --request-rate 10
```
\ No newline at end of file
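A quick way to sanity-check the launched server before running `bench_serving.py` is to time a single request by hand. The snippet below is a minimal sketch, not part of this commit; it assumes the server started above is listening on `http://127.0.0.1:30000` and exposes SGLang's `/generate` HTTP endpoint accepting a `text` prompt and a `sampling_params` dict.

```
# Minimal smoke test for the server launched above (sketch, not from the repo).
# Assumes SGLang's /generate endpoint is reachable on localhost:30000.
import time

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 32, "temperature": 0},
}

start = time.perf_counter()
resp = requests.post("http://127.0.0.1:30000/generate", json=payload, timeout=60)
latency = time.perf_counter() - start

print(f"status: {resp.status_code}, latency: {latency:.3f} s")
print(resp.json())
```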
@@ -102,8 +102,8 @@ def run_one_batch_size(bs):
    output_throughput = bs * max_new_tokens / latency
    overall_throughput = bs * (input_len + output_len) / latency
    print(f"latency: {latency:.2f} s")
-    print(f"decode throughput: {output_throughput:.2f} token/s")
-    print(f"overall throughput: {overall_throughput:.2f} token/s")
+    print(f"output throughput: {output_throughput:.2f} token/s")
+    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
    with open("results.jsonl", "a") as fout:
        res = {
...
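For reference, the two renamed metrics are simple ratios over the same measured wall-clock time: output throughput counts only generated tokens, while (input + output) throughput also counts the prompt tokens. A self-contained sketch with made-up numbers:

```
# Illustration of the throughput formulas above (numbers are made up, not measured).
bs = 64           # batch size
input_len = 512   # prompt tokens per request
output_len = 128  # generated tokens per request
latency = 9.5     # wall-clock seconds to finish the batch

output_throughput = bs * output_len / latency                 # ~862 token/s
overall_throughput = bs * (input_len + output_len) / latency  # ~4311 token/s

print(f"output throughput: {output_throughput:.2f} token/s")
print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
```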
@@ -284,23 +284,26 @@ def main(server_args, bench_args):
    else:
        work_func = latency_test

-    workers = []
-    for tp_rank in range(server_args.tp_size):
-        proc = multiprocessing.Process(
-            target=work_func,
-            args=(
-                server_args,
-                bench_args,
-                tp_rank,
-            ),
-        )
-        proc.start()
-        workers.append(proc)
-
-    for proc in workers:
-        proc.join()
-
-    proc.terminate()
+    if server_args.tp_size == 1:
+        work_func(server_args, bench_args, 0)
+    else:
+        workers = []
+        for tp_rank in range(server_args.tp_size):
+            proc = multiprocessing.Process(
+                target=work_func,
+                args=(
+                    server_args,
+                    bench_args,
+                    tp_rank,
+                ),
+            )
+            proc.start()
+            workers.append(proc)
+
+        for proc in workers:
+            proc.join()
+
+        proc.terminate()

if __name__ == "__main__":
...
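The change above runs the benchmark in-process when there is a single tensor-parallel rank instead of always spawning `multiprocessing` workers, which keeps debuggers and the Nsight profiling commands from the README attached to one process. A rough, self-contained sketch of the same dispatch pattern (the names are illustrative, not SGLang's actual API):

```
import multiprocessing


def work_func(server_args, bench_args, tp_rank):
    # Stand-in for the per-rank benchmark body (latency_test in the real script).
    print(f"running tp_rank={tp_rank}")


def run_benchmark(server_args, bench_args, tp_size):
    if tp_size == 1:
        # Single rank: call the function directly, no child processes to manage.
        work_func(server_args, bench_args, 0)
        return

    # Multiple ranks: launch one process per tensor-parallel rank and wait for all.
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=work_func, args=(server_args, bench_args, tp_rank)
        )
        proc.start()
        workers.append(proc)
    for proc in workers:
        proc.join()


if __name__ == "__main__":
    run_benchmark(server_args=None, bench_args=None, tp_size=2)
```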
@@ -96,6 +96,7 @@ class ControllerSingle:
    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
        # Parse args
        self.server_args = server_args
+        self.tp_procs = []

        # Init communication
        context = zmq.Context(2)
...
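Initializing `self.tp_procs` as an empty list gives the controller a place to record handles to the tensor-parallel worker processes it launches, so they can later be joined or terminated (for example on shutdown). A simplified, hypothetical sketch of that bookkeeping pattern:

```
import multiprocessing


def tp_worker(rank):
    # Stand-in for a tensor-parallel model worker loop.
    print(f"worker {rank} started")


class ControllerSketch:
    """Hypothetical simplified controller that keeps handles to its workers."""

    def __init__(self, tp_size):
        self.tp_procs = []  # handles kept so the controller can join/terminate later
        for rank in range(tp_size):
            proc = multiprocessing.Process(target=tp_worker, args=(rank,))
            proc.start()
            self.tp_procs.append(proc)

    def shutdown(self):
        for proc in self.tp_procs:
            proc.terminate()
            proc.join()


if __name__ == "__main__":
    controller = ControllerSketch(tp_size=2)
    controller.shutdown()
```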
@@ -98,6 +98,8 @@ class TokenToKVPool:
        self.can_use_mem_size += len(free_index)

    def clear(self):
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size
...
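The added line makes `clear()` also discard any token indices still sitting in the prefetch buffer, so flushing the cache returns the pool to a truly empty state instead of leaving stale pre-allocated slots around. A toy CPU-only sketch of the idea (simplified fields and methods, not the actual SGLang `TokenToKVPool`):

```
import torch


class ToyTokenToKVPool:
    """Toy free-slot pool showing why clear() must also empty the prefetch buffer."""

    def __init__(self, size):
        self.size = size
        self.mem_state = torch.ones(size, dtype=torch.bool)  # True = slot is free
        self.can_use_mem_size = size
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32)

    def alloc(self, need_size):
        free_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
        self.mem_state[free_index] = False
        self.can_use_mem_size -= len(free_index)
        return free_index.to(torch.int32)

    def clear(self):
        # Drop stale prefetched indices as well; otherwise a request after a flush
        # could be handed slot indices that were reserved before the flush.
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32)
        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size


if __name__ == "__main__":
    pool = ToyTokenToKVPool(size=8)
    pool.prefetch_buffer = pool.alloc(4)  # pretend these slots were prefetched
    pool.clear()
    print(pool.can_use_mem_size)  # 8 -- the flush restored the whole pool
```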