Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
042180d8
Commit 042180d8, authored Dec 08, 2022 by OlivierDehaene
Browse files
fix(server): Only pad to multiple of 8 on GPUs
parent
a2985036
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 4 additions and 2 deletions
+4
-2
server/text_generation/models/causal_lm.py
server/text_generation/models/causal_lm.py
+2
-1
server/text_generation/models/seq2seq_lm.py
server/text_generation/models/seq2seq_lm.py
+2
-1
No files found.
server/text_generation/models/causal_lm.py
View file @
042180d8
...
@@ -71,8 +71,9 @@ class CausalLMBatch:
...
@@ -71,8 +71,9 @@ class CausalLMBatch:
                 )
             )
 
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
...
...
server/text_generation/models/seq2seq_lm.py
View file @
042180d8
...
@@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
...
@@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
             )
 
         # Tokenize batch
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment