Unverified Commit 9ecfa16b authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Speculative (#1308)

parent 3238c491
......@@ -60,30 +60,30 @@ checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87"
[[package]]
name = "anstyle-parse"
version = "0.2.2"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140"
checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.0.0"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b"
checksum = "a3a318f1f38d2418400f8209655bfd825785afd25aa30bb7ba6cc792e4596748"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.1"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628"
checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7"
dependencies = [
"anstyle",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]
......@@ -333,9 +333,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clap"
version = "4.4.10"
version = "4.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272"
checksum = "bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2"
dependencies = [
"clap_builder",
"clap_derive",
......@@ -343,9 +343,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.4.9"
version = "4.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1"
checksum = "a216b506622bb1d316cd51328dce24e07bdff4a6128a47c7e7fad11878d5adbb"
dependencies = [
"anstream",
"anstyle",
......@@ -392,9 +392,9 @@ dependencies = [
[[package]]
name = "core-foundation"
version = "0.9.3"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
......@@ -402,9 +402,9 @@ dependencies = [
[[package]]
name = "core-foundation-sys"
version = "0.8.4"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
[[package]]
name = "cpufeatures"
......@@ -549,9 +549,9 @@ dependencies = [
[[package]]
name = "deranged"
version = "0.3.9"
version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3"
checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc"
dependencies = [
"powerfmt",
]
......@@ -1218,9 +1218,9 @@ dependencies = [
[[package]]
name = "linux-raw-sys"
version = "0.4.11"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829"
checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
[[package]]
name = "lock_api"
......@@ -1612,9 +1612,9 @@ dependencies = [
[[package]]
name = "openssl"
version = "0.10.60"
version = "0.10.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79a4c6c3a2b158f7f8f2a2fc5a969fa3a068df6fc9dbb4a43845436e3af7c800"
checksum = "6b8419dc8cc6d866deb801274bba2e6f8f6108c1bb7fcc10ee5ab864931dbb45"
dependencies = [
"bitflags 2.4.1",
"cfg-if",
......@@ -1644,9 +1644,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-sys"
version = "0.9.96"
version = "0.9.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f"
checksum = "c3eaad34cdd97d81de97964fc7f29e2d104f483840d906ef56daa1912338460b"
dependencies = [
"cc",
"libc",
......@@ -2309,15 +2309,15 @@ dependencies = [
[[package]]
name = "rustix"
version = "0.38.25"
version = "0.38.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc99bc2d4f1fed22595588a013687477aedf3cdcfb26558c559edb67b4d9b22e"
checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a"
dependencies = [
"bitflags 2.4.1",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]
......@@ -2567,9 +2567,9 @@ dependencies = [
[[package]]
name = "slotmap"
version = "1.0.6"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342"
checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a"
dependencies = [
"version_check",
]
......@@ -3835,18 +3835,18 @@ dependencies = [
[[package]]
name = "zerocopy"
version = "0.7.26"
version = "0.7.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e97e415490559a91254a2979b4829267a57d2fcd741a98eee8b722fb57289aa0"
checksum = "7d6f15f7ade05d2a4935e34a457b936c23dc70a05cc1d97133dc99e7a3fe0f0e"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.7.26"
version = "0.7.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd7e48ccf166952882ca8bd778a43502c64f33bf94c12ebe2a7f08e5a0f6689f"
checksum = "dbbad221e3f78500350ecbd7dfa4e63ef945c05f4c61cb7f4d3f84cd0bba649b"
dependencies = [
"proc-macro2",
"quote",
......
......@@ -67,6 +67,14 @@ Options:
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
```
## SPECULATE
```shell
--speculate <SPECULATE>
The number of input_ids to speculate on If using a medusa model, the heads will be picked up automatically Other wise, it will use n-gram speculation which is relatively free in terms of compute, but the speedup heavily depends on the task
[env: SPECULATE=]
```
## DTYPE
```shell
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 338,
"logprob": -10.0078125,
"text": "is"
},
{
"id": 21784,
"logprob": -15.515625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -2.8847656,
"text": "Learning"
},
{
"id": 29973,
"logprob": -4.140625,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.1582031,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.23083496,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": 0.0,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": 0.0,
"special": false,
"text": " learning"
},
{
"id": 29892,
"logprob": -0.61816406,
"special": false,
"text": ","
},
{
"id": 607,
"logprob": -0.7089844,
"special": false,
"text": " which"
},
{
"id": 508,
"logprob": -1.7724609,
"special": false,
"text": " can"
},
{
"id": 367,
"logprob": 0.0,
"special": false,
"text": " be"
},
{
"id": 5545,
"logprob": 0.0,
"special": false,
"text": " considered"
},
{
"id": 408,
"logprob": -0.3869629,
"special": false,
"text": " as"
}
]
},
"generated_text": "What is Deep Learning?\nDeep learning, which can be considered as"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2753906,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.48046875,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1845703,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.5727539,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.00010967255,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.04510498,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.00020992756,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.0046539307,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025844574,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2724609,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.47729492,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2724609,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.47729492,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2724609,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.47729492,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
}
]
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1724,
"logprob": -10.734375,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"text": "is"
},
{
"id": 21784,
"logprob": -9.2890625,
"text": "Deep"
},
{
"id": 29257,
"logprob": -1.2753906,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.48046875,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.1845703,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.5727539,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108122826,
"special": false,
"text": "ep"
},
{
"id": 6509,
"logprob": -0.1239624,
"special": false,
"text": " learning"
},
{
"id": 338,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.01852417,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004787445,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00026226044,
"special": false,
"text": " learning"
}
]
},
"generated_text": "\nDeep learning is a subset of machine learning"
}
import pytest
@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_medusa(flash_medusa_handle):
await flash_medusa_handle.health(300)
return flash_medusa_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'
assert responses == response_snapshot
......@@ -21,6 +21,7 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
)
assert response.details.generated_tokens == 10
assert response.generated_text == ": Let n = 10 - 1"
assert response == response_snapshot
......@@ -55,6 +56,7 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ": Let n = 10 - 1"
assert responses == response_snapshot
......@@ -155,6 +155,13 @@ struct Args {
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The number of input_ids to speculate on
/// If using a medusa model, the heads will be picked up automatically
/// Other wise, it will use n-gram speculation which is relatively free
/// in terms of compute, but the speedup heavily depends on the task.
#[clap(long, env)]
speculate: Option<usize>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
......@@ -375,6 +382,7 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
......@@ -432,6 +440,11 @@ fn shard_manager(
shard_args.push(quantize.to_string())
}
if let Some(speculate) = speculate {
shard_args.push("--speculate".to_string());
shard_args.push(speculate.to_string())
}
if let Some(dtype) = dtype {
shard_args.push("--dtype".to_string());
shard_args.push(dtype.to_string())
......@@ -882,6 +895,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
......@@ -896,6 +910,7 @@ fn spawn_shards(
model_id,
revision,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
......
......@@ -7,7 +7,9 @@ const seed = 0;
const host = __ENV.HOST || '127.0.0.1:8000';
const timePerToken = new Trend('time_per_token', true);
const throughput = new Counter('tokens_per_s');
const tokens = new Counter('tokens');
const new_tokens = new Counter('new_tokens');
const input_tokens = new Counter('input_tokens');
randomSeed(seed);
// const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
......@@ -19,7 +21,7 @@ export function get_options(reference_latency_ms){
thresholds: {
http_req_failed: ['rate==0'],
time_per_token: [{
threshold: `p(50)<${3 * reference_latency_ms}`,
threshold: `p(50)<${5 * reference_latency_ms}`,
abortOnFail: true,
delayAbortEval: '10s'
}],
......@@ -28,7 +30,7 @@ export function get_options(reference_latency_ms){
load_test: {
executor: 'constant-arrival-rate',
duration: '60s',
preAllocatedVUs: 100,
preAllocatedVUs: 10,
rate: 10,
timeUnit: '1s',
},
......@@ -48,17 +50,22 @@ export function run(host, generate_payload, max_new_tokens) {
return;
}
check(res, {
'Post status is 200': (r) => res.status === 200,
});
const n_tokens = max_new_tokens;
const timings = res.timings.duration;
const duration = res.timings.duration;
if (res.status === 200) {
const latency_ms_per_token = timings / n_tokens;
const body = res.json();
const n_tokens = body.details.tokens.length;
const latency_ms_per_token = duration / n_tokens;
timePerToken.add(latency_ms_per_token);
const latency_in_s = latency_ms_per_token / 1000;
const individual_throughput = 1 / latency_in_s;
throughput.add(individual_throughput);
const _input_tokens = body.details.prefill.length;
tokens.add(n_tokens + _input_tokens);
input_tokens.add(_input_tokens);
new_tokens.add(n_tokens);
}
}
import { get_options, run } from "./common.js";
const reference_latency_ms = 30;
const reference_latency_ms = 70;
const host = __ENV.HOST || '127.0.0.1:8000';
const max_new_tokens = 50;
function generate_payload(gpt){
const input = gpt["conversations"][0]["value"];
return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "temperature" : 0.5}}
return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
}
export const options = get_options(reference_latency_ms);
......
syntax = "proto3";
package generate.v1;
package generate.v2;
service TextGenerationService {
/// Model Info
......@@ -32,6 +32,7 @@ message InfoResponse {
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
}
/// Empty request
......@@ -135,43 +136,27 @@ message GeneratedText {
optional uint64 seed = 4;
}
message PrefillTokens {
/// Prefill Token IDs
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Prefill Logprobs
/// Logprobs
repeated float logprobs = 2;
/// Prefill tokens
/// tokens
repeated string texts = 3;
}
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
repeated bool is_special = 6;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
PrefillTokens prefill_tokens = 2;
/// Token ID
uint32 token_id = 3;
/// Logprob
float token_logprob = 4;
/// Text
string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 7;
optional GeneratedText generated_text = 4;
/// Top tokens
TopTokens top_tokens = 8;
repeated Tokens top_tokens = 5;
}
message FilterBatchRequest {
......
/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
......
......@@ -6,11 +6,11 @@ mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
PrefillTokens, Request, StoppingCriteriaParameters,
Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
......
......@@ -9,7 +9,7 @@ use std::sync::{
Arc,
};
use text_generation_client::{
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
......@@ -50,10 +50,11 @@ impl Infer {
max_concurrent_requests: usize,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size);
let queue = Queue::new(requires_padding, 16, window_size, speculate);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
......@@ -523,50 +524,63 @@ fn send_responses(
}
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
// generation.top_tokens
let mut top_tokens = Vec::new();
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text,
text: text.to_string(),
logprob,
special,
}),
)
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: generated_text.clone(),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
} else {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
Ok(stopped)
}
......@@ -591,7 +605,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
Prefill(Tokens),
// Intermediate messages
Intermediate {
token: Token,
......
......@@ -34,7 +34,12 @@ pub(crate) struct Queue {
}
impl Queue {
pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
pub(crate) fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
......@@ -43,6 +48,7 @@ impl Queue {
requires_padding,
block_size,
window_size,
speculate,
queue_receiver,
));
......@@ -91,9 +97,10 @@ async fn queue_task(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size, window_size);
let mut state = State::new(requires_padding, block_size, window_size, speculate);
while let Some(cmd) = receiver.recv().await {
match cmd {
......@@ -136,10 +143,18 @@ struct State {
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
}
impl State {
fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
fn new(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
speculate: u32,
) -> Self {
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
......@@ -147,6 +162,7 @@ impl State {
requires_padding,
block_size,
window_size,
speculate,
}
}
......@@ -229,7 +245,7 @@ impl State {
}
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
......@@ -359,7 +375,7 @@ mod tests {
#[test]
fn test_append() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
......@@ -375,7 +391,7 @@ mod tests {
#[test]
fn test_next_batch_empty() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none());
......@@ -383,7 +399,7 @@ mod tests {
#[test]
fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -415,7 +431,7 @@ mod tests {
#[test]
fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None);
let mut state = State::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -448,14 +464,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
......@@ -463,7 +479,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -496,7 +512,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -519,9 +535,28 @@ mod tests {
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
queue.append(entry2);
// Budget of 1 is not enough
assert!(queue.next_batch(None, 1, 1).await.is_none());
let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
assert_eq!(batch.id, 0);
assert_eq!(batch.size, 2);
}
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None);
let queue = Queue::new(false, 1, None, 0);
let (entry, _) = default_entry();
queue.append(entry);
......
......@@ -596,6 +596,7 @@ pub async fn run(
max_concurrent_requests,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
);
......
build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
build-vllm-cuda: BRANCH=main
build-vllm-cuda: build-vllm
vllm-cuda:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
git clone https://github.com/vllm-project/vllm.git vllm
build-vllm-cuda: vllm-cuda
cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
cd vllm && python setup.py build
build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
build-vllm-rocm: build-vllm
install-vllm-cuda: build-vllm-cuda
pip uninstall vllm -y || true
cd vllm && python setup.py install
vllm:
vllm-rocm:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
git clone https://github.com/fxmarty/vllm-public.git vllm
build-vllm: vllm
cd vllm && git fetch && git checkout $(VLLM_COMMIT)
build-vllm-rocm: vllm-rocm
cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
cd vllm && python setup.py build
install-vllm: build-vllm
install-vllm-rocm: build-vllm-rocm
pip uninstall vllm -y || true
cd vllm && python setup.py install
......@@ -133,8 +133,8 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 10264 for generation in generations])
assert all([generation.token_text == "Test" for generation in generations])
assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0
......
......@@ -129,8 +129,8 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 13 for generation in generations])
assert all([generation.token_text == "." for generation in generations])
assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0
......
......@@ -151,8 +151,8 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations])
assert all([generation.token_text == " " for generation in generations])
assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
assert generations[0].request_id == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment