OpenDAS / text-generation-inference / Commits / 99879600
"csrc/ktransformers_ext/ext_bindings.cpp" did not exist on "21fca5a326097de6629098ede47357b899868010"
Unverified commit 99879600, authored Apr 09, 2023 by OlivierDehaene; committed by GitHub on Apr 09, 2023.
feat(router): make router input validation optional (#164)
Parent: 7dec65a2
Showing 7 of the commit's 27 changed files on this page, with 32 additions and 42 deletions:
- server/text_generation_server/models/flash_neox.py (+4, -32)
- server/text_generation_server/models/flash_santacoder.py (+9, -1)
- server/text_generation_server/models/galactica.py (+12, -6)
- server/text_generation_server/models/gpt_neox.py (+1, -1)
- server/text_generation_server/models/santacoder.py (+1, -1)
- server/text_generation_server/models/seq2seq_lm.py (+4, -0)
- server/text_generation_server/models/t5.py (+1, -1)
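Every file below applies the same core change: the tokenizer is built with `truncation_side="left"`, and the batch classes pass `truncation=True, max_length=max_truncation` when tokenizing, so over-long prompts are clipped server-side rather than relying on the router to validate them. A minimal sketch of the resulting behavior, assuming the `transformers` API and using `EleutherAI/gpt-neox-20b` purely as an example checkpoint:

```python
from transformers import AutoTokenizer

# Example checkpoint only; any causal-LM tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-neox-20b", padding_side="left", truncation_side="left"
)

# truncation_side="left" keeps the *end* of the prompt when clipping to
# max_length, which is what a causal LM needs in order to continue the text.
batch = tokenizer(
    ["a very long prompt " * 100],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=16,
)
print(batch["input_ids"].shape)  # torch.Size([1, 16]) -- only the last 16 tokens survive
```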
server/text_generation_server/models/flash_neox.py
```diff
@@ -45,18 +45,19 @@ class FlashNeoXSharded(FlashNeoX):
             raise NotImplementedError("FlashNeoX does not support quantization")
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)
 
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
```
```diff
@@ -147,32 +148,3 @@ class FlashNeoXSharded(FlashNeoX):
                     module._parameters[param_name] = tensor
                 else:
                     module._buffers[param_name] = tensor
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_s: int,
-        past_key_values: Optional = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.model.gpt_neox.tp_embeddings:
-            logits, present = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlens=cu_seqlens,
-                max_s=max_s,
-                past_key_values=past_key_values,
-            )
-
-            # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
-
-            return world_logits, present
-        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
-        else:
-            return super(FlashNeoXSharded, self).forward(
-                input_ids, position_ids, cu_seqlens, max_s, past_key_values
-            )
```
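The `forward` removed above gathered vocab-sharded logits by hand; now that `FlashGPTNeoXForCausalLM` is constructed with `self.process_group`, the base class `forward` suffices. For reference, a self-contained sketch of the gather pattern the deleted code used, run as a single-process `gloo` group purely so the example is executable:

```python
import torch
import torch.distributed as dist


def gather_sharded_logits(logits: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    """Gather logits sharded on the vocab dimension across ranks.

    Each rank holds a [tokens, vocab // world_size] shard; the result is the
    full [tokens, vocab] matrix, concatenated in rank order along dim 1.
    """
    world_logits = [torch.empty_like(logits) for _ in range(world_size)]
    dist.all_gather(world_logits, logits, group=group)
    return torch.cat(world_logits, dim=1)


if __name__ == "__main__":
    # Single-process demo: a world of size 1 gathers its own shard back.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    shard = torch.randn(4, 8)  # [tokens, vocab_shard]
    print(gather_sharded_logits(shard, world_size=1).shape)  # torch.Size([4, 8])
    dist.destroy_process_group()
```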
server/text_generation_server/models/flash_santacoder.py
```diff
@@ -33,7 +33,7 @@ class FlashSantacoder(FlashCausalLM):
             raise NotImplementedError("FlashSantacoder does not support quantization")
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
```
```diff
@@ -56,6 +56,8 @@ class FlashSantacoder(FlashCausalLM):
         self.load_weights(
             model,
             filenames,
+            device,
+            dtype,
         )
 
         self.model = model.eval().to(device).to(dtype)
```
```diff
@@ -68,10 +70,14 @@ class FlashSantacoder(FlashCausalLM):
     def load_weights(
         model: FlashSantacoderForCausalLM,
         filenames: List[Path],
+        device: torch.device,
+        dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
+                value = value.to(device).to(dtype)
+
                 layer_name = ".".join(key.split(".")[:4])
 
                 # Fused qkv
```
```diff
@@ -141,6 +147,8 @@ class FlashSantacoder(FlashCausalLM):
                     module._parameters[param_name] = value
                 else:
                     module._buffers[param_name] = value
 
+                del value
 
+        torch.cuda.empty_cache()
         model.post_load_weights()
```
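`load_weights` now receives `device` and `dtype` and casts each tensor as it streams out of the checkpoint, releasing the CPU copy immediately, instead of holding a full float32 model and converting it at the end. A reduced sketch of the pattern, assuming dotted state-dict keys that match the model's module tree (the real loader also remaps fused-qkv layers and handles buffers via `module._buffers`, both elided here):

```python
from pathlib import Path
from typing import List

import torch


def load_weights_streaming(
    model: torch.nn.Module,
    filenames: List[Path],
    device: torch.device,
    dtype: torch.dtype,
):
    """Cast tensors one at a time, keeping peak memory near the final model size."""
    for filename in filenames:
        state_dict = torch.load(filename, map_location="cpu")
        for key, value in state_dict.items():
            value = value.to(device).to(dtype)  # cast immediately, per tensor
            module_name, _, param_name = key.rpartition(".")
            module = model.get_submodule(module_name)
            # Swap the tensor straight into the module, mirroring the loaders'
            # module._parameters[param_name] = value (buffers not handled here).
            module._parameters[param_name] = torch.nn.Parameter(
                value, requires_grad=False
            )
            del value  # drop the local reference so the memory can be reclaimed
    if device.type == "cuda":
        torch.cuda.empty_cache()  # return freed blocks to the CUDA allocator
```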
server/text_generation_server/models/galactica.py
```diff
@@ -96,7 +96,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         input_lengths = []
 
         # Parse batch
-        max_sequence_length = 0
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
```
```diff
@@ -107,7 +107,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
```
```diff
@@ -118,14 +118,20 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
 
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
```
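Because truncation can shorten a prompt, the batch no longer trusts router-supplied lengths: as the hunk above shows, it re-derives them from the attention mask returned by the tokenizer. A small self-contained sketch of that derivation on a toy left-padded batch:

```python
import torch

# Attention mask for a left-padded batch of 3 prompts (1 = real token).
attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1],
    ]
)

# Post-truncation lengths come from the mask itself, not from the request.
input_lengths = attention_mask.sum(1)        # tensor([3, 5, 2])
max_input_length = int(input_lengths.max())  # 5

# Pre-allocate room on the right for tokens generated later...
padding_right_offset = 4
full_mask = attention_mask.new_zeros((3, max_input_length + padding_right_offset))
# ...and copy the tokenizer mask into the left part.
full_mask[:, :max_input_length] = attention_mask
print(full_mask.shape)  # torch.Size([3, 9])
```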
```diff
@@ -143,7 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
         )
```
```diff
@@ -188,7 +194,7 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
```
server/text_generation_server/models/gpt_neox.py
```diff
@@ -44,7 +44,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token
```
server/text_generation_server/models/santacoder.py
```diff
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.add_special_tokens(
             {
```
server/text_generation_server/models/seq2seq_lm.py
```diff
@@ -73,6 +73,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_lengths = []
 
         # Parse batch
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
```
```diff
@@ -84,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
```
```diff
@@ -94,6 +96,8 @@ class Seq2SeqLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
```
server/text_generation_server/models/t5.py
```diff
@@ -44,7 +44,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
```