Commit 118f1fc7 authored by maxiao1

sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct

import itertools
from typing import Optional, Tuple, Union
import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from vllm import _custom_ops as vllm_ops
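# Compare three RMSNorm implementations on CUDA: a naive HuggingFace-style
# PyTorch reference, FlashInfer's fused kernels, and vLLM's custom ops,
# both with and without a fused residual add.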
class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
return x
else:
return x, residual
def rmsnorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
naive_norm.weight = nn.Parameter(weight)
naive_norm = naive_norm.to(x.device)
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
output = naive_norm(x, residual)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_flashinfer(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
fused_add_rmsnorm(x, residual, weight, eps)
output = (x, residual)
else:
output = rmsnorm(x, weight, eps)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
output = (x, residual)
else:
out = torch.empty_like(x)
vllm_ops.rms_norm(out, x, weight, eps)
output = out
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_vllm = rmsnorm_vllm(
x.clone(), weight, residual.clone() if residual is not None else None
)
if use_residual:
output_naive = output_naive[0]
output_flashinfer = output_flashinfer[0]
output_vllm = output_vllm[0]
print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}")
print(f"VLLM output={output_vllm}")
if torch.allclose(
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm"],
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
args={},
)
)
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
quantiles = [0.5, 0.2, 0.8]
if provider == "huggingface":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
elif provider == "flashinfer":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_flashinfer(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--use_residual", action="store_true", help="Whether to use residual connection"
)
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/rmsnorm/",
help="Path to save rmsnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
)
# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
import os
import torch
import triton
import triton.language as tl
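# Compare two ways of looking up the last cached token location per request:
# a torch.compile-d torch.where reference and a Triton kernel. Requests with
# a non-positive prefix length map to -1.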
@torch.compile(dynamic=True)
def get_last_loc_torch(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
return torch.where(
prefix_lens_tensor > 0,
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
torch.full_like(prefix_lens_tensor, -1),
)
@triton.jit
def get_last_loc_kernel(
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token_stride,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
mask = offset < num_tokens
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
token_mask = prefix_lens > 0
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
tl.store(result + offset, tokens, mask=mask)
def get_last_loc_triton(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
BLOCK_SIZE = 256
num_tokens = prefix_lens_tensor.shape[0]
result = torch.empty_like(prefix_lens_tensor)
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
get_last_loc_kernel[grid](
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token.stride(0),
BLOCK_SIZE,
)
return result
def test_get_last_loc():
max_batch = 4097
max_context_len = 6148
batch_size = 20
# Initialize input tensors
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
pre_lens = torch.randint(
-max_context_len // 2,
max_context_len,
(batch_size,),
dtype=torch.int64,
device="cuda",
)
last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)
# Compare results
torch.testing.assert_close(last_loc_res, last_loc_ref)
def get_benchmark():
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=batch_sizes,
line_arg="provider",
line_vals=["reference", "triton"],
line_names=["PyTorch", "Triton"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="get-last-loc-performance",
args={},
)
)
def benchmark(batch_size, provider):
max_batch = 2048
max_context_len = 16384
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
pre_lens = torch.randint(
-max_context_len // 2,
max_context_len,
(batch_size,),
dtype=torch.int64,
device="cuda",
)
quantiles = [0.5, 0.2, 0.8]
if provider == "reference":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
"""Run benchmark and save results"""
# Ensure save path exists
os.makedirs(save_path, exist_ok=True)
# Run correctness test
test_get_last_loc()
print("Correctness test passed!")
# Run performance test
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/get_last_loc/",
help="Path to save benchmark results",
)
args = parser.parse_args()
run_benchmark(args.save_path)
import itertools
import os
import torch
import triton
import triton.language as tl
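# Compare three ways of scattering newly allocated cache locations
# (out_cache_loc) into the req_to_token pool over each request's extend range
# [pre_len, seq_len): a PyTorch loop, a per-request Triton kernel, and a
# Triton kernel parallelized over (request, token block).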
@triton.jit
def write_req_to_token_pool_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
req_pool_index = tl.load(req_pool_indices + pid)
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)
# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < (seq_len - pre_len)
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
tl.store(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ offset
+ pre_len,
value,
mask=mask,
)
@triton.jit
def write_req_to_token_pool_triton_optimize(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_token = tl.program_id(1)
req_pool_index = tl.load(req_pool_indices + pid_batch)
pre_len = tl.load(pre_lens + pid_batch)
seq_len = tl.load(seq_lens + pid_batch)
extend_len = seq_len - pre_len
cumsum_start = 0
for i in range(pid_batch):
cumsum_start += tl.load(extend_lens + i)
token_start = pid_token * BLOCK_SIZE
offset = tl.arange(0, BLOCK_SIZE)
actual_offset = token_start + offset
mask = actual_offset < extend_len
src_ptr = out_cache_loc + cumsum_start + actual_offset
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
value = tl.load(src_ptr, mask=mask)
dst_ptr = (
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ actual_offset
+ pre_len
)
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
tl.store(dst_ptr, value, mask=mask)
def write_req_to_token_pool_reference(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
pre_lens: torch.Tensor,
seq_lens: torch.Tensor,
extend_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
) -> None:
"""Reference implementation using PyTorch"""
for i in range(len(req_pool_indices)):
req_pool_idx = req_pool_indices[i].item()
pre_len = pre_lens[i].item()
seq_len = seq_lens[i].item()
extend_len = extend_lens[i].item()
cumsum_start = sum(extend_lens[:i].tolist())
# Copy values from out_cache_loc to req_to_token
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
cumsum_start : cumsum_start + extend_len
]
def test_write_req_to_token_pool():
max_batch = 4097
max_context_len = 6148
batch_size = 1
extend_len = 14
# Initialize input tensors
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
# Create copies for reference implementation
req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()
# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)
# Run optimized triton kernel
def grid(batch_size, extend_len):
num_token_blocks = triton.cdiv(extend_len, 512)
return (batch_size, num_token_blocks)
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)
# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)
# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
# Test case 2: batch size > 1
batch_size = 3
extend_lens_list = [14, 20, 30]
total_extend_len = sum(extend_lens_list)
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()
# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)
# Run optimized triton kernel
max_extend_len = max(extend_lens_list)
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)
# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)
# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
def get_benchmark():
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(batch_sizes, extend_lens))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "extend_len"],
x_vals=configs,
line_arg="provider",
line_vals=["reference", "triton", "triton_optimize"],
line_names=["PyTorch", "Triton", "Triton Optimized"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="write-req-to-token-pool-performance",
args={},
)
)
def benchmark(batch_size, extend_len, provider):
max_batch = 256
max_context_len = 16384
extend_lens_list = [extend_len] * batch_size
total_extend_len = sum(extend_lens_list)
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
seq_lens = pre_lens + extend_len
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "reference":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_reference(
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_triton[(batch_size,)](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
),
quantiles=quantiles,
)
else:
def run_optimized():
block_size = 128 if extend_len <= 1024 else 512
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
write_req_to_token_pool_triton_optimize[grid_config](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=block_size,
)
ms, min_ms, max_ms = triton.testing.do_bench(
run_optimized, quantiles=quantiles
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
"""Run benchmark and save results"""
# Ensure save path exists
os.makedirs(save_path, exist_ok=True)
# Run correctness test
test_write_req_to_token_pool()
print("Correctness test passed!")
# Run performance test
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/write_req_to_token_pool/",
help="Path to save benchmark results",
)
args = parser.parse_args()
run_benchmark(args.save_path)
import itertools
import torch
import torch.nn.functional as F
import triton.testing as tt
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
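# Benchmark SGLang's Triton extend-attention kernel against a pure PyTorch
# reference that gathers the prefix KV from the cache and applies a causal
# mask plus an optional sliding window.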
def extend_attention_fwd_torch(
q: torch.Tensor, # [extend_tokens, H_Q, D]
k: torch.Tensor, # [extend_tokens, H_KV, D]
v: torch.Tensor, # [extend_tokens, H_KV, D]
o: torch.Tensor, # [extend_tokens, H_Q, D]
k_cache: torch.Tensor, # [total_tokens, H_KV, D]
v_cache: torch.Tensor, # [total_tokens, H_KV, D]
qo_indptr: torch.Tensor, # [B+1]
kv_indptr: torch.Tensor, # [B+1]
kv_indices: torch.Tensor, # [prefix_tokens]
sliding_window_size: int,
):
B = qo_indptr.size(0) - 1
_, H_Q, D = q.shape
_, H_KV, _ = k.shape
group_size = H_Q // H_KV
scale = 1.0 / D**0.5
for i in range(B):
q_start = int(qo_indptr[i].item())
q_end = int(qo_indptr[i + 1].item())
kv_start = int(kv_indptr[i].item())
kv_end = int(kv_indptr[i + 1].item())
prefix_indices = kv_indices[kv_start:kv_end]
k_prefix = k_cache[prefix_indices] # [prefix_len, H_KV, D]
v_prefix = v_cache[prefix_indices] # [prefix_len, H_KV, D]
k_extend = k[q_start:q_end] # [extend_len, H_KV, D]
v_extend = v[q_start:q_end] # [extend_len, H_KV, D]
q_extend = q[q_start:q_end] # [extend_len, H_Q, D]
k_full = torch.cat([k_prefix, k_extend], dim=0) # [total_len, H_KV, D]
v_full = torch.cat([v_prefix, v_extend], dim=0) # [total_len, H_KV, D]
if group_size != 1:
k_full_hq = k_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
v_full_hq = v_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
else:
k_full_hq = k_full
v_full_hq = v_full
prefix_len = k_prefix.size(0)
extend_len = k_extend.size(0)
total_len = prefix_len + extend_len
# causal
pos_keys = torch.arange(total_len, device=q.device)
t = prefix_len + torch.arange(extend_len, device=q.device) # [extend_len]
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
# sliding window
if sliding_window_size is not None and sliding_window_size > 0:
start = (t - (sliding_window_size)).clamp_min(0) # [extend_len]
else:
start = torch.zeros_like(t)
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
final_mask = causal_mask & window_mask
attn_scores = (
torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
) # [extend_len, H_Q, total_len]
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
def _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device="cuda"
):
b_seq_len_prefix = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len_extend = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(int(b_seq_len_prefix.sum().item()),), dtype=torch.int32, device=device
)
for i in range(B):
s = kv_indptr[i].item()
e = kv_indptr[i + 1].item()
kv_indices[s:e] = torch.arange(
b_start_loc[i],
b_start_loc[i] + b_seq_len_prefix[i],
dtype=torch.int32,
device=device,
)
total_token_num = int(torch.sum(b_seq_len).item())
extend_token_num = int(torch.sum(b_seq_len_extend).item())
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
o_extend_triton = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device=device
)
o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
inputs = dict(
q_extend=q_extend,
k_extend=k_extend,
v_extend=v_extend,
k_buffer=k_buffer,
v_buffer=v_buffer,
o_extend_triton=o_extend_triton,
o_extend_torch=o_extend_torch,
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
max_len_extend=max_len_extend,
WINDOW_SIZE=WINDOW_SIZE,
)
meta = dict(
B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num
)
return inputs, meta
def _run_triton(inputs):
extend_attention_fwd(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_triton"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=inputs["max_len_extend"],
sliding_window_size=inputs["WINDOW_SIZE"],
)
def _run_torch_ref(inputs):
extend_attention_fwd_torch(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_torch"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
inputs["WINDOW_SIZE"],
)
N_CTXS = [1024, 2048, 4096, 8192]
WINDOW_SIZES = [-1, 127, 256, 512]
CONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))
PROVIDERS = ["torch", "triton"]
@tt.perf_report(
tt.Benchmark(
x_names=["N_CTX", "WINDOW_SIZE"],
x_vals=CONFIGS,
line_arg="provider",
line_vals=PROVIDERS,
line_names=PROVIDERS,
ylabel="Runtime (ms)",
plot_name="extend_attention_triton_vs_torch",
args={
"B": 32,
"H_Q": 64,
"H_KV": 8,
"D": 128,
"dtype": "bf16",
"device": "cuda",
"check_correctness": False,
"warmup": 25,
"rep": 100,
},
)
)
def bench(
N_CTX,
provider,
B,
H_Q,
H_KV,
D,
dtype,
device,
WINDOW_SIZE,
check_correctness,
warmup,
rep,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
dt = dtype_map[dtype]
inputs, _ = _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
)
if check_correctness and provider == "triton":
_run_triton(inputs)
_run_torch_ref(inputs)
torch.cuda.synchronize()
if not torch.allclose(
inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
):
raise AssertionError("Mismatch between triton and torch reference.")
if provider == "triton":
ms = tt.do_bench(lambda: _run_triton(inputs), warmup=warmup, rep=rep)
elif provider == "torch":
ms = tt.do_bench(lambda: _run_torch_ref(inputs), warmup=warmup, rep=rep)
else:
raise ValueError(provider)
return ms
if __name__ == "__main__":
bench.run(print_data=True, show_plots=False)
## Download data
```
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
python3 gen_data.py --number 1000
```
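For reference, `gen_data.py` (included further below) writes its output to `lines_<number>_<redirect_ratio>.json`, e.g. `lines_1000_0.0.json` for the command above. A rough sketch of the file's structure, with values elided:
```
{
  "prefix": "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. ...",
  "suffix": "The list has ended. ... The REGISTER_CONTENT of Line ??? is",
  "lines": ["Line <word>-<word>: The REGISTER_CONTENT is <6 digits>.", "..."],
  "indices": ["<word>-<word>", "..."],
  "values": ["<6 digits>", "..."],
  "links": [[], "..."],
  "group_by_num_hoops": {"1": ["<indices of 1-hoop lines>"], "2": ["..."]},
  "contains_ring": []
}
```
`bench_sglang.py` reads `lines`, `indices`, `values`, `links`, and `group_by_num_hoops` from this file.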
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
```
```
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
```
### Results
```
# original
Accuracy: 0.940, latency: 332.83 s
# parallel encoding (no_adjust, offset = 1000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 3000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 0)
Accuracy: 0.520, latency: 238.46 s
# parallel encoding (adjust_cache)
Accuracy: 0.460, latency: 257.66 s
```
import argparse
import json
import re
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
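# Line-retrieval benchmark with parallel encoding: the line list is split into
# four chunks that are encoded in parallel forks with different position id
# offsets, then concatenated before asking for the answer.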
@sgl.function
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
s += prefix + "\n"
contexts = [body_0, body_1, body_2, body_3]
position_ids_offset = [i * 1000 for i in range(len(contexts))]
forks = s.fork(len(contexts), position_ids_offset)
forks += lambda i: contexts[i] + "\n"
forks.join(mode="concate_and_append")
s += "\n" + suffix
s += sgl.gen("answer", max_tokens=16)
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
arguments = []
labels = []
sum_src_indices = []
sum_dst_indices = []
for i in range(len(src_indices)):
for j in range(len(dst_percents)):
src_index = src_indices[i]
dst_percent = dst_percents[j]
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
query_indices = [
q
for q in query_indices
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
]
dst_index = query_indices[
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
]
label = line_obj["values"][dst_index]
body = line_obj["lines"][: src_index + 1]
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
body_part_len = len(body) // 4
arguments.append(
{
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len :]),
"suffix": suffix,
}
)
labels.append(label)
sum_src_indices.append(src_index)
sum_dst_indices.append(dst_index)
# Select backend
backend = select_sglang_backend(args)
tic = time.perf_counter()
states = line_retrieval.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
corrects = []
for i in range(len(arguments)):
output = states[i]["answer"]
prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
label = labels[i]
# Try all numbers
findall = re.findall(r"\d+", output)
if not findall:
response_number = output
else:
for response_number in findall:
if response_number == label:
break
correct = response_number == label
corrects.append(correct)
# Log results
summary = (
f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
f"Prompt len: {prompt_len}, "
f"Correct: {correct}, "
f"Label: {label}, Predicted: {response_number}, "
)
print(summary)
accuracy = np.mean(corrects)
print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "line_retrieval",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"num_questions": len(arguments),
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
def main(args):
line_obj = json.load(open(args.data_path, "r"))
num_hoops = args.num_hoops
for src_index in args.src_index:
src_indices = [src_index]
num_queries = args.num_queries_per_src
dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]
eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
parser.add_argument("--src-index", type=int, nargs="+", default=[100])
parser.add_argument("--num-queries-per-src", type=int, default=10)
parser.add_argument("--num-hoops", type=int, default=1)
args = add_common_sglang_args_and_parse(parser)
main(args)
"""
Generate line data for line retrieval task.
Usage:
python3 gen_data.py --number 1000
"""
import argparse
import json
from collections import defaultdict
import numpy as np
from tqdm import tqdm
def generate_lines(random_words, num_lines, redirect_ratio):
prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resolving the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"
# Raw lines
visited_indices = set([None])
visited_values = set([None])
lines = []
redirects = []
indices = []
values = []
for i in tqdm(range(num_lines)):
line_index = None
while line_index in visited_indices:
line_index = "-".join(np.random.choice(random_words, size=(2,)))
visited_indices.add(line_index)
line_value = np.random.randint(low=0, high=999999)
line_value = f"{line_value:06}"
line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
lines.append(line)
redirects.append(None)
indices.append(line_index)
values.append(line_value)
# Add redirect
if redirect_ratio > 0:
num_redirect_lines = int(len(lines) * redirect_ratio)
redirect_indices = np.random.choice(
np.arange(len(lines)), size=(num_redirect_lines,), replace=False
)
for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
lines[i] = (
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
)
redirects[i] = target_idx
# Build links and find sources
links = [[] for _ in range(num_lines)]
contains_ring = set()
for i in range(num_lines):
if redirects[i] is None:
continue
tmp_link = []
cur = i
visited = set()
while redirects[cur] is not None:
visited.add(cur)
tmp_link.append(redirects[cur])
cur = redirects[cur]
if cur in visited:
contains_ring.add(i)
tmp_link = None
break
values[i] = values[cur]
links[i] = tmp_link
# Group by num_links
group_by_num_hoops = defaultdict(list)
for i in range(num_lines):
if i in contains_ring:
continue
group_by_num_hoops[len(links[i]) + 1].append(i)
keys = sorted(list(group_by_num_hoops.keys()))
for num_links in keys:
print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")
# Append few-shot examples
hoop1_candidates = list(group_by_num_hoops[1])
hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
hoop2_candidates = list(group_by_num_hoops[2])
hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])
i = hoop1_candidates[5]
suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
if len(hoop2_candidates):
i = hoop2_candidates[0]
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
i = hoop2_candidates[1]
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
else:
i = hoop1_candidates[1]
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
i = hoop1_candidates[10]
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
obj = {
"prefix": prefix,
"suffix": suffix,
"lines": lines,
"indices": indices,
"values": values,
"links": links,
"group_by_num_hoops": group_by_num_hoops,
"contains_ring": sorted(list(contains_ring)),
}
return obj
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--number", type=int)
parser.add_argument("--redirect-ratio", type=float, default=0.0)
args = parser.parse_args()
num_lines = args.number
random_words_filename = "random_words.json"
random_words = json.load(open(random_words_filename, "r"))
np.random.seed(42)
obj = generate_lines(random_words, num_lines, args.redirect_ratio)
fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
with open(fout, "w") as fout:
json.dump(obj, fout, indent=2)
## Download benchmark images
```
python3 download_images.py
```
Image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild
### Other Dependencies
```
pip3 install "sglang[all]"
pip3 install "torch>=2.1.2" "transformers>=4.36" pillow
```
## Run benchmark
### Benchmark sglang
Launch a server
```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
```
Run benchmark
```
# Run with local models
python3 bench_sglang.py --num-questions 60
# Run with OpenAI models
python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview
```
### Benchmark LLaVA original code
```
git clone git@github.com:haotian-liu/LLaVA.git
cd LLaVA
git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96
pip3 install -e .
cd ~/sglang/benchmark/llava_bench
CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh
```
### Benchmark llama.cpp
```
# Install
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
pip install sse_starlette starlette_context pydantic_settings
# Download weights
mkdir -p ~/model_weights/llava-v1.5-7b/
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf
```
```
python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000
OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1
```
#!/bin/bash
python -m llava.eval.model_vqa \
--model-path liuhaotian/llava-v1.5-7b \
--question-file ./questions.jsonl \
--image-folder ./images \
--answers-file ./answers_hf.jsonl \
--temperature 0 \
--conv-mode vicuna_v1
#!/bin/bash
python -m llava.eval.model_vqa_loader \
--model-path liuhaotian/llava-v1.5-7b \
--question-file ./mme_pack/llava_mme_bench_replace.jsonl \
--image-folder ./mme_pack/MME_Benchmark_release_version \
--answers-file ./answers_hf_mme.jsonl \
--temperature 0 \
--conv-mode vicuna_v1
import argparse
import json
import os
import time
import tqdm
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
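# LLaVA-bench-in-the-wild style image QA: each request sends one image and one
# question to the selected backend and collects the generated answer.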
@sgl.function
def image_qa(s, image_file, question):
s += sgl.user(sgl.image(image_file) + question)
s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens))
def main(args):
lines = list(read_jsonl(args.question_file))[: args.num_questions]
arguments = [
{
"image_file": os.path.abspath(args.image_folder + "/" + l["image"]),
"question": l["text"],
}
for l in lines
]
# arguments = [
# {"image_file":
# Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
# "question": l["text"]} for l in lines
# ]
states = [None] * len(lines)
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm.tqdm(range(len(lines))):
image_file = arguments[i]["image_file"]
question = arguments[i]["question"]
ret = image_qa.run(image_file=image_file, question=question, temperature=0)
states[i] = ret
else:
states = image_qa.run_batch(
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
)
latency = time.perf_counter() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
print(f"Write output to {args.answer_file}")
with open(args.answer_file, "w") as fout:
for i in range(len(lines)):
value = {
"question_id": lines[i]["question_id"],
"prompt": lines[i]["text"],
"text": states[i]["answer"].strip(),
"model_id": backend.model_info["model_path"],
"answer_id": i,
"metadata": {},
}
fout.write(json.dumps(value) + "\n")
with open(args.result_file, "a") as fout:
value = {
"task": "llava_bench",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(lines),
"parallel": args.parallel,
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--question-file", type=str, default="questions.jsonl")
parser.add_argument("--answer-file", type=str, default="answers.jsonl")
parser.add_argument("--image-folder", type=str, default="./images")
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--num-questions", type=int, default=None)
parser.add_argument("--max-tokens", type=int, default=768)
args = add_common_sglang_args_and_parse(parser)
main(args)
MME_FOLDER=./mme_pack
python3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4
import os
# Create the 'images' directory if it doesn't exist
if not os.path.exists("images"):
os.makedirs("images")
# Base URL
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
# Loop through image numbers
for i in range(1, 25):
# Format the image number with leading zeros
image_number = str(i).zfill(3)
image_url = base_url + image_number + ".jpg"
image_path = "images/" + image_number + ".jpg"
# Download the image using wget
os.system(f"wget -O {image_path} {image_url}")
print("Download complete.")
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 25 --parallel 8
python3 bench_sglang.py --num-questions 16 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --backend vllm --num-questions 25
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --backend lmql --num-questions 25 --parallel 1
```
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
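# LLM-judge benchmark for generic backends: judge an article along six quality
# dimensions sequentially, then ask for an overall 1-10 score.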
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
dimension_prompts = [
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
]
def multi_dimension_judge(article, generate):
s = system_prompt
s += "\n```\n" + article + "\n```\n\n"
judges = []
for i in range(len(dimension_prompts)):
comp = generate(
s
+ "USER: Please judge the quality based on the following metric. "
+ dimension_prompts[i]
+ " Please provide a single-paragraph judgement. "
+ "Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:',
max_tokens=256,
stop="END",
)
judges.append(comp)
s += "I will judge the quality based on the following metrics.\n"
for i in range(len(dimension_prompts)):
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
s += generate(s, max_tokens=2, stop=None)
return s
async def multi_dimension_judge_async(article, generate):
s = system_prompt
s += "\n```\n" + article + "\n```\n\n"
judges = []
for i in range(len(dimension_prompts)):
comp = await generate(
s
+ "USER: Please judge the quality based on the following metric. "
+ dimension_prompts[i]
+ " Please provide a single-paragraph judgement. "
+ "Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:',
max_tokens=256,
stop="END",
)
judges.append(comp)
s += "I will judge the quality based on the following metrics.\n"
for i in range(len(dimension_prompts)):
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
s += await generate(s, max_tokens=2, stop=None)
return s
def main(args):
lines = read_jsonl(args.data_path)[: args.num_questions]
states = [None] * len(lines)
# Select backend
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
tic = time.perf_counter()
if args.backend != "lmql":
def get_one_answer(i):
states[i] = multi_dimension_judge(lines[i], call_generate)
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(lines)))),
total=len(lines),
)
)
else:
import asyncio
async def get_one_answer_async(i):
states[i] = await multi_dimension_judge_async(lines[i], call_generate)
batches = []
for i in range(0, len(lines), args.parallel):
batches.append(list(range(i, min(i + args.parallel, len(lines)))))
loop = asyncio.get_event_loop()
for bt in tqdm(batches):
loop.run_until_complete(
asyncio.gather(*[get_one_answer_async(i) for i in bt])
)
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "llm_judge",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="articles.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
dimension_prompts = [
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
]
@sgl.function
def multi_dimension_judge(s, article):
s += system_prompt
s += "\n```\n" + article + "\n```\n\n"
forks = s.fork(len(dimension_prompts))
for i in range(len(dimension_prompts)):
forks[i] += (
"USER: Please judge the quality based on the following metric. "
+ dimension_prompts[i]
+ " Please provide a single-paragraph judgement. "
+ "Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:'
)
forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
forks.join()
s += "I will judge the quality based on the following metrics.\n"
for i in range(len(dimension_prompts)):
s += (
dimension_prompts[i].split(":")[0]
+ ": "
+ forks[i]["judgement"].strip()
+ "\n"
)
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
s += sgl.gen("score", max_tokens=2)
def main(args):
lines = read_jsonl(args.data_path)[: args.num_questions]
arguments = [{"article": l} for l in lines]
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.perf_counter()
states = multi_dimension_judge.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "llm_judge",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="articles.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_sglang_args_and_parse(parser)
main(args)
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 5 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
```
```
python3 bench_other.py --backend vllm --num-questions 5
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
```
### Build dataset
```
pip install wikipedia
python3 build_dataset.py
```
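For reference, `build_dataset.py` (included below) appends one record per city to `questions.jsonl`; each line is a JSON object of roughly this form, with the page text elided:
```
{"document": "<the city's Wikipedia page, truncated to about 10k tokens>"}
```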
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
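# Long-document JSON decoding: extract a city's name, country, airport code,
# and top landmarks from a truncated Wikipedia page, one field at a time.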
def json_decode(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += "{\n"
s += ' "name": "'
s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "country": "'
s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "air port code": "'
s += generate(s, max_tokens=8, stop='"') + '",\n'
s += ' "top 3 landmarks": "'
s += generate(s, max_tokens=24, stop='"') + '",\n'
s += "}\n"
return s
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
for i in range(len(lines[: args.num_questions])):
arguments.append(
{
"document": lines[i]["document"],
}
)
states = [None] * len(arguments)
# Select backend
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
def get_one_answer(i):
states[i] = json_decode(generate=call_generate, **arguments[i])
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(arguments)))),
total=len(arguments),
)
)
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "long_json_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
@sgl.function
def json_decode(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += "{\n"
s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
s += (
' "air port code": "'
+ sgl.gen("air port code", max_tokens=8, stop='"')
+ '",\n'
)
s += (
' "top 3 landmarks": "'
+ sgl.gen("landmarks", max_tokens=24, stop='"')
+ '",\n'
)
s += "}\n"
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
for i in range(len(lines[: args.num_questions])):
arguments.append(
{
"document": lines[i]["document"],
}
)
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.perf_counter()
states = json_decode.run_batch(
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
)
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "long_json_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=10)
args = add_common_sglang_args_and_parse(parser)
main(args)
import json
import transformers
import wikipedia
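# Build questions.jsonl for the long_json_decode benchmark: download a few
# city Wikipedia pages, truncate each to roughly 10k tokens, and append one
# {"document": ...} record per city.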
name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
city_names = ["los angeles", "london", "tokyo", "beijing", "singapore"]
for city_name in city_names:
content = str(wikipedia.page(city_name).content)
content = content.replace("\n\n", "\n")
tokens = t.encode(content)
truncate_len = int((10000 / len(tokens)) * len(content))
truncate_content = content[:truncate_len]
truncate_tokens = t.encode(truncate_content)
# Count token
print(
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
)
with open("questions.jsonl", "a") as fout:
fout.write(json.dumps({"document": truncate_content}) + "\n")