Commit 87d06573 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

fix tbo to support pipeline-parallel

See merge request dcutoolkit/deeplearing/vllm!130
parents 5e078c69 1a906ab9
......@@ -77,6 +77,9 @@ async def serve_http(app: FastAPI,
"port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.")
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
finish_two_batch_overlap()
return server.shutdown()
finally:
watchdog_task.cancel()
......
......@@ -256,6 +256,10 @@ def _run_worker_process(
and not tunable.record_untuned_is_enabled()):
tunable.write_file()
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
finish_two_batch_overlap()
logger.info("Worker exiting")
......
......@@ -91,13 +91,6 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
else:
request_ids_to_seq_ids_right[key] = value
counter += 1
seq_groups_left = None
seq_groups_right = None
if model_input.sampling_metadata.seq_groups is not None:
seq_groups_left = model_input.sampling_metadata.seq_groups[0:batch_size_left]
seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:]
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
previous_hidden_states_left = None
previous_hidden_states_right = None
......@@ -310,14 +303,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=previous_hidden_states_left,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_left,
selected_token_indices=selected_token_indices_left,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_left,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
sampling_metadata=None, #TBO does not require sampling_stetadata
is_prompt=model_input.is_prompt,
)
model_input_right = ModelInputForGPUWithSamplingMetadata(
......@@ -338,14 +324,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=previous_hidden_states_right,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_right,
selected_token_indices=selected_token_indices_right,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_right,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
sampling_metadata=None, #TBO does not require sampling_stetadata
is_prompt=model_input.is_prompt,
)
return model_input_left, model_input_right
......@@ -4,8 +4,10 @@ import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
from vllm.logger import init_logger
......@@ -87,16 +89,20 @@ class TwoBatchOverlap():
profile.ProfRangePush('start')
self.tbo_thread_synchronize(tid)
model_kwargs = None
intermediate_tensors = None
if is_left_thread:
model_kwargs = self.model_kwargs_left
intermediate_tensors = self.intermediate_tensors_left
else:
model_kwargs = self.model_kwargs_right
intermediate_tensors = self.intermediate_tensors_right
with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=self.intermediate_tensors,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
device=self.self_device),
**self.seqlen_agnostic_kwargs,
......@@ -132,7 +138,8 @@ class TwoBatchOverlap():
vllm_config,
virtual_engine,
model_executable,
intermediate_tensors,
intermediate_tensors_left,
intermediate_tensors_right,
multi_modal_kwargs,
self_device,
seqlen_agnostic_kwargs,
......@@ -143,7 +150,8 @@ class TwoBatchOverlap():
self.vllm_config = vllm_config
self.virtual_engine = virtual_engine
self.model_executable = model_executable
self.intermediate_tensors = intermediate_tensors
self.intermediate_tensors_left = intermediate_tensors_left
self.intermediate_tensors_right = intermediate_tensors_right
self.multi_modal_kwargs = multi_modal_kwargs
self.self_device = self_device
self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
......@@ -204,6 +212,12 @@ def tbo_all_reduce(obj):
return tensor_model_parallel_all_reduce(obj)
def merge_model_output(states_left, states_right):
if isinstance(states_left, IntermediateTensors):
output_map = {}
for key in states_left.tensors:
output_map[key] = torch.concat([states_left.tensors[key], states_right.tensors[key]], dim=0)
output = IntermediateTensors(output_map)
else:
output = torch.concat([states_left, states_right], dim=0)
return output
......@@ -252,12 +266,24 @@ def tbo_model_executable(
model_kwargs_left = model_kwargs.copy()
model_kwargs_right = model_kwargs.copy()
intermediate_tensors_left = None
intermediate_tensors_right = None
if "previous_hidden_states" in model_kwargs:
previous_hidden_states = model_kwargs["previous_hidden_states"]
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0)
model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
if intermediate_tensors != None:
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
intermediate_tensors_left = {}
intermediate_tensors_right = {}
for key in intermediate_tensors.tensors:
split_intermediate_tensors = torch.split(intermediate_tensors.tensors[key], query_tokens_split, dim=0)
intermediate_tensors_left[key] = split_intermediate_tensors[0]
intermediate_tensors_right[key] = split_intermediate_tensors[1]
intermediate_tensors_left = IntermediateTensors(intermediate_tensors_left)
intermediate_tensors_right = IntermediateTensors(intermediate_tensors_right)
tbo_obj.step_event.record()
current_stream = torch.cuda.current_stream()
......@@ -268,7 +294,8 @@ def tbo_model_executable(
vllm_config,
virtual_engine,
model_executable,
intermediate_tensors,
intermediate_tensors_left,
intermediate_tensors_right,
multi_modal_kwargs,
self_device,
seqlen_agnostic_kwargs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment