Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
37df6df3
Unverified
Commit
37df6df3
authored
Jul 24, 2023
by
OlivierDehaene
Committed by
GitHub
Jul 24, 2023
Browse files
fix(server): fix exllama buffers (#689)
Close #683
parent
73a4d65d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
13 deletions
+22
-13
server/text_generation_server/server.py
server/text_generation_server/server.py
+15
-13
server/text_generation_server/utils/gptq/exllama.py
server/text_generation_server/utils/gptq/exllama.py
+7
-0
No files found.
server/text_generation_server/server.py
View file @
37df6df3
...
...
@@ -105,21 +105,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def
serve
(
model_id
:
str
,
revision
:
Optional
[
str
],
sharded
:
bool
,
quantize
:
Optional
[
str
],
dtype
:
Optional
[
str
],
trust_remote_code
:
bool
,
uds_path
:
Path
,
):
async
def
serve_inner
(
model_id
:
str
,
revision
:
Optional
[
str
],
sharded
:
bool
=
False
,
quantize
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
trust_remote_code
:
bool
=
False
,
sharded
:
bool
,
quantize
:
Optional
[
str
],
dtype
:
Optional
[
str
],
trust_remote_code
:
bool
,
uds_path
:
Path
,
):
async
def
serve_inner
(
model_id
:
str
,
revision
:
Optional
[
str
],
sharded
:
bool
=
False
,
quantize
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
trust_remote_code
:
bool
=
False
,
):
unix_socket_template
=
"unix://{}-{}"
if
sharded
:
...
...
@@ -147,8 +147,10 @@ def serve(
# This will allocate those buffers.
from
text_generation_server.utils.gptq.exllama
import
(
create_exllama_buffers
,
set_device
,
)
set_device
(
model
.
device
)
create_exllama_buffers
()
except
ImportError
:
pass
...
...
server/text_generation_server/utils/gptq/exllama.py
View file @
37df6df3
...
...
@@ -32,9 +32,16 @@ TEMP_STATE = None
TEMP_DQ
=
None
def
set_device
(
device
):
global
DEVICE
DEVICE
=
device
def
create_exllama_buffers
():
global
MAX_DQ
,
MAX_INNER
,
ACT_ORDER
,
DEVICE
,
TEMP_STATE
,
TEMP_DQ
assert
DEVICE
is
not
None
,
"call set_device first"
if
ACT_ORDER
:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment