"tasks/orqa/unsupervised/tokenizers.py" did not exist on "1a2cb60c5ca7456f4971772c66a66e2a80f99232"
mamba_lm.py 4.42 KB
Newer Older
1
2
3
4
from typing import Optional, Union

import torch

import lm_eval.models.utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("mamba_ssm")
class MambaLMWrapper(HFLM):
    def __init__(
        self,
        pretrained="state-spaces/mamba-130m",
        **kwargs,
    ) -> None:
        """
        Mamba (via the `mamba_ssm` package) supports the following args:
        ```
        d_model: int,
        n_layer: int,
        vocab_size: int,
        initializer_cfg=None,
        pad_vocab_size_multiple: int = 1,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        fused_add_norm=False,
        residual_in_fp32=False,
        ```

        See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info.
        The above can all be passed via `--model_args` or to this __init__() directly,
        but we recommend placing many of these within the config.json file uploaded alongside
        your Mamba model to the HF Hub instead.
        All other HuggingFace from_pretrained() kwargs, such as those related to
        `parallelize=True`, PEFT, autoGPTQ, or any sub-configurations of these advanced args,
        are unsupported by the `mamba_ssm` package.

        The HFLM arguments

        `backend`, `tokenizer`, `truncation`, `max_length`,
        `device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`

        are all supported by Mamba where they do not conflict
        with Mamba-specific restrictions, such as supporting causal LMs only.
        """

        if "backend" in kwargs:
            # mamba currently only supports causal models
            assert kwargs["backend"] == "causal"

        super().__init__(
            pretrained=pretrained,
            # set appropriate defaults for tokenizer, max length, etc
            backend=kwargs.pop("backend", "causal"),
            tokenizer=kwargs.pop("tokenizer", "EleutherAI/gpt-neox-20b"),
            max_length=kwargs.pop("max_length", 2048),
            **kwargs,
        )
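
    # Illustrative usage (a sketch, not part of the original module): the wrapper
    # is normally selected through its registry name "mamba_ssm". The checkpoint
    # and task names below are examples only.
    #
    #   CLI:
    #     lm_eval --model mamba_ssm \
    #         --model_args pretrained=state-spaces/mamba-130m,dtype=float16 \
    #         --tasks lambada_openai --batch_size 8
    #
    #   Python (assuming lm_eval.simple_evaluate):
    #     import lm_eval
    #     results = lm_eval.simple_evaluate(
    #         model="mamba_ssm",
    #         model_args="pretrained=state-spaces/mamba-130m",
    #         tasks=["lambada_openai"],
    #     )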

    def _get_config(
        self,
        pretrained: str,
        **kwargs,
    ) -> None:
        try:
            from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
            )

        self._config = load_config_hf(pretrained)

    def _create_model(
        self,
        pretrained: str,
        dtype: Optional[Union[str, torch.dtype]] = "float16",
        # no `parallelize=True` options
        # no PEFT and quantization options
        # Mamba does not support arbitrary HF from_pretrained() args
        **kwargs,
    ) -> None:
        try:
            from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # noqa: F811
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
            )

        self._model = MambaLMHeadModel.from_pretrained(
            pretrained,
            device=self._device,
            dtype=torch.float16
            if dtype == "auto"
            else lm_eval.models.utils.get_dtype(dtype),
        )
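
    # Sketch of the dtype resolution above (assumed behavior of
    # lm_eval.models.utils.get_dtype: a dtype string maps to the matching torch
    # dtype, and an explicit torch.dtype passes through unchanged):
    #
    #   lm_eval.models.utils.get_dtype("bfloat16")     # -> torch.bfloat16
    #   lm_eval.models.utils.get_dtype(torch.float32)  # -> torch.float32
    #
    # dtype == "auto" is special-cased to torch.float16 rather than resolved by
    # get_dtype.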

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        for key in ("do_sample", "attention_mask"):
            if key in generation_kwargs:
                generation_kwargs.pop(key)

        # mamba's custom GenerationMixin currently does not support
        # passing stopping criteria.
        # for the time being, we simply generate to max length,
        # then truncate (equivalent result)
        # -- this should be revisited to speed up generation
        # stopping_criteria = stop_sequences_criteria(
        #     self.tokenizer, stop, 1, context.shape[0]
        # )

        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            # stopping_criteria=stopping_criteria,
            # pad_token_id=self.tokenizer.pad_token_id,
            # use_cache=True,
            **generation_kwargs,
        )
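
    # The "generate to max length, then truncate" approach noted above relies on
    # downstream stop-sequence handling in HFLM.generate_until; a minimal sketch
    # of that truncation (names `s` and `stop` are illustrative, assuming `s` is
    # a decoded continuation):
    #
    #   for term in stop:
    #       if term:
    #           s = s.split(term)[0]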