Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
f0000689
Unverified
Commit
f0000689
authored
Mar 28, 2023
by
OlivierDehaene
Committed by
GitHub
Mar 28, 2023
Browse files
feat(server): clear cache on error (#143)
parent
8e8dd984
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
21 additions
and
9 deletions
+21
-9
proto/generate.proto
proto/generate.proto
+4
-2
router/client/src/client.rs
router/client/src/client.rs
+2
-2
router/client/src/sharded_client.rs
router/client/src/sharded_client.rs
+2
-2
router/src/infer.rs
router/src/infer.rs
+2
-0
router/src/main.rs
router/src/main.rs
+1
-1
server/text_generation_server/cache.py
server/text_generation_server/cache.py
+3
-1
server/text_generation_server/pb/.gitignore
server/text_generation_server/pb/.gitignore
+1
-0
server/text_generation_server/server.py
server/text_generation_server/server.py
+6
-1
No files found.
proto/generate.proto
View file @
f0000689
...
...
@@ -21,8 +21,10 @@ message ServiceDiscoveryResponse {
repeated
string
urls
=
1
;
}
/// Empty request
message
ClearCacheRequest
{}
message
ClearCacheRequest
{
/// Optional batch id
optional
uint64
id
=
1
;
}
/// Empty response
message
ClearCacheResponse
{}
...
...
router/client/src/client.rs
View file @
f0000689
...
...
@@ -56,8 +56,8 @@ impl Client {
/// Clear the past generations cache
#[instrument(skip(self))]
pub
async
fn
clear_cache
(
&
mut
self
)
->
Result
<
()
>
{
let
request
=
tonic
::
Request
::
new
(
ClearCacheRequest
{})
.inject_context
();
pub
async
fn
clear_cache
(
&
mut
self
,
batch_id
:
Option
<
u64
>
)
->
Result
<
()
>
{
let
request
=
tonic
::
Request
::
new
(
ClearCacheRequest
{
id
:
batch_id
})
.inject_context
();
self
.stub
.clear_cache
(
request
)
.await
?
;
Ok
(())
}
...
...
router/client/src/sharded_client.rs
View file @
f0000689
...
...
@@ -40,11 +40,11 @@ impl ShardedClient {
/// Clear the past generations cache
#[instrument(skip(self))]
pub
async
fn
clear_cache
(
&
mut
self
)
->
Result
<
()
>
{
pub
async
fn
clear_cache
(
&
mut
self
,
batch_id
:
Option
<
u64
>
)
->
Result
<
()
>
{
let
futures
:
Vec
<
_
>
=
self
.clients
.iter_mut
()
.map
(|
client
|
client
.clear_cache
())
.map
(|
client
|
client
.clear_cache
(
batch_id
))
.collect
();
join_all
(
futures
)
.await
.into_iter
()
.collect
()
}
...
...
router/src/infer.rs
View file @
f0000689
...
...
@@ -330,6 +330,7 @@ async fn prefill(
entries
:
&
mut
IntMap
<
u64
,
Entry
>
,
)
->
Option
<
Batch
>
{
let
start_time
=
Instant
::
now
();
let
batch_id
=
batch
.id
;
match
client
.prefill
(
batch
)
.await
{
Ok
((
generations
,
next_batch
))
=>
{
...
...
@@ -340,6 +341,7 @@ async fn prefill(
}
// If we have an error, we discard the whole batch
Err
(
err
)
=>
{
let
_
=
client
.clear_cache
(
Some
(
batch_id
))
.await
;
send_errors
(
err
,
entries
);
metrics
::
increment_counter!
(
"tgi_batch_inference_failure"
,
"method"
=>
"prefill"
);
None
...
...
router/src/main.rs
View file @
f0000689
...
...
@@ -136,7 +136,7 @@ fn main() -> Result<(), std::io::Error> {
.expect
(
"Could not connect to server"
);
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache
()
.clear_cache
(
None
)
.await
.expect
(
"Unable to clear cache"
);
tracing
::
info!
(
"Connected"
);
...
...
server/text_generation_server/cache.py
View file @
f0000689
...
...
@@ -17,7 +17,9 @@ class Cache:
self
.
cache
[
entry
.
batch_id
]
=
entry
def
delete
(
self
,
batch_id
:
int
):
del
self
.
cache
[
batch_id
]
batch
=
self
.
pop
(
batch_id
)
if
batch
is
not
None
:
del
batch
def
clear
(
self
):
self
.
cache
.
clear
()
...
...
server/text_generation_server/pb/.gitignore
View file @
f0000689
*.py
*.pyi
*.py-e
\ No newline at end of file
server/text_generation_server/server.py
View file @
f0000689
...
...
@@ -30,7 +30,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return
generate_pb2
.
ServiceDiscoveryResponse
(
urls
=
self
.
server_urls
)
async
def
ClearCache
(
self
,
request
,
context
):
self
.
cache
.
clear
()
if
request
.
HasField
(
"id"
):
self
.
cache
.
delete
(
request
.
id
)
else
:
self
.
cache
.
clear
()
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
empty_cache
()
return
generate_pb2
.
ClearCacheResponse
()
async
def
Prefill
(
self
,
request
,
context
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment