# flash_mixtral.py
from typing import Optional

import torch

from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
    MixtralConfig,
    FlashMixtralForCausalLM,
)


class FlashMixtral(BaseFlashMistral):
    """Mixtral variant of ``BaseFlashMistral``.

    The class adds no behavior of its own: it only selects
    ``MixtralConfig`` as the configuration class and
    ``FlashMixtralForCausalLM`` as the model class, then delegates
    construction to ``BaseFlashMistral.__init__``.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Build the model by forwarding every argument to the base class.

        Args:
            model_id: Identifier of the model to load; passed through as-is.
            revision: Optional revision, forwarded unchanged.
            quantize: Optional quantization setting, forwarded unchanged.
            use_medusa: Optional value forwarded unchanged.
                NOTE(review): presumably identifies a Medusa speculation
                head -- confirm against ``BaseFlashMistral``.
            dtype: Optional torch dtype, forwarded unchanged.
            trust_remote_code: Forwarded unchanged; defaults to ``False``.
        """
        # All interpretation of these arguments happens in the base class;
        # this subclass only fixes which config/model classes are used.
        super().__init__(
            config_cls=MixtralConfig,
            model_cls=FlashMixtralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )