use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::{
    Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};

// Note: Request ids and batch ids cannot collide.
const LIVENESS_ID: u64 = u64::MAX;
const BATCH_ID: u64 = u64::MAX;
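// Using u64::MAX keeps the liveness probe out of the id space used by real
// traffic (assuming, per the note above, that real request and batch ids
// never reach u64::MAX).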

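/// Health checker: wraps the sharded gRPC client together with an atomic flag
/// recording whether the most recent generation attempt succeeded.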
#[derive(Clone, Debug)]
pub(crate) struct Health {
    client: ShardedClient,
    generation_health: Arc<AtomicBool>,
}

impl Health {
    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
        Self {
            client,
            generation_health,
        }
    }

    /// Check whether the system can serve generation requests: a cheap gRPC
    /// health call when generation is known to be healthy, a real prefill
    /// probe otherwise.
    pub(crate) async fn check(&mut self) -> bool {
        if self.generation_health.load(Ordering::SeqCst) {
            // Generation is healthy, so we only check that the shards answer gRPC calls
            self.client.health().await.is_ok()
        } else {
            // Generation is unhealthy, or we have not sent any generation request yet

            // Dummy batch of 1 token and 1 generated token
            let liveness_request = Request {
                id: LIVENESS_ID,
                input_chunks: Some(Input {
                    chunks: vec![Chunk::Text("liveness".into()).into()],
                }),
                inputs: "liveness".to_string(),
                truncate: 10,
                prefill_logprobs: false,
                // Neutral, deterministic decoding parameters: sampling is
                // disabled and every penalty is neutral, so the probe is
                // cheap and reproducible.
                parameters: Some(NextTokenChooserParameters {
                    temperature: 1.0,
                    top_k: 0,
                    top_p: 1.0,
                    typical_p: 1.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: String::new(),
                    grammar_type: ProtoGrammarType::None as i32,
                }),
                // Stop after a single generated token.
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                    ignore_eos_token: false,
                }),
                top_n_tokens: 0,
            };
            let batch = Batch {
                id: BATCH_ID,
                requests: vec![liveness_request],
                size: 1,
                // One input token plus one generated token.
                max_tokens: 2,
            };
            // Skips the router queue: the probe is sent straight to the shards
            let value = self.client.prefill(batch).await.is_ok();
            // Update generation health
            self.generation_health.store(value, Ordering::SeqCst);
            value
        }
    }
}
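
// A minimal usage sketch (hypothetical wiring: `sharded_client` and the
// surrounding server are assumed; only `Health::new` and `Health::check`
// come from this file):
//
//     let generation_health = Arc::new(AtomicBool::new(false));
//     let mut health = Health::new(sharded_client, generation_health.clone());
//     // e.g. inside an HTTP health-route handler:
//     let healthy = health.check().await;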