Unverified Commit eb1ae6ae authored by Lianmin Zheng, committed by GitHub

Add sglang.bench_latency for offline benchmark (#564)

parent 2187f362
"""
Usage:
python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4
Reference output:
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
prefill logits tensor([-8.3125, -7.1172, 3.3398, ..., -4.9570, -4.1328, -3.4141],
device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the United Kingdom is London.
The capital of the United Kingdom is London.
prefill logits tensor([-8.9062, -9.0156, 4.1406, ..., -4.9922, -4.4961, -4.0742],
device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park to play in the grass and water.
Today is a very
prefill logits tensor([-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4609],
device='cuda:0')
"""
import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -40,7 +63,6 @@ def normal_text(args):
@torch.inference_mode()
def synthetic_tokens(args):
    t = AutoTokenizer.from_pretrained(args.model_path)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
...
"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
# Usage (latency test):
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test
### Reference output:
prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
[ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
device='cuda:0', dtype=torch.float16)
prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
[-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
[-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
device='cuda:0', dtype=torch.float16)
<s> The capital of France is.
The capital of the United States is Washington, D.C.
<s> The capital of the United Kindom is.
The capital of the United Kingdom is London.
The capital of the
<s> Today is a sunny day and I like go for a walk in the park.
I'm going to the park
"""
import argparse
import dataclasses
import logging
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers
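
# Benchmark-specific CLI arguments; the server/model arguments are parsed separately by ServerArgs.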
@dataclasses.dataclass
class BenchArgs:
    batch_size: int = 1
    input_len: int = 1024
    output_len: int = 4
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
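
# Create a ModelRunner and tokenizer for one tensor-parallel rank.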
def load_model(server_args, tp_rank):
    suppress_other_loggers()

    model_config = ModelConfig(path=server_args.model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=28888,
        server_args=server_args,
    )
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer
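
# Build requests for the correctness test: encode three prompts but feed only the first cut_len tokens in the first prefill.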
def prepare_inputs(bench_args, tokenizer):
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=bench_args.output_len,
    )

    reqs = []
    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len
        tmp_input_ids = input_ids[i][:bench_args.cut_len]
        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.input_ids = req.origin_input_ids
        reqs.append(req)

    return input_ids, reqs
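
# Append the remaining prompt tokens and point each request at its cached prefix for the second (extend) pass.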
def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
    for i in range(len(reqs)):
        req = reqs[i]
        req.input_ids += input_ids[i][bench_args.cut_len:]
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, :bench_args.cut_len
        ]
    return reqs
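
# Build a synthetic batch of batch_size x input_len dummy token ids for the latency test.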
def prepare_synthetic_inputs(bench_args, tokenizer):
    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=bench_args.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.input_ids = req.origin_input_ids
        reqs.append(req)

    return reqs
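
# Run one prefill/extend forward pass over the whole batch and sample the next token for each request.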
def extend(reqs, model_runner):
    batch = Batch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None)
    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
    output = model_runner.forward(batch, ForwardMode.EXTEND)
    next_token_ids, _ = batch.sample(output.next_token_logits)
    return next_token_ids, output.next_token_logits, batch
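
# Run one decode step using the tokens sampled in the previous step.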
def decode(input_token_ids, batch, model_runner):
    batch.prepare_for_decode(input_token_ids.cpu().numpy())
    output = model_runner.forward(batch, ForwardMode.DECODE)
    next_token_ids, _ = batch.sample(output.next_token_logits)
    return next_token_ids, output.next_token_logits
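
# Correctness test: prefill the truncated prompts, extend with the remaining tokens, then decode a few tokens and print the outputs.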
def correctness_test(
    server_args,
    bench_args,
    tp_rank,
):
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)

    # Prepare inputs
    input_ids, reqs = prepare_inputs(bench_args, tokenizer)

    # Prefill
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print("prefill logits (first half)", next_token_logits)

    # Prepare extend inputs
    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)

    # Extend
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print("prefill logits (final)", next_token_logits)

    # Decode
    output_ids = [list(req.input_ids) for req in reqs]
    for _ in range(bench_args.output_len):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids[i])

    # Print
    for i in range(len(reqs)):
        print(tokenizer.decode(output_ids[i]))
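
# Latency test: measure prefill and per-step decode latency/throughput on a synthetic batch.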
def latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)

    # Prepare inputs
    reqs = prepare_synthetic_inputs(bench_args, tokenizer)

    def clear():
        model_runner.req_to_token_pool.clear()
        model_runner.token_to_kv_pool.clear()

    @torch.inference_mode()
    def run_once(output_len):
        # Prefill
        torch.cuda.synchronize()
        tic = time.time()
        next_token_ids, _, batch = extend(reqs, model_runner)
        torch.cuda.synchronize()
        latency = time.time() - tic
        throughput = bench_args.input_len * bench_args.batch_size / latency
        rank_print(f"Prefill. latency: {latency:6.3f} s, throughput: {throughput:9.2f} token/s")

        # Decode
        for _ in range(output_len):
            torch.cuda.synchronize()
            tic = time.time()
            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
            torch.cuda.synchronize()
            latency = time.time() - tic
            throughput = bench_args.batch_size / latency
            rank_print(f"Decode. latency: {latency:6.3f} s, throughput: {throughput:9.2f} token/s")

    # Warm up
    run_once(4)
    clear()

    # Run again
    run_once(bench_args.output_len)
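
# Launch one worker process per tensor-parallel rank and wait for all of them to finish.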
def main(server_args, bench_args):
    print(bench_args)

    if bench_args.correctness_test:
        work_func = correctness_test
    else:
        work_func = latency_test

    workers = []
    for tp_rank in range(server_args.tp_size):
        proc = multiprocessing.Process(
            target=work_func,
            args=(
                server_args,
                bench_args,
                tp_rank,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    main(server_args, bench_args)
@@ -23,6 +23,7 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)
@@ -229,6 +230,7 @@ class ModelRunner:
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(self.model_config)
        monkey_patch_vllm_dummy_weight_loader()

        # Init torch distributed
        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
...
@@ -466,6 +466,48 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
def monkey_patch_vllm_dummy_weight_loader():
    """
    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
    """
    from vllm.model_executor.model_loader.loader import (
        ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
        ParallelConfig, SchedulerConfig, CacheConfig, nn,
        set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
        DummyModelLoader
    )

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()

            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()

    setattr(DummyModelLoader, "load_model", load_model)
API_KEY_HEADER_NAME = "X-API-Key"
...
import multiprocessing as mp
import time
from dataclasses import dataclass
import torch
import torch.distributed as dist
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
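
# A small batch container that talks to the req/token memory pools directly for the standalone benchmark below.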
@dataclass
class BenchBatch:
req_to_token_pool: torch.Tensor
token_to_kv_pool: torch.Tensor
input_ids: torch.Tensor = None
position_ids_offsets: torch.Tensor = None
seq_lens: torch.Tensor = None
prefix_lens: torch.Tensor = None
req_pool_indices: torch.Tensor = None
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
def __init__(self, model_runner: ModelRunner):
self.req_to_token_pool = model_runner.req_to_token_pool
self.token_to_kv_pool = model_runner.token_to_kv_pool
def init_prefill_batch(self, input_ids, batch_size, seq_len):
self.input_ids = input_ids
self.position_ids_offsets = torch.zeros(
batch_size, dtype=torch.int32, device="cuda"
)
self.seq_lens = torch.full(
(batch_size,), seq_len, dtype=torch.int32, device="cuda"
)
self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)
for i in range(batch_size):
n_idx = self.req_pool_indices[i].item()
self.req_to_token_pool.req_to_token[n_idx, :seq_len] = self.out_cache_loc[
i * seq_len : (i + 1) * seq_len
]
def update_extend(
self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
):
self.input_ids = input_ids
self.position_ids_offsets = torch.zeros(
batch_size, dtype=torch.int32, device="cuda"
)
self.seq_lens = torch.full(
(batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
)
self.prefix_lens = torch.full(
(batch_size,), prefix_len, dtype=torch.int32, device="cuda"
)
self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)
req_to_token = self.req_to_token_pool.req_to_token
fork_num = batch_size // prefix_req_idx.shape[0]
for i in range(batch_size):
p_idx = prefix_req_idx[i // fork_num].item()
n_idx = self.req_pool_indices[i].item()
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
)
def update_decode(self, predict_ids, batch_size):
assert predict_ids.shape[0] == batch_size
assert batch_size == self.req_pool_indices.shape[0]
self.input_ids = predict_ids.reshape(-1)
self.prefix_lens = None
(
self.out_cache_loc,
self.out_cache_cont_start,
self.out_cache_cont_end,
) = self.token_to_kv_pool.alloc_contiguous(batch_size)
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
self.out_cache_loc
)
self.seq_lens.add_(1)
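
# prefill/extend/decode below each run one forward pass and pick the next token greedily via argmax.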
def prefill(model_runner: ModelRunner, batch: BenchBatch):
logits, _ = model_runner.forward_extend(
batch.input_ids,
batch.req_pool_indices,
batch.seq_lens,
batch.prefix_lens,
batch.position_ids_offsets,
batch.out_cache_loc,
False,
)
prob_out = torch.softmax(logits, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
return predict_ids
def extend(model_runner: ModelRunner, batch: BenchBatch):
logits, _ = model_runner.forward_extend(
batch.input_ids,
batch.req_pool_indices,
batch.seq_lens,
batch.prefix_lens,
batch.position_ids_offsets,
batch.out_cache_loc,
True,
)
prob_out = torch.softmax(logits, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
return predict_ids
def decode(model_runner: ModelRunner, batch: BenchBatch):
logits = model_runner.forward_decode(
batch.input_ids,
batch.req_pool_indices,
batch.seq_lens,
None,
batch.position_ids_offsets,
None,
batch.out_cache_cont_start,
batch.out_cache_cont_end,
)
prob_out = torch.softmax(logits, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
return predict_ids
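
# Worker: prefill the shared prompts, fork them into unique requests via extend, then run decode steps; rank 0 prints the timings.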
def bench_generate_worker(
model_path,
tp_rank,
tp_size,
shared_num,
unique_num,
shared_len,
unique_len,
decode_len,
server_args_dict,
):
assert unique_num % shared_num == 0
model_config = ModelConfig(path=model_path)
model_runner = ModelRunner(
model_config=model_config,
mem_fraction_static=0.8,
tp_rank=tp_rank,
tp_size=tp_size,
nccl_port=28888,
server_args_dict=server_args_dict,
)
batch = BenchBatch(model_runner)
# warm up
for _ in range(1):
input_ids = torch.randint(
low=5, high=100, size=(shared_num * shared_len,)
).cuda()
batch.init_prefill_batch(input_ids, shared_num, shared_len)
_ = prefill(model_runner, batch)
input_ids = torch.randint(
low=5, high=100, size=(unique_num * unique_len,)
).cuda()
batch.update_extend(
input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
)
predict_ids = extend(model_runner, batch)
for i in range(decode_len):
predict_ids = torch.from_numpy(predict_ids).cuda()
batch.update_decode(predict_ids, unique_num)
predict_ids = decode(model_runner, batch)
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()
if tp_size > 1:
dist.barrier()
prefill_start = time.time()
input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
batch.init_prefill_batch(input_ids, shared_num, shared_len)
_ = prefill(model_runner, batch)
if tp_rank == 0:
print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")
extend_start = time.time()
input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
batch.update_extend(
input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
)
predict_ids = extend(model_runner, batch)
if tp_rank == 0:
print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")
for i in range(decode_len):
decode_start = time.time()
predict_ids = torch.from_numpy(predict_ids).cuda()
batch.update_decode(predict_ids, unique_num)
predict_ids = decode(model_runner, batch)
if tp_rank == 0:
print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")
def bench_generate(
model_path,
tp_size,
shared_num,
unique_num,
shared_len,
unique_len,
decode_len,
server_args_dict,
):
print(
f"tp_size: {tp_size}, "
f"shared_num: {shared_num}, "
f"unique_num: {unique_num}, "
f"shared_len: {shared_len}, "
f"unique_len: {unique_len}, "
f"decode_len: {decode_len}, "
f"server_args: {server_args_dict}"
)
workers = []
for tp_rank in range(tp_size):
proc = mp.Process(
target=bench_generate_worker,
args=(
model_path,
tp_rank,
tp_size,
shared_num,
unique_num,
shared_len,
unique_len,
decode_len,
server_args_dict,
),
)
proc.start()
workers.append(proc)
for proc in workers:
proc.join()
if __name__ == "__main__":
bench_generate(
model_path="meta-llama/Llama-2-7b-chat-hf",
tp_size=1,
shared_num=1,
unique_num=32,
shared_len=256,
unique_len=256,
decode_len=8,
server_args_dict={},
)
import multiprocessing
import os
import time
import numpy as np
import torch
import torch.distributed as dist
import transformers
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
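
# Correctness test: prefill the first cut_num tokens of each prompt, extend with the rest, then decode a few steps and print the sampled tokens.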
def test_generate_worker(model_path, tp_rank, tp_size):
model_config = ModelConfig(path=model_path)
model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
# Input
prompts = [
"The capital of France is",
"Today is a sunny day and I like",
]
sampling_params = SamplingParams(temperature=0)
cut_num = 4
reqs = []
for i in range(len(prompts)):
input_ids = tokenizer.encode(prompts[i])[:cut_num]
req = Req(i, prompts[i], input_ids)
req.sampling_params = sampling_params
reqs.append(req)
# Prefill
batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
batch.prepare_for_extend(model.model_config.vocab_size, None)
logits, _ = model.forward(batch, ForwardMode.EXTEND)
next_token_ids, next_token_probs = batch.sample(logits)
print("extend logits (first)", logits)
# Extend
for i in range(len(prompts)):
req = reqs[i]
req.input_ids += tokenizer.encode(prompts[i])[cut_num:]
req.prefix_indices = model.req_to_token_pool.req_to_token[
batch.req_pool_indices[i], :cut_num
]
batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
batch.prepare_for_extend(model.model_config.vocab_size, None)
logits, _ = model.forward(batch, ForwardMode.EXTEND)
next_token_ids, next_token_probs = batch.sample(logits)
print("extend logits", logits)
print(
"next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids]
)
# Decode
for i in range(6):
batch.prepare_for_decode(next_token_ids.cpu().numpy())
logits, _ = model.forward(batch, ForwardMode.DECODE)
next_token_ids, next_token_probs = batch.sample(logits)
print(
"next_token_ids",
next_token_ids,
[tokenizer.decode(x) for x in next_token_ids],
)
def test_generate(model_path, tp_size):
workers = []
for tp_rank in range(tp_size):
proc = multiprocessing.Process(
target=test_generate_worker,
args=(
model_path,
tp_rank,
tp_size,
),
)
proc.start()
workers.append(proc)
for proc in workers:
proc.join()
if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1)
# Reference output for TinyLlama-1.1B-Chat-v0.4
# extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633],
# [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]],
# device='cuda:0', dtype=torch.float16)
# extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121],
# [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]],
# device='cuda:0', dtype=torch.float16)
# next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to']
# next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go']
# next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
# next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a']
# next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
# next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
# next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
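
# Latency test over a synthetic batch: warm up once, then time the prefill and each decode step.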
def test_generate_worker(
model_path, tp_rank, tp_size, batch_size, input_len, output_len
):
model_config = ModelConfig(path=model_path)
model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
# Prepare data
input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)])
input_ids = input_ids.reshape(-1)
input_ids = torch.tensor(input_ids).cuda()
def init_batch_data(model, batch_size, input_len):
req_pool_indices = model.req_to_token_pool.alloc(batch_size)
seq_lens = torch.full(
(batch_size,), input_len, dtype=torch.int32, device="cuda"
)
prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
for i in range(batch_size):
req_idx = req_pool_indices[i].item()
model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
i * input_len : (i + 1) * input_len
]
return (
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
)
def prefill(print_logits):
nonlocal predict_ids
logits, _ = model.forward_prefill(
input_ids,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
False,
)
prob_out = torch.softmax(logits, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
if print_logits and tp_rank == 0:
print("prefill logits", logits, logits.shape)
def decode(print_logits):
nonlocal predict_ids
(
out_cache_loc,
out_cache_cont_start,
out_cache_cont_end,
) = model.token_to_kv_pool.alloc_contiguous(batch_size)
model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
seq_lens.add_(1)
logits, _ = model.forward_decode(
torch.from_numpy(predict_ids).cuda().reshape(-1),
req_pool_indices,
seq_lens,
None,
position_ids_offsets,
None,
out_cache_cont_start,
out_cache_cont_end,
False,
)
prob_out = torch.softmax(logits, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
if print_logits and tp_rank == 0:
print("decode", i, logits)
# Warm up
(
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
) = init_batch_data(model, batch_size, input_len)
predict_ids = None
prefill(True)
for i in range(output_len):
decode(True)
for i in range(batch_size):
req_idx = req_pool_indices[i].item()
model.token_to_kv_pool.dec_refs(
model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
)
model.req_to_token_pool.free(req_pool_indices)
# Benchmark
if tp_size > 1:
dist.barrier()
start_time = prefill_start_time = time.time()
(
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
) = init_batch_data(model, batch_size, input_len)
prefill(False)
if tp_rank == 0:
print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms")
for i in range(output_len):
step_start = time.time()
decode(False)
step_end = time.time()
if i % 100 == 0 or i == output_len - 1:
if tp_rank == 0:
print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms")
end_time = time.time()
if tp_rank == 0:
print(f"total cost: {(end_time - start_time) * 1000:.2f}")
def test_generate(model_path, tp_size, batch_size, input_len, output_len):
workers = []
for tp_rank in range(tp_size):
proc = multiprocessing.Process(
target=test_generate_worker,
args=(
model_path,
tp_rank,
tp_size,
batch_size,
input_len,
output_len,
),
)
proc.start()
workers.append(proc)
for proc in workers:
proc.join()
if __name__ == "__main__":
test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1, 1, 256, 8)
# test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8)
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8)
# prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00,
# 1.8281e+00, 2.1816e+00]], device='cuda:0')
# decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]],
# device='cuda:0')
# decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]],
# device='cuda:0')
# decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]],
# device='cuda:0')
# decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]],
# device='cuda:0')
# decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]],
# device='cuda:0')
# decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]],
# device='cuda:0')
# decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]],
# device='cuda:0')
# decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]],
# device='cuda:0')
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8)
# prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]],
# device='cuda:0')
# decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]],
# device='cuda:0')
# decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]],
# device='cuda:0')
# decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]],
# device='cuda:0')
# decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]],
# device='cuda:0')
# decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]],
# device='cuda:0')
# decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]],
# device='cuda:0')
# decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]],
# device='cuda:0')
# decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]],
# device='cuda:0')
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.managers.controller.infer_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata, ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.utils import load_image
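
# Multimodal (LLaVA) test: prefill an image+text prompt, decode 16 tokens, and check the decoded output against a reference string.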
def init_batch_data(model, batch_size, input_len):
req_pool_indices = model.req_to_token_pool.alloc(batch_size)
seq_lens = torch.full((batch_size,), input_len, dtype=torch.int32, device="cuda")
prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
for i in range(batch_size):
model.req_to_token_pool.req_to_token[i, :input_len] = out_cache_loc[
i * input_len : (i + 1) * input_len
]
return (
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
)
def prefill(model, tp_rank, params, print_logits):
logits, _ = model.forward_extend_multi_modal(
*params,
False,
)
prob_out = torch.softmax(logits, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
if print_logits and tp_rank == 0:
print("prefill logits", logits, logits.shape)
return predict_ids
def decode(step, model, tp_rank, batch_size, predict_ids, params, print_logits):
(
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
) = params
(
out_cache_loc,
out_cache_cont_start,
out_cache_cont_end,
) = model.token_to_kv_pool.alloc_contiguous(batch_size)
model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
seq_lens.add_(1)
logits, _ = model.forward_decode(
torch.from_numpy(predict_ids).cuda().reshape(-1),
req_pool_indices,
seq_lens,
None,
position_ids_offsets,
None,
out_cache_cont_start,
out_cache_cont_end,
False,
)
prob_out = torch.softmax(logits, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
if print_logits and tp_rank == 0:
print("decode", step, logits)
return predict_ids
def test_generate_worker(
model_path,
tp_rank,
tp_size,
):
model_config = ModelConfig(path=model_path)
model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
# print(model.model)
# Prepare data
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:"
image_path = "/home/ubuntu/sglang/test/lang/test_image.png"
image = load_image(image_path)
processor = get_processor("llava-hf/llava-1.5-7b-hf")
input_ids = processor.tokenizer.encode(prompt)
pixel_values = processor.image_processor(image)["pixel_values"]
input_ids, offset = model.model.pad_input_ids(
input_ids,
[
0,
],
)
params = init_batch_data(model, 1, len(input_ids))
# inference
output_ids = []
prefill_params = (
torch.tensor(np.array(input_ids)).cuda(),
np.array(pixel_values),
[None],
[offset],
*params,
)
predict_ids = prefill(model, tp_rank=0, params=prefill_params, print_logits=False)
output_ids.append(predict_ids[0][0])
for i in range(16):
predict_ids = decode(
i,
model,
tp_rank=0,
batch_size=1,
predict_ids=predict_ids,
params=params,
print_logits=False,
)
output_ids.append(predict_ids[0][0])
# detokenization
output = processor.tokenizer.batch_decode(
[output_ids], skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
assert (
output
== "The image features a man standing on the back of a yellow taxi cab, holding"
)
def test_generate(model_path, tp_size):
workers = []
for tp_rank in range(tp_size):
proc = multiprocessing.Process(
target=test_generate_worker,
args=(
model_path,
tp_rank,
tp_size,
),
)
proc.start()
workers.append(proc)
for proc in workers:
proc.join()
if __name__ == "__main__":
test_generate("liuhaotian/llava-v1.5-7b", 1)
""" """
Usage: Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification
python3 test_httpserver_classify.py python3 test_httpserver_classify.py
""" """
......