# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Optional, Union

from transformers import AutoConfig, PretrainedConfig

9
10
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config

11
12
13
14

class EAGLEConfig(PretrainedConfig):
    """Config that wraps another model's config under the "eagle" model type.

    The wrapped config is stored on ``self.model``. Its serialized fields
    are also copied onto this object, and ``architectures`` is rewritten
    with an ``Eagle`` / ``Eagle3`` prefix depending on ``method``.
    """

    model_type = "eagle"

    def __init__(
        self,
        model: Union[PretrainedConfig, dict, None] = None,
        truncated_vocab_size: Optional[int] = None,
        method: Optional[str] = "eagle",
        **kwargs,
    ):
        """Build the config around a wrapped model config.

        Args:
            model: The wrapped model's config, either an already-built
                config object or a plain dict. A dict that lists
                ``DeepseekV2ForCausalLM`` or ``DeepseekV3ForCausalLM`` in
                its ``architectures`` is built with ``DeepseekV2Config``;
                any other dict goes through ``AutoConfig.for_model``.
            truncated_vocab_size: Stored unchanged when given. When
                ``None``, the wrapped model's ``vocab_size`` is used
                (or ``None`` if there is no wrapped model).
            method: ``"eagle"`` or ``"eagle3"``; selects how the
                architecture names are rewritten.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. Any key
                other than ``architectures`` / ``model_type`` that already
                exists as an attribute on the wrapped config also
                overrides that attribute.

        Raises:
            AssertionError: If ``model`` is ``None`` while ``method`` is
                ``"eagle"`` or ``"eagle3"``.
            ValueError: If ``method`` is neither ``"eagle"`` nor
                ``"eagle3"``.
        """
        model_config: Union[PretrainedConfig, DeepseekV2Config, None]
        if isinstance(model, dict):
            # The key may be present but explicitly None (e.g. a dict
            # produced from a config without architectures); treat that
            # the same as a missing key instead of failing on `in None`.
            archs = model.get("architectures") or []
            target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
            if any(target_arch in archs for target_arch in target_archs):
                # AutoConfig does not support DeepSeek MoE models yet
                model_config = DeepseekV2Config(**model)
            else:
                model_config = AutoConfig.for_model(**model)
        else:
            model_config = model

        # Let caller-supplied kwargs override matching fields of the wrapped
        # config. `architectures` and `model_type` are excluded: they
        # describe this wrapper, not the wrapped model.
        for k, v in kwargs.items():
            if k not in ("architectures", "model_type") and hasattr(model_config, k):
                setattr(model_config, k, v)

        self.model = model_config

        if self.model is None:
            self.truncated_vocab_size = None
        else:
            self.truncated_vocab_size = (
                self.model.vocab_size
                if truncated_vocab_size is None
                else truncated_vocab_size
            )

        # Eagle model name should follow naming convention of
        # LlamaForCausalLM -> EagleLlamaForCausalLM
        # LlamaForCausalLM -> Eagle3LlamaForCausalLM
        # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
        #
        # A wrapped config whose `architectures` is None is treated as
        # declaring no architectures, yielding an empty list rather than a
        # TypeError while iterating None.
        if method == "eagle":
            assert self.model is not None, (
                "model should not be None when method is eagle"
            )
            kwargs["architectures"] = [
                f"Eagle{arch}" if not arch.startswith("Eagle") else arch
                for arch in (self.model.architectures or [])
            ]

        elif method == "eagle3":
            assert self.model is not None, (
                "model should not be None when method is eagle3"
            )
            kwargs["architectures"] = [
                arch
                if arch.startswith("Eagle3") or arch.endswith("Eagle3")
                else f"Eagle3{arch}"
                for arch in (self.model.architectures or [])
            ]
        else:
            raise ValueError(
                f"Invalid method {method}. Supported methods are eagle and eagle3."
            )

        super().__init__(**kwargs)

        # Mirror the wrapped config's fields onto this object, without
        # clobbering anything passed explicitly through kwargs (which now
        # includes the rewritten `architectures`).
        if self.model is not None:
            for k, v in self.model.to_dict().items():
                if k not in kwargs:
                    setattr(self, k, v)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "EAGLEConfig":
        """Load the config dict for the given name/path and build from it.

        Args:
            pretrained_model_name_or_path: Passed to ``get_config_dict``.
            **kwargs: Passed to ``get_config_dict``; whatever it returns as
                remaining kwargs is forwarded to ``from_dict``.

        Returns:
            The config produced by ``cls.from_dict``.
        """
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        return cls.from_dict(config_dict, **kwargs)