ModelZoo / Gemma-2_pytorch / Commits / 9a7de7de

Commit 9a7de7de, authored Nov 06, 2024 by chenych
Modify inference.py and README.
Parent: 74c7c1cc

Showing 2 changed files with 9 additions and 7 deletions:

- README.md (+5, -2)
- inference.py (+4, -5)
README.md @ 9a7de7de

````diff
@@ -106,7 +106,10 @@ HIP_VISIBLE_DEVICES=0,1 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_l
 ### Single node, single GPU
 ```bash
-python inference.py --model_path /path/of/gemma2
+# Select the GPU IDs
+export HIP_VISIBLE_DEVICES=0,1
+# Adjust the max_new_tokens parameter to fit your use case
+python inference.py --model_path /path/of/gemma2 --max_new_tokens xxx
 ```
 ## result
@@ -114,7 +117,7 @@ python inference.py --model_path /path/of/gemma2
 - Model: gemma-2-9b
 <div align=center>
-<img src="./docs/results.png" witdh=1200 height=400 />
+<img src="./docs/results.png" />
 </div>
 ### Accuracy
````
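The README change replaces in-script GPU pinning with an exported HIP_VISIBLE_DEVICES and exposes max_new_tokens on the command line. As a rough illustration of why the variable is now exported in the shell rather than set inside inference.py, here is a minimal sketch; the specific device IDs and the torch.cuda call are illustrative assumptions, not part of this commit.

```python
# Minimal sketch (not part of the commit): GPU masking via HIP_VISIBLE_DEVICES.
# On ROCm builds of PyTorch this variable plays the role CUDA_VISIBLE_DEVICES
# plays on NVIDIA GPUs, and it generally has to be set before the HIP runtime
# is initialized, which is why exporting it in the shell (as the README now
# shows) is the more reliable place to do it.
import os

# Assumption: expose the first two GPUs, matching the README example.
os.environ.setdefault("HIP_VISIBLE_DEVICES", "0,1")

import torch  # imported after the mask is set so it is honored

print(torch.cuda.device_count())  # reports only the GPUs left visible by the mask
```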
inference.py @ 9a7de7de

```diff
@@ -3,10 +3,8 @@ import argparse
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Select the GPU ID
-os.environ['HIP_VISIBLE_DEVICES'] = '0'
-def infer_hf(model_path, input_text):
+def infer_hf(model_path, input_text, max_new_token=32):
     '''gemma2 inference with transformers'''
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     model = AutoModelForCausalLM.from_pretrained(
@@ -16,7 +14,7 @@ def infer_hf(model_path, input_text):
     input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
-    outputs = model.generate(**input_ids, max_new_tokens=1024)
+    outputs = model.generate(**input_ids, max_new_tokens=max_new_token)
     print(tokenizer.decode(outputs[0]))
@@ -26,10 +24,11 @@ def parse_args():
                         default='Write me a poem about Machine Learning.', help='')
     parser.add_argument('--model_path', default='/path/of/gemma2')
+    parser.add_argument('--max_new_tokens', default=32, type=int)
     return parser.parse_args()
 
 if __name__ == '__main__':
     args = parse_args()
-    infer_hf(args.model_path, args.input_text)
+    infer_hf(args.model_path, args.input_text, args.max_new_tokens)
```
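Putting the visible hunks together, inference.py after this commit looks roughly like the sketch below. The diff collapses the remaining from_pretrained arguments and the start of parse_args(), so the device_map keyword and the ArgumentParser construction are assumptions filled in for completeness, not the repository's exact code.

```python
# Sketch of inference.py after this commit, reconstructed from the visible hunks.
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer


def infer_hf(model_path, input_text, max_new_token=32):
    '''gemma2 inference with transformers'''
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",  # assumption: the diff collapses the real keyword arguments
    )
    input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
    # max_new_tokens is now caller-controlled instead of hard-coded to 1024
    outputs = model.generate(**input_ids, max_new_tokens=max_new_token)
    print(tokenizer.decode(outputs[0]))


def parse_args():
    # Assumption: the parser setup is collapsed in the diff; only the option
    # defaults and the new --max_new_tokens argument are visible there.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_text',
                        default='Write me a poem about Machine Learning.', help='')
    parser.add_argument('--model_path', default='/path/of/gemma2')
    parser.add_argument('--max_new_tokens', default=32, type=int)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    infer_hf(args.model_path, args.input_text, args.max_new_tokens)
```

Invocation then matches the updated README, i.e. `python inference.py --model_path /path/of/gemma2 --max_new_tokens xxx`, where the token budget is left as a placeholder for the user to fill in.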