"ssh:/git@developer.sourcefind.cn:2222/gaoqiong/pybind11.git" did not exist on "64f2a5f8e699736f528b6eb3fa143492b65da93a"
Commit 1190e964 authored by zhuwenwen
Browse files

update benchmarks and examples

parent 69185c0b
...@@ -5,11 +5,13 @@ import random ...@@ -5,11 +5,13 @@ import random
import time import time
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...@@ -119,6 +121,23 @@ def run_vllm( ...@@ -119,6 +121,23 @@ def run_vllm(
max_tokens=output_len, max_tokens=output_len,
)) ))
# warmup
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.num_prompts,
args.input_len))
dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
def run_to_completion():
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion()
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True) llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
...@@ -295,6 +314,10 @@ if __name__ == "__main__": ...@@ -295,6 +314,10 @@ if __name__ == "__main__":
default=1, default=1,
help="Number of generated sequences per prompt.") help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup',
type=int,
default=1,
help='Number of iterations to run for warmup.')
parser.add_argument("--num-prompts", parser.add_argument("--num-prompts",
type=int, type=int,
default=1000, default=1000,
......
...@@ -12,7 +12,7 @@ if __name__ == '__main__': ...@@ -12,7 +12,7 @@ if __name__ == '__main__':
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=False) llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -332,7 +332,7 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -332,7 +332,7 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha != 'Unknown': if sha != 'Unknown':
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
version = 'das1.1.git' + sha[:7] version = 'das1.2.git' + sha[:7]
# abi version # abi version
version += "." + get_abi() version += "." + get_abi()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment