Unverified Commit b1482d90 authored by OlivierDehaene, committed by GitHub

breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reasons. We
want to remove this behaviour, as we don't think it is useful and it is even
detrimental to usability.

We also remove the single-element Vec that wrapped the response.
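
Concretely, `/generate` now returns a single JSON object instead of a one-element array, and `generated_text` no longer echoes the prompt. A minimal client-side sketch of the new contract; the URL, port, and request parameters below are illustrative assumptions, not part of this diff:

```python
import requests

# Hypothetical deployment; host and port are placeholders.
resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": "Test", "parameters": {"max_new_tokens": 10}},
)

# Before this commit the body was a one-element array whose text
# echoed the prompt:
#   [{"generated_text": "TestTestTest...", "details": {...}}]
# After this commit it is a single object holding only the newly
# generated tokens:
#   {"generated_text": "TestTest...", "details": {...}}
output = resp.json()["generated_text"]
```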
parent 7b870e1e
@@ -118,6 +118,6 @@
         ]
       ]
     },
-    "generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+    "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
   }
 ]
\ No newline at end of file
@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();
-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }

 fn read_json(name: &str) -> GeneratedText {
-mod infer;
 /// Text Generation Inference Webserver
+mod infer;
 mod queue;
 pub mod server;
 mod validation;
@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }
@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (
@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)
     if config.model_type == "bloom":
         if sharded:
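The `get_model` hunk above also fixes a loading bug independent of the API change: `revision` was accepted by the function but never forwarded, so the config was always resolved from the model's default branch even when the weights were pinned to another revision. A hedged sketch of the corrected behaviour; the model id and revision below are placeholders:

```python
from typing import Optional

from transformers import AutoConfig

def load_config(model_name: str, revision: Optional[str]):
    # Forward `revision` so the config comes from the same git revision
    # of the model repository as the weights.
    return AutoConfig.from_pretrained(model_name, revision=revision)

# Illustrative only: pin the config to a specific revision.
config = load_config("bigscience/bloom-560m", revision="main")
```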
@@ -360,11 +360,9 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
-                generated_text = self.decode(
+                output_text = self.decode(
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
-                output_text = request.inputs + generated_text

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed
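Server-side, the removed line `output_text = request.inputs + generated_text` is where the prompt used to be prepended. Clients that still want the old echoed output can reconstruct it themselves; a small sketch using values from the JSON fixture above:

```python
def legacy_text(prompt: str, payload: dict) -> str:
    # Re-create the pre-change behaviour: the old API returned the
    # prompt concatenated with the generated text.
    return prompt + payload["generated_text"]

# With the fixture values above:
# legacy_text("Test request", {"generated_text": '.get("action");...'})
# returns 'Test request.get("action");...'
```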