Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
78063c05
Unverified
Commit
78063c05
authored
Feb 20, 2023
by
OlivierDehaene
Committed by
GitHub
Feb 20, 2023
Browse files
fix(server): remove position_ids from galactica forward (#82)
closes #80
parent
17bc841b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
2 deletions
+15
-2
server/text_generation/models/galactica.py
server/text_generation/models/galactica.py
+15
-2
No files found.
server/text_generation/models/galactica.py
View file @
78063c05
...
@@ -2,7 +2,7 @@ import re
...
@@ -2,7 +2,7 @@ import re
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
typing
import
List
,
Optional
,
Type
from
typing
import
List
,
Optional
,
Type
,
Tuple
from
accelerate
import
init_empty_weights
from
accelerate
import
init_empty_weights
from
safetensors
import
safe_open
from
safetensors
import
safe_open
...
@@ -145,6 +145,20 @@ class Galactica(CausalLM):
...
@@ -145,6 +145,20 @@ class Galactica(CausalLM):
generated_ids
,
skip_special_tokens
=
False
,
cleanup_tokenization_spaces
=
False
generated_ids
,
skip_special_tokens
=
False
,
cleanup_tokenization_spaces
=
False
)
)
def
forward
(
self
,
input_ids
,
attention_mask
,
position_ids
,
past_key_values
:
Optional
=
None
)
->
Tuple
[
torch
.
Tensor
,
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
past_key_values
=
past_key_values
,
use_cache
=
True
,
)
return
outputs
.
logits
,
outputs
.
past_key_values
class
GalacticaSharded
(
Galactica
):
class
GalacticaSharded
(
Galactica
):
def
__init__
(
def
__init__
(
...
@@ -322,7 +336,6 @@ class GalacticaSharded(Galactica):
...
@@ -322,7 +336,6 @@ class GalacticaSharded(Galactica):
outputs
=
self
.
model
.
forward
(
outputs
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_values
=
past_key_values
,
past_key_values
=
past_key_values
,
use_cache
=
True
,
use_cache
=
True
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment