Unverified commit 895c5f15 authored by OlivierDehaene, committed by GitHub

feat(server): only compute prefill logprobs when asked (#406)

Close #288
parent 83b84486
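With this change, prefill (prompt) logprobs are only computed when a request explicitly asks for them through the new `decoder_input_details` parameter. Below is a minimal client-side sketch of the new behaviour, assuming a text-generation-inference server running locally; the host, port, prompt, and `max_new_tokens` value are placeholders, and the response field names follow TGI's `/generate` API.

```python
# Minimal sketch: request prefill logprobs explicitly via `decoder_input_details`.
# Assumes a text-generation-inference server listening on http://localhost:8080.
import requests

resp = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "Deep learning is",
        "parameters": {
            # New flag from this PR: without it the server skips the
            # prefill logprob computation entirely.
            "decoder_input_details": True,
            "max_new_tokens": 16,
        },
    },
)
data = resp.json()
# When decoder_input_details is true, details.prefill carries the prompt
# tokens together with their logprobs.
print(data["details"]["prefill"])
```

Note that the streaming route rejects `decoder_input_details == true`, as the new `PrefillDetailsStream` validation error below shows.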
@@ -34,6 +34,7 @@ impl Health {
             id: LIVENESS_ID,
             inputs: "liveness".to_string(),
             truncate: 10,
+            prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,
...
@@ -125,6 +125,9 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
+    #[schema(default = "true")]
+    pub decoder_input_details: bool,
+    #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
         nullable = true,
@@ -153,6 +156,7 @@ fn default_parameters() -> GenerateParameters {
         truncate: None,
         watermark: false,
         details: false,
+        decoder_input_details: false,
         seed: None,
     }
 }
...
@@ -201,6 +201,7 @@ impl State {
             batch_requests.push(Request {
                 id,
+                prefill_logprobs: entry.request.decoder_input_details,
                 inputs: entry.request.inputs.clone(),
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
@@ -281,6 +282,7 @@ mod tests {
                 inputs: "".to_string(),
                 input_length: 0,
                 truncate: 0,
+                decoder_input_details: false,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
                     top_k: 0,
...
@@ -160,7 +160,7 @@ async fn generate(
        add_prompt = Some(req.0.inputs.clone());
    }
-   let details = req.0.parameters.details;
+   let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
    // Inference
    let (response, best_of_responses) = match req.0.parameters.best_of {
@@ -364,7 +364,17 @@ async fn generate_stream(
        let details = req.0.parameters.details;
        let best_of = req.0.parameters.best_of.unwrap_or(1);
-       if best_of == 1 {
+       if best_of != 1 {
+           let err = InferError::from(ValidationError::BestOfStream);
+           metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+           tracing::error!("{err}");
+           yield Ok(Event::from(err));
+       } else if req.0.parameters.decoder_input_details {
+           let err = InferError::from(ValidationError::PrefillDetailsStream);
+           metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+           tracing::error!("{err}");
+           yield Ok(Event::from(err));
+       } else {
            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
                // Keep permit as long as generate_stream lives
                Ok((_permit, mut response_stream)) => {
@@ -474,11 +484,6 @@ async fn generate_stream(
                    tracing::error!("{err}");
                    yield Ok(Event::from(err));
                }
-       } else {
-           let err = InferError::from(ValidationError::BestOfStream);
-           metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-           tracing::error!("{err}");
-           yield Ok(Event::from(err));
        }
    };
...
@@ -145,6 +145,7 @@ impl Validation {
            truncate,
            seed,
            watermark,
+           decoder_input_details,
            ..
        } = request.parameters;
@@ -261,6 +262,7 @@ impl Validation {
        Ok(ValidGenerateRequest {
            inputs,
+           decoder_input_details,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
@@ -335,6 +337,7 @@ pub(crate) struct ValidGenerateRequest {
    pub inputs: String,
    pub input_length: u32,
    pub truncate: u32,
+   pub decoder_input_details: bool,
    pub parameters: NextTokenChooserParameters,
    pub stopping_parameters: StoppingCriteriaParameters,
}
@@ -351,6 +354,8 @@ pub enum ValidationError {
    BestOfSeed,
    #[error("`best_of` != 1 is not supported when streaming tokens")]
    BestOfStream,
+   #[error("`decoder_input_details` == true is not supported when streaming tokens")]
+   PrefillDetailsStream,
    #[error("`temperature` must be strictly positive")]
    Temperature,
    #[error("`repetition_penalty` must be strictly positive")]
...
@@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
...
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
...
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
@@ -31,6 +32,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
...
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
...
@@ -104,7 +104,7 @@ class CausalLMBatch(Batch):
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
-            prefix_offsets.append(0)
+            prefix_offsets.append(input_len - 5)
             read_offsets.append(input_len)
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
@@ -617,7 +617,7 @@ class CausalLM(Model):
             generated_text = None
             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
...
@@ -443,6 +443,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.model(
             input_ids,
@@ -453,6 +454,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
         if self.model.tp_embeddings:
...
@@ -481,6 +481,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.gpt_neox(
             input_ids,
@@ -491,6 +492,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.embed_out(hidden_states)
         if self.gpt_neox.tp_embeddings:
...
@@ -752,6 +752,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
@@ -762,6 +763,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
         if self.transformer.tp_embeddings:
...
@@ -358,6 +358,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
@@ -368,6 +369,8 @@ class FlashSantacoderForCausalLM(nn.Module):
             past_key_values,
             pre_allocate_past_size,
         )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
         logits = self.lm_head(hidden_states)
         if self.transformer.tp_embeddings:
...
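Each of the four flash model forwards above gains the same optional `lm_head_indices` argument: when it is set, only the selected rows of the flattened hidden states are projected through the LM head, so a prefill pass that does not need prompt logprobs only pays the vocabulary projection for the last position of each sequence. A small standalone PyTorch sketch of that gather-then-project pattern (the toy sizes and sequence lengths are illustrative only):

```python
import torch

# Toy setup: hidden size 16, vocab 100, a flash-style flattened batch of
# two prefill sequences with lengths 3 and 4 (7 positions in total).
hidden_size, vocab_size = 16, 100
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(7, hidden_size)

# Without prefill logprobs we only need next-token logits, i.e. the last
# position of each sequence: flattened indices 2 and 6.
lm_head_indices = torch.tensor([2, 6])

if lm_head_indices is not None:
    hidden_states = hidden_states[lm_head_indices]  # (2, hidden_size)
logits = lm_head(hidden_states)  # (2, vocab_size) instead of (7, vocab_size)
```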
@@ -42,6 +42,11 @@ class FlashCausalLMBatch(Batch):
     past_key_values: Optional[torch.Tensor]
     max_seqlen: int
+    # Prefill metadata tensors to efficiently compute logprobs
+    prefill_head_indices: Optional[torch.Tensor]
+    prefill_next_token_indices: Optional[torch.tensor]
+    prefill_cu_outlens: Optional[List[int]]
     # All tokens
     all_input_ids: List[List[int]]
     all_input_ids_tensor: torch.Tensor
@@ -84,11 +89,18 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         requests_idx_mapping = {}
+        all_prefill_logprobs = True
+        no_prefill_logprobs = True
+        prefill_head_indices = []
+        prefill_next_token_indices = []
+        prefill_cu_outlens = [0]
         next_token_chooser_parameters = []
         stopping_criterias = []
         # Cumulative length
         cumulative_length = 0
+        prefill_out_cumulative_length = 0
         max_tokens = 0
         max_length = 0
@@ -106,13 +118,14 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
-            prefix_offsets.append(0)
+            prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
             all_input_ids.append(tokenized_input)
             # Position ids
-            position_ids.append(np.arange(0, input_length))
+            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
+            position_ids.append(request_position_ids)
             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -125,6 +138,26 @@ class FlashCausalLMBatch(Batch):
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
+            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
+            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
+            if r.prefill_logprobs:
+                prefill_head_indices.append(request_position_ids + cumulative_length)
+                prefill_next_token_indices.append(
+                    prefill_out_cumulative_length + input_length - 1
+                )
+                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+                prefill_out_cumulative_length += input_length
+            else:
+                prefill_head_indices.append(
+                    torch.tensor(
+                        [cumulative_length + input_length - 1], dtype=torch.int32
+                    )
+                )
+                prefill_next_token_indices.append(prefill_out_cumulative_length)
+                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+                prefill_out_cumulative_length += 1
             # Update
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens
@@ -141,18 +174,35 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
+        if len(pb.requests) > 1:
+            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            position_ids = torch.cat(position_ids)
+        else:
+            input_ids = all_input_ids[0]
+            position_ids = position_ids[0]
         # Create tensors on device
-        input_ids = torch.tensor(
-            np.concatenate(all_input_ids), dtype=torch.int64, device=device
-        )
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
-        position_ids = torch.tensor(
-            np.concatenate(position_ids), dtype=torch.int32, device=device
-        )
+        position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
+        if all_prefill_logprobs:
+            prefill_head_indices = None
+            prefill_next_token_indices = cu_seqlens[1:] - 1
+        elif no_prefill_logprobs:
+            prefill_head_indices = cu_seqlens[1:] - 1
+            prefill_next_token_indices = None
+        else:
+            prefill_head_indices = torch.tensor(
+                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
+            )
+            prefill_next_token_indices = torch.tensor(
+                prefill_next_token_indices, dtype=torch.int64, device=device
+            )
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -162,6 +212,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=None,
             max_seqlen=max_seqlen,
+            prefill_head_indices=prefill_head_indices,
+            prefill_next_token_indices=prefill_next_token_indices,
+            prefill_cu_outlens=prefill_cu_outlens,
             past_key_values=None,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -280,6 +333,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
+            prefill_head_indices=None,
+            prefill_next_token_indices=None,
+            prefill_cu_outlens=None,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -415,6 +471,9 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens=cu_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
+            prefill_head_indices=None,
+            prefill_next_token_indices=None,
+            prefill_cu_outlens=None,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             prefix_offsets=prefix_offsets,
@@ -486,6 +545,7 @@ class FlashCausalLM(Model):
         max_s: int,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -496,6 +556,7 @@ class FlashCausalLM(Model):
             max_s=max_s,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
+            lm_head_indices=lm_head_indices,
         )
     @tracer.start_as_current_span("generate_token")
@@ -503,9 +564,10 @@ class FlashCausalLM(Model):
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
+        prefill_logprobs = batch.prefill_next_token_indices is not None
         single_request = len(batch) == 1
-        if prefill and len(batch) == 1:
+        if prefill and single_request:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
@@ -522,11 +584,12 @@ class FlashCausalLM(Model):
             batch.max_seqlen,
             batch.past_key_values,
             pre_allocate_past_size,
+            batch.prefill_head_indices,
         )
         if prefill:
             next_token_logits = (
-                out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
+                out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
         else:
             next_token_logits = out
@@ -536,10 +599,10 @@ class FlashCausalLM(Model):
         )
         if prefill:
-            if len(batch) > 1:
+            if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
+                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
             # Create batch.cu_seqlens_q for decode
             batch.cu_seqlens_q = torch.arange(
@@ -600,7 +663,6 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.input_lengths,
-            batch.stopping_criterias,
             batch.all_input_ids,
         )
@@ -611,29 +673,33 @@
         # For each member of the batch
         for i, (
             input_length,
-            stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length
             if prefill:
+                # Indexing metadata
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
+                out_length = out_end_index - out_start_index
                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
                 next_position_ids[i] = batch.position_ids[end_index - 1]
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
-                if len(batch) > 1:
-                    prefill_tokens_indices[
-                        start_index : end_index - 1
-                    ] = batch.input_ids[start_index + 1 : end_index]
-                else:
-                    # Set prefill_tokens_indices to the correct slice
-                    prefill_tokens_indices = batch.input_ids[
-                        start_index + 1 : end_index
-                    ]
+                if prefill_logprobs:
+                    if len(batch) > 1:
+                        prefill_tokens_indices[
+                            out_start_index : out_end_index - 1
+                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                    else:
+                        # Set prefill_tokens_indices to the correct slice
+                        prefill_tokens_indices = batch.input_ids[
+                            start_index + 1 : start_index + out_length
+                        ]
             batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
@@ -644,7 +710,7 @@ class FlashCausalLM(Model):
         batch.position_ids = next_position_ids + 1
         batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
-        if prefill:
+        if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
             prefill_logprobs = torch.gather(
@@ -657,8 +723,6 @@ class FlashCausalLM(Model):
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = batch.input_ids.tolist()
-        cumulative_length = 0
         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -688,9 +752,6 @@ class FlashCausalLM(Model):
             next_token_id,
             next_token_logprob,
         ) in enumerate(iterator):
-            start_index = cumulative_length
-            end_index = cumulative_length + input_length
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
@@ -728,10 +789,13 @@ class FlashCausalLM(Model):
             generated_text = None
             # Prefill
-            if prefill:
+            if prefill and request.prefill_logprobs:
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
                 # Remove generated token to only have prefill and add nan for first prompt token
                 request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    start_index : end_index - 1
+                    out_start_index : out_end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
@@ -764,8 +828,10 @@ class FlashCausalLM(Model):
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
-            cumulative_length += input_length
+        batch.prefill_cu_outlens = None
+        batch.prefill_head_indices = None
+        batch.prefill_next_token_indices = None
         batch.max_seqlen = batch.max_seqlen + 1
         # No need to return a batch if we know that all requests stopped
...
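For a batch where only some requests asked for prefill logprobs, `FlashCausalLMBatch.from_pb` above keeps every prompt position of the requests that want them and only the last position of the others, and records where each request's next-token logits land in the reduced output. A standalone sketch of that bookkeeping using the same variable names (the toy lengths and `wants_logprobs` flags are made up for illustration):

```python
import torch

# Two requests: lengths 3 (wants prefill logprobs) and 4 (does not).
input_lengths = [3, 4]
wants_logprobs = [True, False]

prefill_head_indices = []        # flattened positions whose hidden states we keep
prefill_next_token_indices = []  # where, inside the kept logits, each request's
                                 # next-token logits live
prefill_cu_outlens = [0]         # cumulative output lengths per request

cumulative_length = 0
prefill_out_cumulative_length = 0
for input_length, keep_all in zip(input_lengths, wants_logprobs):
    positions = torch.arange(cumulative_length, cumulative_length + input_length)
    if keep_all:
        # Keep the whole prompt: logprobs are needed for every position.
        prefill_head_indices.append(positions)
        prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1)
        prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
        prefill_out_cumulative_length += input_length
    else:
        # Keep only the last position: enough to sample the next token.
        prefill_head_indices.append(positions[-1:])
        prefill_next_token_indices.append(prefill_out_cumulative_length)
        prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
        prefill_out_cumulative_length += 1
    cumulative_length += input_length

print(torch.cat(prefill_head_indices).tolist())  # [0, 1, 2, 6]
print(prefill_next_token_indices)                # [2, 3]
print(prefill_cu_outlens)                        # [0, 3, 4]
```

When every request wants prefill logprobs, the batch skips this gather entirely (`prefill_head_indices = None`); when none do, it keeps only the last position of each sequence (`cu_seqlens[1:] - 1`), matching the two fast paths in the diff.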
@@ -688,7 +688,7 @@ class Seq2SeqLM(Model):
             generated_text = None
             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                 prefill_tokens = PrefillTokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
...