OpenDAS / text-generation-inference · Commits

Commit b1485e18, authored Mar 07, 2023 by OlivierDehaene, committed via GitHub on Mar 07, 2023

fix(server): fix galactica batch (#106)

closes #105
Parent: 3fef90d5

Showing 1 changed file with 21 additions and 5 deletions:
server/text_generation_server/models/galactica.py (+21, -5)
@@ -96,6 +96,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         input_lengths = []

         # Parse batch
+        max_sequence_length = 0
+        padding_right_offset = 0
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
@@ -103,8 +105,13 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
             )
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
             )
+            stopping_criterias.append(stopping_criteria)
+            max_sequence_length = max(max_sequence_length, r.input_length)
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
+            )

         # Tokenize batch
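The loop above now tracks two batch-level sizes while parsing requests: the longest prompt and the largest number of tokens any request may generate. A minimal standalone sketch of that accumulation, with hypothetical requests (the `SimpleNamespace` objects and their numbers are made-up stand-ins for the protobuf requests and their stopping criteria):

```python
# Sketch of the sizing logic accumulated in the parse loop above.
# SimpleNamespace requests are hypothetical stand-ins for the protobuf
# requests; only input_length and max_new_tokens matter here.
from types import SimpleNamespace

requests = [
    SimpleNamespace(input_length=5, max_new_tokens=20),
    SimpleNamespace(input_length=9, max_new_tokens=10),
]

max_sequence_length = 0
padding_right_offset = 0
for r in requests:
    # Longest prompt in the batch sizes the left (prompt) region
    max_sequence_length = max(max_sequence_length, r.input_length)
    # Largest max_new_tokens sizes the reserved right (generation) region
    padding_right_offset = max(padding_right_offset, r.max_new_tokens)

print(max_sequence_length, padding_right_offset)  # 9 20
# The mask is then allocated as (batch_size, 9 + 20), i.e. (2, 29)
```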
@@ -114,6 +121,14 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             padding=True,
             return_token_type_ids=False,
         ).to(device)
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
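The added lines above allocate the attention mask once at its final size and fill only the prompt region; position ids are built from the tokenizer mask so left padding gets no real position. A sketch under assumed toy shapes (the tensors below are invented; in the real code they come from the tokenizer output):

```python
# Toy reproduction of the allocation and position-id logic above.
# Shapes are assumed: batch of 2, prompts of length 2 and 3 (left-padded),
# so max_sequence_length = 3; padding_right_offset = 4 is arbitrary.
import torch

tokenizer_mask = torch.tensor([[0, 1, 1],   # prompt of length 2, left-padded
                               [1, 1, 1]])  # prompt of length 3
input_ids = torch.tensor([[0, 11, 12], [21, 22, 23]])
max_sequence_length, padding_right_offset = 3, 4

# Allocate the mask once, with room for every token that may be generated
attention_mask = input_ids.new_zeros(
    (2, max_sequence_length + padding_right_offset)
)
# Copy the tokenizer mask into the prompt region; the rest stays 0 for now
attention_mask[:, :max_sequence_length] = tokenizer_mask

# cumsum(-1) - 1 turns the mask into 0-based positions; left-padding slots
# would be -1, so they are overwritten with a dummy value of 1
position_ids = tokenizer_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(tokenizer_mask == 0, 1)

print(attention_mask.shape)  # torch.Size([2, 7])
print(position_ids)          # tensor([[1, 0, 1], [0, 1, 2]])
```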
@@ -121,8 +136,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            input_ids=tokenized_inputs["input_ids"],
-            attention_mask=tokenized_inputs["attention_mask"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
@@ -130,7 +145,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(input_lengths),
+            max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
         )
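Taken together (and per the commit message, closing #105), the Galactica batch's attention mask is now sized up front for the longest prompt plus the largest `max_new_tokens`, rather than taken directly from the tokenizer with `max_sequence_length = max(input_lengths)`. An illustrative sketch of how that reserved right-hand region would be consumed during generation; the decode loop below is hypothetical, not code from this repository:

```python
# Hypothetical decode-time view: because the mask already has
# padding_right_offset empty slots on the right, each generated token
# only flips one slot to 1; no per-step concatenation is needed.
# This loop is illustrative, not the repository's actual decode code.
import torch

max_sequence_length, padding_right_offset = 3, 4
attention_mask = torch.zeros(
    2, max_sequence_length + padding_right_offset, dtype=torch.long
)
attention_mask[:, :max_sequence_length] = 1  # prompt region is attended

for step in range(padding_right_offset):
    # Mark the slot of the token generated at this step as attended
    attention_mask[:, max_sequence_length + step] = 1

print(attention_mask[0])  # tensor([1, 1, 1, 1, 1, 1, 1])
```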