// integration_tests.rs
use float_eq::assert_float_eq;
use serde::Deserialize;
use serde_json::Value;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection};

#[derive(Deserialize)]
pub struct Token {
    id: u32,
    text: String,
    logprob: Option<f32>,
17
    special: bool,
18
19
}

20
21
22
23
#[derive(Deserialize)]
struct Details {
    finish_reason: String,
    generated_tokens: u32,
24
    tokens: Vec<Token>,
25
26
27
28
29
30
31
32
}

#[derive(Deserialize)]
struct GeneratedText {
    generated_text: String,
    details: Details,
}

33
fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
34
35
    let argv = vec![
        "text-generation-launcher".to_string(),
36
37
        "--model-id".to_string(),
        model_id.clone(),
38
39
40
41
42
43
44
45
46
47
48
49
50
51
        "--num-shard".to_string(),
        num_shard.to_string(),
        "--port".to_string(),
        port.to_string(),
        "--master-port".to_string(),
        master_port.to_string(),
        "--shard-uds-path".to_string(),
        format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
    ];

    let mut launcher = Popen::create(
        &argv,
        PopenConfig {
            stdout: Redirection::Pipe,
52
            stderr: Redirection::Merge,
53
54
55
            ..Default::default()
        },
    )
56
    .expect("Could not start launcher");
57
58

    // Redirect STDOUT and STDERR to the console
59
    // (STDERR is merged into STDOUT)
60
61
62
63
64
65
66
67
68
    let launcher_stdout = launcher.stdout.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(launcher_stdout);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
    });

69
    for _ in 0..60 {
70
71
72
73
74
75
76
77
78
        let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
        if health.is_ok() {
            return launcher;
        }
        sleep(Duration::from_secs(2));
    }

    launcher.terminate().unwrap();
    launcher.wait().unwrap();
79
    panic!("failed to launch {}", model_id)
80
81
}

82
/// Launches `model_id` on `num_shard` shards and sends a single `/generate`
/// request for the prompt "Test request" with `details` enabled. It then stops
/// the launcher and returns the decoded response.
///
/// # Panics
///
/// Panics if the launcher fails to start or stop, or if the request fails,
/// returns a non-success status, or its body cannot be decoded.
fn test_model(
    model_id: String,
    num_shard: usize,
    port: usize,
    master_port: usize,
) -> GeneratedText {
    let mut launcher = start_launcher(model_id, num_shard, port, master_port);

    let req: Value = serde_json::json!({
        "inputs": "Test request",
        "parameters": {
            "details": true
        }
    });

    // `send()` returns once the response headers arrive; the body is read
    // lazily. Decode it completely *before* stopping the server so the read
    // cannot race with shutdown. Keep the outcome as a `Result` so the
    // launcher is still terminated when the request fails.
    let result = reqwest::blocking::Client::new()
        .post(format!("http://localhost:{}/generate", port))
        .json(&req)
        .send()
        .and_then(|response| response.error_for_status())
        .and_then(|response| response.json::<GeneratedText>());

    launcher.terminate().unwrap();
    launcher.wait().unwrap();

    result.expect("/generate request failed")
}

/// Loads the expected-output fixture `tests/<name>` from the crate root
/// (`CARGO_MANIFEST_DIR`).
///
/// # Panics
///
/// Panics if the file cannot be opened or does not deserialize into a
/// `GeneratedText`.
fn read_json(name: &str) -> GeneratedText {
    let fixture_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("tests")
        .join(name);
    let fixture = File::open(fixture_path).unwrap();
    serde_json::from_reader(BufReader::new(fixture)).unwrap()
}

fn compare_results(result: GeneratedText, expected: GeneratedText) {
    assert_eq!(result.generated_text, expected.generated_text);
    assert_eq!(result.details.finish_reason, expected.details.finish_reason);
127
128
129
130
131
132
133
134
135
136
137
    assert_eq!(
        result.details.generated_tokens,
        expected.details.generated_tokens
    );

    for (token, expected_token) in result
        .details
        .tokens
        .into_iter()
        .zip(expected.details.tokens.into_iter())
    {
138
139
        assert_eq!(token.id, expected_token.id);
        assert_eq!(token.text, expected_token.text);
140
        assert_eq!(token.special, expected_token.special);
141
142
        if let Some(logprob) = token.logprob {
            let expected_logprob = expected_token.logprob.unwrap();
143
144
            assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
        } else {
145
            assert_eq!(token.logprob, expected_token.logprob);
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
        }
    }
}

/// `bigscience/bloom-560m` on a single shard must reproduce the
/// `bloom_560m.json` fixture.
#[test]
fn test_bloom_560m() {
    let expected = read_json("bloom_560m.json");
    let generated = test_model(String::from("bigscience/bloom-560m"), 1, 3000, 29500);
    compare_results(generated, expected);
}

/// `bigscience/bloom-560m` split across two shards must reproduce the same
/// `bloom_560m.json` fixture as the single-shard run.
#[test]
fn test_bloom_560m_distributed() {
    let expected = read_json("bloom_560m.json");
    let generated = test_model(String::from("bigscience/bloom-560m"), 2, 3001, 29501);
    compare_results(generated, expected);
}

/// `bigscience/mt0-base` on a single shard must reproduce the `mt0_base.json`
/// fixture.
#[test]
fn test_mt0_base() {
    let expected = read_json("mt0_base.json");
    let generated = test_model(String::from("bigscience/mt0-base"), 1, 3002, 29502);
    compare_results(generated, expected);
}