"R-package/vscode:/vscode.git/clone" did not exist on "b765fa6efb9372e93292861290dd81428cc7d63a"
auto_model.py 2.35 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import inspect
import logging

from transformers import AutoConfig
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _get_model_config(model_dir, **model_init_kwargs):
    """Load the model configuration for ``model_dir`` through ``AutoConfig``.

    All keyword arguments are forwarded unchanged to
    ``AutoConfig.from_pretrained``; its return value is returned as-is.
    """
    return AutoConfig.from_pretrained(model_dir, **model_init_kwargs)


class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
    """
    This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
    if applicable.

    Keyword arguments are handed to ``_apply_liger_kernel`` first. Any keyword whose name matches a
    parameter of the model type's entry in ``MODEL_TYPE_TO_APPLY_LIGER_FN`` is then removed before the
    remaining keywords are forwarded to the parent ``AutoModelForCausalLM`` constructor.
    """

    @staticmethod
    def _strip_liger_kwargs(model_type, kwargs):
        """Return a copy of ``kwargs`` without the keys accepted by the Liger apply function.

        Args:
            model_type: Key used to look up the apply function in ``MODEL_TYPE_TO_APPLY_LIGER_FN``.
            kwargs: Keyword arguments originally passed by the caller.

        Returns:
            A new dict. If ``model_type`` has no registered apply function, it contains every
            entry of ``kwargs``; otherwise entries named like a parameter of that function are dropped.
        """
        # Use .get() so that a model type without a registered Liger patch falls back to plain
        # model loading instead of raising KeyError.
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
        if apply_fn is None:
            return dict(kwargs)

        liger_params = inspect.signature(apply_fn).parameters
        return {key: value for key, value in kwargs.items() if key not in liger_params}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Apply the Liger Kernel for the checkpoint's model type, then load the model.

        Args:
            pretrained_model_name_or_path: Forwarded to ``AutoConfig.from_pretrained`` and to the
                parent ``from_pretrained``.
            *model_args: Forwarded positionally to the parent ``from_pretrained``.
            **kwargs: Passed to the config loader and to ``_apply_liger_kernel``; the subset not
                matching the Liger apply function's parameters is forwarded to the parent.

        Returns:
            Whatever the parent ``from_pretrained`` returns.
        """
        model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)

        # Determine the model type and apply the Liger Kernel if applicable
        # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
        model_type = model_config.model_type
        _apply_liger_kernel(model_type, **kwargs)

        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
        # model initialization errors otherwise
        applicable_kwargs = cls._strip_liger_kwargs(model_type, kwargs)

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)

    @classmethod
    def from_config(cls, config, **kwargs):
        """Apply the Liger Kernel for ``config``'s model type, then build the model from ``config``.

        Args:
            config: Model configuration; its ``model_type`` attribute (if any) selects the Liger patch.
            **kwargs: Passed to ``_apply_liger_kernel``; the subset not matching the Liger apply
                function's parameters is forwarded to the parent ``from_config``.

        Returns:
            Whatever the parent ``from_config`` returns. When the model type cannot be determined,
            no Liger kernel is applied and the model is still built by the parent.
        """
        model_type = getattr(config, "model_type", None)
        if not model_type:
            logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
            # Still behave like AutoModelForCausalLM: build the unpatched model rather than
            # returning None to the caller.
            return super().from_config(config, **kwargs)

        _apply_liger_kernel(model_type, **kwargs)

        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
        # model initialization errors otherwise
        applicable_kwargs = cls._strip_liger_kwargs(model_type, kwargs)

        return super().from_config(config, **applicable_kwargs)