Unverified Commit c5832d2a authored by Murali Andoorveedu, committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
...@@ -416,7 +416,7 @@ class GroupCoordinator: ...@@ -416,7 +416,7 @@ class GroupCoordinator:
assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank, ( assert dst != self.rank_in_group, (
"Invalid destination rank. Destination rank is the same " "Invalid destination rank. Destination rank is the same "
"as the current rank.") "as the current rank.")
...@@ -446,7 +446,7 @@ class GroupCoordinator: ...@@ -446,7 +446,7 @@ class GroupCoordinator:
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
assert src != self.rank, ( assert src != self.rank_in_group, (
"Invalid source rank. Source rank is the same as the current rank." "Invalid source rank. Source rank is the same as the current rank."
) )
...@@ -454,7 +454,7 @@ class GroupCoordinator: ...@@ -454,7 +454,7 @@ class GroupCoordinator:
# Receive object size # Receive object size
rank_size = torch.distributed.recv(size_tensor, rank_size = torch.distributed.recv(size_tensor,
src=src, src=self.ranks[src],
group=self.cpu_group) group=self.cpu_group)
# Tensor to receive serialized objects into. # Tensor to receive serialized objects into.
...@@ -464,7 +464,7 @@ class GroupCoordinator: ...@@ -464,7 +464,7 @@ class GroupCoordinator:
device="cpu") device="cpu")
rank_object = torch.distributed.recv(object_tensor, rank_object = torch.distributed.recv(object_tensor,
src=src, src=self.ranks[src],
group=self.cpu_group) group=self.cpu_group)
assert rank_object == rank_size, ( assert rank_object == rank_size, (
...@@ -491,10 +491,9 @@ class GroupCoordinator: ...@@ -491,10 +491,9 @@ class GroupCoordinator:
group = self.device_group group = self.device_group
metadata_group = self.cpu_group metadata_group = self.cpu_group
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
src = self.ranks[src]
rank = self.rank rank_in_group = self.rank_in_group
if rank == src: if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = [] metadata_list: List[Tuple[Any, Any]] = []
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
...@@ -512,13 +511,13 @@ class GroupCoordinator: ...@@ -512,13 +511,13 @@ class GroupCoordinator:
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor, handle = torch.distributed.broadcast(tensor,
src=src, src=self.ranks[src],
group=metadata_group, group=metadata_group,
async_op=True) async_op=True)
else: else:
# use group for GPU tensors # use group for GPU tensors
handle = torch.distributed.broadcast(tensor, handle = torch.distributed.broadcast(tensor,
src=src, src=self.ranks[src],
group=group, group=group,
async_op=True) async_op=True)
async_handles.append(handle) async_handles.append(handle)
...@@ -542,13 +541,14 @@ class GroupCoordinator: ...@@ -542,13 +541,14 @@ class GroupCoordinator:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
handle = torch.distributed.broadcast( handle = torch.distributed.broadcast(
tensor, tensor,
src=src, src=self.ranks[src],
group=metadata_group, group=metadata_group,
async_op=True) async_op=True)
else: else:
# use group for GPU tensors # use group for GPU tensors
handle = torch.distributed.broadcast(tensor, handle = torch.distributed.broadcast(
src=src, tensor,
src=self.ranks[src],
group=group, group=group,
async_op=True) async_op=True)
async_handles.append(handle) async_handles.append(handle)
...@@ -575,7 +575,7 @@ class GroupCoordinator: ...@@ -575,7 +575,7 @@ class GroupCoordinator:
metadata_group = self.cpu_group metadata_group = self.cpu_group
if dst is None: if dst is None:
dst = self.next_rank dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = [] metadata_list: List[Tuple[Any, Any]] = []
...@@ -593,10 +593,14 @@ class GroupCoordinator: ...@@ -593,10 +593,14 @@ class GroupCoordinator:
continue continue
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=dst, group=metadata_group) torch.distributed.send(tensor,
dst=self.ranks[dst],
group=metadata_group)
else: else:
# use group for GPU tensors # use group for GPU tensors
torch.distributed.send(tensor, dst=dst, group=group) torch.distributed.send(tensor,
dst=self.ranks[dst],
group=group)
return None return None
def recv_tensor_dict( def recv_tensor_dict(
...@@ -614,7 +618,7 @@ class GroupCoordinator: ...@@ -614,7 +618,7 @@ class GroupCoordinator:
metadata_group = self.cpu_group metadata_group = self.cpu_group
if src is None: if src is None:
src = self.prev_rank src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src) recv_metadata_list = self.recv_object(src=src)
...@@ -631,11 +635,13 @@ class GroupCoordinator: ...@@ -631,11 +635,13 @@ class GroupCoordinator:
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
torch.distributed.recv(tensor, torch.distributed.recv(tensor,
src=src, src=self.ranks[src],
group=metadata_group) group=metadata_group)
else: else:
# use group for GPU tensors # use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group) torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
_update_nested_dict(tensor_dict, key, tensor) _update_nested_dict(tensor_dict, key, tensor)
else: else:
_update_nested_dict(tensor_dict, key, value) _update_nested_dict(tensor_dict, key, value)
...@@ -654,7 +660,7 @@ class GroupCoordinator: ...@@ -654,7 +660,7 @@ class GroupCoordinator:
"""Sends a tensor to the destination rank in a non-blocking way""" """Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank.""" """NOTE: `dst` is the local rank of the destination rank."""
if dst is None: if dst is None:
dst = self.next_rank dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled: if pynccl_comm is not None and not pynccl_comm.disabled:
...@@ -669,7 +675,7 @@ class GroupCoordinator: ...@@ -669,7 +675,7 @@ class GroupCoordinator:
"""Receives a tensor from the src rank.""" """Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the source rank.""" """NOTE: `src` is the local rank of the source rank."""
if src is None: if src is None:
src = self.prev_rank src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device) tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Adapted from # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence from typing import Sequence, Tuple
import torch import torch
...@@ -46,3 +46,12 @@ def split_tensor_along_last_dim( ...@@ -46,3 +46,12 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list) return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list return tensor_list
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Compute the range of hidden layers owned by one pipeline stage.

    The layers are cut into ``pp_size`` equally sized, contiguous
    partitions and the bounds of the partition at index ``pp_rank`` are
    returned.

    Args:
        num_hidden_layers: Number of hidden layers to distribute.
        pp_rank: Index of the pipeline stage whose range is wanted.
        pp_size: Number of pipeline stages.

    Returns:
        ``(start_layer, end_layer)``, where ``end_layer`` equals
        ``start_layer`` plus the partition size.
    """
    # NOTE(review): `divide` is defined outside this block; presumably it
    # rejects layer counts that are not a multiple of pp_size -- confirm.
    partition_size = divide(num_hidden_layers, pp_size)
    first_layer = partition_size * pp_rank
    return (first_layer, first_layer + partition_size)
...@@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods.""" """Extension of LLMEngine to add async methods."""
async def step_async( async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible. The workers are run asynchronously if possible.
...@@ -221,7 +222,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -221,7 +222,8 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Execute the model. # Execute the model.
...@@ -230,6 +232,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -230,6 +232,7 @@ class _AsyncLLMEngine(LLMEngine):
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots, num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size, running_queue_size=scheduler_outputs.running_queue_size,
) )
...@@ -248,16 +251,12 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -248,16 +251,12 @@ class _AsyncLLMEngine(LLMEngine):
# Tracing # Tracing
self.do_tracing(scheduler_outputs) self.do_tracing(scheduler_outputs)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()
return request_outputs return request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
    """Stop the remote worker execution loop.

    Awaits the model executor's coroutine of the same name; this wrapper
    adds no logic of its own.
    """
    await self.model_executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async( async def process_model_inputs_async(
self, self,
request_id: str, request_id: str,
...@@ -491,7 +490,8 @@ class AsyncLLMEngine: ...@@ -491,7 +490,8 @@ class AsyncLLMEngine:
# order of the arguments. # order of the arguments.
cache_config = kwargs["cache_config"] cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"] parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1: if (parallel_config.tensor_parallel_size == 1
and parallel_config.pipeline_parallel_size == 1):
num_gpus = cache_config.gpu_memory_utilization num_gpus = cache_config.gpu_memory_utilization
else: else:
num_gpus = 1 num_gpus = 1
...@@ -499,7 +499,7 @@ class AsyncLLMEngine: ...@@ -499,7 +499,7 @@ class AsyncLLMEngine:
self._engine_class).remote self._engine_class).remote
return engine_class(*args, **kwargs) return engine_class(*args, **kwargs)
async def engine_step(self) -> bool: async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests. """Kick the engine to process the waiting requests.
Returns True if there are in-progress requests.""" Returns True if there are in-progress requests."""
...@@ -530,7 +530,7 @@ class AsyncLLMEngine: ...@@ -530,7 +530,7 @@ class AsyncLLMEngine:
if self.engine_use_ray: if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore request_outputs = await self.engine.step.remote() # type: ignore
else: else:
request_outputs = await self.engine.step_async() request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams. # Put the outputs into the corresponding streams.
for request_output in request_outputs: for request_output in request_outputs:
...@@ -546,18 +546,65 @@ class AsyncLLMEngine: ...@@ -546,18 +546,65 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
has_requests_in_progress = False if self.engine_use_ray:
pipeline_parallel_size = 1 # type: ignore
else:
pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True: while True:
if not has_requests_in_progress: if not any(has_requests_in_progress):
logger.debug("Waiting for new requests...") logger.debug("Waiting for new requests...")
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
if self.engine_use_ray:
await (self.engine.stop_remote_worker_execution_loop.
remote() # type: ignore
)
else:
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests() await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!") logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(self.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
# Abort if iteration takes too long due to unrecoverable errors # Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts). # (eg. NCCL timeouts).
try: try:
async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
has_requests_in_progress = await self.engine_step() done, _ = await asyncio.wait(
requests_in_progress,
return_when=asyncio.FIRST_COMPLETED)
for _ in range(pipeline_parallel_size):
await asyncio.sleep(0)
for task in done:
result = task.result()
virtual_engine = requests_in_progress.index(task)
if self.engine_use_ray:
has_unfinished_requests = (
await (self.engine.
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = (
self.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
self.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc: except asyncio.TimeoutError as exc:
logger.error( logger.error(
"Engine iteration timed out. This should never happen!") "Engine iteration timed out. This should never happen!")
......
...@@ -173,6 +173,7 @@ class LLMEngine: ...@@ -173,6 +173,7 @@ class LLMEngine:
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, " "disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, " "enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
...@@ -195,6 +196,7 @@ class LLMEngine: ...@@ -195,6 +196,7 @@ class LLMEngine:
load_config.download_dir, load_config.download_dir,
load_config.load_format, load_config.load_format,
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce, parallel_config.disable_custom_all_reduce,
model_config.quantization, model_config.quantization,
model_config.enforce_eager, model_config.enforce_eager,
...@@ -296,7 +298,11 @@ class LLMEngine: ...@@ -296,7 +298,11 @@ class LLMEngine:
# Create the scheduler. # Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of # NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor. # GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) self.scheduler = [
Scheduler(scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size)
for _ in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging. # Metric Logging.
if self.log_stats: if self.log_stats:
...@@ -513,8 +519,16 @@ class LLMEngine: ...@@ -513,8 +519,16 @@ class LLMEngine:
raise ValueError( raise ValueError(
"Either SamplingParams or PoolingParams must be provided.") "Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler with least unfinished seqs.
self.scheduler.add_seq_group(seq_group) costs = [
scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler
]
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
def stop_remote_worker_execution_loop(self) -> None:
    """Stop the remote worker execution loop via the model executor.

    Synchronous passthrough to
    ``self.model_executor.stop_remote_worker_execution_loop()``.
    """
    self.model_executor.stop_remote_worker_execution_loop()
def process_model_inputs( def process_model_inputs(
self, self,
...@@ -684,7 +698,8 @@ class LLMEngine: ...@@ -684,7 +698,8 @@ class LLMEngine:
>>> # abort the request >>> # abort the request
>>> engine.abort_request(request_id) >>> engine.abort_request(request_id)
""" """
self.scheduler.abort_seq_group(request_id) for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig: def get_model_config(self) -> ModelConfig:
"""Gets the model configuration.""" """Gets the model configuration."""
...@@ -696,11 +711,20 @@ class LLMEngine: ...@@ -696,11 +711,20 @@ class LLMEngine:
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests.""" """Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups() return sum(scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler)
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return any(scheduler.has_unfinished_seqs()
for scheduler in self.scheduler)
def has_unfinished_requests_for_virtual_engine(
        self, virtual_engine: int) -> bool:
    """
    Returns True if there are unfinished requests for the virtual engine.

    Args:
        virtual_engine: Index into ``self.scheduler`` selecting the
            scheduler that is queried.
    """
    return self.scheduler[virtual_engine].has_unfinished_seqs()
def _process_sequence_group_outputs( def _process_sequence_group_outputs(
self, self,
...@@ -749,7 +773,8 @@ class LLMEngine: ...@@ -749,7 +773,8 @@ class LLMEngine:
self.output_processor.process_outputs(seq_group, outputs) self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# Create the outputs. # Create the outputs.
request_outputs: List[Union[RequestOutput, request_outputs: List[Union[RequestOutput,
...@@ -815,7 +840,12 @@ class LLMEngine: ...@@ -815,7 +840,12 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs): >>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break >>> break
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
execute_model_req = ExecuteModelRequest( execute_model_req = ExecuteModelRequest(
...@@ -886,23 +916,28 @@ class LLMEngine: ...@@ -886,23 +916,28 @@ class LLMEngine:
# System State # System State
# Scheduler State # Scheduler State
num_running_sys = len(self.scheduler.running) num_running_sys = sum(
num_swapped_sys = len(self.scheduler.swapped) len(scheduler.running) for scheduler in self.scheduler)
num_waiting_sys = len(self.scheduler.waiting) num_swapped_sys = sum(
len(scheduler.swapped) for scheduler in self.scheduler)
num_waiting_sys = sum(
len(scheduler.waiting) for scheduler in self.scheduler)
# KV Cache Usage in % # KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
gpu_cache_usage_sys = 0. gpu_cache_usage_sys = 0.
if num_total_gpu is not None: if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( num_free_gpu = sum(
) scheduler.block_manager.get_num_free_gpu_blocks()
for scheduler in self.scheduler)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0. cpu_cache_usage_sys = 0.
if num_total_cpu is not None and num_total_cpu > 0: if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( num_free_cpu = sum(
) scheduler.block_manager.get_num_free_cpu_blocks()
for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Iteration stats # Iteration stats
......
...@@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC):
def create_output_processor( def create_output_processor(
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker", stop_checker: "StopChecker",
......
...@@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def __init__( def __init__(
self, self,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker, stop_checker: StopChecker,
...@@ -141,4 +141,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -141,4 +141,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
break break
if seq.is_finished(): if seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
...@@ -33,7 +33,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -33,7 +33,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self, self,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
stop_checker: StopChecker, stop_checker: StopChecker,
): ):
...@@ -95,7 +95,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -95,7 +95,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# not be used in the future iterations. # not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id) seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent) for scheduler in self.scheduler:
scheduler.free_seq(parent)
continue continue
# Fork the parent sequence if there are multiple child samples. # Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]: for child_sample in child_samples[:-1]:
...@@ -133,7 +134,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -133,7 +134,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent: if seq is not parent:
seq_group.add(seq) seq_group.add(seq)
if not seq.is_finished(): if not seq.is_finished():
self.scheduler.fork_seq(parent, seq) for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block # Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output. # manager. Keep them in the sequence group as candidate output.
...@@ -141,7 +143,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -141,7 +143,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# old sequences. # old sequences.
for seq, parent in child_seqs: for seq, parent in child_seqs:
if seq is parent and seq.is_finished(): if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
return return
# Beam search case # Beam search case
...@@ -226,13 +229,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -226,13 +229,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent: if seq is not parent:
seq_group.add(seq) seq_group.add(seq)
if not seq.is_finished(): if not seq.is_finished():
self.scheduler.fork_seq(parent, seq) for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block # Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output. # manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs: for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished(): if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and # Remove the unselected parent sequences from the sequence group and
# free their memory in block manager. # free their memory in block manager.
...@@ -241,7 +246,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -241,7 +246,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Remove the parent sequence if it is not selected for next # Remove the parent sequence if it is not selected for next
# iteration # iteration
seq_group.remove(seq.seq_id) seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping( def _check_beam_search_early_stopping(
self, self,
......
...@@ -69,7 +69,7 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -69,7 +69,7 @@ class DistributedGPUExecutor(GPUExecutor):
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers( self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop", "start_worker_execution_loop",
async_run_remote_workers_only=True, async_run_tensor_parallel_workers_only=True,
**self.extra_execute_model_run_workers_kwargs) **self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
...@@ -138,17 +138,17 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -138,17 +138,17 @@ class DistributedGPUExecutor(GPUExecutor):
self, self,
method: str, method: str,
*args, *args,
async_run_remote_workers_only: bool = False, async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. """Runs the given method on all workers.
Args: Args:
async_run_remote_workers_only: If True the method will be run only async_run_tensor_parallel_workers_only: If True the method will be
in the remote workers, not the driver worker. It will also be run only in the remote TP workers, not the driver worker.
run asynchronously and return a list of futures rather than It will also be run asynchronously and return a list of futures
blocking on the results. rather than blocking on the results.
""" """
raise NotImplementedError raise NotImplementedError
......
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
...@@ -110,6 +111,30 @@ class ExecutorBase(ABC): ...@@ -110,6 +111,30 @@ class ExecutorBase(ABC):
class ExecutorAsyncBase(ExecutorBase): class ExecutorAsyncBase(ExecutorBase):
def __init__(
    self,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    vision_language_config: Optional[VisionLanguageConfig],
    speculative_config: Optional[SpeculativeConfig],
) -> None:
    """Create one lock per pipeline stage, then run the base initializer.

    Every argument is forwarded positionally and unchanged to
    ``super().__init__``; the only value read here is
    ``parallel_config.pipeline_parallel_size``, which sets the number of
    locks stored in ``self.pp_locks``.
    """
    # This locks each pipeline parallel stage so multiple virtual engines
    # can't execute on the same stage at the same time
    self.pp_locks = [
        asyncio.Lock()
        for _ in range(parallel_config.pipeline_parallel_size)
    ]

    # NOTE(review): pp_locks is assigned before the base initializer runs,
    # presumably so the locks already exist if that initializer starts any
    # work that needs them -- confirm before reordering.
    super().__init__(model_config, cache_config, parallel_config,
                     scheduler_config, device_config, load_config,
                     lora_config, vision_language_config,
                     speculative_config)
@abstractmethod @abstractmethod
async def execute_model_async( async def execute_model_async(
self, self,
......
...@@ -45,7 +45,8 @@ class GPUExecutor(ExecutorBase): ...@@ -45,7 +45,8 @@ class GPUExecutor(ExecutorBase):
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config, speculative_config=self.speculative_config,
is_driver_worker=rank == 0, is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
) )
def _create_worker(self, def _create_worker(self,
......
...@@ -91,17 +91,17 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -91,17 +91,17 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
self, self,
method: str, method: str,
*args, *args,
async_run_remote_workers_only: bool = False, async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. """Runs the given method on all workers.
Args: Args:
async_run_remote_workers_only: If True the method will be run only async_run_tensor_parallel_workers_only: If True the method will be
in the remote workers, not the driver worker. It will also be run only in the remote TP workers, not the driver worker.
run asynchronously and return a list of futures rather than It will also be run asynchronously and return a list of futures
blocking on the results. rather than blocking on the results.
""" """
if max_concurrent_workers: if max_concurrent_workers:
...@@ -114,7 +114,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -114,7 +114,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
for worker in self.workers for worker in self.workers
] ]
if async_run_remote_workers_only: if async_run_tensor_parallel_workers_only:
# Just return futures # Just return futures
return worker_outputs return worker_outputs
......
...@@ -62,7 +62,8 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -62,7 +62,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1: if (self.parallel_config.tensor_parallel_size == 1
and self.parallel_config.pipeline_parallel_size == 1):
# For single GPU case, we use a ray worker with constrained memory. # For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization num_gpus = self.cache_config.gpu_memory_utilization
else: else:
...@@ -189,6 +190,26 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -189,6 +190,26 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config. max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers) max_parallel_loading_workers)
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcast to.
self.non_driver_workers: List[RayWorkerWrapper] = []
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
for tp_rank in range(self.parallel_config.tensor_parallel_size):
rank = (pp_rank *
self.parallel_config.tensor_parallel_size) + tp_rank
if rank == 0:
pass
elif rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(self.workers[rank - 1])
else:
self.non_driver_workers.append(self.workers[rank - 1])
def _driver_execute_model( def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest] self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
...@@ -204,7 +225,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -204,7 +225,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
self, self,
method: str, method: str,
*args, *args,
async_run_remote_workers_only: bool = False, async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None, all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
...@@ -215,10 +236,11 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -215,10 +236,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
"""Runs the given method on all workers. Can be used in the following """Runs the given method on all workers. Can be used in the following
ways: ways:
- async_run_remote_workers_only: If True the method will be run only Args:
in the remote workers, not the driver worker. It will also be - async_run_tensor_parallel_workers_only: If True the method will be
run asynchronously and return a list of futures rather than blocking run only in the remote TP workers, not the driver worker.
on the results. It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs - args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified - all_args/all_kwargs: args/kwargs for each worker are specified
individually individually
...@@ -228,7 +250,9 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -228,7 +250,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError( raise NotImplementedError(
"max_concurrent_workers is not supported yet.") "max_concurrent_workers is not supported yet.")
count = len(self.workers) count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
all_worker_args = repeat(args, count) if all_args is None \ all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None) else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
...@@ -242,14 +266,17 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -242,14 +266,17 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_worker_outputs = [] ray_worker_outputs = []
else: else:
# Start the ray workers first. # Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [ ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, worker.execute_method.remote(method, *worker_args,
**worker_kwargs) **worker_kwargs)
for (worker, worker_args, worker_kwargs for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs) ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
] ]
if async_run_remote_workers_only: if async_run_tensor_parallel_workers_only:
# Just return futures # Just return futures
return ray_worker_outputs return ray_worker_outputs
...@@ -319,12 +346,32 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -319,12 +346,32 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req) async def _run_task_with_lock(task, lock, *args, **kwargs):
async with lock:
return await task(*args, **kwargs)
tasks = []
tasks.append(
asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req)))
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method.remote,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self): async def _start_worker_execution_loop(self):
coros = [ coros = [
worker.execute_method.remote("start_worker_execution_loop") worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers for worker in self.non_driver_workers
] ]
return await asyncio.gather(*coros) return await asyncio.gather(*coros)
...@@ -29,7 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -29,7 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.arctic import ArcticConfig from vllm.transformers_utils.configs.arctic import ArcticConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -426,6 +426,7 @@ class ArcticForCausalLM(nn.Module): ...@@ -426,6 +426,7 @@ class ArcticForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -338,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -338,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...@@ -286,6 +286,7 @@ class BloomForCausalLM(nn.Module): ...@@ -286,6 +286,7 @@ class BloomForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -25,7 +25,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -25,7 +25,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -365,6 +365,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -365,6 +365,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -46,7 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -46,7 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
@torch.compile @torch.compile
...@@ -353,6 +353,7 @@ class CohereForCausalLM(nn.Module): ...@@ -353,6 +353,7 @@ class CohereForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
...@@ -381,6 +381,7 @@ class DbrxForCausalLM(nn.Module): ...@@ -381,6 +381,7 @@ class DbrxForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
...@@ -387,6 +387,7 @@ class DeepseekForCausalLM(nn.Module): ...@@ -387,6 +387,7 @@ class DeepseekForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
...@@ -475,6 +475,7 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -475,6 +475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment