neuron.py 5.56 KB
Newer Older
1
2
3
"""Utilities for selecting and loading neuron models."""
import importlib
import os
4
from typing import Dict, List, Optional, Tuple
5
6
7

import torch
import torch.nn as nn
8
import transformers
9
10
from transformers import PretrainedConfig

11
12
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
13
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
14
from vllm.model_executor.sampling_metadata import SamplingMetadata
15
16
17
18
19
20
21
22
23
24
25
26
27

# Lookup table from the model dtype (given either as a dtype name string or
# as a torch.dtype object) to the precision string that get_neuron_model
# forwards as the ``amp`` keyword when loading weights.
TORCH_DTYPE_TO_NEURON_AMP = {
    # String spellings; "auto" maps to full precision in this table.
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    # torch.dtype objects.
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}

28
# Models supported by Neuron.
# Each key is a Hugging Face architecture name; each value is a tuple of
# (transformers_neuronx module path, Neuron model class name inside that
# module, Hugging Face model class name looked up on ``transformers``).
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                         "LlamaForSampling", "LlamaForCausalLM"),
    "MistralForCausalLM": ("transformers_neuronx.mistral.model",
                           "MistralForSampling", "MistralForCausalLM"),
}


class NeuronCasualLM(nn.Module):
    """Module wrapping a transformers-neuronx causal language model.

    The wrapped Neuron model is not created in ``__init__``; it is built and
    assigned to ``self.model`` by ``load_weights``. ``forward`` delegates to
    that model, while ``compute_logits`` and ``sample`` delegate to the
    logits processor and sampler created in ``__init__``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        """Store ``config`` and create the logits processor and sampler.

        Args:
            config: Hugging Face model config; ``config.vocab_size`` sizes
                the logits processor.
        """
        super().__init__()
        self.config = config
        # logits_as_input=True: compute_logits hands the processor the
        # tensor it receives directly (see compute_logits below).
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

        # Lazy initialized: declared here, assigned in load_weights().
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped Neuron model and return its output.

        Args:
            input_ids: Passed to the model as its positional input.
            positions: Passed to the model as ``cache_ids``.
            input_block_ids: Passed to the model as ``start_ids``.

        Returns:
            Whatever the wrapped model returns (named ``logits`` here).
        """
        logits = self.model(input_ids,
                            cache_ids=positions,
                            start_ids=input_block_ids)
        return logits

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Apply the logits processor to ``hidden_states``.

        ``None`` is passed as the processor's first argument; the processor
        was constructed with ``logits_as_input=True``.
        """
        logits = self.logits_processor(None, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Create the Neuron model from a split checkpoint and load it.

        The checkpoint directory is chosen as follows:

        * if ``<model_name_or_path>/pytorch_model.bin`` is a directory,
          ``model_name_or_path`` itself is used as the split checkpoint;
        * otherwise ``<model_name_or_path>-split`` is used, and it is first
          created from the Hugging Face model if it does not exist yet.

        Args:
            model_name_or_path: Name or path handed to the Hugging Face
                ``from_pretrained`` when a split checkpoint must be created.
            **kwargs: Forwarded unchanged to the Neuron model class's
                ``from_pretrained``.

        Raises:
            ValueError: If the config's architecture is not supported
                (raised by ``_get_model_architecture``).
        """
        arch = _get_model_architecture(self.config)
        neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            # A ``pytorch_model.bin`` *directory* is treated as a sign that
            # the given path already holds a split checkpoint.
            split_model_dir = model_name_or_path
        elif not os.path.exists(split_model_dir):
            hf_model_cls = getattr(transformers, hf_model_cls_name)
            from transformers_neuronx.module import save_pretrained_split

            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
                                                    low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, split_model_dir)

        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
                                                       **kwargs)
        self.model.to_neuron()

99

100
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture listed in ``config`` that Neuron supports.

    Args:
        config: Hugging Face model config whose ``architectures`` attribute
            is inspected.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If no listed architecture is supported, including when
            the attribute is missing, ``None`` or empty.
    """
    # The attribute can exist but hold None (getattr's default only covers
    # the missing case); normalise to an empty list so the caller gets the
    # descriptive ValueError below rather than a TypeError from iteration.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
109
110


111
112
113
114
115
116
117
118
119
120
121
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    """Read a comma-separated list of integer bucket sizes from the environment.

    Args:
        env: Name of the environment variable to read.
        default_value: Buckets to use when the variable is unset or contains
            no values.

    Returns:
        The parsed integers in the order given, or ``default_value`` when the
        variable is unset or holds only blanks/separators.

    Raises:
        ValueError: If any non-blank entry is not a valid integer.
    """
    env_value = os.getenv(env)
    if env_value is None:
        return default_value

    # Blank entries (e.g. from a trailing comma) are ignored.
    entries = [entry.strip() for entry in env_value.split(",")]
    try:
        buckets = [int(entry) for entry in entries if entry]
    except ValueError as err:
        raise ValueError(
            f"Environment variable {env} must be a comma-separated list of "
            f"integers, got {env_value!r}") from err

    # An empty bucket list is never useful downstream; treat a variable that
    # is set but empty the same as an unset one.
    return buckets or default_value


122
123
124
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Build a ``NeuronCasualLM``, load its weights and return it in eval mode.

    Args:
        model_config: Supplies ``hf_config``, the model name/path (``model``)
            and ``dtype``.
        parallel_config: Supplies ``tensor_parallel_size`` (passed as
            ``tp_degree``).
        scheduler_config: Supplies ``max_num_seqs`` (batch size) and
            ``max_model_len`` (default bucket size).

    Returns:
        The loaded model after ``.eval()``.

    Raises:
        KeyError: If ``model_config.dtype`` is not a key of
            ``TORCH_DTYPE_TO_NEURON_AMP``.
    """
    # Imported inside the function rather than at module level; presumably
    # so this module can be imported without transformers_neuronx installed
    # (TODO confirm).
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             NeuronConfig)

    # Create a model instance.
    model = NeuronCasualLM(model_config.hf_config)

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    neuron_config = NeuronConfig(
        continuous_batching=continuous_batching_config)

    # Bucket sizes can be overridden through environment variables; both
    # default to a single bucket of max_model_len.
    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    # Load the weights from the cached or downloaded files.
    model.load_weights(model_config.model,
                       tp_degree=parallel_config.tensor_parallel_size,
                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
                       neuron_config=neuron_config,
                       context_length_estimate=context_length_estimates,
                       n_positions=n_positions,
                       batch_size=scheduler_config.max_num_seqs)

    return model.eval()