lib.rs 3.21 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Text Generation gRPC client library
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2

xuxzh1's avatar
last  
xuxzh1 committed
3
4
5
6
7
pub mod v2;
pub mod v3;

use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
8
use thiserror::Error;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
9
use tonic::transport;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
10
11
use tonic::Status;

xuxzh1's avatar
last  
xuxzh1 committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
pub use v3::{Chunk, Image, Input, InputChunk};

/// Health checks that a client can run against a text-generation server.
///
/// Both methods return this crate's [`Result`] alias, so a failed check surfaces
/// as a [`ClientError`].
#[async_trait]
pub trait Health {
    /// Check if a generate server is healthy by asking it to allocate a tensor on device
    async fn device_health(&self) -> Result<()>;

    /// Check if a generate server is healthy by doing a forward pass.
    /// EXPENSIVE: unlike `device_health`, this runs a full forward pass.
    async fn model_health(&self) -> Result<()>;
}

/// Information describing a model shard.
///
/// NOTE(review): presumably populated from the server's info response by the
/// `v2`/`v3` clients — confirm against the code that constructs it.
#[derive(Debug)]
pub struct ShardInfo {
    /// Whether the shard requires padding (semantics inferred from the name — confirm).
    pub requires_padding: bool,
    /// Data type name as a string (presumably the model's weight dtype — confirm).
    pub dtype: String,
    /// Device type name as a string (presumably e.g. the accelerator kind — confirm).
    pub device_type: String,
    /// Optional window size. NOTE(review): likely a sliding attention window;
    /// `None` presumably means no window — confirm.
    pub window_size: Option<u32>,
    /// Speculation setting. NOTE(review): presumably the number of speculative
    /// tokens — confirm.
    pub speculate: u32,
}

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
33
#[derive(Error, Debug, Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
34
pub enum ClientError {
35
    #[error("Could not connect to Text Generation server: {0}")]
Olivier Dehaene's avatar
Olivier Dehaene committed
36
    Connection(String),
37
    #[error("Server error: {0}")]
Olivier Dehaene's avatar
Olivier Dehaene committed
38
    Generation(String),
39
40
    #[error("Sharded results are empty")]
    EmptyResults,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41
42
43
44
}

impl From<Status> for ClientError {
    fn from(err: Status) -> Self {
45
46
47
        let err = Self::Generation(err.message().to_string());
        tracing::error!("{err}");
        err
Olivier Dehaene's avatar
Olivier Dehaene committed
48
49
50
51
52
    }
}

impl From<transport::Error> for ClientError {
    fn from(err: transport::Error) -> Self {
53
54
55
        let err = Self::Connection(err.to_string());
        tracing::error!("{err}");
        err
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
56
57
58
    }
}

xuxzh1's avatar
last  
xuxzh1 committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
    fn from(chunk: Chunk) -> Self {
        InputChunk { chunk: Some(chunk) }
    }
}

/// Convert input chunks into a single plain-text input, for backwards
/// compatibility with backends that have not implemented chunked inputs.
pub trait ChunksToString {
    /// Flatten `self` into a single `String`.
    fn chunks_to_string(&self) -> String;
}

impl ChunksToString for Vec<InputChunk> {
    /// Concatenate the chunks in order.
    ///
    /// Text chunks are appended verbatim. Image chunks become an inline markdown
    /// image with a data URI: `![](data:<mimetype>;base64,<standard base64 of data>)`.
    ///
    /// # Panics
    ///
    /// Panics if any element's `chunk` is `None`. Empty chunks are never created
    /// by this crate, so that case is treated as a broken invariant.
    fn chunks_to_string(&self) -> String {
        let mut output = String::new();
        for input_chunk in self {
            match &input_chunk.chunk {
                Some(Chunk::Text(text)) => output.push_str(text),
                Some(Chunk::Image(Image { data, mimetype })) => {
                    // Image payloads can be large. Encode straight into `output`
                    // instead of building the base64 text and then the whole data
                    // URI in temporary strings, which would copy the payload twice.
                    output.push_str("![](data:");
                    output.push_str(mimetype);
                    output.push_str(";base64,");
                    STANDARD.encode_string(data, &mut output);
                    output.push(')');
                }
                // We don't create empty chunks, so this should be unreachable.
                None => unreachable!("Chunks should never be empty"),
            }
        }
        output
    }
}

static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
91
/// Crate-wide result alias whose error type is [`ClientError`].
pub type Result<T> = std::result::Result<T, ClientError>;