Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenzhuo
XVERSE-MoE-A4.2B
Commits
2cd61853
Commit
2cd61853
authored
Apr 12, 2024
by
chenzhuo
Browse files
Upload New File
parent
2ee68703
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
82 additions
and
0 deletions
+82
-0
text_generation_demo.py
text_generation_demo.py
+82
-0
No files found.
text_generation_demo.py
0 → 100644
View file @
2cd61853
import
argparse
import
torch
import
gradio
as
gr
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
tokenizer
,
model
=
None
,
None
def
init_model
(
args
):
global
tokenizer
,
model
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
tokenizer_path
,
truncation_side
=
"left"
,
padding_side
=
"left"
)
model
=
AutoModelForCausalLM
.
from_pretrained
(
args
.
model_path
,
trust_remote_code
=
True
,
torch_dtype
=
torch
.
bfloat16
,
device_map
=
'auto'
)
model
=
model
.
eval
()
def
batch_call
(
texts
,
skip_special_tokens
=
True
,
**
kwargs
):
tokenized
=
tokenizer
(
texts
,
padding
=
True
,
return_tensors
=
"pt"
)
inputs
=
{
key
:
value
.
cuda
()
for
key
,
value
in
tokenized
.
items
()
if
key
!=
'token_type_ids'
}
generate_ids
=
model
.
generate
(
**
inputs
,
**
kwargs
)
output
=
[]
for
tok
,
gen
in
zip
(
tokenized
.
input_ids
,
generate_ids
):
generated
=
tokenizer
.
decode
(
gen
[
len
(
tok
):],
skip_special_tokens
=
skip_special_tokens
)
output
.
append
(
generated
)
return
output
def
text_generation
(
texts
,
max_new_tokens
,
temperature
,
top_k
,
top_p
):
output
=
batch_call
(
texts
,
max_new_tokens
=
max_new_tokens
,
do_sample
=
True
,
top_k
=
top_k
,
top_p
=
top_p
,
temperature
=
temperature
,
eos_token_id
=
tokenizer
.
eos_token_id
)
return
output
[
0
]
def
get_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
20014
,
help
=
"server port"
)
parser
.
add_argument
(
"--model_path"
,
type
=
str
,
default
=
"./model"
,
help
=
"Path to the model. Specifies the file path to the pre-trained model to be used for text generation."
)
parser
.
add_argument
(
"--tokenizer_path"
,
type
=
str
,
default
=
"./model"
,
help
=
"Path to the tokenizer."
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
"__main__"
:
args
=
get_args
()
# initialize model and tokenizer
init_model
(
args
)
with
gr
.
Blocks
()
as
demo
:
gr
.
Markdown
(
"# <center>{}</center>"
.
format
(
"XVERSE-MoE-25B Text Generation"
))
with
gr
.
Row
():
with
gr
.
Column
():
inputs
=
gr
.
inputs
.
Textbox
(
lines
=
5
,
label
=
"Input Text"
)
# input
with
gr
.
Column
():
max_new_tokens
=
gr
.
Slider
(
maximum
=
512
,
value
=
100
,
minimum
=
1
,
step
=
1
,
label
=
"max_new_tokens"
,
interactive
=
True
)
# max_new_tokens
temperature
=
gr
.
Slider
(
maximum
=
1.0
,
value
=
1.0
,
minimum
=
0.0
,
step
=
0.05
,
label
=
'temperature'
,
interactive
=
True
)
# temperature
top_k
=
gr
.
Slider
(
maximum
=
50
,
value
=
50
,
minimum
=
0
,
step
=
1
,
label
=
'Top K'
,
interactive
=
True
)
# top_k
top_p
=
gr
.
Slider
(
maximum
=
1
,
value
=
0.92
,
minimum
=
0
,
step
=
0.02
,
label
=
'Top P'
,
interactive
=
True
)
# top_p
with
gr
.
Row
():
outputs
=
gr
.
inputs
.
Textbox
(
lines
=
2
,
label
=
"Output Text"
)
with
gr
.
Row
():
submit_btn
=
gr
.
Button
(
value
=
"生成"
,
variant
=
"secondary"
)
reset_btn
=
gr
.
ClearButton
(
components
=
[
inputs
,
outputs
],
value
=
"清除"
,
variant
=
"secondary"
)
submit_btn
.
click
(
fn
=
text_generation
,
inputs
=
[
inputs
,
max_new_tokens
,
temperature
,
top_k
,
top_p
],
outputs
=
outputs
)
demo
.
launch
(
server_name
=
"0.0.0.0"
,
server_port
=
args
.
port
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment