laibao / llava_vllm · Commit 721cbee2
Authored Oct 16, 2024 by laibao, parent 89e36472 (no commit message).
Showing 1 changed file (README.md) with 5 additions and 86 deletions.
...
@@ -92,95 +92,14 @@ conda create -n llava_vllm python=3.10

### Inference
```
import argparse

import torch
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImageFeatureData, ImagePixelData


def run_llava_pixel_values(*, disable_image_processor: bool = False):
    llm = LLM(
        model="llava/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
        disable_image_processor=disable_image_processor,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    if disable_image_processor:
        # Pre-processed pixel values saved as a tensor
        image = torch.load("images/stop_sign_pixel_values.pt")
    else:
        image = Image.open("images/stop_sign.jpg")  # image location

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    llm = LLM(
        model="llava/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # Pre-computed image features saved as a tensor
    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImageFeatureData(image),
    }, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    if args.type == "pixel_values":
        run_llava_pixel_values()
    else:
        run_llava_image_features()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()
    main(args)
```
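The prompt layout used above (one `<image>` placeholder per image-feature token, followed by the chat text) can be checked in isolation. `build_llava_prompt` is a hypothetical helper name for illustration, not part of vLLM:

```python
# Hypothetical helper: reproduces how the script above assembles the prompt.
# The image occupies image_feature_size placeholder tokens (576 for
# llava-1.5 with a 336x336 input, i.e. 24x24 patches of 14x14 pixels),
# followed by the chat-template text.
def build_llava_prompt(question: str, image_feature_size: int = 576) -> str:
    return "<image>" * image_feature_size + f"\nUSER: {question}\nASSISTANT:"

prompt = build_llava_prompt("What is the content of this image?")
print(prompt.count("<image>"))  # 576
```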
```bash
python examples/offline_inference.py
```

```bash
python examples/llava_example.py
```
Here:

* `prompts` is the prompt (or list of prompts);
* `temperature` controls sampling randomness: lower values make generation more deterministic, higher values more random; 0 means greedy sampling (default 1);
* `max_tokens=16` is the maximum generation length (default 16).

To make sure the source code runs correctly, the following adjustments are also needed:

* `model` is the model path;
* `tensor_parallel_size=1` is the number of GPUs to use (default 1);
* `dtype="float16"` is the inference data type; if the model weights are bfloat16, switch to float16 for inference;
* `quantization="gptq"` runs inference with GPTQ quantization and requires the GPTQ model downloaded above;
* `quantization="awq"` runs inference with AWQ quantization and requires the AWQ model downloaded above.
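The effect of `temperature` described above can be sketched in plain Python (a toy illustration of the sampling rule, not vLLM's actual implementation):

```python
import math
import random

def sample_token(logits, temperature=1.0):
    """Pick a token index from raw logits.

    temperature == 0 means greedy sampling (argmax); higher temperatures
    flatten the softmax distribution, making generation more random.
    """
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(l - m) for l in scaled]
    return random.choices(range(len(logits)), weights=weights, k=1)[0]

print(sample_token([1.0, 4.0, 2.0], temperature=0))  # 1 (greedy picks the argmax)
```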
* **Removed the AWS CLI download logic**
* **Removed some dependencies on the `subprocess` and `os` modules**
### Offline batch inference performance test
...