Unverified commit c5832d2a, authored by Murali Andoorveedu, committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
......@@ -416,7 +416,7 @@ class GroupCoordinator:
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank, (
assert dst != self.rank_in_group, (
"Invalid destination rank. Destination rank is the same "
"as the current rank.")
......@@ -446,7 +446,7 @@ class GroupCoordinator:
assert src < self.world_size, f"Invalid src rank ({src})"
assert src != self.rank, (
assert src != self.rank_in_group, (
"Invalid source rank. Source rank is the same as the current rank."
)
......@@ -454,7 +454,7 @@ class GroupCoordinator:
# Receive object size
rank_size = torch.distributed.recv(size_tensor,
src=src,
src=self.ranks[src],
group=self.cpu_group)
# Tensor to receive serialized objects into.
......@@ -464,7 +464,7 @@ class GroupCoordinator:
device="cpu")
rank_object = torch.distributed.recv(object_tensor,
src=src,
src=self.ranks[src],
group=self.cpu_group)
assert rank_object == rank_size, (
......@@ -491,10 +491,9 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
assert src < self.world_size, f"Invalid src rank ({src})"
src = self.ranks[src]
rank = self.rank
if rank == src:
rank_in_group = self.rank_in_group
if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
......@@ -512,13 +511,13 @@ class GroupCoordinator:
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
src=self.ranks[src],
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
src=self.ranks[src],
group=group,
async_op=True)
async_handles.append(handle)
......@@ -542,15 +541,16 @@ class GroupCoordinator:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor,
src=src,
src=self.ranks[src],
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=group,
async_op=True)
async_handles.append(handle)
_update_nested_dict(tensor_dict, key, tensor)
else:
......@@ -575,7 +575,7 @@ class GroupCoordinator:
metadata_group = self.cpu_group
if dst is None:
dst = self.next_rank
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
......@@ -593,10 +593,14 @@ class GroupCoordinator:
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=dst, group=metadata_group)
torch.distributed.send(tensor,
dst=self.ranks[dst],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=dst, group=group)
torch.distributed.send(tensor,
dst=self.ranks[dst],
group=group)
return None
def recv_tensor_dict(
......@@ -614,7 +618,7 @@ class GroupCoordinator:
metadata_group = self.cpu_group
if src is None:
src = self.prev_rank
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
......@@ -631,11 +635,13 @@ class GroupCoordinator:
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=src,
src=self.ranks[src],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group)
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
_update_nested_dict(tensor_dict, key, tensor)
else:
_update_nested_dict(tensor_dict, key, value)
......@@ -654,7 +660,7 @@ class GroupCoordinator:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = self.next_rank
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
......@@ -669,7 +675,7 @@ class GroupCoordinator:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = self.prev_rank
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
......
......@@ -2,7 +2,7 @@
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence
from typing import Sequence, Tuple
import torch
......@@ -46,3 +46,12 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Compute the layer index range assigned to one pipeline stage.

    The partition size is ``divide(num_hidden_layers, pp_size)``; stage
    ``pp_rank`` starts at ``pp_rank`` times that size.

    Args:
        num_hidden_layers: Total number of hidden layers to partition.
        pp_rank: Index of the pipeline-parallel stage being queried.
        pp_size: Number of pipeline-parallel stages.

    Returns:
        A ``(start_layer, end_layer)`` tuple where ``end_layer`` equals
        ``start_layer`` plus the partition size.
    """
    # NOTE(review): `divide` is defined outside this view; presumably it
    # rejects a layer count that is not a multiple of pp_size -- confirm.
    partition_width = divide(num_hidden_layers, pp_size)
    first_layer = partition_width * pp_rank
    return (first_layer, first_layer + partition_width)
......@@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible.
......@@ -221,7 +222,8 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if not scheduler_outputs.is_empty():
# Execute the model.
......@@ -230,6 +232,7 @@ class _AsyncLLMEngine(LLMEngine):
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
......@@ -248,16 +251,12 @@ class _AsyncLLMEngine(LLMEngine):
# Tracing
self.do_tracing(scheduler_outputs)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()
return request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
    """Ask the model executor to stop the remote worker execution loop.

    Thin async wrapper: awaits the method of the same name on
    ``self.model_executor`` and returns nothing.
    """
    executor = self.model_executor
    await executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async(
self,
request_id: str,
......@@ -491,7 +490,8 @@ class AsyncLLMEngine:
# order of the arguments.
cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1:
if (parallel_config.tensor_parallel_size == 1
and parallel_config.pipeline_parallel_size == 1):
num_gpus = cache_config.gpu_memory_utilization
else:
num_gpus = 1
......@@ -499,7 +499,7 @@ class AsyncLLMEngine:
self._engine_class).remote
return engine_class(*args, **kwargs)
async def engine_step(self) -> bool:
async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
......@@ -530,7 +530,7 @@ class AsyncLLMEngine:
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()
request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams.
for request_output in request_outputs:
......@@ -546,18 +546,65 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
has_requests_in_progress = False
if self.engine_use_ray:
pipeline_parallel_size = 1 # type: ignore
else:
pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True:
if not has_requests_in_progress:
if not any(has_requests_in_progress):
logger.debug("Waiting for new requests...")
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
if self.engine_use_ray:
await (self.engine.stop_remote_worker_execution_loop.
remote() # type: ignore
)
else:
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(self.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
has_requests_in_progress = await self.engine_step()
done, _ = await asyncio.wait(
requests_in_progress,
return_when=asyncio.FIRST_COMPLETED)
for _ in range(pipeline_parallel_size):
await asyncio.sleep(0)
for task in done:
result = task.result()
virtual_engine = requests_in_progress.index(task)
if self.engine_use_ray:
has_unfinished_requests = (
await (self.engine.
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = (
self.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
self.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
......
......@@ -173,6 +173,7 @@ class LLMEngine:
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
......@@ -195,6 +196,7 @@ class LLMEngine:
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
......@@ -296,7 +298,11 @@ class LLMEngine:
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
self.scheduler = [
Scheduler(scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size)
for _ in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging.
if self.log_stats:
......@@ -513,8 +519,16 @@ class LLMEngine:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
# Add the sequence group to the scheduler with least unfinished seqs.
costs = [
scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler
]
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
def stop_remote_worker_execution_loop(self) -> None:
    """Stop the remote worker execution loop via the model executor.

    Synchronous wrapper that forwards to the method of the same name on
    ``self.model_executor``.
    """
    executor = self.model_executor
    executor.stop_remote_worker_execution_loop()
def process_model_inputs(
self,
......@@ -684,7 +698,8 @@ class LLMEngine:
>>> # abort the request
>>> engine.abort_request(request_id)
"""
self.scheduler.abort_seq_group(request_id)
for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
......@@ -696,11 +711,20 @@ class LLMEngine:
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
return sum(scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler)
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
return any(scheduler.has_unfinished_seqs()
for scheduler in self.scheduler)
def has_unfinished_requests_for_virtual_engine(
        self, virtual_engine: int) -> bool:
    """Report whether one virtual engine still has unfinished requests.

    Args:
        virtual_engine: Index into ``self.scheduler`` selecting the
            scheduler to query.

    Returns:
        Whatever ``has_unfinished_seqs()`` returns for that scheduler.
    """
    selected_scheduler = self.scheduler[virtual_engine]
    return selected_scheduler.has_unfinished_seqs()
def _process_sequence_group_outputs(
self,
......@@ -749,7 +773,8 @@ class LLMEngine:
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[Union[RequestOutput,
......@@ -815,7 +840,12 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
if not scheduler_outputs.is_empty():
execute_model_req = ExecuteModelRequest(
......@@ -886,23 +916,28 @@ class LLMEngine:
# System State
# Scheduler State
num_running_sys = len(self.scheduler.running)
num_swapped_sys = len(self.scheduler.swapped)
num_waiting_sys = len(self.scheduler.waiting)
num_running_sys = sum(
len(scheduler.running) for scheduler in self.scheduler)
num_swapped_sys = sum(
len(scheduler.swapped) for scheduler in self.scheduler)
num_waiting_sys = sum(
len(scheduler.waiting) for scheduler in self.scheduler)
# KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks
gpu_cache_usage_sys = 0.
if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
)
num_free_gpu = sum(
scheduler.block_manager.get_num_free_gpu_blocks()
for scheduler in self.scheduler)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0.
if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
num_free_cpu = sum(
scheduler.block_manager.get_num_free_cpu_blocks()
for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Iteration stats
......
......@@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC):
def create_output_processor(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
......
......@@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def __init__(
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
......@@ -141,4 +141,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
break
if seq.is_finished():
self.scheduler.free_seq(seq)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
......@@ -33,7 +33,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self,
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
scheduler: List[Scheduler],
seq_counter: Counter,
stop_checker: StopChecker,
):
......@@ -95,7 +95,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
for scheduler in self.scheduler:
scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
......@@ -133,7 +134,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
......@@ -141,7 +143,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# Beam search case
......@@ -226,13 +229,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
......@@ -241,7 +246,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
self,
......
......@@ -69,7 +69,7 @@ class DistributedGPUExecutor(GPUExecutor):
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
async_run_tensor_parallel_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results.
......@@ -138,17 +138,17 @@ class DistributedGPUExecutor(GPUExecutor):
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
raise NotImplementedError
......
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
......@@ -110,6 +111,30 @@ class ExecutorBase(ABC):
class ExecutorAsyncBase(ExecutorBase):
def __init__(
    self,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    vision_language_config: Optional[VisionLanguageConfig],
    speculative_config: Optional[SpeculativeConfig],
) -> None:
    """Create the per-stage locks, then run the base-class initializer.

    ``self.pp_locks`` holds one ``asyncio.Lock`` for each of
    ``parallel_config.pipeline_parallel_size`` stages. All configuration
    objects are forwarded positionally, in the order received, to the
    parent ``__init__``.
    """
    stage_count = parallel_config.pipeline_parallel_size

    # One lock per pipeline parallel stage, so that multiple virtual
    # engines can't execute on the same stage at the same time. The locks
    # are created before the parent initializer runs, as in the original.
    self.pp_locks = [asyncio.Lock() for _stage in range(stage_count)]

    super().__init__(
        model_config,
        cache_config,
        parallel_config,
        scheduler_config,
        device_config,
        load_config,
        lora_config,
        vision_language_config,
        speculative_config,
    )
@abstractmethod
async def execute_model_async(
self,
......
......@@ -45,7 +45,8 @@ class GPUExecutor(ExecutorBase):
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config,
is_driver_worker=rank == 0,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
def _create_worker(self,
......
......@@ -91,17 +91,17 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
if max_concurrent_workers:
......@@ -114,7 +114,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
for worker in self.workers
]
if async_run_remote_workers_only:
if async_run_tensor_parallel_workers_only:
# Just return futures
return worker_outputs
......
......@@ -62,7 +62,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
if (self.parallel_config.tensor_parallel_size == 1
and self.parallel_config.pipeline_parallel_size == 1):
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
......@@ -189,6 +190,26 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
for tp_rank in range(self.parallel_config.tensor_parallel_size):
rank = (pp_rank *
self.parallel_config.tensor_parallel_size) + tp_rank
if rank == 0:
pass
elif rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(self.workers[rank - 1])
else:
self.non_driver_workers.append(self.workers[rank - 1])
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
......@@ -204,7 +225,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
......@@ -215,10 +236,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
"""Runs the given method on all workers. Can be used in the following
ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
Args:
- async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
......@@ -228,7 +250,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers)
count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
......@@ -242,14 +266,17 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_worker_outputs = []
else:
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
if async_run_tensor_parallel_workers_only:
# Just return futures
return ray_worker_outputs
......@@ -319,12 +346,32 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)
async def _run_task_with_lock(task, lock, *args, **kwargs):
async with lock:
return await task(*args, **kwargs)
tasks = []
tasks.append(
asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req)))
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method.remote,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
......@@ -29,7 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.arctic import ArcticConfig
logger = init_logger(__name__)
......@@ -426,6 +426,7 @@ class ArcticForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA
......@@ -338,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
......@@ -286,6 +286,7 @@ class BloomForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -25,7 +25,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA
......@@ -365,6 +365,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -46,7 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
@torch.compile
......@@ -353,6 +353,7 @@ class CohereForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -23,7 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig
......@@ -381,6 +381,7 @@ class DbrxForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class DeepseekMLP(nn.Module):
......@@ -387,6 +387,7 @@ class DeepseekForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class DeepseekV2MLP(nn.Module):
......@@ -475,6 +475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment