model_card.rs 2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use dynamo_llm::model_card::{ModelDeploymentCard, PromptFormatterArtifact, TokenizerKind};
5
use tempfile::tempdir;
/// Sample model directory laid out like a Hugging Face repository (TinyLlama v1.1).
/// The path is relative, so it resolves against the test process's working directory.
const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";

/// Loads the sample model at `HF_PATH` and checks the metadata read from it.
#[tokio::test]
async fn test_model_info_from_hf_like_local_repo() {
    let mdc = ModelDeploymentCard::load(HF_PATH)
        .await
        .expect("sample model directory should load");
    let info = mdc
        .model_info
        .expect("model card should carry model info")
        .get_model_info()
        .await
        .expect("model info should be readable");

    assert_eq!(info.model_type(), "llama");

    assert_eq!(info.bos_token_id(), 1);
    assert_eq!(info.eos_token_ids(), vec![2]);

    assert_eq!(info.max_position_embeddings(), Some(2048));
    assert_eq!(info.vocab_size(), Some(32000));
}

/// Loading a model card from a directory that does not exist must fail.
#[tokio::test]
async fn test_model_info_from_non_existent_local_repo() {
    let path = "tests/data/sample-models/this-model-does-not-exist";
    // Guard the premise: if someone adds this directory, the test would no
    // longer be testing the missing-path case.
    assert!(
        !std::path::Path::new(path).exists(),
        "fixture path {path} is expected to be absent"
    );

    let result = ModelDeploymentCard::load(path).await;

    assert!(result.is_err(), "loading {path} should fail");
}

#[tokio::test]
async fn test_tokenizer_from_hf_like_local_repo() {
29
    let mdc = ModelDeploymentCard::load(HF_PATH).await.unwrap();
30
    // Verify tokenizer file was found
31
    match mdc.tokenizer.unwrap() {
32
        TokenizerKind::HfTokenizerJson(_) => (),
33
        TokenizerKind::GGUF(_) => (),
34
35
36
37
38
    }
}

/// The HF-like sample directory should yield an HF tokenizer-config prompt formatter.
#[tokio::test]
async fn test_prompt_formatter_from_hf_like_local_repo() {
    let mdc = ModelDeploymentCard::load(HF_PATH)
        .await
        .expect("sample model directory should load");

    // Verify prompt formatter was found
    match mdc.prompt_formatter {
        Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (),
        _ => panic!("Expected HfTokenizerConfigJson prompt formatter"),
    }
}

#[tokio::test]
async fn test_missing_required_files() {
    // Create empty temp directory
    let temp_dir = tempdir().unwrap();
51
    let result = ModelDeploymentCard::load(temp_dir.path()).await;
52
53
54
55
    assert!(result.is_err());
    let err = result.unwrap_err().to_string();
    // Should fail because config.json is missing
    assert!(err.contains("unable to extract"));
56
}