OpenDAS / text-generation-inference · Commit 53aa9194 (unverified)

Authored by OlivierDehaene on Jun 20, 2023; committed by GitHub on Jun 20, 2023
Parent: ece7ffa4

fix(server): fix warpers on CPU (#472)

Closes #471
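The underlying problem (#471) is that `StaticWarper` unconditionally captured a CUDA graph, which cannot work without a CUDA device. The commit's guard, shown in isolation below; the exact failure mode on a CPU-only machine depends on the PyTorch build, so this sketch only probes device availability:

```python
import torch

# The guard this commit introduces: only attempt CUDA-graph capture when a
# CUDA device is actually present; otherwise take the eager CPU path.
if torch.cuda.is_available():
    graph = torch.cuda.CUDAGraph()  # GPU fast path: capture, then replay
else:
    print("No CUDA device: falling back to eager execution")
```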
Changes: 2 changed files, with 31 additions and 33 deletions (+31 -33)

server/text_generation_server/models/__init__.py (+6 -14)
server/text_generation_server/utils/logits_process.py (+25 -19)
server/text_generation_server/models/__init__.py
```diff
@@ -237,20 +237,12 @@ def get_model(
         )
     elif model_type == "t5":
-        if sharded:
-            return T5Sharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
-        else:
-            return Seq2SeqLM(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
+        return T5Sharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
+        )
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
```
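With this change, t5 checkpoints always go through `T5Sharded`, even when sharding is not requested. That works when a sharded implementation degenerates cleanly to a single shard; a minimal sketch of the idea (illustrative names, not the library's internals):

```python
import torch

# Hypothetical sketch: a tensor-parallel layer that degenerates to a plain
# layer when world_size == 1, so one class can serve both code paths.
class ShardedLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, world_size: int = 1):
        super().__init__()
        assert out_features % world_size == 0
        # Each rank holds only its slice of the output dimension; with
        # world_size == 1 the "slice" is the whole weight matrix.
        self.linear = torch.nn.Linear(in_features, out_features // world_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

layer = ShardedLinear(16, 32, world_size=1)  # behaves like an ordinary Linear
out = layer(torch.randn(2, 16))
```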
server/text_generation_server/utils/logits_process.py
```diff
@@ -42,25 +42,31 @@ class StaticWarper:
         self.static_next_logprob = None
 
     def __call__(self, scores):
-        if self.cuda_graph is None:
-            self.static_scores = scores
-            self.cuda_graph = torch.cuda.CUDAGraph()
-
-            with torch.cuda.graph(self.cuda_graph, pool=mempool):
-                local_scores = self.static_scores
-                for warper in self.warpers:
-                    local_scores = warper(None, local_scores)
-
-                self.static_warped_scores = local_scores
-                # Compute logprobs
-                self.static_next_logprob = torch.log_softmax(
-                    self.static_warped_scores, -1
-                )
-
-        self.static_scores.copy_(scores)
-        self.cuda_graph.replay()
-
-        return self.static_warped_scores, self.static_next_logprob
+        if torch.cuda.is_available():
+            if self.cuda_graph is None:
+                self.static_scores = scores
+                self.cuda_graph = torch.cuda.CUDAGraph()
+
+                with torch.cuda.graph(self.cuda_graph, pool=mempool):
+                    local_scores = self.static_scores
+                    for warper in self.warpers:
+                        local_scores = warper(None, local_scores)
+
+                    self.static_warped_scores = local_scores
+                    # Compute logprobs
+                    self.static_next_logprob = torch.log_softmax(
+                        self.static_warped_scores, -1
+                    )
+
+            self.static_scores.copy_(scores)
+            self.cuda_graph.replay()
+
+            return self.static_warped_scores, self.static_next_logprob
+
+        # CPU branch
+        for warper in self.warpers:
+            scores = warper(None, scores)
+        return scores, torch.log_softmax(scores, -1)
 
 
 @lru_cache(10)
```
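The new CPU branch skips CUDA-graph capture entirely and applies the warpers eagerly. A self-contained reproduction of that fallback using the `transformers` warpers that `StaticWarper` wraps (the specific warper mix and vocabulary size here are illustrative):

```python
import torch
from transformers import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Eager CPU equivalent of the fallback branch: apply each warper in order,
# then compute log-probabilities on the warped scores.
warpers = [
    TemperatureLogitsWarper(0.7),
    TopKLogitsWarper(top_k=50),
    TopPLogitsWarper(top_p=0.95),
]

scores = torch.randn(1, 32000)  # fake next-token logits on a CPU tensor
for warper in warpers:
    scores = warper(None, scores)  # these warpers ignore input_ids

next_logprob = torch.log_softmax(scores, -1)
```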