"""Utilities for selecting and loading neuron models."""
import importlib
import os
4
from typing import Dict, Optional, Tuple
5
6
7

import torch
import torch.nn as nn
8
import transformers
9
10
from transformers import PretrainedConfig

11
12
13
14
15
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
16
17
18
19
20
21
22
23
24
25
26
27
28

# Lookup from a dtype -- given either by name or as a ``torch.dtype`` -- to
# the precision string used as the ``amp`` value for a Neuron model.
TORCH_DTYPE_TO_NEURON_AMP = {
    # Dtypes spelled as strings.
    **{
        "auto": "f32",
        "half": "f16",
        "float16": "f16",
        "bfloat16": "bf16",
        "float": "f32",
        "float32": "f32",
    },
    # Dtypes given as torch.dtype objects.
    **{
        torch.float16: "f16",
        torch.bfloat16: "bf16",
        torch.float32: "f32",
    },
}

29
# Models supported by Neuron.
# Keys are architecture names as they appear in a model config's
# ``architectures`` list. Each value is a triple of
#   (transformers_neuronx module path,
#    name of the model class inside that module,
#    name of the matching model class in ``transformers``).
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": (
        "transformers_neuronx.llama.model",
        "LlamaForSampling",
        "LlamaForCausalLM",
    ),
    "MistralForCausalLM": (
        "transformers_neuronx.mistral.model",
        "MistralForSampling",
        "MistralForCausalLM",
    ),
}


class NeuronCasualLM(nn.Module):
    """Causal LM wrapper that delegates the forward pass to a Neuron model.

    The instance owns a ``LogitsProcessor`` and a ``Sampler``. The underlying
    Neuron model (``self.model``) is not built by ``__init__``; it is created
    by ``load_weights``, which must run before ``forward`` is used.
    """

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        """Store ``config`` and build the logits processor and sampler.

        Args:
            config: Model config; its ``vocab_size`` sizes the logits
                processor and it is later used to pick the architecture.
        """
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

        # Declared here, assigned later by load_weights().
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the Neuron model and return its output.

        ``positions`` is handed to the model as ``cache_ids`` and
        ``input_block_ids`` as ``start_ids``.
        """
        return self.model(input_ids,
                          cache_ids=positions,
                          start_ids=input_block_ids)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor.

        The processor is called with ``None`` as its first argument.
        """
        return self.logits_processor(None, hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Create ``self.model`` from split weights and move it to Neuron.

        If ``<model_name_or_path>/pytorch_model.bin`` is a directory, the
        weights are loaded from ``model_name_or_path`` itself. Otherwise they
        are loaded from ``<model_name_or_path>-split``; when that path does
        not exist yet, it is produced first by loading the checkpoint with
        ``transformers`` and calling ``save_pretrained_split``.

        Args:
            model_name_or_path: Model location or identifier.
            **kwargs: Forwarded unchanged to the Neuron model class's
                ``from_pretrained``.
        """
        # Resolve which transformers_neuronx class serves this architecture.
        module_path, neuron_cls_name, hf_cls_name = _NEURON_SUPPORTED_MODELS[
            _get_model_architecture(self.config)]
        neuron_cls = getattr(importlib.import_module(module_path),
                             neuron_cls_name)

        presplit_marker = os.path.join(model_name_or_path, "pytorch_model.bin")
        if os.path.isdir(presplit_marker):
            # The given path already has the split layout.
            split_model_dir = model_name_or_path
        else:
            split_model_dir = f"{model_name_or_path}-split"
            if not os.path.exists(split_model_dir):
                hf_cls = getattr(transformers, hf_cls_name)
                from transformers_neuronx.module import save_pretrained_split

                save_pretrained_split(
                    hf_cls.from_pretrained(model_name_or_path,
                                           low_cpu_mem_usage=True),
                    split_model_dir)

        self.model = neuron_cls.from_pretrained(split_model_dir, **kwargs)
        self.model.to_neuron()

100

101
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture listed in ``config`` that Neuron supports.

    Args:
        config: Model config whose ``architectures`` attribute is scanned in
            order.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If none of the listed architectures is supported. This
            includes a config whose ``architectures`` is missing, ``None``
            or empty.
    """
    # The attribute may exist but be None, in which case the getattr default
    # is not used. ``or []`` routes that case to the descriptive ValueError
    # below instead of a TypeError from iterating over None.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
110
111


112
113
114
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Build a ``NeuronCasualLM``, load its weights and return it in eval mode.

    Args:
        model_config: Supplies ``hf_config``, the ``model`` path or name and
            the ``dtype`` looked up in ``TORCH_DTYPE_TO_NEURON_AMP``.
        parallel_config: Supplies ``tensor_parallel_size`` (passed as
            ``tp_degree``).
        scheduler_config: Supplies ``max_num_seqs`` (used as batch size) and
            ``max_model_len`` (used for ``n_positions`` and
            ``context_length_estimate``).

    Returns:
        The loaded model after ``.eval()``.

    Raises:
        KeyError: If ``model_config.dtype`` is not a key of
            ``TORCH_DTYPE_TO_NEURON_AMP``.
    """
    # Imported inside the function rather than at module import time.
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             NeuronConfig)

    # Create a model instance.
    model = NeuronCasualLM(model_config.hf_config)

    max_num_seqs = scheduler_config.max_num_seqs
    neuron_config = NeuronConfig(
        continuous_batching=ContinuousBatchingConfig(
            batch_size_for_shared_caches=max_num_seqs))

    # Everything below is forwarded by load_weights() to from_pretrained().
    model_path = model_config.model
    load_kwargs = {
        "tp_degree": parallel_config.tensor_parallel_size,
        "amp": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "neuron_config": neuron_config,
        "context_length_estimate": [scheduler_config.max_model_len],
        "n_positions": [scheduler_config.max_model_len],
        "batch_size": scheduler_config.max_num_seqs,
    }

    # Load the weights from the cached or downloaded files.
    model.load_weights(model_path, **load_kwargs)

    return model.eval()