OpenDAS / text-generation-inference · Commit 6abec14a

Unverified commit 6abec14a, authored Jun 05, 2023 by OlivierDehaene, committed by GitHub on Jun 05, 2023
feat(server): batch tokenization for flash causal lm (#411)
parent 895c5f15
Showing 2 changed files with 17 additions and 5 deletions (+17, -5):

  server/text_generation_server/models/flash_causal_lm.py  (+14, -4)
  server/text_generation_server/utils/hub.py  (+3, -1)
server/text_generation_server/models/flash_causal_lm.py
@@ -79,6 +79,16 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        batch_inputs = []
+        max_truncation = 0
+        for r in pb.requests:
+            batch_inputs.append(r.inputs)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs, truncation=True, max_length=max_truncation
+        )["input_ids"]
+
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -106,13 +116,13 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
 
         # Parse batch
-        for i, r in enumerate(pb.requests):
+        for i, (r, tokenized_input) in enumerate(
+            zip(pb.requests, batch_tokenized_inputs)
+        ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenizer(
-                r.inputs, truncation=True, max_length=r.truncate
-            )["input_ids"]
+            tokenized_input = tokenized_input[-r.truncate :]
 
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
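The change above replaces one tokenizer call per request with a single call over the whole batch, truncated to the largest per-request limit, and then trims each sequence to its own `truncate` value by slicing. Below is a minimal, standalone sketch of that pattern, assuming a standard `transformers` fast tokenizer; the `Request` dataclass and the `gpt2` tokenizer name are illustrative placeholders, not code from this repository.

```python
# Sketch of the batched-tokenization pattern adopted in this commit:
# tokenize every request's input in one call, then trim each sequence
# to that request's own truncation limit.
from dataclasses import dataclass
from typing import List

from transformers import AutoTokenizer


@dataclass
class Request:
    # Illustrative stand-in for the protobuf request used by the server.
    inputs: str
    truncate: int


def batch_tokenize(requests: List[Request], tokenizer) -> List[List[int]]:
    batch_inputs = [r.inputs for r in requests]
    max_truncation = max(r.truncate for r in requests)

    # One tokenizer call over the whole batch instead of one per request.
    batch_ids = tokenizer(
        batch_inputs, truncation=True, max_length=max_truncation
    )["input_ids"]

    # The batched call can only truncate to the batch-wide maximum, so
    # keep just the last `truncate` tokens of each request's sequence.
    return [ids[-r.truncate :] for r, ids in zip(requests, batch_ids)]


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    reqs = [
        Request("Hello world", truncate=8),
        Request("A somewhat longer prompt goes here", truncate=4),
    ]
    for ids in batch_tokenize(reqs, tok):
        print(len(ids), ids)
```

Batching the call amortizes per-request tokenizer overhead; per-request truncation is then applied by slicing, exactly as the second hunk does inside the parse loop.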
server/text_generation_server/utils/hub.py
@@ -134,7 +134,7 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""
 
-    def download_file(filename, tries=5):
+    def download_file(filename, tries=5, backoff: int = 5):
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
@@ -158,6 +158,8 @@ def download_weights(
                 if i + 1 == tries:
                     raise e
                 logger.error(e)
+                logger.info(f"Retrying in {backoff} seconds")
+                time.sleep(backoff)
                 logger.info(f"Retry {i + 1}/{tries - 1}")
 
     # We do this instead of using tqdm because we want to parse the logs with the launcher
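The hub.py change adds a fixed pause between download retries via a new `backoff` parameter. Here is a minimal sketch of the same retry-with-backoff pattern as a standalone helper, using the standard `logging` module rather than the server's logger; `with_retries` and `flaky_fetch` are illustrative names, not part of the repository.

```python
# Sketch of retry-with-fixed-backoff: retry a flaky operation a bounded
# number of times, sleeping `backoff` seconds between attempts and
# re-raising the error on the final failure.
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def with_retries(fn, tries: int = 5, backoff: int = 5):
    """Call fn(), retrying up to `tries` times with a fixed pause between attempts."""
    for i in range(tries):
        try:
            return fn()
        except Exception as e:
            # Last attempt: propagate the error to the caller.
            if i + 1 == tries:
                raise e
            logger.error(e)
            logger.info(f"Retrying in {backoff} seconds")
            time.sleep(backoff)
            logger.info(f"Retry {i + 1}/{tries - 1}")


if __name__ == "__main__":
    def flaky_fetch():
        # Placeholder for a network call such as hf_hub_download(...).
        raise RuntimeError("transient network error")

    try:
        with_retries(flaky_fetch, tries=3, backoff=1)
    except RuntimeError as e:
        print(f"gave up after retries: {e}")
```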