neuron.py 8.13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Utilities for selecting and loading neuron models."""
3
import copy
4
5
import importlib
import os
6
from typing import Dict, List, Optional, Tuple
7
8
9
10
11

import torch
import torch.nn as nn
from transformers import PretrainedConfig

12
13
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
14
from vllm.model_executor.layers.quantization import get_quantization_config
15
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
16
from vllm.model_executor.sampling_metadata import SamplingMetadata
17
18
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput)
19
20
21
22
23
24
25
26
27
28
29
30
31

# Maps a model dtype -- given either as a string name or as a torch.dtype --
# to the precision string used on the Neuron side. Within this file the
# looked-up value is passed as `amp=` when loading weights and as
# `dequant_dtype` in the quantization config. Note that "auto" resolves to
# "f32".
TORCH_DTYPE_TO_NEURON_AMP = {
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}

32
# Models supported by Neuron.
# Keyed by the architecture name found in the model config. Each value is
# (module path of the transformers-neuronx implementation,
#  class name inside that module,
#  corresponding Hugging Face class name).
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                         "LlamaForSampling", "LlamaForCausalLM"),
    "MistralForCausalLM": ("transformers_neuronx.mistral.model",
                           "MistralForSampling", "MistralForCausalLM"),
}


41
class NeuronCausalLM(nn.Module):
    """Causal LM wrapper around a transformers-neuronx model.

    The underlying model is not created in ``__init__``; it is instantiated
    and moved to Neuron by :meth:`load_weights`.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 on_device_sampling_disabled: bool = False) -> None:
        """Set up the logits processor and, if needed, the default sampler.

        Args:
            config: Hugging Face model config; ``vocab_size`` and
                ``architectures`` are read from it.
            on_device_sampling_disabled: When True, a ``Sampler`` is created
                and used by :meth:`sample`. When False, :meth:`sample`
                treats its input as already-sampled token ids.
        """
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)

        self.on_device_sampling_disabled = on_device_sampling_disabled
        if self.on_device_sampling_disabled:
            # Use default sampler
            self.sampler = Sampler()

        # Lazy initialized
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped model and return its output unchanged.

        ``positions`` is forwarded as ``cache_ids`` and ``input_block_ids``
        as ``start_ids``.
        """
        logits = self.model(input_ids,
                            cache_ids=positions,
                            start_ids=input_block_ids)
        return logits

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Apply the logits processor to ``hidden_states``.

        ``None`` is passed in place of an LM head; the processor was built
        with ``logits_as_input=True``, so ``hidden_states`` is presumably
        already the logits tensor -- confirm against ``LogitsProcessor``.
        """
        logits = self.logits_processor(None, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Produce the sampler output for the current step.

        With on-device sampling disabled, this delegates to the default
        sampler. Otherwise ``logits`` is interpreted as the sampled token
        ids, consumed in order: one per sequence id, iterating
        ``sampling_metadata.seq_groups`` and each group's ``seq_ids``.
        """
        if self.on_device_sampling_disabled:
            next_tokens = self.sampler(logits, sampling_metadata)
            return next_tokens

        # On-device sampling outputs the token ids directly.
        # Convert to Python values once up front rather than calling
        # .item() for every sequence inside the loop.
        sampled_token_ids = logits.flatten().tolist()
        next_tokens = []
        sample_idx = 0
        for seq_group in sampling_metadata.seq_groups:
            samples = []
            for seq_id in seq_group.seq_ids:
                token_id = sampled_token_ids[sample_idx]
                # NOTE(review): the Logprob is constructed from the token id
                # itself; this looks like a placeholder since no real
                # log-probability is available here -- confirm.
                samples.append(
                    SequenceOutput(parent_seq_id=seq_id,
                                   output_token=token_id,
                                   logprobs={token_id: Logprob(token_id)}))
                sample_idx += 1
            next_tokens.append(
                CompletionSequenceGroupOutput(samples=samples,
                                              prompt_logprobs=None))

        return SamplerOutput(outputs=next_tokens)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Instantiate the transformers-neuronx model and move it to Neuron.

        The implementation class is looked up in ``_NEURON_SUPPORTED_MODELS``
        by the architecture found in ``self.config``.

        Args:
            model_name_or_path: Passed to the class's ``from_pretrained``.
            **kwargs: Forwarded unchanged to ``from_pretrained``.

        Raises:
            ValueError: If none of the config's architectures is supported
                (raised by ``_get_model_architecture``).
        """
        arch = _get_model_architecture(self.config)
        # The third tuple element (the Hugging Face class name) is not
        # needed for loading.
        neuronx_module_path, neuronx_model_cls_name, _ = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

        self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
                                                       **kwargs)
        self.model.to_neuron()

115

116
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture listed in ``config`` that Neuron supports.

    Args:
        config: Model config whose ``architectures`` attribute is inspected.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If no listed architecture is supported, including when
            ``architectures`` is missing, ``None`` or empty.
    """
    # The attribute can exist but be None; fall back to an empty list so the
    # ValueError below is raised instead of a TypeError from iterating None.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
125
126


127
128
129
130
131
132
133
134
135
136
137
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    """Read a comma-separated list of integer bucket sizes from ``env``.

    Args:
        env: Name of the environment variable to read.
        default_value: Returned as-is when the variable is not set.

    Returns:
        The parsed integers, in the order they appear. Blank entries (for
        example from a trailing comma) are skipped, so a variable that is
        set but empty yields ``[]``.

    Raises:
        ValueError: If an entry is not a valid integer. The message names
            the environment variable and the offending entry.
    """
    env_value = os.getenv(env)
    if env_value is None:
        return default_value

    buckets: List[int] = []
    for raw_entry in env_value.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue
        try:
            buckets.append(int(entry))
        except ValueError as err:
            raise ValueError(
                f"Invalid bucket size {entry!r} in environment variable "
                f"{env}={env_value!r}; expected comma-separated integers."
            ) from err
    return buckets


138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Build the default keyword arguments for the Neuron configuration.

    Args:
        model_config: Supplies ``dtype``, ``quantization`` and (optionally)
            the on-device sampling parameters.
        parallel_config: Accepted for signature compatibility; not used by
            this function.
        scheduler_config: Supplies ``max_num_seqs``, used as the shared-cache
            batch size for continuous batching.

    Returns:
        A dict of keyword arguments. ``quant`` is ``None`` and
        ``weight_tiling`` is ``False`` when no quantization is configured.

    Raises:
        KeyError: If ``model_config.dtype`` is not a key of
            ``TORCH_DTYPE_TO_NEURON_AMP``.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    # Built unconditionally so an unsupported dtype is reported here even
    # when quantization is off.
    quant_config = dict(
        dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        quantize_method="vector_dynamic")

    quant = None
    if model_config.quantization:
        quantization_config_cls = get_quantization_config(
            model_config.quantization)
        quant = quantization_config_cls.from_config(
            quant_config).get_quant_method(None, "")

    # TODO: Add Paged attention config to the default neuron arguments.
    default_neuron_args = dict(
        collectives_layout=LAYOUT_BSH,
        attention_layout=LAYOUT_BSH,
        fuse_qkv=True,
        quant=quant,
        continuous_batching=continuous_batching_config,
        weight_tiling=bool(model_config.quantization),
        on_device_generation=_get_neuron_on_device_generation_config(
            model_config))
    return default_neuron_args


165
166
167
168
169
170
171
172
173
174
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
    """Return a deep copy of the on-device sampling parameters, or ``None``.

    ``None`` is returned when on-device sampling is disabled for
    ``model_config``.
    """
    if _is_neuron_on_device_sampling_disabled(model_config):
        return None
    # Copy so the returned object is independent of the one on model_config.
    return copy.deepcopy(model_config.neuron_sampling_params)


def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
    """Tell whether on-device sampling is disabled for ``model_config``.

    It counts as disabled when ``neuron_sampling_params`` is missing from
    the config or is falsy.
    """
    sampling_params = getattr(model_config, "neuron_sampling_params", None)
    return not sampling_params


175
176
177
178
179
180
181
182
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Create a ``NeuronConfig`` from defaults plus optional overrides.

    Args:
        default_neuron_config: Dict of default keyword arguments. It is not
            modified.
        overridden_neuron_config: Overrides applied on top of the defaults;
            entries here win on key conflicts. May be ``None`` or empty.

    Returns:
        A ``NeuronConfig`` constructed from the merged keyword arguments.
    """
    from transformers_neuronx.config import NeuronConfig

    # Merge into a copy so the caller's defaults dict is left untouched.
    merged_neuron_config = dict(default_neuron_config)
    merged_neuron_config.update(overridden_neuron_config or {})
    return NeuronConfig(**merged_neuron_config)


183
184
185
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Create a ``NeuronCausalLM``, load its weights and return it in eval mode.

    The Neuron configuration is the defaults from
    ``_get_default_neuron_config`` with ``model_config.override_neuron_config``
    applied on top. Bucket sizes are read from the environment variables
    ``NEURON_CONTEXT_LENGTH_BUCKETS`` and ``NEURON_TOKEN_GEN_BUCKETS``; each
    falls back to ``[scheduler_config.max_model_len]`` when unset.

    Args:
        model_config: Supplies the HF config, model path, dtype and Neuron
            overrides.
        parallel_config: Supplies ``tensor_parallel_size`` (passed as
            ``tp_degree``).
        scheduler_config: Supplies ``max_num_seqs`` (passed as
            ``batch_size``) and ``max_model_len``.

    Returns:
        The loaded model, after calling ``eval()`` on it.
    """
    # Create a model instance.
    model = NeuronCausalLM(
        model_config.hf_config,
        _is_neuron_on_device_sampling_disabled(model_config))

    default_neuron_config_args = _get_default_neuron_config(
        model_config, parallel_config, scheduler_config)

    neuron_config = _get_neuron_config_after_override(
        default_neuron_config_args, model_config.override_neuron_config)

    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    # Load the weights from the cached or downloaded files.
    model.load_weights(model_config.model,
                       tp_degree=parallel_config.tensor_parallel_size,
                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
                       neuron_config=neuron_config,
                       context_length_estimate=context_length_estimates,
                       n_positions=n_positions,
                       batch_size=scheduler_config.max_num_seqs)

    return model.eval()