Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
c0928e6f
Unverified
Commit
c0928e6f
authored
Jun 01, 2023
by
OlivierDehaene
Committed by
GitHub
Jun 01, 2023
Browse files
feat(server): remove trust_remote_code requirement for falcon models (#396)
parent
d69a0633
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
11 deletions
+9
-11
server/text_generation_server/models/__init__.py
server/text_generation_server/models/__init__.py
+9
-11
No files found.
server/text_generation_server/models/__init__.py
View file @
c0928e6f
import
torch
import
torch
from
loguru
import
logger
from
loguru
import
logger
from
transformers
import
Auto
Config
from
transformers
.configuration_utils
import
Pretrained
Config
from
transformers.models.auto
import
modeling_auto
from
transformers.models.auto
import
modeling_auto
from
typing
import
Optional
from
typing
import
Optional
...
@@ -138,10 +138,8 @@ def get_model(
...
@@ -138,10 +138,8 @@ def get_model(
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
)
)
config
=
AutoConfig
.
from_pretrained
(
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
model_id
,
revision
=
revision
,
trust_remote_code
=
trust_remote_code
)
model_id
,
revision
=
revision
,
trust_remote_code
=
trust_remote_code
model_type
=
config_dict
[
"model_type"
]
)
model_type
=
config
.
model_type
if
model_type
==
"gpt_bigcode"
:
if
model_type
==
"gpt_bigcode"
:
if
sharded
:
if
sharded
:
...
@@ -201,9 +199,9 @@ def get_model(
...
@@ -201,9 +199,9 @@ def get_model(
if
model_type
in
[
"RefinedWeb"
,
"RefinedWebModel"
]:
if
model_type
in
[
"RefinedWeb"
,
"RefinedWebModel"
]:
if
sharded
:
if
sharded
:
if
FLASH_ATTENTION
:
if
FLASH_ATTENTION
:
if
config
.
alibi
or
(
if
config
_dict
.
get
(
"alibi"
,
False
)
or
(
config
.
model_type
==
"RefinedWebModel"
model_type
==
"RefinedWebModel"
and
config
.
n_head_kv
!=
config
.
n_head
and
config
_dict
.
get
(
"multi_query"
,
True
)
):
):
raise
NotImplementedError
(
"sharded is not supported for this model"
)
raise
NotImplementedError
(
"sharded is not supported for this model"
)
return
FlashRWSharded
(
return
FlashRWSharded
(
...
@@ -216,7 +214,7 @@ def get_model(
...
@@ -216,7 +214,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE
.
format
(
f
"Sharded RefinedWeb"
)
FLASH_ATT_ERROR_MESSAGE
.
format
(
f
"Sharded RefinedWeb"
)
)
)
else
:
else
:
if
FLASH_ATTENTION
and
not
config
.
alibi
:
if
FLASH_ATTENTION
and
not
config
_dict
.
get
(
"alibi"
,
False
)
:
return
FlashRW
(
return
FlashRW
(
model_id
,
model_id
,
revision
,
revision
,
...
@@ -250,7 +248,7 @@ def get_model(
...
@@ -250,7 +248,7 @@ def get_model(
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
)
)
if
config
.
model_type
==
"opt"
:
if
model_type
==
"opt"
:
if
sharded
:
if
sharded
:
return
OPTSharded
(
return
OPTSharded
(
model_id
,
model_id
,
...
@@ -294,7 +292,7 @@ def get_model(
...
@@ -294,7 +292,7 @@ def get_model(
model_id
,
revision
,
quantize
=
quantize
,
trust_remote_code
=
trust_remote_code
model_id
,
revision
,
quantize
=
quantize
,
trust_remote_code
=
trust_remote_code
)
)
auto_map
=
getattr
(
config
,
"auto_map"
,
None
)
auto_map
=
config_dict
.
get
(
"auto_map"
,
None
)
if
trust_remote_code
and
auto_map
is
not
None
:
if
trust_remote_code
and
auto_map
is
not
None
:
if
"AutoModelForCausalLM"
in
auto_map
.
keys
():
if
"AutoModelForCausalLM"
in
auto_map
.
keys
():
return
CausalLM
(
return
CausalLM
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment