Commit 385eeae9 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

fix tbo support deepseek mtp

See merge request dcutoolkit/deeplearing/vllm!123
parents 2b8700e0 7124e74d
......@@ -99,6 +99,13 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
previous_hidden_states_left = None
previous_hidden_states_right = None
if model_input.previous_hidden_states != None:
split_previous_hidden_states = torch.split(model_input.previous_hidden_states, batch_size_split, dim=0)
previous_hidden_states_left = split_previous_hidden_states[0]
previous_hidden_states_right = split_previous_hidden_states[1]
if isinstance(model_input.attn_metadata, MLACommonMetadata):
attn_metadata_left = MLACommonMetadata(
num_prefills = num_prefills_left,
......@@ -302,7 +309,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states,
previous_hidden_states=previous_hidden_states_left,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_left,
selected_token_indices=selected_token_indices_left,
......@@ -330,7 +337,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states,
previous_hidden_states=previous_hidden_states_right,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_right,
selected_token_indices=selected_token_indices_right,
......
......@@ -86,6 +86,11 @@ class TwoBatchOverlap():
break
profile.ProfRangePush('start')
self.tbo_thread_synchronize(tid)
model_kwargs = None
if is_left_thread:
model_kwargs = self.model_kwargs_left
else:
model_kwargs = self.model_kwargs_right
with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable(
......@@ -95,7 +100,7 @@ class TwoBatchOverlap():
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
device=self.self_device),
**self.seqlen_agnostic_kwargs,
**self.model_kwargs,
**model_kwargs,
)
if is_left_thread:
self.sem_right.release()
......@@ -131,7 +136,8 @@ class TwoBatchOverlap():
multi_modal_kwargs,
self_device,
seqlen_agnostic_kwargs,
model_kwargs):
model_kwargs_left,
model_kwargs_right):
if self.left_thread == None:
self.init_tbo_thread()
self.vllm_config = vllm_config
......@@ -141,7 +147,8 @@ class TwoBatchOverlap():
self.multi_modal_kwargs = multi_modal_kwargs
self.self_device = self_device
self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
self.model_kwargs = model_kwargs
self.model_kwargs_left = model_kwargs_left
self.model_kwargs_right = model_kwargs_right
self.model_input_left_queue.put(model_input_left)
self.model_input_right_queue.put(model_input_right)
......@@ -242,6 +249,16 @@ def tbo_model_executable(
batch_size_right += 1
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
model_kwargs_left = model_kwargs
model_kwargs_right = model_kwargs
if "previous_hidden_states" in model_kwargs:
previous_hidden_states = model_kwargs["previous_hidden_states"]
batch_size_split = [batch_size_left, batch_size_right]
split_previous_hidden_states = torch.split(previous_hidden_states, batch_size_split, dim=0)
model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
tbo_obj.step_event.record()
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(tbo_step_stream):
......@@ -255,7 +272,8 @@ def tbo_model_executable(
multi_modal_kwargs,
self_device,
seqlen_agnostic_kwargs,
model_kwargs)
model_kwargs_left,
model_kwargs_right)
tbo_obj.all_reduce()
states_left, states_right = tbo_obj.get_model_output()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment