Commit 1190e964 authored by zhuwenwen
Browse files

update benchmarks and examples

parent 69185c0b
......@@ -5,11 +5,13 @@ import random
import time
from typing import List, Optional, Tuple
import numpy as np
import torch
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
......@@ -119,6 +121,23 @@ def run_vllm(
max_tokens=output_len,
))
# Warmup: run a configurable number of untimed generation passes on synthetic
# prompts before the timed `llm.generate(prompts, ...)` call that follows.
#
# Build a (num_prompts, input_len) array of random integers in [0, 10000)
# to use as dummy token ids.
# NOTE(review): `args` is read directly here; whether it is a parameter of the
# enclosing function or a module-level global is not visible in this hunk —
# confirm it is in scope on every call path.
# NOTE(review): `args.input_len` must be an integer for `size=`; if it can be
# None (e.g. when prompts come from a dataset), this call would fail — confirm.
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.num_prompts,
args.input_len))
# One input dict per row of the array, each carrying that row's token ids as a
# plain Python list.
dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
# Single warmup pass: generate for all dummy inputs with the progress bar
# disabled.
# NOTE(review): `sampling_params` is defined outside this hunk; if it is a
# per-prompt list, its length presumably has to match `dummy_inputs` — confirm.
def run_to_completion():
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print("Warming up...")
# Repeat the warmup pass `args.num_iters_warmup` times; tqdm only reports
# progress across iterations.
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion()
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
......@@ -295,6 +314,10 @@ if __name__ == "__main__":
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup',
type=int,
default=1,
help='Number of iterations to run for warmup.')
parser.add_argument("--num-prompts",
type=int,
default=1000,
......
......@@ -12,7 +12,7 @@ if __name__ == '__main__':
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=False)
llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
......
......@@ -332,7 +332,7 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
version = 'das1.1.git' + sha[:7]
version = 'das1.2.git' + sha[:7]
# abi version
version += "." + get_abi()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment