Commit 3b5f8197 authored by wangsen's avatar wangsen
Browse files

Update cli_demo_hf.py, cli_demo.py files

parent af684142
...@@ -21,7 +21,7 @@ def main(): ...@@ -21,7 +21,7 @@ def main():
parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling') parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
parser.add_argument("--english", action='store_true', help='only output English') parser.add_argument("--english", action='store_true', help='only output English')
parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits') parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
parser.add_argument("--from_pretrained", type=str, default="/data", help='pretrained ckpt') parser.add_argument("--from_pretrained", type=str, default="THUDM/visualglm-6b", help='pretrained ckpt')
parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round') parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round') parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
args = parser.parse_args() args = parser.parse_args()
......
...@@ -4,8 +4,8 @@ import signal ...@@ -4,8 +4,8 @@ import signal
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import torch import torch
tokenizer = AutoTokenizer.from_pretrained("/data", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("/data", trust_remote_code=True).half().cuda() model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
model = model.eval() model = model.eval()
os_name = platform.system() os_name = platform.system()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment