test_get_model.py 1.14 KB
Newer Older
zhouxiang's avatar
zhouxiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os

import pytest

from lmdeploy.turbomind.utils import get_model_from_config


@pytest.mark.parametrize('item',
                         [('baichuan-inc/Baichuan-7B', 'baichuan'),
                          ('baichuan-inc/Baichuan2-7B-Base', 'baichuan2'),
                          ('internlm/internlm2-7b', 'internlm2'),
                          ('internlm/internlm2-chat-7b', 'internlm2'),
                          ('internlm/internlm2-math-20b', 'internlm2'),
                          ('internlm/internlm-20b', 'llama'),
                          ('NousResearch/Llama-2-7b-chat-hf', 'llama'),
                          ('Qwen/Qwen-7B-Chat', 'qwen'),
                          ('Qwen/Qwen-14B', 'qwen'),
                          ('NousResearch/Nous-Hermes-2-SOLAR-10.7B', 'llama'),
                          ('01-ai/Yi-34B-Chat', 'llama')])
def test_get_model_from_config(item):
    from transformers.utils import cached_file
    model_id, result = item
    local_file = cached_file(model_id, 'config.json')
    local_dir = os.path.dirname(local_file)
    print(get_model_from_config(local_dir))
    assert get_model_from_config(local_dir) == result