eagle.py 3.04 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import os

6
from transformers import AutoConfig, DeepseekV2Config, PretrainedConfig
7

8
9
from vllm.transformers_utils.utils import without_trust_remote_code

10
11
12
13

class EAGLEConfig(PretrainedConfig):
    """Config that wraps another model config for the eagle/eagle3 methods.

    The wrapped config is stored on ``self.model``. Its ``architectures``
    are renamed with an ``Eagle``/``Eagle3`` marker and its remaining
    attributes are mirrored onto this config.
    """

    model_type = "eagle"

    def __init__(
        self,
        model: PretrainedConfig | dict | None = None,
        truncated_vocab_size: int | None = None,
        method: str | None = "eagle",
        **kwargs,
    ):
        """Build the config from a wrapped model config.

        Args:
            model: The wrapped model config, either as a config object or as
                a dict that is expanded into ``AutoConfig.for_model``.
            truncated_vocab_size: Value stored on the config; falls back to
                the wrapped model's ``vocab_size`` when ``None``.
            method: Either ``"eagle"`` or ``"eagle3"``; selects how the
                architecture names are rewritten.
            **kwargs: Overrides applied to matching attributes of the wrapped
                config, then forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``method`` is not ``"eagle"``/``"eagle3"``, or if
                ``model`` is ``None``.
        """
        model_config: PretrainedConfig | DeepseekV2Config | None
        if isinstance(model, dict):
            model_config = AutoConfig.for_model(**model)
        else:
            model_config = model

        # Push caller overrides down into the wrapped config. The
        # "architectures" and "model_type" keys are skipped here;
        # "architectures" is recomputed below.
        for k, v in kwargs.items():
            if k != "architectures" and k != "model_type" and hasattr(model_config, k):
                setattr(model_config, k, v)

        self.model = model_config

        if self.model is None:
            self.truncated_vocab_size = None
        else:
            self.truncated_vocab_size = (
                self.model.vocab_size
                if truncated_vocab_size is None
                else truncated_vocab_size
            )

        # Validate with explicit raises rather than `assert`, which is
        # stripped under `python -O` and would let a missing model surface
        # later as an unrelated AttributeError. The method is checked first
        # so that an invalid method is reported before a missing model.
        if method not in ("eagle", "eagle3"):
            raise ValueError(
                f"Invalid method {method}. Supported methods are eagle and eagle3."
            )
        if self.model is None:
            raise ValueError(f"model should not be None when method is {method}")

        # Eagle model name should follow naming convention of
        #   method "eagle":
        #     LlamaForCausalLM       -> EagleLlamaForCausalLM
        #   method "eagle3":
        #     LlamaForCausalLM       -> Eagle3LlamaForCausalLM
        #     LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3 (unchanged)
        if method == "eagle":
            kwargs["architectures"] = [
                f"Eagle{arch}" if not arch.startswith("Eagle") else arch
                for arch in self.model.architectures
            ]
        else:
            kwargs["architectures"] = [
                arch
                if arch.startswith("Eagle3") or arch.endswith("Eagle3")
                else f"Eagle3{arch}"
                for arch in self.model.architectures
            ]

        super().__init__(**kwargs)

        # Mirror the wrapped config's fields onto this config, without
        # clobbering anything the caller passed explicitly (or the rewritten
        # "architectures").
        for k, v in self.model.to_dict().items():
            if k not in kwargs:
                setattr(self, k, v)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike,
        **kwargs,
    ) -> "EAGLEConfig":
        """Load the config dict for the given name/path and build the config.

        ``kwargs`` are passed through ``without_trust_remote_code`` before
        being handed to ``get_config_dict`` (presumably to drop the
        ``trust_remote_code`` option -- confirm against that helper).
        """
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **without_trust_remote_code(kwargs)
        )
        return cls.from_dict(config_dict, **kwargs)

    def to_json_string(self, use_diff: bool = True) -> str:
        """Serialize to JSON, always with ``use_diff=False``.

        The ``use_diff`` argument is accepted for signature compatibility
        but ignored.
        """
        # we override use_diff to False as initializing
        # EAGLEConfig with default arguments is not supported
        del use_diff
        return super().to_json_string(use_diff=False)