"examples/vscode:/vscode.git/clone" did not exist on "25e48a3aae35849fd777f8a48c3c494337c11d83"
neuron_worker.py 1.61 KB
Newer Older
1
"""A Neuron worker class."""
2
from typing import List, Optional
3
4
5
6

import torch
import torch.distributed

7
8
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
9
10
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
11
from vllm.worker.neuron_model_runner import NeuronModelRunner
12
13


14
class NeuronWorker:
    """A worker class that executes the model on a group of neuron cores.

    The worker keeps the configuration objects it was constructed with and
    forwards model loading and execution to a ``NeuronModelRunner`` built
    from those same objects.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ) -> None:
        """Store the configs and create the model runner.

        Args:
            model_config: Model configuration; its ``seed`` attribute is read
                by ``init_device``.
            parallel_config: Parallelism configuration, passed to the runner.
            scheduler_config: Scheduler configuration, passed to the runner.
            device_config: Device configuration, passed to the runner.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config

        self.model_runner = NeuronModelRunner(model_config, parallel_config,
                                              scheduler_config, device_config)

    def init_device(self) -> None:
        """Seed the random number generators with ``model_config.seed``."""
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self) -> None:
        """Ask the model runner to load the model."""
        self.model_runner.load_model()

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Optional[SamplerOutput]:
        """Run the model runner on the given sequence groups.

        Args:
            seq_group_metadata_list: Metadata of the sequence groups to run.

        Returns:
            Whatever ``model_runner.execute_model`` returns, or an empty
            dict when there are no sequence groups to run.
        """
        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            # NOTE(review): an empty dict does not obviously match the
            # annotated Optional[SamplerOutput]; kept as-is because callers
            # may rely on it -- confirm against the SamplerOutput definition.
            return {}

        return self.model_runner.execute_model(seq_group_metadata_list)