### Inference
```python
import argparse

import torch
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImageFeatureData, ImagePixelData


def run_llava_pixel_values(*, disable_image_processor: bool = False):
    # Feed the model raw pixel values (or a pre-processed pixel tensor).
    llm = LLM(
        model="llava/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
        disable_image_processor=disable_image_processor,
    )

    # One <image> placeholder token per visual feature (576 in total).
    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    if disable_image_processor:
        # With the image processor disabled, load pre-computed pixel values.
        image = torch.load("images/stop_sign_pixel_values.pt")
    else:
        image = Image.open("images/stop_sign.jpg")  # path to the input image

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    # Feed the model pre-extracted vision-tower features instead of pixels.
    llm = LLM(
        model="llava/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

    sampling_params = SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImageFeatureData(image),
    }, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    if args.type == "pixel_values":
        run_llava_pixel_values()
    else:
        run_llava_image_features()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo on LLaVA")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()
    main(args)
```
```bash
python examples/llava_example.py
```
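
The script also exposes the `--type` flag defined in the code above, so the input path can be selected explicitly (the `image_features` variant assumes the pre-extracted tensor `images/stop_sign_image_features.pt` is available):

```bash
# Default: raw pixel values through the image processor
python examples/llava_example.py --type pixel_values

# Alternative: pre-extracted vision features
python examples/llava_example.py --type image_features
```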
Here, `prompts` is the prompt text. `temperature` controls sampling randomness: lower values make generation more deterministic, higher values make it more random, 0 means greedy sampling, and the default is 1. `max_tokens` is the generation length and defaults to 16.
`model` is the model path. `tensor_parallel_size=1` is the number of GPUs to use (default 1). `dtype="float16"` is the inference data type; if the model weights are in bfloat16, switch to float16 for inference. `quantization="gptq"` runs inference with GPTQ quantization and requires the GPTQ model downloaded above; `quantization="awq"` runs inference with AWQ quantization and requires the AWQ model downloaded above. The sketch below shows how these options fit together.
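
As a minimal sketch of how these options combine, the snippet below wires them into `LLM` and `SamplingParams`; the quantized checkpoint path is a placeholder, not a path shipped with this repository:

```python
from vllm import LLM, SamplingParams

# Sampling configuration: temperature=0 would mean greedy decoding;
# max_tokens caps the number of generated tokens (vLLM's default is 16).
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

# Engine configuration, following the parameter descriptions above.
llm = LLM(
    model="llava/llava-1.5-7b-hf",  # model path
    tensor_parallel_size=1,         # number of GPUs to use
    dtype="float16",                # inference dtype
)

# For a quantized checkpoint, pass quantization="gptq" (or "awq") instead;
# "path/to/gptq-model" is a placeholder for the downloaded quantized weights.
# llm = LLM(model="path/to/gptq-model", quantization="gptq", dtype="float16")

outputs = llm.generate(
    ["USER: What is vLLM?\nASSISTANT:"],  # prompts
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
```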
To make sure the source code runs correctly, the following adjustments are also needed:

* **Removed the AWS CLI download logic**
* **Removed some of the dependencies on the `subprocess` and `os` modules**
### Offline Batch Inference Performance Test