Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e5887034
Unverified
Commit
e5887034
authored
Mar 07, 2023
by
BlueRum
Committed by
GitHub
Mar 07, 2023
Browse files
[chatgpt]fix inference model load (#2988)
* fix lora bug * polish * fix lora gemini * fix inference load model bug
parent
82503a96
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
9 deletions
+13
-9
applications/ChatGPT/examples/README.md
applications/ChatGPT/examples/README.md
+5
-2
applications/ChatGPT/examples/inference.py
applications/ChatGPT/examples/inference.py
+8
-7
No files found.
applications/ChatGPT/examples/README.md
View file @
e5887034
...
@@ -69,10 +69,13 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy
...
@@ -69,10 +69,13 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy
## Inference example (After Stage 3)
## Inference example (After Stage 3)
We support naive inference demo after training.
We support naive inference demo after training.
```
shell
```
shell
# inference
# inference, using pretrain path to configure model
python inference.py
--pretrain
<your actor model path>
--model
<your model
type
>
python inference.py
--model_path
<your actor model path>
--model
<your model
type
>
--pretrain
<your pretrain model name/path>
# example
python inference.py
--model_path
./actor_checkpoint_prompts.pt
--pretrain
bigscience/bloom-560m
--model
bloom
```
```
#### data
#### data
-
[
x] [rm-static
](
https://huggingface.co/datasets/Dahoas/rm-static
)
-
[
x] [rm-static
](
https://huggingface.co/datasets/Dahoas/rm-static
)
-
[
x] [hh-rlhf
](
https://huggingface.co/datasets/Anthropic/hh-rlhf
)
-
[
x] [hh-rlhf
](
https://huggingface.co/datasets/Anthropic/hh-rlhf
)
...
...
applications/ChatGPT/examples/inference.py
View file @
e5887034
import
argparse
import
argparse
import
torch
import
torch
from
chatgpt.nn
import
BLOOMActor
,
GPTActor
,
OPTActor
from
chatgpt.nn
import
BLOOMActor
,
GPTActor
,
OPTActor
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
...
@@ -9,18 +9,17 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
...
@@ -9,18 +9,17 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
def
eval
(
args
):
def
eval
(
args
):
# configure model
# configure model
if
args
.
model
==
'gpt2'
:
if
args
.
model
==
'gpt2'
:
actor
=
GPTActor
().
to
(
torch
.
cuda
.
current_device
())
actor
=
GPTActor
(
pretrained
=
args
.
pretrain
).
to
(
torch
.
cuda
.
current_device
())
elif
args
.
model
==
'bloom'
:
elif
args
.
model
==
'bloom'
:
actor
=
BLOOMActor
().
to
(
torch
.
cuda
.
current_device
())
actor
=
BLOOMActor
(
pretrained
=
args
.
pretrain
).
to
(
torch
.
cuda
.
current_device
())
elif
args
.
model
==
'opt'
:
elif
args
.
model
==
'opt'
:
actor
=
OPTActor
().
to
(
torch
.
cuda
.
current_device
())
actor
=
OPTActor
(
pretrained
=
args
.
pretrain
).
to
(
torch
.
cuda
.
current_device
())
else
:
else
:
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
raise
ValueError
(
f
'Unsupported model "
{
args
.
model
}
"'
)
state_dict
=
torch
.
load
(
args
.
pretrain
)
state_dict
=
torch
.
load
(
args
.
model_path
)
actor
.
model
.
load_state_dict
(
state_dict
)
actor
.
model
.
load_state_dict
(
state_dict
)
# configure tokenizer
# configure tokenizer
if
args
.
model
==
'gpt2'
:
if
args
.
model
==
'gpt2'
:
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
'gpt2'
)
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
'gpt2'
)
...
@@ -49,7 +48,9 @@ def eval(args):
...
@@ -49,7 +48,9 @@ def eval(args):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model'
,
default
=
'gpt2'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
])
parser
.
add_argument
(
'--model'
,
default
=
'gpt2'
,
choices
=
[
'gpt2'
,
'bloom'
,
'opt'
])
# We suggest using a pretrained model from HuggingFace; use --pretrain to configure the model
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--pretrain'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--model_path'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
'Question: How are you ? Answer:'
)
parser
.
add_argument
(
'--input'
,
type
=
str
,
default
=
'Question: How are you ? Answer:'
)
parser
.
add_argument
(
'--max_length'
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
'--max_length'
,
type
=
int
,
default
=
100
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment