Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
1f570d18
Unverified
Commit
1f570d18
authored
Jan 20, 2023
by
OlivierDehaene
Committed by
GitHub
Jan 20, 2023
Browse files
fix(server): Fix position ids (#28)
parent
15511edc
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
33 additions
and
36 deletions
+33
-36
server/tests/models/test_santacoder.py
server/tests/models/test_santacoder.py
+2
-0
server/text_generation/models/bloom.py
server/text_generation/models/bloom.py
+2
-1
server/text_generation/models/causal_lm.py
server/text_generation/models/causal_lm.py
+22
-3
server/text_generation/models/galactica.py
server/text_generation/models/galactica.py
+5
-1
server/text_generation/models/santacoder.py
server/text_generation/models/santacoder.py
+1
-30
server/text_generation/models/seq2seq_lm.py
server/text_generation/models/seq2seq_lm.py
+1
-1
No files found.
server/tests/models/test_santacoder.py
View file @
1f570d18
...
@@ -42,6 +42,7 @@ def default_fim_pb_batch(default_fim_pb_request):
...
@@ -42,6 +42,7 @@ def default_fim_pb_batch(default_fim_pb_request):
return
generate_pb2
.
Batch
(
id
=
0
,
requests
=
[
default_fim_pb_request
],
size
=
1
)
return
generate_pb2
.
Batch
(
id
=
0
,
requests
=
[
default_fim_pb_request
],
size
=
1
)
@
pytest
.
mark
.
skip
def
test_santacoder_generate_token_completion
(
default_santacoder
,
default_pb_batch
):
def
test_santacoder_generate_token_completion
(
default_santacoder
,
default_pb_batch
):
batch
=
CausalLMBatch
.
from_pb
(
batch
=
CausalLMBatch
.
from_pb
(
default_pb_batch
,
default_santacoder
.
tokenizer
,
default_santacoder
.
device
default_pb_batch
,
default_santacoder
.
tokenizer
,
default_santacoder
.
device
...
@@ -65,6 +66,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
...
@@ -65,6 +66,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
)
)
@
pytest
.
mark
.
skip
def
test_fim_santacoder_generate_token_completion
(
def
test_fim_santacoder_generate_token_completion
(
default_santacoder
,
default_fim_pb_batch
default_santacoder
,
default_fim_pb_batch
):
):
...
...
server/text_generation/models/bloom.py
View file @
1f570d18
...
@@ -236,10 +236,11 @@ class BLOOMSharded(BLOOM):
...
@@ -236,10 +236,11 @@ class BLOOMSharded(BLOOM):
if
name
==
"word_embeddings.weight"
:
if
name
==
"word_embeddings.weight"
:
model
.
lm_head
.
_parameters
[
"weight"
]
=
tensor
model
.
lm_head
.
_parameters
[
"weight"
]
=
tensor
def
forward
(
self
,
input_ids
,
attention_mask
,
past_key_values
:
Optional
=
None
):
def
forward
(
self
,
input_ids
,
attention_mask
,
position_ids
,
past_key_values
:
Optional
=
None
):
outputs
=
self
.
model
.
forward
(
outputs
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_values
=
past_key_values
,
past_key_values
=
past_key_values
,
use_cache
=
True
,
use_cache
=
True
,
)
)
...
...
server/text_generation/models/causal_lm.py
View file @
1f570d18
...
@@ -18,6 +18,7 @@ class CausalLMBatch(Batch):
...
@@ -18,6 +18,7 @@ class CausalLMBatch(Batch):
# Decoder values
# Decoder values
input_ids
:
torch
.
Tensor
input_ids
:
torch
.
Tensor
attention_mask
:
torch
.
Tensor
attention_mask
:
torch
.
Tensor
position_ids
:
torch
.
Tensor
past_key_values
:
Optional
[
List
[
Tuple
]]
past_key_values
:
Optional
[
List
[
Tuple
]]
# All tokens
# All tokens
...
@@ -76,6 +77,8 @@ class CausalLMBatch(Batch):
...
@@ -76,6 +77,8 @@ class CausalLMBatch(Batch):
pad_to_multiple_of
=
pad_to_multiple_of
,
pad_to_multiple_of
=
pad_to_multiple_of
,
return_token_type_ids
=
False
,
return_token_type_ids
=
False
,
).
to
(
device
)
).
to
(
device
)
position_ids
=
tokenized_inputs
[
"attention_mask"
].
long
().
cumsum
(
-
1
)
-
1
position_ids
.
masked_fill_
(
tokenized_inputs
[
"attention_mask"
]
==
0
,
1
)
all_input_ids
=
tokenized_inputs
[
"input_ids"
].
unsqueeze
(
-
1
)
all_input_ids
=
tokenized_inputs
[
"input_ids"
].
unsqueeze
(
-
1
)
return
cls
(
return
cls
(
...
@@ -83,6 +86,7 @@ class CausalLMBatch(Batch):
...
@@ -83,6 +86,7 @@ class CausalLMBatch(Batch):
requests
=
pb
.
requests
,
requests
=
pb
.
requests
,
input_ids
=
tokenized_inputs
[
"input_ids"
],
input_ids
=
tokenized_inputs
[
"input_ids"
],
attention_mask
=
tokenized_inputs
[
"attention_mask"
],
attention_mask
=
tokenized_inputs
[
"attention_mask"
],
position_ids
=
position_ids
,
past_key_values
=
None
,
past_key_values
=
None
,
all_input_ids
=
all_input_ids
,
all_input_ids
=
all_input_ids
,
all_logprobs
=
all_logprobs
,
all_logprobs
=
all_logprobs
,
...
@@ -110,6 +114,7 @@ class CausalLMBatch(Batch):
...
@@ -110,6 +114,7 @@ class CausalLMBatch(Batch):
# Batch tensors
# Batch tensors
input_ids
=
None
input_ids
=
None
attention_mask
=
None
attention_mask
=
None
position_ids
=
None
past_key_values
=
[]
past_key_values
=
[]
# Used for slicing correctly inside the tensors
# Used for slicing correctly inside the tensors
...
@@ -149,6 +154,12 @@ class CausalLMBatch(Batch):
...
@@ -149,6 +154,12 @@ class CausalLMBatch(Batch):
start_index
:
end_index
,
-
batch
.
max_sequence_length
:
start_index
:
end_index
,
-
batch
.
max_sequence_length
:
]
=
batch
.
attention_mask
[:,
-
batch
.
max_sequence_length
:]
]
=
batch
.
attention_mask
[:,
-
batch
.
max_sequence_length
:]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if
position_ids
is
None
:
position_ids
=
batch
.
position_ids
.
new_empty
((
total_batch_size
,
1
))
position_ids
[
start_index
:
end_index
]
=
batch
.
position_ids
for
j
,
past
in
enumerate
(
batch
.
past_key_values
):
for
j
,
past
in
enumerate
(
batch
.
past_key_values
):
past_keys
,
past_values
=
past
past_keys
,
past_values
=
past
...
@@ -211,6 +222,7 @@ class CausalLMBatch(Batch):
...
@@ -211,6 +222,7 @@ class CausalLMBatch(Batch):
requests
=
requests
,
requests
=
requests
,
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_values
=
past_key_values
,
past_key_values
=
past_key_values
,
all_input_ids
=
all_input_ids
,
all_input_ids
=
all_input_ids
,
all_logprobs
=
all_logprobs
,
all_logprobs
=
all_logprobs
,
...
@@ -263,12 +275,13 @@ class CausalLM(Model):
...
@@ -263,12 +275,13 @@ class CausalLM(Model):
)
)
def
forward
(
def
forward
(
self
,
input_ids
,
attention_mask
,
past_key_values
:
Optional
=
None
self
,
input_ids
,
attention_mask
,
position_ids
,
past_key_values
:
Optional
=
None
)
->
Tuple
[
torch
.
Tensor
,
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
)
->
Tuple
[
torch
.
Tensor
,
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Model Forward
# Model Forward
outputs
=
self
.
model
.
forward
(
outputs
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_values
=
past_key_values
,
past_key_values
=
past_key_values
,
use_cache
=
True
,
use_cache
=
True
,
)
)
...
@@ -283,7 +296,7 @@ class CausalLM(Model):
...
@@ -283,7 +296,7 @@ class CausalLM(Model):
)
)
with
context_manager
():
with
context_manager
():
logits
,
past
=
self
.
forward
(
logits
,
past
=
self
.
forward
(
batch
.
input_ids
,
batch
.
attention_mask
,
batch
.
past_key_values
batch
.
input_ids
,
batch
.
attention_mask
,
batch
.
position_ids
,
batch
.
past_key_values
)
)
# List of indices to cache
# List of indices to cache
...
@@ -356,7 +369,7 @@ class CausalLM(Model):
...
@@ -356,7 +369,7 @@ class CausalLM(Model):
token_ids
=
all_input_ids
[
-
new_input_length
:]
token_ids
=
all_input_ids
[
-
new_input_length
:]
tokens
=
self
.
tokenizer
.
batch_decode
(
token_ids
)
tokens
=
self
.
tokenizer
.
batch_decode
(
token_ids
)
# Add NaN for the first prompt token
# Add NaN for the first prompt token
logprobs
=
[
float
(
"nan"
)]
+
all_logprobs
[
-
new_
input_length
:].
squeeze
(
logprobs
=
[
float
(
"nan"
)]
+
all_logprobs
[
-
input_length
:].
squeeze
(
1
1
).
tolist
()
).
tolist
()
...
@@ -394,6 +407,7 @@ class CausalLM(Model):
...
@@ -394,6 +407,7 @@ class CausalLM(Model):
if
generated_texts
:
if
generated_texts
:
# Apply indices to attention mask, past key values and other items that need to be cached
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask
=
batch
.
attention_mask
[
next_batch_keep_indices
]
next_batch_attention_mask
=
batch
.
attention_mask
[
next_batch_keep_indices
]
next_batch_position_ids
=
batch
.
position_ids
[
next_batch_keep_indices
]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_past_key_values
=
[
next_batch_past_key_values
=
[
[
[
...
@@ -411,6 +425,7 @@ class CausalLM(Model):
...
@@ -411,6 +425,7 @@ class CausalLM(Model):
]
]
else
:
else
:
next_batch_attention_mask
=
batch
.
attention_mask
next_batch_attention_mask
=
batch
.
attention_mask
next_batch_position_ids
=
batch
.
position_ids
next_batch_past_key_values
=
past
next_batch_past_key_values
=
past
next_batch_requests
=
batch
.
requests
next_batch_requests
=
batch
.
requests
next_batch_next_token_choosers
=
batch
.
next_token_choosers
next_batch_next_token_choosers
=
batch
.
next_token_choosers
...
@@ -425,11 +440,15 @@ class CausalLM(Model):
...
@@ -425,11 +440,15 @@ class CausalLM(Model):
dim
=
1
,
dim
=
1
,
)
)
# Update position_ids
next_batch_position_ids
=
next_batch_position_ids
[:,
-
1
:]
+
1
next_batch
=
CausalLMBatch
(
next_batch
=
CausalLMBatch
(
batch_id
=
batch
.
batch_id
,
batch_id
=
batch
.
batch_id
,
requests
=
next_batch_requests
,
requests
=
next_batch_requests
,
input_ids
=
next_batch_input_ids
,
input_ids
=
next_batch_input_ids
,
attention_mask
=
next_batch_attention_mask
,
attention_mask
=
next_batch_attention_mask
,
position_ids
=
next_batch_position_ids
,
past_key_values
=
next_batch_past_key_values
,
past_key_values
=
next_batch_past_key_values
,
all_input_ids
=
next_batch_all_input_ids
,
all_input_ids
=
next_batch_all_input_ids
,
all_logprobs
=
next_batch_all_logprobs
,
all_logprobs
=
next_batch_all_logprobs
,
...
...
server/text_generation/models/galactica.py
View file @
1f570d18
...
@@ -116,6 +116,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
...
@@ -116,6 +116,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
pad_to_multiple_of
=
pad_to_multiple_of
,
pad_to_multiple_of
=
pad_to_multiple_of
,
return_token_type_ids
=
False
,
return_token_type_ids
=
False
,
).
to
(
device
)
).
to
(
device
)
position_ids
=
tokenized_inputs
[
"attention_mask"
].
long
().
cumsum
(
-
1
)
-
1
position_ids
.
masked_fill_
(
tokenized_inputs
[
"attention_mask"
]
==
0
,
1
)
all_input_ids
=
tokenized_inputs
[
"input_ids"
].
unsqueeze
(
-
1
)
all_input_ids
=
tokenized_inputs
[
"input_ids"
].
unsqueeze
(
-
1
)
return
cls
(
return
cls
(
...
@@ -123,6 +125,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
...
@@ -123,6 +125,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests
=
pb
.
requests
,
requests
=
pb
.
requests
,
input_ids
=
tokenized_inputs
[
"input_ids"
],
input_ids
=
tokenized_inputs
[
"input_ids"
],
attention_mask
=
tokenized_inputs
[
"attention_mask"
],
attention_mask
=
tokenized_inputs
[
"attention_mask"
],
position_ids
=
position_ids
,
past_key_values
=
None
,
past_key_values
=
None
,
all_input_ids
=
all_input_ids
,
all_input_ids
=
all_input_ids
,
input_lengths
=
input_lengths
,
input_lengths
=
input_lengths
,
...
@@ -330,10 +333,11 @@ class GalacticaSharded(Galactica):
...
@@ -330,10 +333,11 @@ class GalacticaSharded(Galactica):
if
name
==
"model.decoder.embed_tokens.weight"
:
if
name
==
"model.decoder.embed_tokens.weight"
:
model
.
lm_head
.
_parameters
[
"weight"
]
=
tensor
model
.
lm_head
.
_parameters
[
"weight"
]
=
tensor
def
forward
(
self
,
input_ids
,
attention_mask
,
past_key_values
:
Optional
=
None
):
def
forward
(
self
,
input_ids
,
attention_mask
,
position_ids
,
past_key_values
:
Optional
=
None
):
outputs
=
self
.
model
.
forward
(
outputs
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_values
=
past_key_values
,
past_key_values
=
past_key_values
,
use_cache
=
True
,
use_cache
=
True
,
)
)
...
...
server/text_generation/models/santacoder.py
View file @
1f570d18
...
@@ -42,10 +42,9 @@ class SantaCoder(CausalLM):
...
@@ -42,10 +42,9 @@ class SantaCoder(CausalLM):
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
model_name
,
torch_dtype
=
dtype
,
torch_dtype
=
dtype
,
device_map
=
"auto"
if
torch
.
cuda
.
is_available
()
else
None
,
load_in_8bit
=
quantize
,
load_in_8bit
=
quantize
,
trust_remote_code
=
True
,
# required
trust_remote_code
=
True
,
# required
).
eval
()
).
to
(
device
).
eval
()
super
(
CausalLM
,
self
).
__init__
(
super
(
CausalLM
,
self
).
__init__
(
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
...
@@ -57,31 +56,3 @@ class SantaCoder(CausalLM):
...
@@ -57,31 +56,3 @@ class SantaCoder(CausalLM):
return
self
.
tokenizer
.
decode
(
return
self
.
tokenizer
.
decode
(
generated_ids
,
skip_special_tokens
=
False
,
cleanup_tokenization_spaces
=
False
generated_ids
,
skip_special_tokens
=
False
,
cleanup_tokenization_spaces
=
False
)
)
def
forward
(
self
,
input_ids
,
attention_mask
,
past_key_values
:
Optional
=
None
)
->
Tuple
[
torch
.
Tensor
,
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# FIXME: current forward with past is bugged for bigcode/santacoder because past_key_values does not have
# the correct shape ([batch_size, D, seq_length] instead of [batch_size, seq_length D]
# this leads to position_ids being wrong
input_length
=
input_ids
.
shape
[
-
1
]
past_key_values_length
=
(
0
if
past_key_values
is
None
else
past_key_values
[
0
][
0
].
shape
[
-
1
]
)
position_ids
=
torch
.
arange
(
past_key_values_length
,
input_length
+
past_key_values_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
,
).
view
(
1
,
input_length
)
# Model Forward
outputs
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
past_key_values
=
past_key_values
,
position_ids
=
position_ids
,
use_cache
=
True
,
)
return
outputs
.
logits
,
outputs
.
past_key_values
server/text_generation/models/seq2seq_lm.py
View file @
1f570d18
...
@@ -449,7 +449,7 @@ class Seq2SeqLM(Model):
...
@@ -449,7 +449,7 @@ class Seq2SeqLM(Model):
tokens
=
self
.
tokenizer
.
batch_decode
(
token_ids
)
tokens
=
self
.
tokenizer
.
batch_decode
(
token_ids
)
# Add NaN for the bos token
# Add NaN for the bos token
logprobs
=
[
float
(
"nan"
)]
+
decoder_logprobs
[
logprobs
=
[
float
(
"nan"
)]
+
decoder_logprobs
[
-
new_
decoder_input_length
:
-
decoder_input_length
:
].
tolist
()
].
tolist
()
# Add to the list of finished generations with the original request
# Add to the list of finished generations with the original request
generated_texts
.
append
(
generated_texts
.
append
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment