OpenDAS / text-generation-inference · Commits

Commit 5fa8ae04 (unverified)
Authored Apr 12, 2023 by OlivierDehaene, committed by GitHub on Apr 12, 2023

feat(server): optimize decode for sane tokenizers (#170)

Parent: 6f0f1d70
Showing 8 changed files with 67 additions and 32 deletions (+67 -32)
benchmark/Cargo.lock                                       +2  -2
server/text_generation_server/models/bloom.py              +6  -2
server/text_generation_server/models/causal_lm.py          +8  -3
server/text_generation_server/models/flash_causal_lm.py    +3  -3
server/text_generation_server/models/flash_santacoder.py   +1  -2
server/text_generation_server/models/model.py              +37 -14
server/text_generation_server/models/santacoder.py         +1  -2
server/text_generation_server/models/seq2seq_lm.py         +9  -4
benchmark/Cargo.lock  (view file @ 5fa8ae04)

@@ -853,7 +853,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
 
 [[package]]
 name = "grpc-metadata"
-version = "0.4.1"
+version = "0.1.0"
 dependencies = [
  "opentelemetry",
  "tonic",

@@ -2140,7 +2140,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "0.4.3"
+version = "0.5.0"
 dependencies = [
  "futures",
  "grpc-metadata",
server/text_generation_server/models/bloom.py  (view file @ 5fa8ae04)

@@ -49,6 +49,11 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        super(BLOOM, self).__init__(
+            model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+        )
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return BloomCausalLMBatch

@@ -94,8 +99,7 @@ class BLOOMSharded(BLOOM):
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
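BLOOM opts into decode_buffer=1 because its byte-level BPE tokenizer decodes a single token id into its exact surface text, so no trailing-window re-decode is needed. A minimal sketch of that property, assuming the transformers library is installed; the bigscience/bloom-560m checkpoint is only an example and is not referenced by this commit:

# Illustrative check only (not part of this commit): a "sane" tokenizer decodes
# token ids one at a time to the same text as decoding the whole sequence,
# which is what makes decode_buffer=1 safe for BLOOM.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # example checkpoint
ids = tok("Hello world, how are you?").input_ids

per_token = "".join(tok.decode([i], skip_special_tokens=False) for i in ids)
full = tok.decode(ids, skip_special_tokens=False)
print(per_token == full)  # expected True for plain ASCII text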
server/text_generation_server/models/causal_lm.py  (view file @ 5fa8ae04)

@@ -291,7 +291,13 @@ class CausalLMBatch(Batch):
 
 
 class CausalLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

@@ -319,8 +325,7 @@ class CausalLM(Model):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
        )
 
     @property
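CausalLM now exposes decode_buffer (default 3) and forwards it to the Model base class, so existing callers are unchanged. A hedged usage sketch; "gpt2" is only an example model id and is not part of this commit, and constructing a model downloads weights:

# Hedged usage sketch, not taken from the repository's own tests.
from text_generation_server.models.causal_lm import CausalLM

# Default: keep a 3-token trailing buffer for tokenizers that need context.
model = CausalLM("gpt2")

# A tokenizer known to decode single tokens cleanly can drop the buffer.
fast_model = CausalLM("gpt2", decode_buffer=1)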
server/text_generation_server/models/flash_causal_lm.py  (view file @ 5fa8ae04)

@@ -212,7 +212,8 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize=False,
+        quantize: bool = False,
+        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -237,8 +238,7 @@ class FlashCausalLM(Model):
         )
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property
server/text_generation_server/models/flash_santacoder.py  (view file @ 5fa8ae04)

@@ -62,8 +62,7 @@ class FlashSantacoder(FlashCausalLM):
         self.model = model.eval().to(device).to(dtype)
 
         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     @staticmethod
server/text_generation_server/models/model.py  (view file @ 5fa8ae04)

@@ -10,10 +10,19 @@ B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+        decode_buffer: int = 3,
+    ):
+        if decode_buffer < 1:
+            raise ValueError("decode_buffer must be >= 1")
+
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
+        self.decode_buffer = decode_buffer
 
     @property
     @abstractmethod

@@ -39,23 +48,37 @@ class Model(ABC):
         )
 
         if token_offset is None:
-            token_offset = len(all_input_ids) - 3
-
-        # Decode token_offset token minus last one and token_offset tokens
-        results = self.tokenizer.batch_decode(
-            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
-            skip_special_tokens=False,
-        )
+            token_offset = len(all_input_ids) - self.decode_buffer
+            # left token buffer
+            if self.decode_buffer > 1:
+                # Decode token_offset token minus last one and token_offset tokens
+                raw_texts = self.tokenizer.batch_decode(
+                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
+                    skip_special_tokens=False,
+                )
 
-        # default offset is only the last token
-        if offset is None:
-            offset = len(results[0])
+                # default offset is only the last token
+                offset = len(raw_texts[0])
+                sequence_text = raw_texts[1]
+            else:
+                # Only decode the last token without using a token buffer
+                sequence_text = self.tokenizer.decode(
+                    all_input_ids[-1], skip_special_tokens=False
+                )
+                # no offset in this case
+                offset = 0
+        else:
+            assert offset is not None
+            sequence_text = self.tokenizer.decode(
+                all_input_ids[token_offset:],
+                skip_special_tokens=False,
+            )
 
         # get text
-        text = results[1][offset:]
+        token_text = sequence_text[offset:]
 
         # if text is utf-8
-        if text and text[-1] != "�":
-            return text, None, None
+        if token_text and token_text[-1] != "�":
+            return token_text, None, None
         else:
             return "", offset, token_offset
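The decode_token change above is the core of the commit: instead of always re-decoding a fixed 3-token trailing window, the window size is now the configurable decode_buffer, and a buffer of 1 decodes only the newly generated token. The trailing "�" check still defers output when a byte-level tokenizer has produced an incomplete UTF-8 sequence. Below is a minimal, self-contained sketch of the same buffered-decode idea, assuming the transformers library; the gpt2 checkpoint and the helper name decode_token_buffered are illustrative, not part of the repository.

# Minimal sketch of the buffered incremental decode idea shown in the diff above.
# The helper name and the gpt2 tokenizer are illustrative assumptions.
from typing import List, Optional, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def decode_token_buffered(
    tokenizer: PreTrainedTokenizerBase,
    all_input_ids: List[int],
    offset: Optional[int],
    token_offset: Optional[int],
    decode_buffer: int = 3,
) -> Tuple[str, Optional[int], Optional[int]]:
    """Return the text of the last token, re-decoding only a small trailing window."""
    if token_offset is None:
        token_offset = len(all_input_ids) - decode_buffer
        if decode_buffer > 1:
            # Decode the window with and without the last token; the difference
            # is the text contributed by the last token.
            raw_texts = tokenizer.batch_decode(
                [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
                skip_special_tokens=False,
            )
            offset = len(raw_texts[0])
            sequence_text = raw_texts[1]
        else:
            # "Sane" tokenizers can decode a single token in isolation.
            sequence_text = tokenizer.decode(
                all_input_ids[-1], skip_special_tokens=False
            )
            offset = 0
    else:
        # A previous step returned an incomplete UTF-8 sequence; keep the same
        # window and try again now that one more token is available.
        assert offset is not None
        sequence_text = tokenizer.decode(
            all_input_ids[token_offset:], skip_special_tokens=False
        )

    token_text = sequence_text[offset:]
    if token_text and token_text[-1] != "�":
        return token_text, None, None
    # Partial UTF-8: return nothing and hand the offsets back to the caller.
    return "", offset, token_offset


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint
    ids = tok(" Hello world, this is a test").input_ids
    print(decode_token_buffered(tok, ids, None, None, decode_buffer=3))

The returned (offset, token_offset) pair is what lets the caller carry state across generation steps, so a multi-byte character split across tokens is emitted once it is complete.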
server/text_generation_server/models/santacoder.py  (view file @ 5fa8ae04)

@@ -54,8 +54,7 @@ class SantaCoder(CausalLM):
         )
 
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=1
         )
 
     def decode(self, generated_ids: List[int]) -> str:
server/text_generation_server/models/seq2seq_lm.py  (view file @ 5fa8ae04)

@@ -330,7 +330,13 @@ class Seq2SeqLMBatch(Batch):
 
 
 class Seq2SeqLM(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

@@ -354,8 +360,7 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
 
         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
+            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
     @property

@@ -496,7 +501,7 @@ class Seq2SeqLM(Model):
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
-                output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])
+                output_text = self.decode(decoder_input_ids[-decoder_input_length:])
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
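The last hunk also corrects the slice used for the final output text: it now matches the comment and slices with decoder_input_length rather than new_decoder_input_length. A minimal sketch of what that slice does, with made-up ids that are not taken from the repository:

# Illustrative only: left-padded decoder ids and the slice that drops padding
# before decoding the finished sequence.
import torch

pad_token_id = 0
decoder_input_ids = torch.tensor([pad_token_id, pad_token_id, 42, 17, 99])
decoder_input_length = 3  # number of real (non-padding) decoder tokens

# Same shape of call as the corrected line:
#     self.decode(decoder_input_ids[-decoder_input_length:])
real_ids = decoder_input_ids[-decoder_input_length:]
print(real_ids.tolist())  # [42, 17, 99]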