OpenDAS / text-generation-inference · Commits

Commit 895c5f15 (unverified), authored Jun 02, 2023 by OlivierDehaene, committed via GitHub on Jun 02, 2023
Parent: 83b84486

feat(server): only compute prefill logprobs when asked (#406)

Close #288

Showing 16 changed files with 143 additions and 43 deletions (+143, -43)
router/src/health.rs                                                                +1   -0
router/src/lib.rs                                                                   +4   -0
router/src/queue.rs                                                                 +2   -0
router/src/server.rs                                                                +12  -7
router/src/validation.rs                                                            +5   -0
server/tests/models/test_bloom.py                                                   +1   -0
server/tests/models/test_causal_lm.py                                               +1   -0
server/tests/models/test_santacoder.py                                              +2   -0
server/tests/models/test_seq2seq_lm.py                                              +1   -0
server/text_generation_server/models/causal_lm.py                                   +2   -2
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py        +3   -0
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py         +3   -0
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py           +3   -0
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py   +3   -0
server/text_generation_server/models/flash_causal_lm.py                             +99  -33
server/text_generation_server/models/seq2seq_lm.py                                  +1   -1
router/src/health.rs

@@ -34,6 +34,7 @@ impl Health {
             id: LIVENESS_ID,
             inputs: "liveness".to_string(),
             truncate: 10,
+            prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,
router/src/lib.rs

@@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "true")]
     pub details: bool,
+    #[serde(default)]
+    #[schema(default = "true")]
+    pub decoder_input_details: bool,
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
         nullable = true,

@@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters {
         truncate: None,
         watermark: false,
         details: false,
+        decoder_input_details: false,
         seed: None,
     }
 }
router/src/queue.rs

@@ -201,6 +201,7 @@ impl State {
             batch_requests.push(Request {
                 id,
+                prefill_logprobs: entry.request.decoder_input_details,
                 inputs: entry.request.inputs.clone(),
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),

@@ -281,6 +282,7 @@ mod tests {
             inputs: "".to_string(),
             input_length: 0,
             truncate: 0,
+            decoder_input_details: false,
             parameters: NextTokenChooserParameters {
                 temperature: 0.0,
                 top_k: 0,
router/src/server.rs

@@ -160,7 +160,7 @@ async fn generate(
         add_prompt = Some(req.0.inputs.clone());
     }

-    let details = req.0.parameters.details;
+    let details = req.0.parameters.details || req.0.parameters.decoder_input_details;

     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {

@@ -364,7 +364,17 @@ async fn generate_stream(
         let details = req.0.parameters.details;

         let best_of = req.0.parameters.best_of.unwrap_or(1);
-        if best_of == 1 {
+        if best_of != 1 {
+            let err = InferError::from(ValidationError::BestOfStream);
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            yield Ok(Event::from(err));
+        } else if req.0.parameters.decoder_input_details {
+            let err = InferError::from(ValidationError::PrefillDetailsStream);
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            yield Ok(Event::from(err));
+        } else {
             match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {

@@ -474,11 +484,6 @@ async fn generate_stream(
                         tracing::error!("{err}");
                         yield Ok(Event::from(err));
                     }
                 }
-            } else {
-                let err = InferError::from(ValidationError::BestOfStream);
-                metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-                tracing::error!("{err}");
-                yield Ok(Event::from(err));
-            }
+            }
         };
router/src/validation.rs

@@ -145,6 +145,7 @@ impl Validation {
             truncate,
             seed,
             watermark,
+            decoder_input_details,
             ..
         } = request.parameters;

@@ -261,6 +262,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
+            decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,

@@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
     pub truncate: u32,
+    pub decoder_input_details: bool,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }

@@ -351,6 +354,8 @@ pub enum ValidationError {
     BestOfSeed,
     #[error("`best_of` != 1 is not supported when streaming tokens")]
     BestOfStream,
+    #[error("`decoder_input_details` == true is not supported when streaming tokens")]
+    PrefillDetailsStream,
     #[error("`temperature` must be strictly positive")]
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
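For illustration (not part of this commit's diff), a minimal client-side sketch of how the new flag is exercised against a running text-generation-inference server; the host, port, and prompt are assumptions here, and only the `decoder_input_details` parameter itself comes from this change. As the new `PrefillDetailsStream` error above indicates, the flag is only accepted on the non-streaming generate route.

# Hypothetical client call; URL, port and prompt are illustrative.
import requests

response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Deep learning is",
        "parameters": {
            "max_new_tokens": 16,
            "details": True,
            # New in this commit: prefill (prompt) logprobs are only computed
            # and returned when this flag is set.
            "decoder_input_details": True,
        },
    },
)
print(response.json())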
server/tests/models/test_bloom.py

@@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
server/tests/models/test_causal_lm.py

@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
server/tests/models/test_santacoder.py

@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,

@@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
server/tests/models/test_seq2seq_lm.py

@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
server/text_generation_server/models/causal_lm.py

@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
         ).to(device)

         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
-            prefix_offsets.append(0)
+            prefix_offsets.append(input_len - 5)
             read_offsets.append(input_len)

         input_lengths = tokenized_inputs["attention_mask"].sum(1)

@@ -617,7 +617,7 @@ class CausalLM(Model):
             generated_text = None

             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
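For illustration (not part of this commit), a standalone sketch of the prefill logprobs that CausalLM now only computes when `request.prefill_logprobs` is set: the log-probability of each prompt token given the tokens before it, with NaN for the first token. Names and shapes below are hypothetical.

import torch

def prompt_logprobs(logits: torch.Tensor, input_ids: torch.Tensor) -> list:
    """Log-probability of each prompt token under the model, NaN for the first.

    logits:    [prompt_len, vocab_size] scores produced during prefill
    input_ids: [prompt_len] prompt token ids
    """
    logprobs = torch.log_softmax(logits, dim=-1)
    # The token at position t is predicted by the logits at position t - 1.
    picked = logprobs[:-1].gather(1, input_ids[1:].unsqueeze(1)).squeeze(1)
    return [float("nan")] + picked.tolist()

# Example: a random 5-token "prompt" over a 10-token vocabulary.
print(prompt_logprobs(torch.randn(5, 10), torch.randint(0, 10, (5,))))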
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.model(
             input_ids,

@@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)

         if self.model.tp_embeddings:
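The same `lm_head_indices` parameter is threaded through each flash model below. For illustration (not part of this commit), a toy sketch of the effect: the vocabulary projection, the largest matmul at the end of the forward pass, only runs on the positions whose logits are actually needed, e.g. just the last position of each sequence when no prefill logprobs were requested. Shapes and values are hypothetical.

import torch

total_tokens, hidden, vocab = 10, 16, 32000
hidden_states = torch.randn(total_tokens, hidden)   # all prefill positions, flattened
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

# Without indices: logits for every prefill position -> [10, 32000]
all_logits = lm_head(hidden_states)

# With lm_head_indices, e.g. only the last position of two sequences of
# lengths 5 and 5 -> [2, 32000]
lm_head_indices = torch.tensor([4, 9])
needed_logits = lm_head(hidden_states[lm_head_indices])

print(all_logits.shape, needed_logits.shape)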
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.gpt_neox(
             input_ids,

@@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.embed_out(hidden_states)

         if self.gpt_neox.tp_embeddings:
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

@@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,

@@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)

         if self.transformer.tp_embeddings:
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,

@@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)

         if self.transformer.tp_embeddings:
server/text_generation_server/models/flash_causal_lm.py

@@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch):
     past_key_values: Optional[torch.Tensor]
     max_seqlen: int

+    # Prefill metadata tensors to efficiently compute logprobs
+    prefill_head_indices: Optional[torch.Tensor]
+    prefill_next_token_indices: Optional[torch.tensor]
+    prefill_cu_outlens: Optional[List[int]]
+
     # All tokens
     all_input_ids: List[List[int]]
     all_input_ids_tensor: torch.Tensor

@@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         requests_idx_mapping = {}

+        all_prefill_logprobs = True
+        no_prefill_logprobs = True
+        prefill_head_indices = []
+        prefill_next_token_indices = []
+        prefill_cu_outlens = [0]
+
         next_token_chooser_parameters = []
         stopping_criterias = []

         # Cumulative length
         cumulative_length = 0
+        prefill_out_cumulative_length = 0

         max_tokens = 0
         max_length = 0

@@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)

             input_lengths.append(input_length)
-            prefix_offsets.append(0)
+            prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
             all_input_ids.append(tokenized_input)

             # Position ids
-            position_ids.append(np.arange(0, input_length))
+            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
+            position_ids.append(request_position_ids)

             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)

@@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch):
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

+            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
+            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
+
+            if r.prefill_logprobs:
+                prefill_head_indices.append(request_position_ids + cumulative_length)
+                prefill_next_token_indices.append(
+                    prefill_out_cumulative_length + input_length - 1
+                )
+                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+                prefill_out_cumulative_length += input_length
+            else:
+                prefill_head_indices.append(
+                    torch.tensor(
+                        [cumulative_length + input_length - 1], dtype=torch.int32
+                    )
+                )
+                prefill_next_token_indices.append(prefill_out_cumulative_length)
+                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+                prefill_out_cumulative_length += 1
+
             # Update
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens
@@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids

+        if len(pb.requests) > 1:
+            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            position_ids = torch.cat(position_ids)
+        else:
+            input_ids = all_input_ids[0]
+            position_ids = position_ids[0]
+
         # Create tensors on device
-        input_ids = torch.tensor(
-            np.concatenate(all_input_ids), dtype=torch.int64, device=device
-        )
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
-        position_ids = torch.tensor(
-            np.concatenate(position_ids), dtype=torch.int32, device=device
-        )
+        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

+        if all_prefill_logprobs:
+            prefill_head_indices = None
+            prefill_next_token_indices = cu_seqlens[1:] - 1
+        elif no_prefill_logprobs:
+            prefill_head_indices = cu_seqlens[1:] - 1
+            prefill_next_token_indices = None
+        else:
+            prefill_head_indices = torch.tensor(
+                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
+            )
+            prefill_next_token_indices = torch.tensor(
+                prefill_next_token_indices, dtype=torch.int64, device=device
+            )
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
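For illustration (not part of this commit), the sketch below mirrors the metadata-building loop above on a hypothetical two-request batch with prompt lengths 3 and 4, where only the first request asks for prefill logprobs. This mixed case is the one that needs the `torch.cat` branch; when every request (or no request) wants prefill logprobs, the `cu_seqlens[1:] - 1` fast paths above apply instead.

import torch

lengths = [3, 4]
wants_prefill_logprobs = [True, False]

prefill_head_indices, prefill_next_token_indices, prefill_cu_outlens = [], [], [0]
cumulative_length = prefill_out_cumulative_length = 0

for input_length, keep_all in zip(lengths, wants_prefill_logprobs):
    if keep_all:
        # Keep every prefill position of this request.
        prefill_head_indices.append(torch.arange(input_length) + cumulative_length)
        prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1)
        prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
        prefill_out_cumulative_length += input_length
    else:
        # Keep only the last position, which is needed to sample the next token.
        prefill_head_indices.append(torch.tensor([cumulative_length + input_length - 1]))
        prefill_next_token_indices.append(prefill_out_cumulative_length)
        prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
        prefill_out_cumulative_length += 1
    cumulative_length += input_length

print(torch.cat(prefill_head_indices))   # tensor([0, 1, 2, 6]) -> rows kept for the lm_head
print(prefill_next_token_indices)        # [2, 3] -> next-token logits row per request
print(prefill_cu_outlens)                # [0, 3, 4] -> per-request slice in the kept rows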
@@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=None,
             max_seqlen=max_seqlen,
+            prefill_head_indices=prefill_head_indices,
+            prefill_next_token_indices=prefill_next_token_indices,
+            prefill_cu_outlens=prefill_cu_outlens,
             past_key_values=None,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,

@@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
+            prefill_head_indices=None,
+            prefill_next_token_indices=None,
+            prefill_cu_outlens=None,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,

@@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
+            prefill_head_indices=None,
+            prefill_next_token_indices=None,
+            prefill_cu_outlens=None,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -486,6 +545,7 @@ class FlashCausalLM(Model):
         max_s: int,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(

@@ -496,6 +556,7 @@ class FlashCausalLM(Model):
             max_s=max_s,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
+            lm_head_indices=lm_head_indices,
         )

     @tracer.start_as_current_span("generate_token")
@@ -503,9 +564,10 @@ class FlashCausalLM(Model):
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
+        prefill_logprobs = batch.prefill_next_token_indices is not None
         single_request = len(batch) == 1

-        if prefill and len(batch) == 1:
+        if prefill and single_request:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (

@@ -522,11 +584,12 @@ class FlashCausalLM(Model):
             batch.max_seqlen,
             batch.past_key_values,
             pre_allocate_past_size,
+            batch.prefill_head_indices,
         )

         if prefill:
             next_token_logits = (
-                out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
+                out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
         else:
             next_token_logits = out

@@ -536,10 +599,10 @@ class FlashCausalLM(Model):
         )

         if prefill:
-            if len(batch) > 1:
+            if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
+                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
             # Create batch.cu_seqlens_q for decode
             batch.cu_seqlens_q = torch.arange(
@@ -600,7 +663,6 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.input_lengths,
-            batch.stopping_criterias,
             batch.all_input_ids,
         )

@@ -611,29 +673,33 @@ class FlashCausalLM(Model):
         # For each member of the batch
         for i, (
             input_length,
-            stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

             if prefill:
+                # Indexing metadata
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
+                out_length = out_end_index - out_start_index
+
                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
-                if len(batch) > 1:
-                    prefill_tokens_indices[
-                        start_index : end_index - 1
-                    ] = batch.input_ids[start_index + 1 : end_index]
-                else:
-                    # Set prefill_tokens_indices to the correct slice
-                    prefill_tokens_indices = batch.input_ids[start_index + 1 : end_index]
+                if prefill_logprobs:
+                    if len(batch) > 1:
+                        prefill_tokens_indices[
+                            out_start_index : out_end_index - 1
+                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                    else:
+                        # Set prefill_tokens_indices to the correct slice
+                        prefill_tokens_indices = batch.input_ids[
+                            start_index + 1 : start_index + out_length
+                        ]

             batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

@@ -644,7 +710,7 @@ class FlashCausalLM(Model):
         batch.position_ids = next_position_ids + 1
         batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q

-        if prefill:
+        if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
             prefill_logprobs = torch.gather(

@@ -657,8 +723,6 @@ class FlashCausalLM(Model):
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = batch.input_ids.tolist()

-        cumulative_length = 0
-
         # Zipped iterator
         iterator = zip(
             batch.requests,

@@ -688,9 +752,6 @@ class FlashCausalLM(Model):
             next_token_id,
             next_token_logprob,
         ) in enumerate(iterator):
-            start_index = cumulative_length
-            end_index = cumulative_length + input_length
-
             # Append next token to all tokens
             all_input_ids.append(next_token_id)

@@ -728,10 +789,13 @@ class FlashCausalLM(Model):
             generated_text = None

             # Prefill
-            if prefill:
+            if prefill and request.prefill_logprobs:
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
+
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index : end_index - 1
+                    out_start_index : out_end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(

@@ -764,8 +828,10 @@ class FlashCausalLM(Model):
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
-            cumulative_length += input_length

+        batch.prefill_cu_outlens = None
+        batch.prefill_head_indices = None
+        batch.prefill_next_token_indices = None
         batch.max_seqlen = batch.max_seqlen + 1

         # No need to return a batch if we know that all requests stopped
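For illustration (not part of this commit), continuing the hypothetical two-request example from above: `out` now holds one row per kept prefill position, `prefill_next_token_indices` selects the rows used for sampling, and `prefill_cu_outlens` delimits each request's slice when gathering its prompt logprobs. Token ids and logits below are made up.

import torch

vocab = 8
# Rows: positions 0, 1, 2 of request 0 (wants prefill logprobs) + last position of request 1.
out = torch.randn(4, vocab)
prefill_next_token_indices = torch.tensor([2, 3])
prefill_cu_outlens = [0, 3, 4]

# One row of next-token logits per request, used for sampling.
next_token_logits = out[prefill_next_token_indices]               # shape [2, vocab]

# Prompt logprobs for request 0 only, over its slice of `out`.
input_ids_req0 = torch.tensor([3, 5, 7])                          # hypothetical prompt ids
logprobs = torch.log_softmax(out, dim=-1)
start, end = prefill_cu_outlens[0], prefill_cu_outlens[1]
picked = logprobs[start : end - 1].gather(1, input_ids_req0[1:, None]).squeeze(1)
request0_prefill_logprobs = [float("nan")] + picked.tolist()      # one value per prompt token

print(next_token_logits.shape, request0_prefill_logprobs)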
server/text_generation_server/models/seq2seq_lm.py

@@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
             generated_text = None

             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],