neuron.py 8.1 KB
Newer Older
1
"""Utilities for selecting and loading neuron models."""
2
import copy
3
4
import importlib
import os
5
from typing import Dict, List, Optional, Tuple
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import PretrainedConfig

11
12
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
13
from vllm.model_executor.layers.quantization import get_quantization_config
14
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
15
from vllm.model_executor.sampling_metadata import SamplingMetadata
16
17
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput)
18
19
20
21
22
23
24
25
26
27
28
29
30

# Lookup table from a dtype, given either as a string alias or as a
# ``torch.dtype`` object, to the short precision string used in this module
# for the ``amp`` and ``dequant_dtype`` arguments.
TORCH_DTYPE_TO_NEURON_AMP = {
    # String aliases.
    **{
        "auto": "f32",
        "half": "f16",
        "float16": "f16",
        "bfloat16": "bf16",
        "float": "f32",
        "float32": "f32",
    },
    # ``torch.dtype`` objects.
    **{
        torch.float16: "f16",
        torch.bfloat16: "bf16",
        torch.float32: "f32",
    },
}

31
# Models supported by Neuron.
# Each key is an architecture name as listed in the HF config's
# ``architectures``. Each value is a 3-tuple of:
#   (transformers-neuronx module path,
#    class name to load from that module,
#    HF model class name).
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                         "LlamaForSampling", "LlamaForCausalLM"),
    "MistralForCausalLM": ("transformers_neuronx.mistral.model",
                           "MistralForSampling", "MistralForCausalLM"),
}


40
class NeuronCausalLM(nn.Module):
    """Causal-LM wrapper around a transformers-neuronx sampling model.

    The wrapped model is not created in ``__init__``; it is attached to
    ``self.model`` by :meth:`load_weights`, which must be called before
    :meth:`forward`.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 on_device_sampling_disabled: bool = False) -> None:
        """Set up the logits processor and, if needed, the default sampler.

        Args:
            config: HF model config; ``config.vocab_size`` sizes the logits
                processor and ``config.architectures`` is consulted later by
                :meth:`load_weights`.
            on_device_sampling_disabled: When true, a ``Sampler`` is created
                and used by :meth:`sample`; otherwise :meth:`sample` treats
                the model output as already-sampled token ids.
        """
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)

        self.on_device_sampling_disabled = on_device_sampling_disabled
        if self.on_device_sampling_disabled:
            # Use default sampler
            self.sampler = Sampler()

        # Lazy initialized (annotation only; assigned in load_weights).
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped model and return its output unchanged.

        ``positions`` is forwarded as ``cache_ids`` and ``input_block_ids``
        as ``start_ids``.
        """
        logits = self.model(input_ids,
                            cache_ids=positions,
                            start_ids=input_block_ids)
        return logits

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor.

        ``None`` is supplied as the processor's first argument; the processor
        was constructed with ``logits_as_input=True``.
        """
        logits = self.logits_processor(None, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Turn model output into a ``SamplerOutput``.

        When on-device sampling is disabled, the default sampler is applied
        to ``logits``. Otherwise ``logits`` is flattened and read as one
        sampled token id per sequence, in the iteration order of
        ``sampling_metadata.seq_groups`` and each group's ``seq_ids``.
        """
        if self.on_device_sampling_disabled:
            return self.sampler(logits, sampling_metadata)

        # On-device sampling outputs the token ids directly.
        # Convert to Python values once instead of calling ``.item()`` for
        # every sequence.
        sampled_token_ids = logits.flatten().tolist()
        next_tokens = []
        sample_idx = 0
        for seq_group in sampling_metadata.seq_groups:
            samples = []
            for seq_id in seq_group.seq_ids:
                token_id = sampled_token_ids[sample_idx]
                # NOTE(review): the Logprob is built from the token id itself,
                # which looks like a placeholder value - confirm.
                samples.append(
                    SequenceOutput(parent_seq_id=seq_id,
                                   output_token=token_id,
                                   logprobs={token_id: Logprob(token_id)}))
                sample_idx += 1
            next_tokens.append(
                CompletionSequenceGroupOutput(samples=samples,
                                              prompt_logprobs=None))

        return SamplerOutput(outputs=next_tokens)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Instantiate the transformers-neuronx model and move it to Neuron.

        Args:
            model_name_or_path: Forwarded to the model class's
                ``from_pretrained``.
            **kwargs: Forwarded unchanged to ``from_pretrained``.

        Raises:
            ValueError: If none of the config's architectures is supported
                (raised by ``_get_model_architecture``).
        """
        arch = _get_model_architecture(self.config)
        # The third tuple entry (HF model class name) is not needed here.
        neuronx_module_path, neuronx_model_cls_name, _ = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

        self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
                                                       **kwargs)
        self.model.to_neuron()

114

115
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture in ``config`` that Neuron supports.

    Args:
        config: HF model config whose ``architectures`` attribute is read.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If no listed architecture is supported, including when
            ``architectures`` is missing, ``None`` or empty.
    """
    # ``architectures`` may be present but set to None; the getattr default
    # alone would not cover that case and the loop would raise TypeError.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
124
125


126
127
128
129
130
131
132
133
134
135
136
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    """Read a comma-separated list of integer bucket sizes from an env var.

    Blank entries (e.g. from a trailing comma) are ignored and surrounding
    whitespace is stripped.

    Args:
        env: Name of the environment variable to read.
        default_value: Buckets returned when the variable is unset or
            contains no entries.

    Returns:
        The parsed buckets in the order given, or ``default_value``.

    Raises:
        ValueError: If an entry is not a valid integer.
    """
    env_value = os.getenv(env)
    if env_value is None:
        return default_value

    entries = [item.strip() for item in env_value.split(",")]
    try:
        buckets = [int(item) for item in entries if item]
    except ValueError as err:
        raise ValueError(
            f"Environment variable {env} must be a comma-separated list of "
            f"integers, got {env_value!r}") from err

    # A set-but-blank variable yields no buckets; fall back to the default
    # instead of handing an empty list to the caller.
    return buckets or default_value


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Build the default keyword arguments for the Neuron config.

    Args:
        model_config: Supplies ``dtype``, ``quantization`` and (optionally)
            ``neuron_sampling_params``.
        parallel_config: Not used by this function; kept in the signature.
        scheduler_config: Supplies ``max_num_seqs`` for continuous batching.

    Returns:
        A dict of default arguments; callers may override entries before
        constructing the Neuron config from it.

    Raises:
        KeyError: If ``model_config.dtype`` is not a key of
            ``TORCH_DTYPE_TO_NEURON_AMP``.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)

    # Looked up unconditionally so an unsupported dtype fails here even when
    # no quantization is requested.
    quant_config = dict(
        dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        quantize_method="vector_dynamic")

    quant = None
    if model_config.quantization:
        quant = get_quantization_config(
            model_config.quantization).from_config(
                quant_config).get_quant_method(None, "")

    # TODO: Add Paged attention config to the default neuron arguments.
    default_neuron_args = dict(
        collectives_layout=LAYOUT_BSH,
        attention_layout=LAYOUT_BSH,
        fuse_qkv=True,
        quant=quant,
        continuous_batching=continuous_batching_config,
        weight_tiling=bool(model_config.quantization),
        on_device_generation=_get_neuron_on_device_generation_config(
            model_config))
    return default_neuron_args


164
165
166
167
168
169
170
171
172
173
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
    """Return a deep copy of the Neuron sampling params, or ``None``.

    ``None`` is returned when on-device sampling is disabled for
    ``model_config``; otherwise the caller receives an independent copy of
    ``model_config.neuron_sampling_params``.
    """
    if _is_neuron_on_device_sampling_disabled(model_config):
        return None
    return copy.deepcopy(model_config.neuron_sampling_params)


def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
    """Return True when ``model_config`` has no usable Neuron sampling params.

    That is the case when the ``neuron_sampling_params`` attribute is absent
    or its value is falsy.
    """
    sampling_params = getattr(model_config, "neuron_sampling_params", None)
    return not sampling_params


174
175
176
177
178
179
180
181
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Build a ``NeuronConfig`` from defaults plus optional overrides.

    Args:
        default_neuron_config: Dict of default ``NeuronConfig`` keyword
            arguments. It is not modified.
        overridden_neuron_config: Entries that replace or extend the
            defaults; ``None`` or an empty value means no overrides.

    Returns:
        A ``NeuronConfig`` constructed from the merged arguments.
    """
    from transformers_neuronx.config import NeuronConfig

    # Merge into a copy so the caller's defaults dict is left untouched.
    merged_neuron_config = dict(default_neuron_config)
    merged_neuron_config.update(overridden_neuron_config or {})
    return NeuronConfig(**merged_neuron_config)


182
183
184
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Create a ``NeuronCausalLM``, load its weights and return it in eval mode.

    The Neuron config is built from the defaults and then updated with
    ``model_config.override_neuron_config``. Context-length and
    token-generation buckets are read from the ``NEURON_CONTEXT_LENGTH_BUCKETS``
    and ``NEURON_TOKEN_GEN_BUCKETS`` environment variables, each defaulting to
    ``[scheduler_config.max_model_len]``.

    Args:
        model_config: Supplies the HF config, model path, dtype, quantization
            and Neuron overrides.
        parallel_config: Supplies ``tensor_parallel_size`` (passed as
            ``tp_degree``).
        scheduler_config: Supplies ``max_num_seqs`` (passed as ``batch_size``)
            and ``max_model_len``.

    Returns:
        The loaded model, after ``.eval()``.
    """
    # Create a model instance.
    model = NeuronCausalLM(
        model_config.hf_config,
        _is_neuron_on_device_sampling_disabled(model_config))

    default_neuron_config_args = _get_default_neuron_config(
        model_config, parallel_config, scheduler_config)

    neuron_config = _get_neuron_config_after_override(
        default_neuron_config_args, model_config.override_neuron_config)

    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    # Load the weights from the cached or downloaded files.
    model.load_weights(model_config.model,
                       tp_degree=parallel_config.tensor_parallel_size,
                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
                       neuron_config=neuron_config,
                       context_length_estimate=context_length_estimates,
                       n_positions=n_positions,
                       batch_size=scheduler_config.max_num_seqs)

    return model.eval()