Commit 87d06573 authored by zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

fix tbo to support pipeline-parallel

See merge request dcutoolkit/deeplearing/vllm!130
parents 5e078c69 1a906ab9
...@@ -77,6 +77,9 @@ async def serve_http(app: FastAPI, ...@@ -77,6 +77,9 @@ async def serve_http(app: FastAPI,
"port %s is used by process %s launched with command:\n%s", "port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline())) port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.") logger.info("Shutting down FastAPI HTTP server.")
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
finish_two_batch_overlap()
return server.shutdown() return server.shutdown()
finally: finally:
watchdog_task.cancel() watchdog_task.cancel()
......
...@@ -256,6 +256,10 @@ def _run_worker_process( ...@@ -256,6 +256,10 @@ def _run_worker_process(
and not tunable.record_untuned_is_enabled()): and not tunable.record_untuned_is_enabled()):
tunable.write_file() tunable.write_file()
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
finish_two_batch_overlap()
logger.info("Worker exiting") logger.info("Worker exiting")
......
...@@ -91,13 +91,6 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -91,13 +91,6 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
else: else:
request_ids_to_seq_ids_right[key] = value request_ids_to_seq_ids_right[key] = value
counter += 1 counter += 1
seq_groups_left = None
seq_groups_right = None
if model_input.sampling_metadata.seq_groups is not None:
seq_groups_left = model_input.sampling_metadata.seq_groups[0:batch_size_left]
seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:]
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
previous_hidden_states_left = None previous_hidden_states_left = None
previous_hidden_states_right = None previous_hidden_states_right = None
...@@ -310,14 +303,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -310,14 +303,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
async_callback=model_input.async_callback, async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs, scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=previous_hidden_states_left, previous_hidden_states=previous_hidden_states_left,
sampling_metadata=SamplingMetadata( sampling_metadata=None, #TBO does not require sampling_metadata
seq_groups=seq_groups_left,
selected_token_indices=selected_token_indices_left,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_left,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
is_prompt=model_input.is_prompt, is_prompt=model_input.is_prompt,
) )
model_input_right = ModelInputForGPUWithSamplingMetadata( model_input_right = ModelInputForGPUWithSamplingMetadata(
...@@ -338,14 +324,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -338,14 +324,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
async_callback=model_input.async_callback, async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs, scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=previous_hidden_states_right, previous_hidden_states=previous_hidden_states_right,
sampling_metadata=SamplingMetadata( sampling_metadata=None, #TBO does not require sampling_metadata
seq_groups=seq_groups_right,
selected_token_indices=selected_token_indices_right,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_right,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
is_prompt=model_input.is_prompt, is_prompt=model_input.is_prompt,
) )
return model_input_left, model_input_right return model_input_left, model_input_right
...@@ -4,8 +4,10 @@ import queue ...@@ -4,8 +4,10 @@ import queue
import threading import threading
import torch import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -87,16 +89,20 @@ class TwoBatchOverlap(): ...@@ -87,16 +89,20 @@ class TwoBatchOverlap():
profile.ProfRangePush('start') profile.ProfRangePush('start')
self.tbo_thread_synchronize(tid) self.tbo_thread_synchronize(tid)
model_kwargs = None model_kwargs = None
intermediate_tensors = None
if is_left_thread: if is_left_thread:
model_kwargs = self.model_kwargs_left model_kwargs = self.model_kwargs_left
intermediate_tensors = self.intermediate_tensors_left
else: else:
model_kwargs = self.model_kwargs_right model_kwargs = self.model_kwargs_right
intermediate_tensors = self.intermediate_tensors_right
with set_forward_context(model_input.attn_metadata, with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine): self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable( hidden_or_intermediate_states = self.model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
intermediate_tensors=self.intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs, **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
device=self.self_device), device=self.self_device),
**self.seqlen_agnostic_kwargs, **self.seqlen_agnostic_kwargs,
...@@ -132,7 +138,8 @@ class TwoBatchOverlap(): ...@@ -132,7 +138,8 @@ class TwoBatchOverlap():
vllm_config, vllm_config,
virtual_engine, virtual_engine,
model_executable, model_executable,
intermediate_tensors, intermediate_tensors_left,
intermediate_tensors_right,
multi_modal_kwargs, multi_modal_kwargs,
self_device, self_device,
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
...@@ -143,7 +150,8 @@ class TwoBatchOverlap(): ...@@ -143,7 +150,8 @@ class TwoBatchOverlap():
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.virtual_engine = virtual_engine self.virtual_engine = virtual_engine
self.model_executable = model_executable self.model_executable = model_executable
self.intermediate_tensors = intermediate_tensors self.intermediate_tensors_left = intermediate_tensors_left
self.intermediate_tensors_right = intermediate_tensors_right
self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_kwargs = multi_modal_kwargs
self.self_device = self_device self.self_device = self_device
self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
...@@ -204,6 +212,12 @@ def tbo_all_reduce(obj): ...@@ -204,6 +212,12 @@ def tbo_all_reduce(obj):
return tensor_model_parallel_all_reduce(obj) return tensor_model_parallel_all_reduce(obj)
def merge_model_output(states_left, states_right): def merge_model_output(states_left, states_right):
if isinstance(states_left, IntermediateTensors):
output_map = {}
for key in states_left.tensors:
output_map[key] = torch.concat([states_left.tensors[key], states_right.tensors[key]], dim=0)
output = IntermediateTensors(output_map)
else:
output = torch.concat([states_left, states_right], dim=0) output = torch.concat([states_left, states_right], dim=0)
return output return output
...@@ -252,12 +266,24 @@ def tbo_model_executable( ...@@ -252,12 +266,24 @@ def tbo_model_executable(
model_kwargs_left = model_kwargs.copy() model_kwargs_left = model_kwargs.copy()
model_kwargs_right = model_kwargs.copy() model_kwargs_right = model_kwargs.copy()
intermediate_tensors_left = None
intermediate_tensors_right = None
if "previous_hidden_states" in model_kwargs: if "previous_hidden_states" in model_kwargs:
previous_hidden_states = model_kwargs["previous_hidden_states"] previous_hidden_states = model_kwargs["previous_hidden_states"]
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])] query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0) split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0)
model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0] model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1] model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
if intermediate_tensors != None:
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
intermediate_tensors_left = {}
intermediate_tensors_right = {}
for key in intermediate_tensors.tensors:
split_intermediate_tensors = torch.split(intermediate_tensors.tensors[key], query_tokens_split, dim=0)
intermediate_tensors_left[key] = split_intermediate_tensors[0]
intermediate_tensors_right[key] = split_intermediate_tensors[1]
intermediate_tensors_left = IntermediateTensors(intermediate_tensors_left)
intermediate_tensors_right = IntermediateTensors(intermediate_tensors_right)
tbo_obj.step_event.record() tbo_obj.step_event.record()
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
...@@ -268,7 +294,8 @@ def tbo_model_executable( ...@@ -268,7 +294,8 @@ def tbo_model_executable(
vllm_config, vllm_config,
virtual_engine, virtual_engine,
model_executable, model_executable,
intermediate_tensors, intermediate_tensors_left,
intermediate_tensors_right,
multi_modal_kwargs, multi_modal_kwargs,
self_device, self_device,
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment