lib.rs 3.21 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Text Generation gRPC client library
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2

OlivierDehaene's avatar
OlivierDehaene committed
3
4
pub mod v2;
pub mod v3;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
5

OlivierDehaene's avatar
OlivierDehaene committed
6
use async_trait::async_trait;
7
use base64::{engine::general_purpose::STANDARD, Engine};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
8
use thiserror::Error;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
9
use tonic::transport;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
10
11
use tonic::Status;

OlivierDehaene's avatar
OlivierDehaene committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
pub use v3::{Chunk, Image, Input, InputChunk};

/// Health probes that can be run against a generate server.
#[async_trait]
pub trait Health {
    /// Probe the server by asking it to allocate a tensor on its device.
    async fn device_health(&self) -> Result<()>;

    /// Probe the server by having it run a forward pass.
    ///
    /// EXPENSIVE: this exercises the model itself rather than just the device.
    async fn model_health(&self) -> Result<()>;
}

/// Information describing a shard.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers can
/// copy and compare shard descriptions (all fields are plain owned values).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardInfo {
    /// NOTE(review): presumably whether batched inputs must be padded — confirm.
    pub requires_padding: bool,
    /// Data type of the shard, as a string (NOTE(review): presumably the model dtype).
    pub dtype: String,
    /// Device type of the shard, as a string.
    pub device_type: String,
    /// Optional window size (NOTE(review): presumably a sliding attention window — confirm).
    pub window_size: Option<u32>,
    /// NOTE(review): presumably the number of speculative tokens — confirm.
    pub speculate: u32,
}

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
33
#[derive(Error, Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
34
pub enum ClientError {
35
    #[error("Could not connect to Text Generation server: {0}")]
Olivier Dehaene's avatar
Olivier Dehaene committed
36
    Connection(String),
37
    #[error("Server error: {0}")]
Olivier Dehaene's avatar
Olivier Dehaene committed
38
    Generation(String),
39
40
    #[error("Sharded results are empty")]
    EmptyResults,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41
42
43
44
}

impl From<Status> for ClientError {
    fn from(err: Status) -> Self {
45
46
47
        let err = Self::Generation(err.message().to_string());
        tracing::error!("{err}");
        err
Olivier Dehaene's avatar
Olivier Dehaene committed
48
49
50
51
52
    }
}

impl From<transport::Error> for ClientError {
    fn from(err: transport::Error) -> Self {
53
54
55
        let err = Self::Connection(err.to_string());
        tracing::error!("{err}");
        err
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
56
57
58
    }
}

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
    fn from(chunk: Chunk) -> Self {
        InputChunk { chunk: Some(chunk) }
    }
}

/// Flatten chunked input into a single string.
///
/// Exists for backwards compatibility with backends that have not yet
/// implemented chunked inputs.
pub trait ChunksToString {
    /// Render the chunks as one `String`.
    fn chunks_to_string(&self) -> String;
}

impl ChunksToString for Vec<InputChunk> {
    /// Concatenate all chunks, in order, into one string.
    ///
    /// Text chunks are appended verbatim. Image chunks are appended as a
    /// Markdown image `![](data:<mimetype>;base64,<data>)` whose payload is the
    /// image bytes encoded with the standard base64 alphabet.
    ///
    /// # Panics
    ///
    /// Panics if any `InputChunk` has no inner chunk (`chunk == None`); such
    /// chunks are never created, so this is treated as an invariant violation.
    fn chunks_to_string(&self) -> String {
        let mut output = String::new();
        for input_chunk in self {
            match &input_chunk.chunk {
                Some(Chunk::Text(text)) => output.push_str(text),
                Some(Chunk::Image(Image { data, mimetype })) => {
                    // Encode directly into `output` rather than building a base64
                    // `String` and then a second `format!` copy of it: images can be
                    // large, and this avoids two temporary allocations per image.
                    output.push_str("![](data:");
                    output.push_str(mimetype);
                    output.push_str(";base64,");
                    STANDARD.encode_string(data, &mut output);
                    output.push(')');
                }
                // We don't create empty chunks, so this should be unreachable.
                None => unreachable!("Chunks should never be empty"),
            }
        }
        output
    }
}
OlivierDehaene's avatar
OlivierDehaene committed
88
89
90
91

/// Base64-encoded (standard alphabet, padded) PNG image.
/// NOTE(review): presumably used as image input during warmup — confirm at call sites.
static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

/// Result type whose error is always a [`ClientError`].
pub type Result<T> = core::result::Result<T, ClientError>;