OpenDAS / text-generation-inference · Commits

Commit f2f0289f
authored Jul 12, 2023 by OlivierDehaene

feat(server): empty cache on errors

parent 67347950
Showing 3 changed files with 13 additions and 3 deletions (+13 / -3)

server/text_generation_server/interceptor.py              +4 / -0
server/text_generation_server/models/flash_causal_lm.py   +0 / -3
server/text_generation_server/server.py                   +9 / -0
server/text_generation_server/interceptor.py

+import torch
 import grpc

 from google.rpc import status_pb2, code_pb2
...
@@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")

+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
...
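Note on the interceptor hunk: the interceptor already turns every handler exception into an INTERNAL gRPC status; this commit additionally releases PyTorch's cached CUDA blocks at that point, so a failing (for example out-of-memory) request does not leave the allocator's cache inflated for the requests that follow. A minimal sketch of the same pattern outside gRPC, assuming only torch and the standard library (the decorator name is hypothetical, not part of this repository):

import functools
import logging

import torch

logger = logging.getLogger(__name__)


def empty_cuda_cache_on_error(fn):
    """Run `fn`; if it raises, log the error, release cached CUDA memory, re-raise."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            logger.exception("%s encountered an error.", fn.__name__)
            if torch.cuda.is_available():
                # Hand cached-but-unused allocator blocks back to the driver,
                # mirroring the guard added to ExceptionInterceptor above.
                torch.cuda.empty_cache()
            raise

    return wrapper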
server/text_generation_server/models/flash_causal_lm.py

...
@@ -639,7 +639,6 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             b.block_tables = None
             del b
-        torch.cuda.empty_cache()

         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
...
@@ -733,7 +732,6 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
             ) from e
         del batch
-        torch.cuda.empty_cache()

     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
...
@@ -790,7 +788,6 @@ class FlashCausalLM(Model):
             )
         except Exception as e:
             del batch
-            torch.cuda.empty_cache()
             raise e

         if prefill:
...
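These three deletions remove per-call-site cache clearing from the model code; after this commit the clearing happens once at the gRPC layer (in server.py and the interceptor) instead. For background, torch.cuda.empty_cache() only returns blocks that the caching allocator holds and no live tensor is using, which is why a single call after the failing batch has been deleted has the same effect as one at every site. A small self-contained illustration (the tensor size is arbitrary):

import torch

if torch.cuda.is_available():
    x = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB of float32
    del x
    # The blocks are free but still reserved by the caching allocator...
    print("reserved after del:        ", torch.cuda.memory_reserved())
    torch.cuda.empty_cache()
    # ...until empty_cache() hands them back to the driver.
    print("reserved after empty_cache:", torch.cuda.memory_reserved())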
server/text_generation_server/server.py

...
@@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)

+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
...
@@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
         self.model.warmup(batch, request.max_total_tokens)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.WarmupResponse()

     async def Prefill(self, request, context):
...
@@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
         else:
             batch = batches[0]
...
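All three server.py hunks add the same two-line guard after operations that drop or replace large GPU-resident batches (filtering, warmup, and batch concatenation during decode); the commit keeps the guard inline at each call site. If it ever needed to grow, it could be factored into a helper along these lines (a sketch; the name is hypothetical and not part of this repository):

import torch


def _empty_cuda_cache() -> None:
    """Hypothetical helper equivalent to the guard repeated in the hunks above."""
    if torch.cuda.is_available():
        # Return cached, currently unused allocator blocks to the driver.
        torch.cuda.empty_cache()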