"vscode:/vscode.git/clone" did not exist on "bf6a8dc2156b9761e7bcdd0df605cc1d875f8435"
Unverified commit fd4ea8ef, authored by Zhuohan Li and committed by GitHub

Use NCCL instead of ray for control-plane communication to remove serialization overhead (#2221)

parent 1066cbd1
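
At its core, the change moves driver-to-worker control traffic from Ray remote calls onto torch.distributed collectives: small metadata goes through pickled object lists, bulk tensors go over raw NCCL broadcasts. The standalone sketch below is illustrative only and not part of the diff; run_worker and the toy payload are hypothetical, and each process would be launched separately (e.g. with torchrun or multiprocessing).

import torch
import torch.distributed as dist


def run_worker(rank: int, world_size: int, init_method: str) -> None:
    # One process per GPU; rank 0 plays the role of the driver worker.
    dist.init_process_group("nccl", init_method=init_method,
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    if rank == 0:
        meta = [{"batch_size": 4, "seq_len": 16}]   # tiny control data
        tokens = torch.randint(0, 32000, (4, 16), device="cuda")
    else:
        meta = [None]
        tokens = None

    # Control-plane: pickled Python objects, negligible size.
    dist.broadcast_object_list(meta, src=0)
    if rank != 0:
        shape = (meta[0]["batch_size"], meta[0]["seq_len"])
        tokens = torch.empty(shape, dtype=torch.long, device="cuda")

    # Data-plane: raw NCCL broadcast, no Python serialization of the tensor.
    dist.broadcast(tokens, src=0)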
@@ -255,7 +255,7 @@ class InternLMForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -291,7 +291,7 @@ class LlamaForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -287,7 +287,7 @@ class MistralForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -320,7 +320,7 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -361,7 +361,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -276,7 +276,7 @@ class MPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -309,7 +309,7 @@ class OPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -280,7 +280,7 @@ class PhiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         head = self.lm_head.linear
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)

@@ -247,7 +247,7 @@ class QWenLMHeadModel(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -286,7 +286,7 @@ class YiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
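
All of the model hunks above make the same change: sample() is now annotated Optional[SamplerOutput]. The reason appears later in this diff, where SamplingMetadata gains a perform_sampling flag and non-driver workers construct it with perform_sampling=False. The sampler change itself is not in this excerpt; the guard below is only a hedged sketch of what that implies for callers.

from typing import Optional


def sample_or_skip(sampler, lm_head_weight, hidden_states,
                   sampling_metadata) -> Optional["SamplerOutput"]:
    # Hypothetical illustration: when perform_sampling is False (non-driver
    # ranks), no SamplerOutput is produced and the caller receives None.
    if not sampling_metadata.perform_sampling:
        return None
    return sampler(lm_head_weight, hidden_states, sampling_metadata)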
 import torch
 from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
 )
@@ -45,3 +46,61 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
                                  (world_size * input_size[dim], ) +
                                  input_size[dim + 1:])
     return output_tensor
+
+
+def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+    """Gather the input tensor across model parallel group.
+    NOTE: We assume that the input tensor is on the same device across
+    all the ranks.
+    """
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    assert -input_.dim() <= dim < input_.dim(), (
+        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+    if dim < 0:
+        # Convert negative dim to positive.
+        dim += input_.dim()
+    # Allocate output tensor.
+    if get_tensor_model_parallel_rank() == dst:
+        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
+    else:
+        gather_list = None
+    # Gather.
+    torch.distributed.gather(input_,
+                             gather_list,
+                             dst=dst,
+                             group=get_tensor_model_parallel_group())
+    if get_tensor_model_parallel_rank() == dst:
+        output_tensor = torch.cat(gather_list, dim=dim)
+    else:
+        output_tensor = None
+    return output_tensor
+
+
+def broadcast(input_, src=0):
+    """Broadcast the input tensor."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    # Broadcast.
+    torch.distributed.broadcast(input_, src=src)
+    return input_
+
+
+def broadcast_object_list(obj_list, src=0):
+    """Broadcast the input object list."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return obj_list
+    # Broadcast.
+    torch.distributed.broadcast_object_list(obj_list, src=src)
+    return obj_list
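
A short usage sketch of the three helpers added above (illustrative only; it assumes torch.distributed and the tensor-model-parallel groups are already initialized, and the tensors are toy placeholders):

import torch
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast, broadcast_object_list, tensor_model_parallel_gather)

rank = torch.distributed.get_rank()

# Control-plane: rank 0 ships a small Python payload to every rank.
payload = [{"num_seq_groups": 8}] if rank == 0 else [None]
broadcast_object_list(payload, src=0)

# Data-plane: every rank passes a tensor of the same shape/dtype/device.
tokens = torch.zeros(8, 16, dtype=torch.long, device="cuda")
broadcast(tokens, src=0)

# Gather shards onto the destination rank only; other ranks get None back.
logits_shard = torch.randn(8, 4096, device="cuda")
full_logits = tensor_model_parallel_gather(logits_shard, dst=0, dim=-1)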
 from dataclasses import dataclass
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
@@ -18,24 +18,29 @@ class SamplingMetadata:
         seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
-        categorized_sample_indices: SamplingType -> token indicies to sample.
+        categorized_sample_indices: SamplingType -> token indices to sample.
+        perform_sampling: Whether to perform sampling. This option is used to
+            make the sampling only happens in the driver worker, and disable
+            sampling in other worker processes.
     """

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
-        prompt_lens: List[int],
+        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
+        seq_data: Optional[Dict[int, SequenceData]],
+        prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.perform_sampling = perform_sampling

-        self.num_prompts = len(prompt_lens)
+        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

     def __repr__(self) -> str:
         return (
@@ -44,7 +49,8 @@ class SamplingMetadata:
             f"seq_data={self.seq_data}, "
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
-            f"categorized_sample_indices={self.categorized_sample_indices})")
+            f"categorized_sample_indices={self.categorized_sample_indices}), "
+            f"perform_sampling={self.perform_sampling})")


 @dataclass
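
The fields become Optional because non-driver workers never see the scheduler's Python-level state; they only receive the broadcast tensors. A toy sketch of the two ways the object is now constructed (placeholder values, mirroring what model_runner.py does later in this diff):

import torch
from vllm.model_executor import SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType

selected = torch.tensor([0], dtype=torch.long)

# Driver worker: full metadata; perform_sampling defaults to True.
driver_metadata = SamplingMetadata(
    seq_groups=[([0], SamplingParams(temperature=0.8))],
    seq_data={},  # seq_id -> SequenceData, elided in this toy example
    prompt_lens=[4],
    selected_token_indices=selected,
    categorized_sample_indices={SamplingType.RANDOM: selected},
)

# Non-driver worker: only the broadcast tensor is known; sampling is
# disabled, so the model's sample() returns None on this rank.
worker_metadata = SamplingMetadata(
    seq_groups=None,
    seq_data=None,
    prompt_lens=None,
    selected_token_indices=selected,
    categorized_sample_indices=None,
    perform_sampling=False,
)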
 import enum
+import os
 import socket
 import uuid
 from platform import uname
+from typing import List

 import psutil
 import torch
@@ -55,7 +57,15 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()


-def get_open_port():
+def get_ip() -> str:
+    return socket.gethostbyname(socket.gethostname())
+
+
+def get_open_port() -> int:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.bind(("", 0))
         return s.getsockname()[1]
+
+
+def set_cuda_visible_devices(device_ids: List[int]) -> None:
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
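
These small helpers replace the environment plumbing that Ray previously handled implicitly. A brief usage sketch (illustrative; the actual call sites are in the engine code, outside this excerpt):

from vllm.utils import get_ip, get_open_port, set_cuda_visible_devices

# Restrict which GPUs a worker process may see (previously arranged by Ray
# through environment variables).
set_cuda_visible_devices([0, 1])

# Build a rendezvous address that can be handed to every Worker as
# `distributed_init_method` for torch.distributed initialization.
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
print(distributed_init_method)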
 import time
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -8,6 +8,8 @@ import torch.nn as nn
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast, broadcast_object_list)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import in_wsl
@@ -28,10 +30,12 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.is_driver_worker = is_driver_worker

         # model_config can be None in tests/samplers/test_sampler.py.
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
@@ -70,7 +74,7 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -135,14 +139,14 @@ class ModelRunner:
                                              dtype=torch.long)

         input_metadata = InputMetadata(
-            prompt_lens=prompt_lens,
+            is_prompt=True,
             slot_mapping=slot_mapping,
             max_context_len=None,
             context_lens=None,
             block_tables=None,
             use_cuda_graph=False,
         )
-        return input_tokens, input_positions, input_metadata
+        return input_tokens, input_positions, input_metadata, prompt_lens

     def _prepare_decode(
         self,
@@ -203,32 +207,24 @@ class ModelRunner:
                 block_tables.append([])
             batch_size = graph_batch_size

-        # When using CUDA graph, we don't need to make the tensors on the GPU
-        # because they will be eventually copied to the designated GPU buffer.
-        device = "cpu" if use_captured_graph else "cuda"
-        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device,
-                                                pin_memory=pin_memory)
+                                                device="cuda")
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device,
-                                    pin_memory=pin_memory)
+                                    device="cuda")

         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -237,17 +233,18 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
                 max_len=max_context_len,
                 pad=0,
                 dtype=torch.int,
+                device="cuda",
             )

         input_metadata = InputMetadata(
-            prompt_lens=[],
+            is_prompt=False,
             slot_mapping=slot_mapping,
             max_context_len=max_context_len,
             context_lens=context_lens,
@@ -326,23 +323,127 @@ class ModelRunner:
         )
         return sampling_metadata

+    def prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+        if self.is_driver_worker:
+            # NOTE: We assume that all sequences in the group are all prompts
+            # or all decodes.
+            is_prompt = seq_group_metadata_list[0].is_prompt
+            # Prepare input tensors.
+            if is_prompt:
+                (input_tokens, input_positions, input_metadata,
+                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+            else:
+                (input_tokens, input_positions, input_metadata
+                 ) = self._prepare_decode(seq_group_metadata_list)
+                prompt_lens = []
+            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)

+            def get_size_or_none(x: Optional[torch.Tensor]):
+                return x.size() if x is not None else None

+            # Broadcast the input data. For input tensors, we first broadcast
+            # its shape and then broadcast the tensor to avoid high
+            # serialization cost.
+            py_data = {
+                "input_tokens_size":
+                input_tokens.size(),
+                "input_positions_size":
+                input_positions.size(),
+                "is_prompt":
+                input_metadata.is_prompt,
+                "slot_mapping_size":
+                get_size_or_none(input_metadata.slot_mapping),
+                "max_context_len":
+                input_metadata.max_context_len,
+                "context_lens_size":
+                get_size_or_none(input_metadata.context_lens),
+                "block_tables_size":
+                get_size_or_none(input_metadata.block_tables),
+                "use_cuda_graph":
+                input_metadata.use_cuda_graph,
+                "selected_token_indices_size":
+                sampling_metadata.selected_token_indices.size(),
+            }
+            broadcast_object_list([py_data], src=0)
+            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
+            broadcast(input_tokens, src=0)
+            broadcast(input_positions, src=0)
+            if input_metadata.slot_mapping is not None:
+                broadcast(input_metadata.slot_mapping, src=0)
+            if input_metadata.context_lens is not None:
+                broadcast(input_metadata.context_lens, src=0)
+            if input_metadata.block_tables is not None:
+                broadcast(input_metadata.block_tables, src=0)
+            broadcast(sampling_metadata.selected_token_indices, src=0)
+        else:
+            receving_list = [None]
+            broadcast_object_list(receving_list, src=0)
+            py_data = receving_list[0]
+            input_tokens = torch.empty(*py_data["input_tokens_size"],
+                                       dtype=torch.long,
+                                       device="cuda")
+            broadcast(input_tokens, src=0)
+            input_positions = torch.empty(*py_data["input_positions_size"],
+                                          dtype=torch.long,
+                                          device="cuda")
+            broadcast(input_positions, src=0)
+            if py_data["slot_mapping_size"] is not None:
+                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
+                                           dtype=torch.long,
+                                           device="cuda")
+                broadcast(slot_mapping, src=0)
+            else:
+                slot_mapping = None
+            if py_data["context_lens_size"] is not None:
+                context_lens = torch.empty(*py_data["context_lens_size"],
+                                           dtype=torch.int,
+                                           device="cuda")
+                broadcast(context_lens, src=0)
+            else:
+                context_lens = None
+            if py_data["block_tables_size"] is not None:
+                block_tables = torch.empty(*py_data["block_tables_size"],
+                                           dtype=torch.int,
+                                           device="cuda")
+                broadcast(block_tables, src=0)
+            else:
+                block_tables = None
+            selected_token_indices = torch.empty(
+                *py_data["selected_token_indices_size"],
+                dtype=torch.long,
+                device="cuda")
+            broadcast(selected_token_indices, src=0)
+            input_metadata = InputMetadata(
+                is_prompt=py_data["is_prompt"],
+                slot_mapping=slot_mapping,
+                max_context_len=py_data["max_context_len"],
+                context_lens=context_lens,
+                block_tables=block_tables,
+                use_cuda_graph=py_data["use_cuda_graph"],
+            )
+            sampling_metadata = SamplingMetadata(
+                seq_groups=None,
+                seq_data=None,
+                prompt_lens=None,
+                selected_token_indices=selected_token_indices,
+                categorized_sample_indices=None,
+                perform_sampling=False,
+            )

+        return input_tokens, input_positions, input_metadata, sampling_metadata

     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> SamplerOutput:
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        is_prompt = seq_group_metadata_list[0].is_prompt
-        # Prepare input tensors.
-        if is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
+    ) -> Optional[SamplerOutput]:
+        input_tokens, input_positions, input_metadata, sampling_metadata = (
+            self.prepare_input_tensors(seq_group_metadata_list))

         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]
@@ -356,9 +457,6 @@ class ModelRunner:
             input_metadata=input_metadata,
         )

-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
-
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -424,7 +522,7 @@ class ModelRunner:
         for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
             # Create dummy input_metadata.
             input_metadata = InputMetadata(
-                prompt_lens=[],
+                is_prompt=False,
                 slot_mapping=slot_mapping[:batch_size],
                 max_context_len=self.max_context_len_to_capture,
                 context_lens=context_lens[:batch_size],
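
The key idea in prepare_input_tensors above is a two-step broadcast: tensor shapes and scalars travel once as a pickled object list, and each tensor then moves as a raw broadcast into a preallocated buffer, so no tensor is ever pickled. A stripped-down sketch of the same pattern outside vLLM (illustrative names; assumes an initialized process group with a CUDA-capable backend):

import torch
import torch.distributed as dist


def broadcast_optional_tensor(tensor, is_driver, dtype=torch.long):
    # Step 1: ship only the shape (or None) through the pickled object list.
    meta = [tuple(tensor.size()) if (is_driver and tensor is not None) else None]
    dist.broadcast_object_list(meta, src=0)
    size = meta[0]
    if size is None:
        return None
    # Step 2: non-driver ranks allocate an empty buffer of that shape, then
    # the payload moves as a plain NCCL broadcast (no serialization).
    if not is_driver:
        tensor = torch.empty(*size, dtype=dtype, device="cuda")
    dist.broadcast(tensor, src=0)
    return tensor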
@@ -8,6 +8,8 @@ import torch.distributed
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast_object_list)
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -28,17 +30,23 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: Optional[int] = None,
-        distributed_init_method: Optional[str] = None,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+        if self.is_driver_worker:
+            assert self.rank == 0, "The driver worker must have rank 0."

         self.model_runner = ModelRunner(model_config, parallel_config,
-                                        scheduler_config)
+                                        scheduler_config, is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -57,13 +65,7 @@ class Worker:
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

-        # Env vars will be set by Ray.
-        self.rank = self.rank if self.rank is not None else int(
-            os.getenv("RANK", "-1"))
-        local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        self.device = torch.device(f"cuda:{local_rank}")
-        if self.rank < 0:
-            raise ValueError("Invalid or unspecified rank.")
+        self.device = torch.device(f"cuda:{self.local_rank}")
         torch.cuda.set_device(self.device)

         _check_if_gpu_supports_dtype(self.model_config.dtype)
@@ -125,14 +127,12 @@ class Worker:
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

-    @torch.inference_mode()
-    def execute_model(
+    def cache_swap(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+    ) -> None:
         # Issue cache operations.
         issued_cache_op = False
         if blocks_to_swap_in:
@@ -152,8 +152,38 @@ class Worker:
         if cache_events is not None:
             for event in cache_events:
                 event.wait()
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
+            num_seq_groups = len(seq_group_metadata_list)
+            assert blocks_to_swap_in is not None
+            assert blocks_to_swap_out is not None
+            assert blocks_to_copy is not None
+            block_swapping_info = [
+                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
+            ]
+            broadcast_object_list([num_seq_groups] + block_swapping_info,
                                  src=0)
+        else:
+            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
+            # blocks_to_copy (4 elements)
+            recv_data = [None] * 4
+            broadcast_object_list(recv_data, src=0)
+            num_seq_groups = recv_data[0]
+            block_swapping_info = recv_data[1:]
+
+        self.cache_swap(*block_swapping_info)
+
         # If there is no input, we don't need to execute the model.
-        if not seq_group_metadata_list:
+        if num_seq_groups == 0:
             return {}

         output = self.model_runner.execute_model(seq_group_metadata_list,
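
After this change only the driver's execute_model receives arguments; every other worker calls it with none, learns the step's contents over the broadcast, and returns None. A hedged sketch of the resulting per-process call pattern (worker and scheduler_output are hypothetical handles; the real orchestration lives in the engine code outside this excerpt):

from typing import Optional

from vllm.sequence import SamplerOutput


def run_step(worker, scheduler_output=None) -> Optional[SamplerOutput]:
    # `worker` is this process's Worker; `scheduler_output` is non-None only
    # on the driver (rank 0) and carries the scheduler's decisions.
    if worker.is_driver_worker:
        return worker.execute_model(
            seq_group_metadata_list=scheduler_output.seq_group_metadata_list,
            blocks_to_swap_in=scheduler_output.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_output.blocks_to_swap_out,
            blocks_to_copy=scheduler_output.blocks_to_copy,
        )
    # Non-driver ranks pass nothing: the counts and block-swap dicts arrive
    # via broadcast_object_list inside execute_model; the result is None here.
    return worker.execute_model()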