"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "622c46397c1aeae92d8bf970567910b9f2acdb58"
Unverified commit a02071a1, authored by Teng Ma and committed by GitHub

[Bench] feat: mooncake trace integration (#9839)


Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Signed-off-by: Teng Ma <sima.mt@alibaba-inc.com>
Co-authored-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
parent 45b3a6a2
@@ -305,6 +305,21 @@ python3 -m sglang.bench_serving \
--disable-ignore-eos
```
9) Evaluating large-scale KVCache sharing with mooncake trace (sglang only):
```bash
python3 -m sglang.bench_serving \
--backend sglang \
--host 127.0.0.1 --port 30000 \
--model model-name \
--dataset-name mooncake \
--mooncake-slowdown-factor 1.0 \
--mooncake-num-rounds 1000 \
--mooncake-workload conversation|mooncake|synthetic|toolagent \
--use-trace-timestamps \
--random-output-len 256
```
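
Each line of a Mooncake trace file is one JSON record. Judging from the fields the benchmark reads (`timestamp` in milliseconds, `hash_ids`, `output_length`), a record looks roughly like this; the values below are illustrative, not taken from the published traces:
```json
{"timestamp": 27482, "hash_ids": [46, 47, 48], "output_length": 66}
```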
### Troubleshooting
- All requests failed: verify `--backend`, server URL/port, `--model`, and authentication. Check warmup errors printed by the script.
...
@@ -75,6 +75,7 @@ class RequestFuncInput:
    lora_name: str
    image_data: Optional[List[str]]
    extra_request_body: Dict[str, Any]
    timestamp: Optional[float] = None

@dataclass
@@ -696,6 +697,22 @@ def get_dataset(args, tokenizer):
            apply_chat_template=args.apply_chat_template,
            random_sample=True,
        )
    elif args.dataset_name == "mooncake":
        # For mooncake, we don't generate the prompts here.
        # We just load the raw trace data. The async generator will handle the rest.
        if not args.dataset_path:
            local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
        else:
            local_path = args.dataset_path

        if not os.path.exists(local_path):
            download_and_cache_file(MOONCAKE_DATASET_URL[args.mooncake_workload], local_path)

        with open(local_path, "r") as f:
            all_requests_data = [json.loads(line) for line in f if line.strip()]

        # Limit the number of requests based on --num-prompts
        input_requests = all_requests_data[: args.num_prompts]
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    return input_requests
@@ -750,6 +767,12 @@ class BenchmarkMetrics:
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

MOONCAKE_DATASET_URL = {
    "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
    "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
    "synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
    "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
}

def download_and_cache_file(url: str, filename: Optional[str] = None):
@@ -808,6 +831,80 @@ class DatasetRow:
    prompt_len: int
    output_len: int
    image_data: Optional[List[str]] = None
    timestamp: Optional[float] = None

async def get_mooncake_request_over_time(
    input_requests: List[Dict],
    tokenizer: PreTrainedTokenizerBase,
    slowdown_factor: float,
    num_rounds: int,
) -> AsyncGenerator[DatasetRow, None]:
    """
    An async generator that yields requests based on the timestamps in the Mooncake trace file,
    with support for multi-round sessions.
    """
    if not input_requests:
        return

    input_requests.sort(key=lambda r: r["timestamp"])

    start_time = time.perf_counter()
    trace_start_time_ms = input_requests[0]["timestamp"]

    for record in input_requests:
        # Calculate when this entire session should start
        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
        target_arrival_time_s = relative_arrival_time_s * slowdown_factor

        current_elapsed_time_s = time.perf_counter() - start_time
        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
        if sleep_duration_s > 0:
            await asyncio.sleep(sleep_duration_s)

        # Once the session starts, generate all rounds for it as a burst.
        # This simulates a user engaging in a multi-turn conversation.

        # Base user query constructed from hash_ids
        user_query_base = ""
        hash_ids = record.get("hash_ids", [])
        for hash_id in hash_ids:
            user_query_base += f"{hash_id}" + " ".join(
                ["hi"] * 128
            )  # Shorter for multi-round
        user_query_base += "Tell me a story based on this context."

        output_len_per_round = record.get("output_length", 256)
        chat_history = []

        for i in range(num_rounds):
            # Add user query for the current round
            chat_history.append(
                {"role": "user", "content": f"Round {i+1}: {user_query_base}"}
            )

            # Form the full prompt from history
            try:
                full_prompt_text = tokenizer.apply_chat_template(
                    chat_history, tokenize=False, add_generation_prompt=True
                )
            except Exception:
                full_prompt_text = "\n".join(
                    [f"{msg['role']}: {msg['content']}" for msg in chat_history]
                )

            prompt_len = len(tokenizer.encode(full_prompt_text))

            yield DatasetRow(
                prompt=full_prompt_text,
                prompt_len=prompt_len,
                output_len=output_len_per_round,
            )

            # Add a placeholder assistant response for the next round's context
            # We use a placeholder because we don't know the real response
            placeholder_response = " ".join(["story"] * output_len_per_round)
            chat_history.append({"role": "assistant", "content": placeholder_response})
def sample_mmmu_requests(

@@ -1359,19 +1456,41 @@ def sample_generated_shared_prefix_requests(
async def get_request(
    input_requests: List[DatasetRow],
    request_rate: float,
    use_trace_timestamps: bool = False,
    slowdown_factor: float = 1.0,
) -> AsyncGenerator[DatasetRow, None]:
    if use_trace_timestamps:
        print(
            f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
        )
        # Sort requests by timestamp for correct replay
        input_requests.sort(key=lambda r: r.timestamp)

        start_time = time.perf_counter()
        trace_start_time_ms = input_requests[0].timestamp if input_requests else 0

        for request in input_requests:
            trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
            target_arrival_time = start_time + (trace_time_s * slowdown_factor)

            sleep_duration = target_arrival_time - time.perf_counter()
            if sleep_duration > 0:
                await asyncio.sleep(sleep_duration)

            yield request
    else:
        input_requests_iter = iter(input_requests)
        for request in input_requests_iter:
            yield request

            if request_rate == float("inf"):
                # If the request rate is infinity, then we don't need to wait.
                continue

            # Sample the request interval from the exponential distribution.
            interval = np.random.exponential(1.0 / request_rate)
            # The next request will be sent after the interval.
            await asyncio.sleep(interval)
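
For intuition, the arithmetic in the trace-replay branch reduces to the small calculation below (illustrative timestamps, assumed to be in milliseconds as in the trace files):
```python
# Illustrative only: turning trace timestamps into send offsets.
timestamps_ms = [1000, 1250, 2000]  # assumed millisecond timestamps from a trace
slowdown_factor = 2.0

t0 = timestamps_ms[0]
offsets_s = [(ts - t0) / 1000.0 * slowdown_factor for ts in timestamps_ms]
print(offsets_s)  # [0.0, 0.5, 2.0] -> send 0 s, 0.5 s, and 2 s after the replay starts
```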
def calculate_metrics(

@@ -1397,7 +1516,7 @@ def calculate_metrics(
            retokenized_output_len = len(
                tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
            )
            retokenized_output_lens.append(retokenized_output_len)
            total_input += outputs[i].prompt_len
            if output_len > 1:
                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
@@ -1469,6 +1588,9 @@ async def benchmark(
    pd_separated: bool = False,
    flush_cache: bool = False,
    warmup_requests: int = 1,
    use_trace_timestamps: bool = False,
    mooncake_slowdown_factor=1.0,
    mooncake_num_rounds=1,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]

@@ -1488,8 +1610,32 @@ async def benchmark(
    # Warmup
    print(f"Starting warmup with {warmup_requests} sequences...")

    # Handle the data structure difference for the warmup request
    if args.dataset_name == "mooncake":
        # For mooncake, input_requests is a list of dicts.
        # We need to build a temporary DatasetRow for the warmup phase.
        warmup_record = input_requests[0]

        # Build prompt from hash_ids, just like in the async generator
        hash_ids = warmup_record.get("hash_ids", [])
        prompt_text = ""
        for hash_id in hash_ids:
            prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
        prompt_text += "Can you tell me a detailed story in 1000 words?"

        output_len = warmup_record.get("output_length", 32)
        prompt_len = len(tokenizer.encode(prompt_text))

        # Create a temporary DatasetRow object for warmup
        test_request = DatasetRow(
            prompt=prompt_text,
            prompt_len=prompt_len,
            output_len=output_len,
            image_data=None,  # Mooncake doesn't have image data
        )
    else:
        # For all other datasets, input_requests is a list of DatasetRow objects
        test_request = input_requests[0]
    if lora_names is not None and len(lora_names) != 0:
        lora_name = lora_names[0]

@@ -1543,12 +1689,26 @@ async def benchmark(
        if profile_output.success:
            print("Profiler started")

    # Run all requests
    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []

    pbar_total = len(input_requests)
    if (
        backend == "sglang" and args.dataset_name == "mooncake"
    ):  # Assuming mooncake is mainly for sglang or similar backends
        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
        request_generator = get_mooncake_request_over_time(
            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
        )
        print(
            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
        )
        pbar_total *= args.mooncake_num_rounds
    else:
        request_generator = get_request(input_requests, request_rate)

    pbar = None if disable_tqdm else tqdm(total=pbar_total)

    async for request in request_generator:
        if lora_names is not None and len(lora_names) != 0:
            idx = random.randint(0, len(lora_names) - 1)
            lora_name = lora_names[idx]
@@ -1564,6 +1724,7 @@ async def benchmark(
            lora_name=lora_name,
            image_data=request.image_data,
            extra_request_body=extra_request_body,
            timestamp=request.timestamp,
        )
        tasks.append(

@@ -1609,7 +1770,11 @@ async def benchmark(
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Backend:", backend)) print("{:<40} {:<10}".format("Backend:", backend))
print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) print(
"{:<40} {:<10}".format(
"Traffic request rate:", "trace" if use_trace_timestamps else request_rate
)
)
print( print(
"{:<40} {:<10}".format( "{:<40} {:<10}".format(
"Max request concurrency:", "Max request concurrency:",
...@@ -1678,7 +1843,7 @@ async def benchmark( ...@@ -1678,7 +1843,7 @@ async def benchmark(
# Arguments # Arguments
"backend": args.backend, "backend": args.backend,
"dataset_name": args.dataset_name, "dataset_name": args.dataset_name,
"request_rate": request_rate, "request_rate": "trace" if use_trace_timestamps else request_rate,
"max_concurrency": max_concurrency, "max_concurrency": max_concurrency,
"sharegpt_output_len": args.sharegpt_output_len, "sharegpt_output_len": args.sharegpt_output_len,
"random_input_len": args.random_input_len, "random_input_len": args.random_input_len,
...@@ -1731,7 +1896,9 @@ async def benchmark( ...@@ -1731,7 +1896,9 @@ async def benchmark(
elif args.dataset_name.startswith("random"): elif args.dataset_name.startswith("random"):
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
else: else:
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl" output_file_name = (
f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
)
    result_details = {
        "input_lens": [output.prompt_len for output in outputs],

@@ -1786,6 +1953,17 @@ def run_benchmark(args_: argparse.Namespace):
    if not hasattr(args, "tokenize_prompt"):
        args.tokenize_prompt = False

    if not hasattr(args, "use_trace_timestamps"):
        args.use_trace_timestamps = False
    if not hasattr(args, "mooncake_slowdown_factor"):
        args.mooncake_slowdown_factor = 1.0
    if not hasattr(args, "mooncake_num_rounds"):
        args.mooncake_num_rounds = 1

    print(f"benchmark_args={args}")
    # Set global environments

@@ -1919,6 +2097,9 @@ def run_benchmark(args_: argparse.Namespace):
            pd_separated=args.pd_separated,
            flush_cache=args.flush_cache,
            warmup_requests=args.warmup_requests,
            use_trace_timestamps=args.use_trace_timestamps,
            mooncake_slowdown_factor=args.mooncake_slowdown_factor,
            mooncake_num_rounds=args.mooncake_num_rounds,
        )
    )

@@ -1975,6 +2156,7 @@ if __name__ == "__main__":
            "generated-shared-prefix",
            "mmmu",
            "random-image",
            "mooncake",
        ],
        help="Name of the dataset to benchmark on.",
    )
@@ -2051,6 +2233,11 @@ if __name__ == "__main__":
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument(
        "--use-trace-timestamps",
        action="store_true",
        help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,

@@ -2174,5 +2361,33 @@ if __name__ == "__main__":
        default=256,
        help="Target length in tokens for outputs in generated-shared-prefix dataset",
    )
    mooncake_group = parser.add_argument_group("mooncake dataset arguments")
    mooncake_group.add_argument(
        "--mooncake-slowdown-factor",
        type=float,
        default=1.0,
        help="Slowdown factor for replaying the mooncake trace. "
        "A value of 2.0 means the replay is twice as slow. "
        "NOTE: --request-rate is IGNORED in mooncake mode.",
    )
    mooncake_group.add_argument(
        "--mooncake-num-rounds",
        type=int,
        default=1,
        help="Number of conversation rounds for each session in the mooncake dataset. "
        "A value > 1 will enable true multi-turn session benchmarking.",
    )
    mooncake_group.add_argument(
        "--mooncake-workload",
        type=str,
        default="conversation",
        choices=[
            "mooncake",
            "conversation",
            "synthetic",
            "toolagent",
        ],
        help="Underlying workload for the mooncake dataset.",
    )
    args = parser.parse_args()
    run_benchmark(args)