OpenDAS / text-generation-inference · Commits

Commit 8ad60b75 (unverified)
Authored Mar 15, 2023 by OlivierDehaene; committed via GitHub on Mar 15, 2023

fix(server): add position ids to neox (#126)

Parent: cbd36aa4
Showing 7 changed files with 12 additions and 32 deletions (+12 -32):
server/Makefile: +1 -1
server/text_generation_server/models/__init__.py: +2 -3
server/text_generation_server/models/causal_lm.py: +1 -3
server/text_generation_server/models/galactica.py: +1 -3
server/text_generation_server/models/gpt_neox.py: +3 -18
server/text_generation_server/models/seq2seq_lm.py: +1 -3
server/text_generation_server/utils/watermark.py: +3 -1
server/Makefile

-transformers_commit := 517563354a3226ecfc3dca6e7a38012668d7156a
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
 
 gen-server:
 	# Compile protos
...
server/text_generation_server/models/__init__.py

@@ -9,7 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
-from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
 __all__ = [
@@ -19,7 +19,6 @@ __all__ = [
     "CausalLM",
     "Galactica",
     "GalacticaSharded",
-    "GPTNeox",
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -62,7 +61,7 @@ def get_model(
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_id, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
 
     if config.model_type == "t5":
         if sharded:
...
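With the unsharded GPTNeox wrapper gone, only the tensor-parallel path keeps a NeoX-specific class; a non-sharded NeoX checkpoint is now served by the generic CausalLM, which forwards position ids to the model. A condensed sketch of the resulting branch of get_model(), assuming the enclosing `config.model_type == "gpt_neox"` check that the hunk above does not show:

```python
# Sketch of the gpt_neox branch of get_model() after this commit.
# The enclosing model_type check is assumed; it is not visible in the hunk above.
if config.model_type == "gpt_neox":
    if sharded:
        # tensor-parallel path keeps its dedicated class
        return GPTNeoxSharded(model_id, revision, quantize=quantize)
    else:
        # unsharded path: the generic CausalLM already passes position_ids,
        # so the old GPTNeox wrapper (which dropped them) is no longer needed
        return CausalLM(model_id, revision, quantize=quantize)
```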
server/text_generation_server/models/causal_lm.py

@@ -72,9 +72,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
...
server/text_generation_server/models/galactica.py

@@ -102,9 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
...
server/text_generation_server/models/gpt_neox.py

 import torch
 import torch.distributed
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
...
@@ -30,23 +30,7 @@ except Exception as e:
     HAS_BITS_AND_BYTES = False
 
-class GPTNeox(CausalLM):
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Overwrite forward to ignore position_ids"""
-
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
-
-class GPTNeoxSharded(GPTNeox):
+class GPTNeoxSharded(CausalLM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
...
@@ -224,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
...
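The deleted GPTNeox.forward dropped position_ids before calling the underlying model, which matters for left-padded batches, where positions must be counted from the first real token rather than from the padding. After this commit the sharded forward above passes position_ids explicitly and the unsharded case relies on CausalLM. A minimal stand-alone sketch of the same pattern, assuming a recent transformers release in which GPTNeoXForCausalLM accepts position_ids; the checkpoint name and the cumsum-based position construction are illustrative, not taken from this repository:

```python
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM

# Illustrative checkpoint; any GPT-NeoX style model works the same way.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m")

batch = tokenizer(["Hello", "The quick brown fox"], return_tensors="pt", padding=True)

# Derive positions from the attention mask so left padding does not shift
# the positions of the real tokens.
position_ids = batch["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(batch["attention_mask"] == 0, 1)

outputs = model.forward(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    position_ids=position_ids,  # previously ignored for NeoX; now passed through
    past_key_values=None,
    use_cache=True,
)
logits, past_key_values = outputs.logits, outputs.past_key_values
```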
server/text_generation_server/models/seq2seq_lm.py

@@ -82,9 +82,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, device)
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
...
server/text_generation_server/utils/watermark.py

@@ -43,7 +43,9 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
 
     def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)
...
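For context, the lines above show the watermark processor's seeding scheme: before drawing the "greenlist" of favoured token ids, the RNG is re-seeded from the last token of the prefix, so the same prefix always yields the same greenlist. A self-contained sketch of that pattern; the class name, the gamma default, and the usage at the bottom are illustrative assumptions rather than the repository's exact code:

```python
import torch

class WatermarkSeedingSketch:
    """Illustrative sketch of the seeding/greenlist pattern shown above;
    names and defaults here are assumptions, not the project's exact code."""

    def __init__(self, gamma: float = 0.5, hash_key: int = 15485863, device: str = "cpu"):
        self.gamma = gamma          # fraction of the vocabulary marked "green"
        self.hash_key = hash_key    # constant mixed with the previous token id
        self.device = device
        self.rng = torch.Generator(device=device)

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        # seed the rng using the previous token of the prefix
        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
        # seed the rng using the previous tokens/prefix, then sample the greenlist
        self._seed_rng(input_ids)
        greenlist_size = int(max_value * self.gamma)
        vocab_permutation = torch.randperm(max_value, device=self.device, generator=self.rng)
        return vocab_permutation[:greenlist_size].tolist()

# Usage: the same prefix always yields the same greenlist for a given hash_key.
ids = torch.tensor([464, 3290, 318])
green = WatermarkSeedingSketch()._get_greenlist_ids(ids, max_value=50432)
```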