embedding.rs 2.83 KB
Newer Older
1
use serde_json::{from_str, json, to_string};
2
use sglang_router_rs::protocols::{common::GenerationRequest, embedding::EmbeddingRequest};
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

/// Round-trips an `EmbeddingRequest` whose `input` is a plain JSON string and
/// whose optional fields are all populated, checking that every field comes
/// back unchanged after `to_string` followed by `from_str`.
#[test]
fn test_embedding_request_serialization_string_input() {
    let original = EmbeddingRequest {
        model: "test-emb".to_string(),
        input: json!("hello"),
        encoding_format: Some("float".to_string()),
        user: Some("user-1".to_string()),
        dimensions: Some(128),
        rid: Some("rid-123".to_string()),
    };

    // Serialize to JSON text and immediately parse it back.
    let round_tripped = to_string(&original)
        .and_then(|wire| from_str::<EmbeddingRequest>(&wire))
        .unwrap();

    // Required fields.
    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.input, original.input);

    // Optional fields (all `Some(..)` in this fixture).
    assert_eq!(round_tripped.encoding_format, original.encoding_format);
    assert_eq!(round_tripped.user, original.user);
    assert_eq!(round_tripped.dimensions, original.dimensions);
    assert_eq!(round_tripped.rid, original.rid);
}

/// Round-trips an `EmbeddingRequest` whose `input` is a JSON array of strings
/// and whose optional fields are all `None`.
///
/// Besides `model` and `input`, this also checks that every unset optional
/// field deserializes back to `None`. A regression that loses or fabricates a
/// value for an absent field is therefore caught, not just changes to the
/// required fields.
#[test]
fn test_embedding_request_serialization_array_input() {
    let req = EmbeddingRequest {
        model: "test-emb".to_string(),
        input: json!(["a", "b", "c"]),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let serialized = to_string(&req).unwrap();
    let deserialized: EmbeddingRequest = from_str(&serialized).unwrap();

    assert_eq!(deserialized.model, req.model);
    assert_eq!(deserialized.input, req.input);

    // Unset optional fields must remain unset after the round trip.
    assert!(deserialized.encoding_format.is_none());
    assert!(deserialized.user.is_none());
    assert!(deserialized.dimensions.is_none());
    assert!(deserialized.rid.is_none());
}

/// Exercises the `GenerationRequest` trait on a bare-string `input`. The
/// request is non-streaming, reports its model name, and is routed on the
/// input string itself.
#[test]
fn test_embedding_generation_request_trait_string() {
    let model_name = "emb-model";
    let req = EmbeddingRequest {
        model: model_name.to_string(),
        input: json!("hello"),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let streaming = req.is_stream();
    assert!(!streaming);

    assert_eq!(req.get_model(), Some(model_name));

    let routing_text = req.extract_text_for_routing();
    assert_eq!(routing_text, "hello");
}

/// An array of strings as `input` is routed on its elements joined by a
/// single space.
#[test]
fn test_embedding_generation_request_trait_array() {
    let words = ["hello", "world"];
    let req = EmbeddingRequest {
        model: "emb-model".to_string(),
        input: json!(words),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    // Expected routing text: "hello world".
    let expected = words.join(" ");
    assert_eq!(req.extract_text_for_routing(), expected);
}

/// Non-text `input` (here a JSON object carrying token ids) contributes no
/// routing text, so the extracted string is empty.
#[test]
fn test_embedding_generation_request_trait_non_text() {
    let token_payload = json!({ "tokens": [1, 2, 3] });
    let req = EmbeddingRequest {
        model: String::from("emb-model"),
        input: token_payload,
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    assert!(req.extract_text_for_routing().is_empty());
}

/// For a heterogeneous array `input`, only the top-level string elements feed
/// the routing text. Nested arrays, numbers, and objects are skipped rather
/// than flattened or stringified.
#[test]
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
    let mixed = json!([
        "a",          // top-level string: kept
        ["b", "c"],   // nested array: not flattened
        123,          // number: skipped
        { "k": "v" }  // object: skipped
    ]);
    let req = EmbeddingRequest {
        model: "emb-model".to_string(),
        input: mixed,
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let routed = req.extract_text_for_routing();
    assert_eq!(routed, "a");
}