OpenDAS / AutoAWQ · Commits

Commit fd5b8c88, authored Aug 19, 2023 by Casper Hansen
Parent: 0834fb46

    FP16 weights example works

Showing 1 changed file with 10 additions and 7 deletions.

tinychat/utils/prompt_templates.py (+10, -7)
 from typing import List
 from awq.models import *
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.falcon.modeling_falcon import FalconForCausalLM


 class BasePrompter:
     def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
@@ -127,14 +129,14 @@ class MPTChatPrompter(BasePrompter):
 def get_prompter(model, model_path = ""):
-    if isinstance(model, LlamaAWQForCausalLM):
+    if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
         if "vicuna" in model_path:
             return VicunaPrompter()
         else:
             return Llama2Prompter()
-    elif isinstance(model, FalconAWQForCausalLM):
+    elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
         return FalconSimplePrompter()
-    elif isinstance(model, MptAWQForCausalLM):
+    elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
         if "mpt" and "chat" in model_path:
             return MPTChatPrompter()
         else:
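With these widened checks, a plain FP16 Hugging Face checkpoint, loaded without any AWQ wrapper class, now resolves to the same prompter as its quantized counterpart. One aside on the inherited condition: `"mpt" and "chat" in model_path` evaluates as just `"chat" in model_path`, because the non-empty literal "mpt" is always truthy. A minimal usage sketch of the new FP16 path (not part of this commit; the checkpoint name is only illustrative):

import torch
from transformers import AutoModelForCausalLM
from tinychat.utils.prompt_templates import get_prompter

model_path = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)

# `model` is a transformers LlamaForCausalLM, so the new
# `or isinstance(model, LlamaForCausalLM)` branch selects Llama2Prompter
# even though no AWQ quantization wrapper is involved.
prompter = get_prompter(model, model_path)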
@@ -143,14 +145,15 @@ def get_prompter(model, model_path = ""):
         raise ValueError(f"model type {model.model_type} is not supported")

 def get_stop_token_ids(model, model_path = ""):
-    if isinstance(model, LlamaAWQForCausalLM):
+    if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
         return []
-    elif isinstance(model, FalconAWQForCausalLM):
+    elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
         return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
-    elif isinstance(model, MptAWQForCausalLM):
+    elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
         if "mpt" and "chat" in model_path:
             return [50278, 0]
         else:
             return []
     else:
-        raise ValueError(f"model type {model.model_type} is not supported")
+        model_type = str(model.__class__).lower()
+        raise ValueError(f"model type {model_type} is not supported")
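Continuing the sketch above, the per-family stop ids are meant to be checked against each sampled token during decoding; for the FP16 Llama path the list is empty, so only the tokenizer's EOS ends generation. The helper below is hypothetical, not code from this repository:

from tinychat.utils.prompt_templates import get_stop_token_ids

# `model` and `model_path` are the FP16 Llama objects from the sketch above.
stop_token_ids = set(get_stop_token_ids(model, model_path))  # empty set for Llama

def is_stop(token_id: int, eos_token_id: int) -> bool:
    # Stop decoding on the tokenizer's EOS or any family-specific stop id
    # (e.g. {50278, 0} for MPT chat checkpoints).
    return token_id == eos_token_id or token_id in stop_token_ids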