from inspect import signature
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import patch

from transformers import AutoConfig
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama


def test_auto_liger_kernel_for_causal_lm_from_pretrained():
    """from_pretrained should split kwargs: liger-kernel options are routed to the
    model-type patching function, while all remaining kwargs (and positional args)
    are forwarded untouched to the base ``AutoModelForCausalLM.from_pretrained``."""
    model_path = "/path/to/llama/model"
    positional_args = ("model_arg1", "model_arg2")

    # kwargs the base HF loader should receive untouched
    passthrough_kwargs = {
        "valid_arg_1": "some_value_1",
        "valid_arg_2": 10,
    }
    # kwargs that must instead be routed to apply_liger_kernel_to_llama
    liger_kwargs = {
        "rope": False,
        "swiglu": True,
    }
    combined_kwargs = dict(passthrough_kwargs, **liger_kwargs)

    # Fake config so AutoConfig resolves the model type to "llama"
    fake_config = MagicMock()
    fake_config.model_type = "llama"
    fake_apply_fn = mock.Mock()
    # Give the mock the real patch function's signature so kwarg filtering works
    fake_apply_fn.__signature__ = signature(apply_liger_kernel_to_llama)

    with (
        patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": fake_apply_fn}),
        mock.patch.object(AutoConfig, "from_pretrained", return_value=fake_config),
        mock.patch.object(
            AutoModelForCausalLM, "from_pretrained", return_value="mock_model"
        ) as base_from_pretrained,
    ):
        result = AutoLigerKernelForCausalLM.from_pretrained(model_path, *positional_args, **combined_kwargs)

        # Liger kwargs were routed to the patching function...
        fake_apply_fn.assert_called_once_with(rope=False, swiglu=True)
        # ...and only the remaining kwargs reached the base loader
        base_from_pretrained.assert_called_once_with(model_path, *positional_args, **passthrough_kwargs)
        assert result == "mock_model"


def test_auto_liger_kernel_for_causal_lm_from_config():
    """from_config should split kwargs the same way as from_pretrained: liger-kernel
    options go to the model-type patching function, the rest are forwarded to the
    base ``AutoModelForCausalLM.from_config``."""
    # kwargs the base HF constructor should receive untouched
    passthrough_kwargs = {
        "valid_arg_1": "some_value_1",
        "valid_arg_2": 10,
    }
    # kwargs that must instead be routed to apply_liger_kernel_to_llama
    liger_kwargs = {
        "rope": False,
        "swiglu": True,
    }
    combined_kwargs = dict(passthrough_kwargs, **liger_kwargs)

    # Fake config carrying the model type that selects the llama patch function
    fake_config = MagicMock()
    fake_config.model_type = "llama"
    fake_apply_fn = mock.Mock()
    # Give the mock the real patch function's signature so kwarg filtering works
    fake_apply_fn.__signature__ = signature(apply_liger_kernel_to_llama)

    with (
        patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": fake_apply_fn}),
        mock.patch.object(AutoModelForCausalLM, "from_config", return_value="mock_model") as base_from_config,
    ):
        result = AutoLigerKernelForCausalLM.from_config(fake_config, **combined_kwargs)

        # Liger kwargs were routed to the patching function...
        fake_apply_fn.assert_called_once_with(rope=False, swiglu=True)
        # ...and only the remaining kwargs reached the base from_config
        base_from_config.assert_called_once_with(fake_config, **passthrough_kwargs)
        assert result == "mock_model"