Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
462530c2
Unverified
Commit
462530c2
authored
Mar 27, 2023
by
Nick Hill
Committed by
GitHub
Mar 27, 2023
Browse files
fix(server): Avoid using try/except to determine kind of AutoModel (#142)
parent
ab5fd8cf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
5 deletions
+10
-5
server/text_generation_server/models/__init__.py
server/text_generation_server/models/__init__.py
+10
-5
No files found.
server/text_generation_server/models/__init__.py
View file @
462530c2
...
@@ -3,6 +3,7 @@ import torch
...
@@ -3,6 +3,7 @@ import torch
from
loguru
import
logger
from
loguru
import
logger
from
transformers
import
AutoConfig
from
transformers
import
AutoConfig
from
transformers.models.auto
import
modeling_auto
from
typing
import
Optional
from
typing
import
Optional
from
text_generation_server.models.model
import
Model
from
text_generation_server.models.model
import
Model
...
@@ -65,14 +66,15 @@ def get_model(
...
@@ -65,14 +66,15 @@ def get_model(
return
SantaCoder
(
model_id
,
revision
,
quantize
)
return
SantaCoder
(
model_id
,
revision
,
quantize
)
config
=
AutoConfig
.
from_pretrained
(
model_id
,
revision
=
revision
)
config
=
AutoConfig
.
from_pretrained
(
model_id
,
revision
=
revision
)
model_type
=
config
.
model_type
if
config
.
model_type
==
"bloom"
:
if
model_type
==
"bloom"
:
if
sharded
:
if
sharded
:
return
BLOOMSharded
(
model_id
,
revision
,
quantize
=
quantize
)
return
BLOOMSharded
(
model_id
,
revision
,
quantize
=
quantize
)
else
:
else
:
return
BLOOM
(
model_id
,
revision
,
quantize
=
quantize
)
return
BLOOM
(
model_id
,
revision
,
quantize
=
quantize
)
if
config
.
model_type
==
"gpt_neox"
:
if
model_type
==
"gpt_neox"
:
if
sharded
:
if
sharded
:
neox_cls
=
FlashNeoXSharded
if
FLASH_NEOX
else
GPTNeoxSharded
neox_cls
=
FlashNeoXSharded
if
FLASH_NEOX
else
GPTNeoxSharded
return
neox_cls
(
model_id
,
revision
,
quantize
=
quantize
)
return
neox_cls
(
model_id
,
revision
,
quantize
=
quantize
)
...
@@ -80,7 +82,7 @@ def get_model(
...
@@ -80,7 +82,7 @@ def get_model(
neox_cls
=
FlashNeoX
if
FLASH_NEOX
else
CausalLM
neox_cls
=
FlashNeoX
if
FLASH_NEOX
else
CausalLM
return
neox_cls
(
model_id
,
revision
,
quantize
=
quantize
)
return
neox_cls
(
model_id
,
revision
,
quantize
=
quantize
)
if
config
.
model_type
==
"t5"
:
if
model_type
==
"t5"
:
if
sharded
:
if
sharded
:
return
T5Sharded
(
model_id
,
revision
,
quantize
=
quantize
)
return
T5Sharded
(
model_id
,
revision
,
quantize
=
quantize
)
else
:
else
:
...
@@ -88,7 +90,10 @@ def get_model(
...
@@ -88,7 +90,10 @@ def get_model(
if
sharded
:
if
sharded
:
raise
ValueError
(
"sharded is not supported for AutoModel"
)
raise
ValueError
(
"sharded is not supported for AutoModel"
)
try
:
if
model_type
in
modeling_auto
.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
:
return
CausalLM
(
model_id
,
revision
,
quantize
=
quantize
)
return
CausalLM
(
model_id
,
revision
,
quantize
=
quantize
)
except
Exception
:
if
model_type
in
modeling_auto
.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
:
return
Seq2SeqLM
(
model_id
,
revision
,
quantize
=
quantize
)
return
Seq2SeqLM
(
model_id
,
revision
,
quantize
=
quantize
)
raise
ValueError
(
f
"Unsupported model type
{
model_type
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment