"vscode:/vscode.git/clone" did not exist on "e3b40d1cf4396f010f68eca3079b48352a592db3"
Unverified Commit 20c3c594 authored by OlivierDehaene, committed by GitHub

feat(router): refactor API and add openAPI schemas (#53)

parent b1482d90
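To give a feel for the refactored request shape, here is a minimal client-side sketch. It is not part of the commit: the host and port, and the use of the `reqwest` (with the `blocking` and `json` features) and `serde_json` crates, are assumptions for illustration only.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Sampling parameters are now nullable; any field left out is sent as null
    // and the router falls back to its own defaults.
    let body = json!({
        "inputs": "My name is Olivier and I",
        "parameters": {
            "temperature": 0.5,
            "top_p": 0.95,
            "do_sample": true,
            "max_new_tokens": 20,
            "details": true
        }
    });

    // Assumed local address; adjust to wherever the router is listening.
    let response: serde_json::Value = reqwest::blocking::Client::new()
        .post("http://127.0.0.1:3000/generate")
        .json(&body)
        .send()?
        .error_for_status()?
        .json()?;

    println!("{response}");
    Ok(())
}
```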
/// Text Generation Inference Webserver
mod infer;
mod queue;
pub mod server;
@@ -8,45 +7,55 @@ mod validation;
use infer::Infer;
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
+use utoipa::ToSchema;
use validation::Validation;

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
-    #[serde(default = "default_temperature")]
-    pub temperature: f32,
-    #[serde(default = "default_repetition_penalty")]
-    pub repetition_penalty: f32,
-    #[serde(default = "default_top_k")]
-    pub top_k: i32,
-    #[serde(default = "default_top_p")]
-    pub top_p: f32,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 0.5
+    )]
+    pub temperature: Option<f32>,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 1.03
+    )]
+    pub repetition_penalty: Option<f32>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
+    pub top_k: Option<i32>,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        maximum = 1.0,
+        nullable = true,
+        default = "null",
+        example = 0.95
+    )]
+    pub top_p: Option<f32>,
    #[serde(default = "default_do_sample")]
+    #[schema(default = "false", example = true)]
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
+    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
    pub max_new_tokens: u32,
    #[serde(default)]
+    #[schema(inline, max_items = 4, example = json!(["photographer"]))]
    pub stop: Vec<String>,
    #[serde(default)]
+    #[schema(default = "true")]
    pub details: bool,
    #[serde(default)]
    pub seed: Option<u64>,
}

-fn default_temperature() -> f32 {
-    1.0
-}
-fn default_repetition_penalty() -> f32 {
-    1.0
-}
-fn default_top_k() -> i32 {
-    0
-}
-fn default_top_p() -> f32 {
-    1.0
-}
fn default_do_sample() -> bool {
    false
}
@@ -57,10 +66,10 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters {
    GenerateParameters {
-        temperature: default_temperature(),
-        repetition_penalty: default_repetition_penalty(),
-        top_k: default_top_k(),
-        top_p: default_top_p(),
+        temperature: None,
+        repetition_penalty: None,
+        top_k: None,
+        top_p: None,
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
        stop: vec![],
@@ -69,42 +78,77 @@ fn default_parameters() -> GenerateParameters {
    }
}

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateRequest {
+    #[schema(example = "My name is Olivier and I")]
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

-#[derive(Debug, Serialize)]
-pub struct Token(u32, String, f32);
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Token {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(nullable = true, example = -0.34)]
+    logprob: f32,
+}
+
+#[derive(Serialize, ToSchema)]
+#[serde(rename_all(serialize = "snake_case"))]
+pub(crate) enum FinishReason {
+    #[schema(rename = "length")]
+    Length,
+    #[serde(rename = "eos_token")]
+    #[schema(rename = "eos_token")]
+    EndOfSequenceToken,
+    #[schema(rename = "stop_sequence")]
+    StopSequence,
+}

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
pub(crate) struct Details {
-    pub finish_reason: String,
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
    pub generated_tokens: u32,
+    #[schema(example = 42)]
    pub seed: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefill: Option<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<Vec<Token>>,
}

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse {
+    #[schema(example = "test")]
    pub generated_text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Details>,
}

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
+pub(crate) struct StreamDetails {
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
+    pub generated_tokens: u32,
+    #[schema(example = 42)]
+    pub seed: Option<u64>,
+}
+
+#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
    pub token: Token,
+    #[schema(nullable = true, default = "null", example = "test")]
    pub generated_text: Option<String>,
-    pub details: Option<Details>,
+    #[schema(nullable = true, default = "null")]
+    pub details: Option<StreamDetails>,
}

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
+    #[schema(inline)]
    pub error: String,
}
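The switch from sentinel defaults (`1.0`, `0`) to `Option` fields with `#[serde(default)]` means an omitted parameter now deserializes to `None` instead of a magic value. Below is a standalone sketch of that behaviour; the struct only mirrors a few fields for illustration and is not the real `GenerateParameters`.

```rust
use serde::Deserialize;

// Illustrative mirror of a few GenerateParameters fields.
#[derive(Debug, Deserialize, PartialEq)]
struct Params {
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    top_k: Option<i32>,
    #[serde(default = "default_max_new_tokens")]
    max_new_tokens: u32,
}

fn default_max_new_tokens() -> u32 {
    20
}

fn main() {
    // Fields missing from the JSON become None (or the named default), so the
    // validation layer can decide how to interpret "not set".
    let p: Params = serde_json::from_str(r#"{"top_k": 10}"#).unwrap();
    assert_eq!(p.temperature, None);
    assert_eq!(p.top_k, Some(10));
    assert_eq!(p.max_new_tokens, 20);
    println!("{p:?}");
}
```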
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
-    StreamResponse, Validation,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    Infer, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
@@ -19,6 +19,8 @@ use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::instrument;
+use utoipa::OpenApi;
+use utoipa_swagger_ui::SwaggerUi;

/// Health check method
#[instrument(skip(infer))]
@@ -32,13 +34,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
-                temperature: 1.0,
-                repetition_penalty: 1.0,
-                top_k: 0,
-                top_p: 1.0,
+                temperature: None,
+                repetition_penalty: None,
+                top_k: None,
+                top_p: None,
                do_sample: false,
                max_new_tokens: 1,
-                stop: vec![],
+                stop: Vec::new(),
                details: false,
                seed: None,
            },
@@ -47,7 +49,24 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
    Ok(())
}

-/// Generate method
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/generate",
+    request_body = GenerateRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = [GenerateResponse]),
+        (status = 424, description = "Generation Error", body = [ErrorResponse],
+            example = json!({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = [ErrorResponse],
+            example = json!({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = [ErrorResponse],
+            example = json!({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = [ErrorResponse],
+            example = json!({"error": "Incomplete generation"})),
+    )
+)]
#[instrument(
    skip(infer),
    fields(
@@ -76,7 +95,7 @@ async fn generate(
    // Token details
    let details = match details {
        true => Some(Details {
-            finish_reason: response.generated_text.finish_reason,
+            finish_reason: FinishReason::from(response.generated_text.finish_reason),
            generated_tokens: response.generated_text.generated_tokens,
            prefill: Some(response.prefill),
            tokens: Some(response.tokens),
@@ -132,7 +151,29 @@ async fn generate(
    Ok((headers, Json(response)))
}

-/// Generate stream method
+/// Generate a stream of token using Server Side Events
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/generate_stream",
+    request_body = GenerateRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = [StreamResponse],
+            content_type = "text/event-stream"),
+        (status = 424, description = "Generation Error", body = [ErrorResponse],
+            example = json!({"error": "Request failed during generation"}),
+            content_type = "text/event-stream"),
+        (status = 429, description = "Model is overloaded", body = [ErrorResponse],
+            example = json!({"error": "Model is overloaded"}),
+            content_type = "text/event-stream"),
+        (status = 422, description = "Input validation error", body = [ErrorResponse],
+            example = json!({"error": "Input validation error"}),
+            content_type = "text/event-stream"),
+        (status = 500, description = "Incomplete generation", body = [ErrorResponse],
+            example = json!({"error": "Incomplete generation"}),
+            content_type = "text/event-stream"),
+    )
+)]
#[instrument(
    skip(infer),
    fields(
@@ -185,11 +226,9 @@ async fn generate_stream(
                } => {
                    // Token details
                    let details = match details {
-                        true => Some(Details {
-                            finish_reason: generated_text.finish_reason,
+                        true => Some(StreamDetails {
+                            finish_reason: FinishReason::from(generated_text.finish_reason),
                            generated_tokens: generated_text.generated_tokens,
-                            prefill: None,
-                            tokens: None,
                            seed: generated_text.seed,
                        }),
                        false => None,
@@ -265,6 +304,39 @@ pub async fn run(
    validation_workers: usize,
    addr: SocketAddr,
) {
+    // OpenAPI documentation
+    #[derive(OpenApi)]
+    #[openapi(
+        paths(
+            generate,
+            generate_stream,
+        ),
+        components(
+            schemas(
+                GenerateRequest,
+                GenerateParameters,
+                Token,
+                GenerateResponse,
+                Details,
+                FinishReason,
+                StreamResponse,
+                StreamDetails,
+                ErrorResponse,
+            )
+        ),
+        tags(
+            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
+        ),
+        info(
+            title = "Text Generation Inference",
+            license(
+                name = "Apache 2.0",
+                url = "https://www.apache.org/licenses/LICENSE-2.0"
+            )
+        )
+    )]
+    struct ApiDoc;

    // Create state
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
    let infer = Infer::new(
@@ -277,6 +349,7 @@ pub async fn run(
    // Create router
    let app = Router::new()
+        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
        .route("/", post(generate))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
@@ -320,6 +393,17 @@ async fn shutdown_signal() {
    tracing::info!("signal received, starting graceful shutdown");
}

+impl From<i32> for FinishReason {
+    fn from(finish_reason: i32) -> Self {
+        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
+        match finish_reason {
+            text_generation_client::FinishReason::Length => FinishReason::Length,
+            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
+        }
+    }
+}
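Because of `rename_all(serialize = "snake_case")` and the explicit `rename` on the EOS variant, the three variants serialize as `"length"`, `"eos_token"`, and `"stop_sequence"` on the wire. A small standalone sketch of that serialization, using a mirror of the enum rather than the crate's own type:

```rust
use serde::Serialize;

// Illustrative mirror of the FinishReason wire format.
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum FinishReason {
    Length,
    #[serde(rename = "eos_token")]
    EndOfSequenceToken,
    StopSequence,
}

fn main() {
    let reasons = [
        FinishReason::Length,
        FinishReason::EndOfSequenceToken,
        FinishReason::StopSequence,
    ];
    for reason in reasons {
        // Prints "length", "eos_token", "stop_sequence" as quoted JSON strings.
        println!("{}", serde_json::to_string(&reason).unwrap());
    }
}
```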
/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
......
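With the `SwaggerUi` merge above, the interactive docs are mounted at `/docs` and the generated schema at `/api-doc/openapi.json`. A quick way to pull the schema; the host, port, and blocking `reqwest` client are assumptions, not part of the commit:

```rust
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Fetch the OpenAPI document generated from the ApiDoc derive.
    let spec = reqwest::blocking::get("http://127.0.0.1:3000/api-doc/openapi.json")?
        .error_for_status()?
        .text()?;
    println!("{spec}");
    Ok(())
}
```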
@@ -110,30 +110,58 @@ fn validate(
    max_input_length: usize,
    rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> {
-    if request.parameters.temperature <= 0.0 {
+    let GenerateParameters {
+        temperature,
+        repetition_penalty,
+        top_k,
+        top_p,
+        do_sample,
+        max_new_tokens,
+        stop: stop_sequences,
+        seed,
+        ..
+    } = request.parameters;
+
+    let temperature = temperature.unwrap_or(1.0);
+    if temperature <= 0.0 {
        return Err(ValidationError::Temperature);
    }
-    if request.parameters.repetition_penalty <= 0.0 {
+
+    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
+    if repetition_penalty <= 0.0 {
        return Err(ValidationError::RepetitionPenalty);
    }
-    if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
+
+    let top_p = top_p.unwrap_or(1.0);
+    if top_p <= 0.0 || top_p > 1.0 {
        return Err(ValidationError::TopP);
    }
-    if request.parameters.top_k < 0 {
+
+    // Different because the proto default value is 0 while it is not a valid value
+    // for the user
+    let top_k: u32 = match top_k {
+        None => Ok(0),
+        Some(top_k) => {
+            if top_k <= 0 {
                return Err(ValidationError::TopK);
            }
-    if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
+            Ok(top_k as u32)
+        }
+    }?;
+
+    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
        return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
    }
-    if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
+
+    if stop_sequences.len() > MAX_STOP_SEQUENCES {
        return Err(ValidationError::StopSequence(
            MAX_STOP_SEQUENCES,
-            request.parameters.stop.len(),
+            stop_sequences.len(),
        ));
    }

    // If seed is None, assign a random one
-    let seed = match request.parameters.seed {
+    let seed = match seed {
        None => rng.gen(),
        Some(seed) => seed,
    };
@@ -147,21 +175,10 @@ fn validate(
        Err(ValidationError::InputLength(input_length, max_input_length))
    } else {
        // Return ValidGenerateRequest
-        let GenerateParameters {
-            temperature,
-            repetition_penalty,
-            top_k,
-            top_p,
-            do_sample,
-            max_new_tokens,
-            stop: stop_sequences,
-            ..
-        } = request.parameters;
        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
-            top_k: top_k as u32,
+            top_k,
            top_p,
            do_sample,
            seed,
@@ -206,7 +223,7 @@ pub enum ValidationError {
    TopP,
    #[error("top_k must be strictly positive")]
    TopK,
-    #[error("max_new_tokens must be <= {0}")]
+    #[error("max_new_tokens must be strictly positive and <= {0}")]
    MaxNewTokens(u32),
    #[error("inputs must have less than {1} tokens. Given: {0}")]
    InputLength(usize, usize),
......
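The comment about the proto default captures the one asymmetry in the new validation: `top_k` is a `u32` on the wire whose `0` means "disabled", so `None` maps to `0` while an explicit non-positive value is rejected. A standalone sketch of that mapping, with a hypothetical helper name rather than the crate's own API:

```rust
// Hypothetical helper mirroring the top_k handling in validate().
fn convert_top_k(top_k: Option<i32>) -> Result<u32, &'static str> {
    match top_k {
        // Not provided by the user: the proto's 0 means "top_k disabled".
        None => Ok(0),
        // Provided but invalid: reject, since users may not pass 0 or negatives.
        Some(k) if k <= 0 => Err("top_k must be strictly positive"),
        Some(k) => Ok(k as u32),
    }
}

fn main() {
    assert_eq!(convert_top_k(None), Ok(0));
    assert_eq!(convert_top_k(Some(10)), Ok(10));
    assert!(convert_top_k(Some(0)).is_err());
    println!("top_k conversion behaves as expected");
}
```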
-# BLOOM Inference Python gRPC Server
+# Text Generation Inference Python gRPC Server

-A Python gRPC server for BLOOM Inference
+A Python gRPC server for Text Generation Inference

## Install
......
[tool.poetry]
name = "text-generation"
-version = "0.1.0"
+version = "0.2.0"
-description = "BLOOM Inference Python gRPC Server"
+description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
......
@@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
    assert len(generations) == 1
    assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
    )
    assert generations[0].request_id == default_bloom_batch.requests[0].id
    assert (
@@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(
    assert len(generations) == 1
    assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
    )
    assert (
        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -283,8 +281,7 @@ def test_batch_concatenate(
    assert len(generations) == 2
    assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
    )
    assert generations[0].request_id == default_bloom_batch.requests[0].id
    assert (
@@ -306,8 +303,7 @@ def test_batch_concatenate(
    assert len(generations) == 1
    assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
    )
    assert (
        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
......
@@ -9,6 +9,7 @@ from text_generation.utils import (
    StopSequenceCriteria,
    StoppingCriteria,
    LocalEntryNotFoundError,
+    FinishReason,
)
@@ -24,13 +25,13 @@ def test_stop_sequence_criteria():
def test_stopping_criteria():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(65827, "/test") == (False, None)
-    assert criteria(30, ";") == (True, "stop_sequence")
+    assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)


def test_stopping_criteria_eos():
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(1, "") == (False, None)
-    assert criteria(0, "") == (True, "eos_token")
+    assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)


def test_stopping_criteria_max():
@@ -39,7 +40,7 @@ def test_stopping_criteria_max():
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
-    assert criteria(1, "") == (True, "length")
+    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)


def test_weight_hub_files():
......
@@ -13,7 +13,7 @@ app = typer.Typer()
@app.command()
def serve(
-    model_name: str,
+    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
@@ -46,16 +46,16 @@ def serve(
        os.getenv("MASTER_PORT", None) is not None
    ), "MASTER_PORT must be set when sharded is True"

-    server.serve(model_name, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, uds_path)


@app.command()
def download_weights(
-    model_name: str,
+    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
):
-    utils.download_weights(model_name, revision, extension)
+    utils.download_weights(model_id, revision, extension)


if __name__ == "__main__":
......
@@ -30,31 +30,31 @@ torch.backends.cudnn.allow_tf32 = True
def get_model(
-    model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
-    config = AutoConfig.from_pretrained(model_name, revision=revision)
+    config = AutoConfig.from_pretrained(model_id, revision=revision)

    if config.model_type == "bloom":
        if sharded:
-            return BLOOMSharded(model_name, revision, quantize=quantize)
+            return BLOOMSharded(model_id, revision, quantize=quantize)
        else:
-            return BLOOM(model_name, revision, quantize=quantize)
+            return BLOOM(model_id, revision, quantize=quantize)
    elif config.model_type == "gpt_neox":
        if sharded:
-            return GPTNeoxSharded(model_name, revision, quantize=quantize)
+            return GPTNeoxSharded(model_id, revision, quantize=quantize)
        else:
-            return GPTNeox(model_name, revision, quantize=quantize)
+            return GPTNeox(model_id, revision, quantize=quantize)
-    elif model_name.startswith("facebook/galactica"):
+    elif model_id.startswith("facebook/galactica"):
        if sharded:
-            return GalacticaSharded(model_name, revision, quantize=quantize)
+            return GalacticaSharded(model_id, revision, quantize=quantize)
        else:
-            return Galactica(model_name, revision, quantize=quantize)
+            return Galactica(model_id, revision, quantize=quantize)
-    elif "santacoder" in model_name:
-        return SantaCoder(model_name, revision, quantize)
+    elif "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
    else:
        if sharded:
            raise ValueError("sharded is not supported for AutoModel")
        try:
-            return CausalLM(model_name, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
        except Exception:
-            return Seq2SeqLM(model_name, revision, quantize=quantize)
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
@@ -57,10 +57,10 @@ class BLOOM(CausalLM):
class BLOOMSharded(BLOOM):
    def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
-        if not model_name.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_id} is not supported")

        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
@@ -72,22 +72,20 @@ class BLOOMSharded(BLOOM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
        )

        config = AutoConfig.from_pretrained(
-            model_name, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
        )
        config.pad_token_id = 3

        # Only download weights for small models
-        if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "bigscience/bloom-560m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        if not filenames:
            raise ValueError("No safetensors weights found")
......
@@ -232,7 +232,7 @@ class CausalLMBatch(Batch):
class CausalLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -244,10 +244,10 @@ class CausalLM(Model):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
......
@@ -149,10 +149,10 @@ class Galactica(CausalLM):
class GalacticaSharded(Galactica):
    def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
-        if not model_name.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("facebook/galactica"):
+            raise ValueError(f"Model {model_id} is not supported")

        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
@@ -164,22 +164,20 @@ class GalacticaSharded(Galactica):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
        )

        config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
        )
        tokenizer.pad_token_id = config.pad_token_id

        # Only download weights for small models
-        if self.master and model_name == "facebook/galactica-125m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "facebook/galactica-125m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        if not filenames:
            raise ValueError("No safetensors weights found")
......
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
class GPTNeoxSharded(GPTNeox):
    def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
@@ -61,22 +61,20 @@ class GPTNeoxSharded(GPTNeox):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
        )

        # Only master download weights
        if self.master:
-            download_weights(model_name, revision=revision, extension=".safetensors")
+            download_weights(model_id, revision=revision, extension=".safetensors")

        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        if not filenames:
            raise ValueError("No safetensors weights found")
......
@@ -14,7 +14,7 @@ EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
        )
        tokenizer.add_special_tokens(
            {
@@ -43,7 +43,7 @@ class SantaCoder(CausalLM):
        self.model = (
            AutoModelForCausalLM.from_pretrained(
-                model_name,
+                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
......
@@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):
class Seq2SeqLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -301,14 +301,14 @@ class Seq2SeqLM(Model):
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
+            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
        )
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id
......
@@ -7,6 +7,7 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase

from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason


class Batch(ABC):
@@ -38,7 +39,7 @@ class Batch(ABC):
class GeneratedText:
    text: str
    generated_tokens: int
-    finish_reason: str
+    finish_reason: FinishReason
    seed: Optional[int]

    def to_pb(self) -> generate_pb2.GeneratedText:
......
@@ -66,14 +66,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
-    model_name: str,
+    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: bool,
    uds_path: Path,
):
    async def serve_inner(
-        model_name: str,
+        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: bool = False,
@@ -89,7 +89,7 @@ def serve(
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

-        model = get_model(model_name, revision, sharded, quantize)
+        model = get_model(model_id, revision, sharded, quantize)
        server = aio.server(interceptors=[ExceptionInterceptor()])
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -109,4 +109,4 @@ def serve(
        logger.info("Signal received. Shutting down")
        await server.stop(0)

-    asyncio.run(serve_inner(model_name, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
@@ -24,9 +24,11 @@ from transformers.generation.logits_process import (
)

from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason

WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
@@ -129,15 +131,15 @@ class StoppingCriteria:
    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
-            return True, "length"
+            return True, FinishReason.FINISH_REASON_LENGTH

        if last_token == self.eos_token_id:
-            return True, "eos_token"
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
-                return True, "stop_sequence"
+                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None
@@ -180,20 +182,20 @@ def initialize_torch_distributed():
    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


-def weight_hub_files(model_name, revision=None, extension=".safetensors"):
+def weight_hub_files(model_id, revision=None, extension=".safetensors"):
    """Get the safetensors filenames on the hub"""
    api = HfApi()
-    info = api.model_info(model_name, revision=revision)
+    info = api.model_info(model_id, revision=revision)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
    return filenames


-def try_to_load_from_cache(model_name, revision, filename):
+def try_to_load_from_cache(model_id, revision, filename):
    """Try to load a file from the Hugging Face cache"""
    if revision is None:
        revision = "main"

-    object_id = model_name.replace("/", "--")
+    object_id = model_id.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
@@ -228,38 +230,38 @@ def try_to_load_from_cache(model_name, revision, filename):
    return str(cached_file) if cached_file.is_file() else None


-def weight_files(model_name, revision=None, extension=".safetensors"):
+def weight_files(model_id, revision=None, extension=".safetensors"):
    """Get the local safetensors filenames"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
-            model_name, revision=revision, filename=filename
+            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
-                f"File {filename} of model {model_name} not found in "
+                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `text-generation-server download-weights {model_name}` first."
+                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)

    return files


-def download_weights(model_name, revision=None, extension=".safetensors"):
+def download_weights(model_id, revision=None, extension=".safetensors"):
    """Download the safetensors files from the hub"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)

    download_function = partial(
        hf_hub_download,
-        repo_id=model_name,
+        repo_id=model_id,
        local_files_only=False,
    )
......