# flash_llama.py
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama import LlamaTokenizer
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
    LlamaConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashLlama(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        use_medusa: Optional[str] = None,
    ):
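        # One process per GPU shard: set up the distributed group and learn this
        # rank's position so the weights can be tensor-parallel sharded below.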
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
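            # Flash attention kernels run in half precision; default to float16
            # unless a dtype was explicitly requested.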
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

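        # Prefer the slow, sentencepiece-backed LlamaTokenizer; fall back to
        # AutoTokenizer for checkpoints that ship a different tokenizer class.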
        try:
            tokenizer = LlamaTokenizer.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
            )

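        # Load the model config and record the requested quantization scheme so
        # the weight loader can handle quantized checkpoints.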
        config = LlamaConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

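        # Make sure every rank reaches this point before weights are read from disk.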
        torch.distributed.barrier(group=self.process_group)

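        # Locate the local .safetensors shards; Weights loads tensors lazily,
        # sharding them across the process group.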
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
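        # GPTQ/AWQ checkpoints carry extra quantization metadata (e.g. bits and
        # groupsize) that the weight loader needs before reading tensors.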
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashLlamaForCausalLM(config, weights)
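        # Optionally stack Medusa speculative-decoding heads on top of the LM head.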
        if use_medusa:
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path
            
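            # Medusa artifacts may live in a local directory (or a pre-populated
            # weights cache); otherwise fetch them from the Hugging Face Hub.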
            is_local_model = (
                Path(use_medusa).exists() and Path(use_medusa).is_dir()
            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
            
            if not is_local_model:
                medusa_config = hf_hub_download(
                    use_medusa, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
                )
            else:
                medusa_config = str(Path(use_medusa) / "config.json")
                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
                
            with open(medusa_config, "r") as f:
                config = json.load(f)
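            # The head ships as a .pt checkpoint; load the .safetensors
            # conversion expected to sit next to it.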
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

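        # Wait for all ranks to finish loading, then hand the sharded model and
        # its geometry (layers, KV heads, head size) to FlashCausalLM.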
        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
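

# Usage sketch (illustrative only): in text-generation-inference this class is
# normally constructed via the model dispatch in
# text_generation_server/models/__init__.py, and initialize_torch_distributed
# expects the usual torch.distributed environment variables (RANK, WORLD_SIZE,
# MASTER_ADDR, MASTER_PORT) to be set, e.g. RANK=0 / WORLD_SIZE=1 on one GPU:
#
#     model = FlashLlama(
#         model_id="meta-llama/Llama-2-7b-hf",
#         dtype=torch.float16,
#     )
#     # The instance is then warmed up and driven through FlashCausalLM's
#     # generate_token loop by the serving layer.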