model_card.rs 2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use dynamo_llm::model_card::{ModelDeploymentCard, PromptFormatterArtifact, TokenizerKind};
5
use tempfile::tempdir;
Biswa Panda's avatar
Biswa Panda committed
6
7

const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";
8
9
10

/// Loads the sample repo at `HF_PATH` and checks the model info values extracted from it.
#[tokio::test]
async fn test_model_info_from_hf_like_local_repo() {
    let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None)
        .expect("sample HF repo should load");
    let info = mdc
        .model_info
        .expect("model card should carry model info")
        .get_model_info()
        .expect("model info should be readable");
    assert_eq!(info.model_type(), "llama");
    assert_eq!(info.bos_token_id(), 1);
    assert_eq!(info.eos_token_ids(), vec![2]);
    assert_eq!(info.max_position_embeddings(), Some(2048));
    assert_eq!(info.vocab_size(), Some(32000));
}

/// Loading from a directory that does not exist must return an error.
#[tokio::test]
async fn test_model_info_from_non_existent_local_repo() {
    let path = "tests/data/sample-models/this-model-does-not-exist";
    // Check the premise first, so a failure here points at the fixtures rather than the loader.
    assert!(!std::path::Path::new(path).exists(), "fixture path unexpectedly exists: {path}");
    let result = ModelDeploymentCard::load_from_disk(path, None);
    assert!(result.is_err(), "loading a non-existent repo should fail");
}

#[tokio::test]
async fn test_tokenizer_from_hf_like_local_repo() {
29
    let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None).unwrap();
30
    // Verify tokenizer file was found
31
    match mdc.tokenizer.unwrap() {
32
33
34
35
36
37
        TokenizerKind::HfTokenizerJson(_) => (),
    }
}

/// The sample repo at `HF_PATH` should yield a prompt formatter backed by `tokenizer_config.json`.
#[tokio::test]
async fn test_prompt_formatter_from_hf_like_local_repo() {
    let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None)
        .expect("sample HF repo should load");
    // Verify prompt formatter was found
    assert!(
        matches!(
            mdc.prompt_formatter,
            Some(PromptFormatterArtifact::HfTokenizerConfigJson(_))
        ),
        "Expected HfTokenizerConfigJson prompt formatter"
    );
}

#[tokio::test]
async fn test_missing_required_files() {
    // Create empty temp directory
    let temp_dir = tempdir().unwrap();
50
    let result = ModelDeploymentCard::load_from_disk(temp_dir.path(), None);
51
52
53
54
    assert!(result.is_err());
    let err = result.unwrap_err().to_string();
    // Should fail because config.json is missing
    assert!(err.contains("unable to extract"));
55
}