Unverified commit 5df80590 authored by Nicolas Patry, committed by GitHub

Auto max prefill (#2797)

* Attempt at automatic max batch prefill.

* Taking into account number of shards.

* Adding more cards.

* Adding A100 + H100

* Adding a few more cards.

* Logprobs cost too much.

* h100 better name, and keep factor of 2

* Damn inflated sparse tflops.

* Typo in h100.

* Updated the flops calculation (checked with fvcore).

* chunking by default.

* Fix prefix caching for chat completion since we removed logprobs.

* More tests.

* Dropping all the prefill logprobs.

* Add a flag that enables users to get logprobs back.

* Repairing prompt token counting.

* Fixing a few tests.

* Remove some scaffolding.

* Attempting to reduce the issues (workarounds for now).
parent 8c3669b2
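The commit message outlines the idea: instead of requiring `--max-batch-prefill-tokens`, derive a default from the detected card's compute budget (taking the number of shards into account) and the model's per-token FLOPs. The sketch below is only an illustration of that kind of heuristic; the card names, the dense TFLOPs figures (the commit notes that sparse figures are inflated), the 2 x parameters FLOPs-per-token rule of thumb and the one-second latency target are assumptions for the example, not the values the launcher actually uses.

# Illustrative sketch only: estimate a prefill-token budget from peak GPU compute.
# Card table, TFLOPs figures, the 2 * n_params FLOPs-per-token estimate and the
# 1 s latency target are assumptions, not the launcher's real table or formula.
CARD_TFLOPS = {
    "nvidia-a100": 312.0,
    "nvidia-h100": 900.0,
}

def estimate_max_prefill_tokens(
    card: str, n_params: float, num_shards: int, target_seconds: float = 1.0
) -> int:
    """Rough number of prefill tokens that fit in `target_seconds` of compute."""
    flops_per_token = 2 * n_params  # ~2 FLOPs per parameter per prefill token
    flops_budget = CARD_TFLOPS[card] * 1e12 * num_shards * target_seconds
    return int(flops_budget / flops_per_token)

# e.g. an 8B-parameter model sharded over 2 H100s -> ~112,500 tokens
print(estimate_max_prefill_tokens("nvidia-h100", 8e9, num_shards=2))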
import { check } from 'k6';
import { scenario } from 'k6/execution';
import http from 'k6/http';
import { Trend, Counter } from 'k6/metrics';

const host = __ENV.HOST;
const model_id = __ENV.MODEL_ID;
const timePerToken = new Trend('time_per_token', true);
const tokens = new Counter('tokens');
const new_tokens = new Counter('new_tokens');
const input_tokens = new Counter('input_tokens');
const max_new_tokens = 50;

// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
const shareGPT = JSON.parse(open("long.json"))

export function get_options() {
    return {
        thresholds: {
            http_req_failed: ['rate==0'],
            // time_per_token: [{
            //     threshold: `p(50)<${5 * reference_latency_ms}`,
            //     abortOnFail: true,
            //     delayAbortEval: '10s'
            // }],
        },
        scenarios: {
            // single_user: {
            //     executor: 'constant-arrival-rate',
            //     duration: '60s',
            //     preAllocatedVUs: 1,
            //     rate: 20,
            //     timeUnit: '1s',
            // },
            // load_test: {
            //     executor: 'constant-arrival-rate',
            //     duration: '60s',
            //     preAllocatedVUs: 100,
            //     rate: 1,
            //     timeUnit: '1s',
            // },
            // breakpoint: {
            //     executor: 'ramping-arrival-rate', //Assure load increase if the system slows
            //     preAllocatedVUs: 300,
            //     stages: [
            //         { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
            //     ],
            // },
            throughput: {
                executor: 'shared-iterations',
                vus: 10,
                iterations: 10,
                maxDuration: '120s',
            },
        },
    };
}

function generate_payload(gpt, max_new_tokens) {
    const input = gpt["conversations"][0]["value"];
    return { "messages": [{ "role": "user", "content": input }], "temperature": 0, "model": `${model_id}`, "max_tokens": max_new_tokens }
}

export const options = get_options();

export default function run() {
    const headers = { 'Content-Type': 'application/json' };
    const query = shareGPT[scenario.iterationInTest % shareGPT.length];
    const payload = JSON.stringify(generate_payload(query, max_new_tokens));
    const res = http.post(`http://${host}/v1/chat/completions`, payload, {
        headers,
    });
    if (res.status >= 400 && res.status < 500) {
        return;
    }

    check(res, {
        'Post status is 200': (res) => res.status === 200,
    });

    const duration = res.timings.duration;
    if (res.status === 200) {
        const body = res.json();
        const completion_tokens = body.usage.completion_tokens;
        const latency_ms_per_token = duration / completion_tokens;
        timePerToken.add(latency_ms_per_token);
        const prompt_tokens = body.usage.prompt_tokens;
        input_tokens.add(prompt_tokens);
        new_tokens.add(completion_tokens);
        tokens.add(completion_tokens + prompt_tokens);
    }
}
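The benchmark above reads `HOST` (host:port, no scheme) and `MODEL_ID` from k6's environment, so a typical invocation would look like `k6 run -e HOST=localhost:3000 -e MODEL_ID=meta-llama/Llama-3.1-8B-Instruct load_test.js` (the script name and port are illustrative), with the `long.json` file produced by the script below sitting next to it.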
# Build long-prompt conversations for the k6 benchmark from the GovReport
# summarization test split and dump them to long.json.
import datasets
import json

dataset = datasets.load_dataset("ccdv/govreport-summarization")

max_new_tokens = 50

conversations = []
for i, item in enumerate(dataset["test"]):
    report = item["report"]
    messages = [{"from": "human", "value": f"Summarize this report: ```{report}```"}]
    conversations.append({"conversations": messages})

with open("long.json", "w") as f:
    json.dump(conversations, f, indent=4)
# https://www.gutenberg.org/cache/epub/103/pg103.txt
from openai import OpenAI
import os
import requests

if not os.path.exists("pg103.txt"):
    response = requests.get("https://www.gutenberg.org/cache/epub/103/pg103.txt")
    with open("pg103.txt", "w") as f:
        f.write(response.text)

# Send one very long prompt (~130k tokens, assuming roughly 4 characters per
# token) to a locally running server and request only 2 new tokens.
length = 130000
with open("pg103.txt", "r") as f:
    data = f.read()
messages = [{"role": "user", "content": data[: length * 4]}]

client = OpenAI(base_url="http://localhost:8000/v1", api_key="w")
completion = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct", messages=messages, max_tokens=2
)
import json

import datasets
import tqdm


def main():
    dataset = datasets.load_dataset("Open-Orca/OpenOrca", split="train")
    # Select only the first 2k conversations that start with a human.
    max = min(2000, len(dataset))
    conversations = []
    for item in tqdm.tqdm(dataset, total=max):
        conversation = {
            "conversations": [
                {"from": "human", "value": item["question"]},
            ],
            "id": item["id"],
        }
        conversations.append(conversation)
        if len(conversations) >= max:
            break

    with open("./small.json", "w") as f:
        json.dump(conversations, f, indent=4)


if __name__ == "__main__":
    main()
@@ -191,7 +191,7 @@ pub enum Config {
     #[serde(rename = "phi-msft")]
     PhiMsft,
     Phi3,
-    PhiMoe,
+    Phimoe,
     Llama,
     Baichuan,
     Paligemma(Paligemma),
......
@@ -647,6 +647,7 @@ enum CompletionType {
 }

 impl ChatCompletion {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
@@ -655,6 +656,7 @@ impl ChatCompletion {
         details: Details,
         return_logprobs: bool,
         tool_calls: Option<Vec<ToolCall>>,
+        prompt_tokens: u32,
     ) -> Self {
         let message = match (output, tool_calls) {
             (Some(content), None) => OutputMessage::ChatMessage(TextMessage {
@@ -693,9 +695,9 @@ impl ChatCompletion {
                 finish_reason: details.finish_reason.format(true),
             }],
             usage: Usage {
-                prompt_tokens: details.prefill.len() as u32,
+                prompt_tokens,
                 completion_tokens: details.generated_tokens,
-                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+                total_tokens: prompt_tokens + details.generated_tokens,
             },
         }
     }
@@ -919,7 +921,6 @@ impl ChatRequest {
             messages,
             seed,
             stop,
-            stream,
             tools,
             tool_choice,
             tool_prompt,
@@ -999,7 +1000,7 @@ impl ChatRequest {
             truncate: None,
             watermark: false,
             details: true,
-            decoder_input_details: !stream,
+            decoder_input_details: false,
             seed,
             top_n_tokens: top_logprobs,
             grammar,
......
@@ -271,7 +271,9 @@ async fn generate(
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
-    generate_internal(infer, ComputeType(compute_type), Json(req), span).await
+    let (headers, _, response) =
+        generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
+    Ok((headers, response))
 }

 pub(crate) async fn generate_internal(
@@ -279,7 +281,7 @@ pub(crate) async fn generate_internal(
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     span: tracing::Span,
-) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
@@ -423,7 +425,7 @@ pub(crate) async fn generate_internal(
         generated_text: output_text,
         details,
     };

-    Ok((headers, Json(response)))
+    Ok((headers, input_length, Json(response)))
 }

 /// Generate a stream of token using Server-Sent Events
@@ -980,7 +982,9 @@ pub(crate) async fn completions(
                     span_clone,
                 )
                 .await;
-                result.map(|(headers, generation)| (index, headers, generation))
+                result.map(|(headers, input_length, generation)| {
+                    (index, headers, input_length, generation)
+                })
             };
             responses.push(response_future);
         }
@@ -1001,7 +1005,7 @@ pub(crate) async fn completions(
         let choices = generate_responses
             .into_iter()
-            .map(|(index, headers, Json(generation))| {
+            .map(|(index, headers, input_length, Json(generation))| {
                 let details = generation.details.ok_or((
                     // this should never happen but handle if details are missing unexpectedly
                     StatusCode::INTERNAL_SERVER_ERROR,
@@ -1056,9 +1060,9 @@ pub(crate) async fn completions(
                     .and_then(|v| v.to_str().ok()?.parse().ok())
                     .unwrap_or(0);
-                prompt_tokens += details.prefill.len() as u32;
+                prompt_tokens += input_length;
                 completion_tokens += details.generated_tokens;
-                total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+                total_tokens += input_length + details.generated_tokens;

                 Ok(CompletionComplete {
                     finish_reason: details.finish_reason.format(true),
@@ -1381,7 +1385,7 @@ pub(crate) async fn chat_completions(
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) =
+        let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;

         let current_time = std::time::SystemTime::now()
@@ -1452,6 +1456,7 @@ pub(crate) async fn chat_completions(
             generation.details.unwrap(),
             logprobs,
             tool_calls,
+            input_length,
         ));

         // wrap generation inside a Vec to match api-inference
......
@@ -122,7 +122,7 @@ pub(crate) async fn vertex_compatibility(
                     span_clone,
                 )
                 .await
-                .map(|(_, Json(generation))| generation.generated_text)
+                .map(|(_, _, Json(generation))| generation.generated_text)
                 .map_err(|_| {
                     (
                         StatusCode::INTERNAL_SERVER_ERROR,
......
@@ -57,6 +57,7 @@ from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
     CUDA_GRAPHS,
+    REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
@@ -292,6 +293,10 @@ class FlashCausalLMBatch(Batch):
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
         ):
+            ### XXX: This consumes so much memory on long requests
+            ### Deactivating it by default seems like the best course.
+            if not REQUEST_LOGPROBS:
+                r.prefill_logprobs = False

             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
@@ -1554,12 +1559,13 @@ class FlashCausalLM(Model):
             )
             batch_num_blocks = batch.num_blocks
+            num_tokens = batch.to_pb().current_tokens
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
             _, _batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
-                f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
+                f"Not enough memory to handle {num_tokens} prefill tokens. "
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
@@ -1592,6 +1598,8 @@ class FlashCausalLM(Model):
                 if max_input_tokens is None
                 else max_input_tokens
             )
+        elif max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1

         del _batch, batch
         self.kv_cache = []
......
@@ -5,13 +5,14 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master

+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
     "1",
     "true",
 }
-PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
......