"vscode:/vscode.git/clone" did not exist on "d9942bae249329bd8c8bf5c92f0f108595fcb84f"
llava_next.py 1.35 KB
Newer Older
1
2
import torch

from typing import Optional, Tuple

from transformers import (
    AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLM


class LlavaNext(VlmCausalLM):
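    """LLaVA-NeXT multimodal model served through the shared ``VlmCausalLM``
    pipeline, using the custom ``LlavaNextForConditionalGeneration`` module."""
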
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
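        # Load the multimodal processor (tokenizer + image processor) that
        # prepares text and image inputs for LLaVA-NeXT.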
        self.processor = AutoProcessor.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
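        # Defer the remaining setup to the VlmCausalLM base class, passing the
        # custom LLaVA-NeXT modeling class.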
        super().__init__(
            model_cls=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
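        """Return (num_layers, num_key_value_heads, head_size) for the
        underlying language model."""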
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
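        """Return the language model's ``max_past`` attribute, or ``None`` if unset."""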
        return getattr(self.model.language_model, "max_past", None)
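

# Example usage, as a minimal sketch: the checkpoint id is an assumption (any
# LLaVA-NeXT checkpoint should work), and `image`/`prompt` are placeholders for
# user-provided inputs.
#
#   model = LlavaNext(
#       model_id="llava-hf/llava-v1.6-mistral-7b-hf",
#       dtype=torch.float16,
#   )
#   inputs = model.processor(images=image, text=prompt, return_tensors="pt")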