import math

import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional

from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashQwen2(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        # One process per GPU: set up the tensor-parallel process group and
        # pin this rank to its device. The flash attention kernels require
        # CUDA, so there is no CPU fallback.
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashQwen2 is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        # Configure the sliding attention window, recording it both in tokens
        # and in KV-cache blocks (rounded up to a whole number of blocks).
        if config.sliding_window is not None:
            set_sliding_window(
                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
            )

        torch.distributed.barrier(group=self.process_group)

        # Load the sharded safetensors checkpoint onto this rank's device.
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = Qwen2ForCausalLM(config, weights)

        # Cache for captured CUDA graphs, keyed by batch size.
        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        # Bypass BaseFlashMistral.__init__ (which would build a Mistral model)
        # and initialize the shared FlashCausalLM base directly, since the
        # Qwen2 model has already been constructed above.
        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )
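

# A minimal loading sketch, assuming a CUDA device and the torch.distributed
# environment (RANK / WORLD_SIZE, e.g. via `torchrun`) that
# initialize_torch_distributed() expects. The model id below is only an
# example Qwen2-architecture checkpoint, not something this file prescribes.
if __name__ == "__main__":
    model = FlashQwen2("Qwen/Qwen1.5-0.5B", dtype=torch.float16)
    print("FlashQwen2 loaded:", type(model).__name__)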