Unverified Commit 895c5f15 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(server): only compute prefill logprobs when asked (#406)

Close #288
parent 83b84486
......@@ -34,6 +34,7 @@ impl Health {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
......
......@@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters {
#[schema(default = "true")]
pub details: bool,
#[serde(default)]
#[schema(default = "true")]
pub decoder_input_details: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
......@@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters {
truncate: None,
watermark: false,
details: false,
decoder_input_details: false,
seed: None,
}
}
......
......@@ -201,6 +201,7 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.clone(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
......@@ -281,6 +282,7 @@ mod tests {
inputs: "".to_string(),
input_length: 0,
truncate: 0,
decoder_input_details: false,
parameters: NextTokenChooserParameters {
temperature: 0.0,
top_k: 0,
......
......@@ -160,7 +160,7 @@ async fn generate(
add_prompt = Some(req.0.inputs.clone());
}
let details = req.0.parameters.details;
let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
......@@ -364,7 +364,17 @@ async fn generate_stream(
let details = req.0.parameters.details;
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of == 1 {
if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else if req.0.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
......@@ -474,11 +484,6 @@ async fn generate_stream(
tracing::error!("{err}");
yield Ok(Event::from(err));
}
} else {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
}
};
......
......@@ -145,6 +145,7 @@ impl Validation {
truncate,
seed,
watermark,
decoder_input_details,
..
} = request.parameters;
......@@ -261,6 +262,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
......@@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
}
......@@ -351,6 +354,8 @@ pub enum ValidationError {
BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`decoder_input_details` == true is not supported when streaming tokens")]
PrefillDetailsStream,
#[error("`temperature` must be strictly positive")]
Temperature,
#[error("`repetition_penalty` must be strictly positive")]
......
......@@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
......
......@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
......
......@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
......@@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
......
......@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
......
......@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
prefix_offsets.append(input_len - 5)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
......@@ -617,7 +617,7 @@ class CausalLM(Model):
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1:
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
......
......@@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.model(
input_ids,
......@@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.model.tp_embeddings:
......
......@@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.gpt_neox(
input_ids,
......@@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
if self.gpt_neox.tp_embeddings:
......
......@@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
input_ids,
......@@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
......
......@@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module):
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.transformer(
input_ids,
......@@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module):
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
......
......@@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch):
past_key_values: Optional[torch.Tensor]
max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs
prefill_head_indices: Optional[torch.Tensor]
prefill_next_token_indices: Optional[torch.tensor]
prefill_cu_outlens: Optional[List[int]]
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor
......@@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch):
all_input_ids = []
requests_idx_mapping = {}
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
prefill_out_cumulative_length = 0
max_tokens = 0
max_length = 0
......@@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
prefix_offsets.append(0)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
# Position ids
position_ids.append(np.arange(0, input_length))
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
......@@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch):
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1], dtype=torch.int32
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
......@@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch):
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlens[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlens[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
return cls(
batch_id=pb.id,
requests=pb.requests,
......@@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch):
cu_seqlens=cu_seqlens,
cu_seqlens_q=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
past_key_values=None,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
......@@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch):
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
......@@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch):
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
......@@ -486,6 +545,7 @@ class FlashCausalLM(Model):
max_s: int,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
......@@ -496,6 +556,7 @@ class FlashCausalLM(Model):
max_s=max_s,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices,
)
@tracer.start_as_current_span("generate_token")
......@@ -503,9 +564,10 @@ class FlashCausalLM(Model):
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
prefill_logprobs = batch.prefill_next_token_indices is not None
single_request = len(batch) == 1
if prefill and len(batch) == 1:
if prefill and single_request:
# Ask to pre-allocate kv to its max size
# == number of tokens + max_new_tokens
pre_allocate_past_size = (
......@@ -522,11 +584,12 @@ class FlashCausalLM(Model):
batch.max_seqlen,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices,
)
if prefill:
next_token_logits = (
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
else:
next_token_logits = out
......@@ -536,10 +599,10 @@ class FlashCausalLM(Model):
)
if prefill:
if len(batch) > 1:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.cu_seqlens_q for decode
batch.cu_seqlens_q = torch.arange(
......@@ -600,7 +663,6 @@ class FlashCausalLM(Model):
# Zipped iterator
iterator = zip(
batch.input_lengths,
batch.stopping_criterias,
batch.all_input_ids,
)
......@@ -611,29 +673,33 @@ class FlashCausalLM(Model):
# For each member of the batch
for i, (
input_length,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill:
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if len(batch) > 1:
prefill_tokens_indices[
start_index : end_index - 1
] = batch.input_ids[start_index + 1 : end_index]
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : end_index
]
if prefill_logprobs:
if len(batch) > 1:
prefill_tokens_indices[
out_start_index : out_end_index - 1
] = batch.input_ids[start_index + 1 : start_index + out_length]
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : start_index + out_length
]
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
......@@ -644,7 +710,7 @@ class FlashCausalLM(Model):
batch.position_ids = next_position_ids + 1
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
if prefill:
if prefill and prefill_logprobs:
# Get prefill logprobs
prefill_logprobs_tensor = torch.log_softmax(out, -1)
prefill_logprobs = torch.gather(
......@@ -657,8 +723,6 @@ class FlashCausalLM(Model):
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = batch.input_ids.tolist()
cumulative_length = 0
# Zipped iterator
iterator = zip(
batch.requests,
......@@ -688,9 +752,6 @@ class FlashCausalLM(Model):
next_token_id,
next_token_logprob,
) in enumerate(iterator):
start_index = cumulative_length
end_index = cumulative_length + input_length
# Append next token to all tokens
all_input_ids.append(next_token_id)
......@@ -728,10 +789,13 @@ class FlashCausalLM(Model):
generated_text = None
# Prefill
if prefill:
if prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
start_index : end_index - 1
out_start_index : out_end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
......@@ -764,8 +828,10 @@ class FlashCausalLM(Model):
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
cumulative_length += input_length
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
# No need to return a batch if we know that all requests stopped
......
......@@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1:
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
prefill_tokens = PrefillTokens(
[self.tokenizer.bos_token_id],
[float("nan")],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment