Commit b3b7ea0d authored by OlivierDehaene

feat: Use json formatter by default in docker image

parent 3cf6368c
@@ -2175,6 +2175,16 @@ dependencies = [
  "tracing-core",
 ]
 
+[[package]]
+name = "tracing-serde"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1"
+dependencies = [
+ "serde",
+ "tracing-core",
+]
+
 [[package]]
 name = "tracing-subscriber"
 version = "0.3.16"
@@ -2182,11 +2192,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
 dependencies = [
  "nu-ansi-term",
+ "serde",
+ "serde_json",
  "sharded-slab",
  "smallvec",
  "thread_local",
  "tracing-core",
  "tracing-log",
+ "tracing-serde",
 ]
 
 [[package]]
......
@@ -73,4 +73,4 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
 # Install launcher
 COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
 
-CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS
+CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS --json-output
\ No newline at end of file
-# LLM Text Generation Inference
+# Text Generation Inference
 
 <div align="center">
@@ -6,12 +6,12 @@
 </div>
 
-A Rust and gRPC server for large language models text generation inference.
+A Rust and gRPC server for text generation inference.
 
 ## Features
 
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
+- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
......
@@ -10,4 +10,4 @@ clap = { version = "4.0.15", features = ["derive", "env"] }
 ctrlc = "3.2.3"
 subprocess = "0.2.9"
 tracing = "0.1.37"
-tracing-subscriber = "0.3.16"
+tracing-subscriber = { version = "0.3.16", features = ["json"] }
@@ -39,11 +39,11 @@ struct Args {
     master_addr: String,
     #[clap(default_value = "29500", long, env)]
     master_port: usize,
+    #[clap(long, env)]
+    json_output: bool,
 }
 
 fn main() -> ExitCode {
-    tracing_subscriber::fmt().compact().with_ansi(false).init();
-
     // Pattern match configuration
     let Args {
         model_name,
@@ -57,8 +57,15 @@ fn main() -> ExitCode {
         shard_uds_path,
         master_addr,
         master_port,
+        json_output,
     } = Args::parse();
 
+    if json_output {
+        tracing_subscriber::fmt().json().init();
+    } else {
+        tracing_subscriber::fmt().compact().init();
+    }
+
     // By default we only have one master shard
     let num_shard = num_shard.unwrap_or(1);
@@ -139,24 +146,30 @@ fn main() -> ExitCode {
     // All shard started
     // Start webserver
     tracing::info!("Starting Webserver");
+    let mut argv = vec![
+        "text-generation-router".to_string(),
+        "--max-concurrent-requests".to_string(),
+        max_concurrent_requests.to_string(),
+        "--max-input-length".to_string(),
+        max_input_length.to_string(),
+        "--max-batch-size".to_string(),
+        max_batch_size.to_string(),
+        "--max-waiting-tokens".to_string(),
+        max_waiting_tokens.to_string(),
+        "--port".to_string(),
+        port.to_string(),
+        "--master-shard-uds-path".to_string(),
+        format!("{}-0", shard_uds_path),
+        "--tokenizer-name".to_string(),
+        model_name,
+    ];
+
+    if json_output {
+        argv.push("--json-output".to_string());
+    }
+
     let mut webserver = match Popen::create(
-        &[
-            "text-generation-router",
-            "--max-concurrent-requests",
-            &max_concurrent_requests.to_string(),
-            "--max-input-length",
-            &max_input_length.to_string(),
-            "--max-batch-size",
-            &max_batch_size.to_string(),
-            "--max-waiting-tokens",
-            &max_waiting_tokens.to_string(),
-            "--port",
-            &port.to_string(),
-            "--master-shard-uds-path",
-            &format!("{}-0", shard_uds_path),
-            "--tokenizer-name",
-            &model_name,
-        ],
+        &argv,
         PopenConfig {
             stdout: Redirection::Pipe,
             stderr: Redirection::Pipe,
......
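
Taken together, the launcher side of this change is: accept a `--json-output` flag, use it to pick a formatter, and forward the same flag to the router it spawns so both processes log in the same format. A compressed sketch of the spawning half (the real code uses the `subprocess` crate's `Popen::create` with piped stdout/stderr as shown above; `std::process::Command` is used here only to keep the example dependency-free, and the argument list is trimmed to a few entries):

```rust
use std::process::{Child, Command};

/// Hypothetical helper mirroring the argv construction in the diff above.
fn spawn_router(json_output: bool, port: u16, model_name: &str) -> std::io::Result<Child> {
    // Owned strings so optional flags can be appended before spawning.
    let mut argv: Vec<String> = vec![
        "--port".to_string(),
        port.to_string(),
        "--tokenizer-name".to_string(),
        model_name.to_string(),
    ];
    if json_output {
        // Propagate the launcher's logging choice to the router process.
        argv.push("--json-output".to_string());
    }
    Command::new("text-generation-router").args(&argv).spawn()
}
```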
@@ -24,5 +24,5 @@ thiserror = "1.0.37"
 tokenizers = "0.13.0"
 tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tracing = "0.1.36"
-tracing-subscriber = "0.3.15"
+tracing-subscriber = { version = "0.3.15", features = ["json"] }
@@ -25,6 +25,8 @@ struct Args {
     tokenizer_name: String,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
+    #[clap(long, env)]
+    json_output: bool,
 }
 
 fn main() -> Result<(), std::io::Error> {
@@ -40,11 +42,16 @@ fn main() -> Result<(), std::io::Error> {
         master_shard_uds_path,
         tokenizer_name,
         validation_workers,
+        json_output,
     } = args;
 
-    tracing_subscriber::fmt().compact().with_ansi(false).init();
+    if json_output {
+        tracing_subscriber::fmt().json().init();
+    } else {
+        tracing_subscriber::fmt().compact().init();
+    }
 
-    if validation_workers == 1 {
+    if validation_workers == 0 {
         panic!("validation_workers must be > 0");
     }
......
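
On the router side the flag is wired through clap's derive API; because `clap` is built with the "derive" and "env" features (see the launcher's Cargo.toml hunk above), the same option can be supplied either as a CLI flag or through an environment variable. A trimmed-down, hypothetical sketch of just the pieces this commit touches, including the corrected `validation_workers` guard:

```rust
use clap::Parser;

/// Trimmed-down stand-in for the router's Args struct, keeping only
/// the fields touched by this commit.
#[derive(Parser, Debug)]
struct Args {
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    /// `env` lets the flag also be set through an environment variable
    /// derived from the field name, in addition to `--json-output`.
    #[clap(long, env)]
    json_output: bool,
}

fn main() {
    let args = Args::parse();

    if args.json_output {
        tracing_subscriber::fmt().json().init();
    } else {
        tracing_subscriber::fmt().compact().init();
    }

    // The corrected guard: zero workers is invalid, a single worker is fine.
    if args.validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }
}
```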
@@ -88,14 +88,6 @@ grpcio = ">=1.50.0"
 protobuf = ">=4.21.6,<5.0dev"
 setuptools = "*"
 
-[[package]]
-name = "joblib"
-version = "1.2.0"
-description = "Lightweight pipelining with Python functions"
-category = "main"
-optional = false
-python-versions = ">=3.7"
-
 [[package]]
 name = "numpy"
 version = "1.23.4"
@@ -210,10 +202,13 @@ category = "main"
 optional = false
 python-versions = ">=3.7"
 
+[extras]
+bnb = ["bitsandbytes"]
+
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "50d9d44577a0222f125c770732d5f88807378573bd7386036eb5c79fc2a7c552"
+content-hash = "224b1e379d6105fe911bff4563946a90dfa6ff5918cf2e7be59f8d4f7c5cd7cf"
 
 [metadata.files]
 accelerate = [
@@ -330,10 +325,6 @@ grpcio-tools = [
     {file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"},
     {file = "grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"},
 ]
-joblib = [
-    {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"},
-    {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"},
-]
 numpy = [
     {file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"},
     {file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"},
......
@@ -14,7 +14,6 @@ grpcio = "^1.49.1"
 typer = "^0.6.1"
 grpcio-reflection = "^1.49.1"
 accelerate = "^0.12.0"
-joblib = "^1.2.0"
 bitsandbytes = "^0.35.1"
 
 [tool.poetry.extras]
......
@@ -15,7 +15,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
         return Model(model_name)
     else:
         if sharded:
-            raise ValueError("sharded is only supported for BLOOM")
+            raise ValueError("sharded is only supported for BLOOM models")
         if quantize:
             raise ValueError("Quantization is only supported for BLOOM models")
......
@@ -20,7 +20,7 @@ class Model:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
         self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="auto"
+            model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None
         ).eval()
         self.num_heads = self.model.config.num_attention_heads
......
+import concurrent
 import os
+import signal
 import torch
 import torch.distributed
 
 from datetime import timedelta
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from joblib import Parallel, delayed
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
@@ -124,8 +126,9 @@ def download_weights(model_name, extension=".safetensors"):
     download_function = partial(
         hf_hub_download, repo_id=model_name, local_files_only=False
     )
-    files = Parallel(n_jobs=5)(
-        delayed(download_function)(filename=filename) for filename in tqdm(filenames)
-    )
+    # FIXME: fix the overlapping progress bars
+    executor = ThreadPoolExecutor(max_workers=5)
+    futures = [executor.submit(download_function, filename=filename) for filename in filenames]
+    files = [file for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))]
 
     return files