Unverified Commit b1482d90 authored by OlivierDehaene, committed by GitHub

breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reasons. We
want to remove this behaviour, as we don't think it is useful and it is even
detrimental to usability.

We also remove the unused Vec.
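
For reference, a rough sketch of what the change looks like from a client's point of view (the URL, prompt and parameters below are placeholders, not part of this change):

```python
# Hypothetical client call; host, port, prompt and parameters are made up
# for illustration and are not part of this commit.
import requests

resp = requests.post(
    "http://127.0.0.1:3000/generate",
    json={"inputs": "Test request", "parameters": {"max_new_tokens": 10}},
)

# Before: a one-element list whose generated_text echoed the prompt, e.g.
#   [{"generated_text": "Test request.get(\"action\");...", "details": {...}}]
# After: a single object containing only the newly generated text, e.g.
#   {"generated_text": ".get(\"action\");...", "details": {...}}
print(resp.json()["generated_text"])
```

Clients that previously indexed into the returned list, or stripped the prompt prefix from generated_text themselves, will need to be updated.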
parent 7b870e1e
@@ -118,6 +118,6 @@
         ]
       ]
     },
-    "generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+    "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
   }
 ]
\ No newline at end of file
@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();

-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }

 fn read_json(name: &str) -> GeneratedText {
...
-mod infer;
 /// Text Generation Inference Webserver
+mod infer;
 mod queue;
 pub mod server;
 mod validation;
...
@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }
...
@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
...
@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
...
@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch)
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (
...
@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)

     if config.model_type == "bloom":
         if sharded:
...
@@ -360,11 +360,9 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
-                generated_text = self.decode(
+                output_text = self.decode(
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
-                output_text = request.inputs + generated_text
-
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed
...