Commit 385eeae9 authored by zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

Fix TBO (two-batch overlap) support for DeepSeek MTP

See merge request dcutoolkit/deeplearing/vllm!123
parents 2b8700e0 7124e74d
...@@ -98,6 +98,13 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -98,6 +98,13 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:] seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:]
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1 selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1 selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
previous_hidden_states_left = None
previous_hidden_states_right = None
if model_input.previous_hidden_states != None:
split_previous_hidden_states = torch.split(model_input.previous_hidden_states, batch_size_split, dim=0)
previous_hidden_states_left = split_previous_hidden_states[0]
previous_hidden_states_right = split_previous_hidden_states[1]
if isinstance(model_input.attn_metadata, MLACommonMetadata): if isinstance(model_input.attn_metadata, MLACommonMetadata):
attn_metadata_left = MLACommonMetadata( attn_metadata_left = MLACommonMetadata(
...@@ -302,7 +309,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -302,7 +309,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
virtual_engine=model_input.virtual_engine, virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback, async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs, scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states, previous_hidden_states=previous_hidden_states_left,
sampling_metadata=SamplingMetadata( sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_left, seq_groups=seq_groups_left,
selected_token_indices=selected_token_indices_left, selected_token_indices=selected_token_indices_left,
...@@ -330,7 +337,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -330,7 +337,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
virtual_engine=model_input.virtual_engine, virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback, async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs, scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states, previous_hidden_states=previous_hidden_states_right,
sampling_metadata=SamplingMetadata( sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_right, seq_groups=seq_groups_right,
selected_token_indices=selected_token_indices_right, selected_token_indices=selected_token_indices_right,
......
...@@ -86,6 +86,11 @@ class TwoBatchOverlap(): ...@@ -86,6 +86,11 @@ class TwoBatchOverlap():
break break
profile.ProfRangePush('start') profile.ProfRangePush('start')
self.tbo_thread_synchronize(tid) self.tbo_thread_synchronize(tid)
model_kwargs = None
if is_left_thread:
model_kwargs = self.model_kwargs_left
else:
model_kwargs = self.model_kwargs_right
with set_forward_context(model_input.attn_metadata, with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine): self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable( hidden_or_intermediate_states = self.model_executable(
...@@ -95,7 +100,7 @@ class TwoBatchOverlap(): ...@@ -95,7 +100,7 @@ class TwoBatchOverlap():
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs, **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
device=self.self_device), device=self.self_device),
**self.seqlen_agnostic_kwargs, **self.seqlen_agnostic_kwargs,
**self.model_kwargs, **model_kwargs,
) )
if is_left_thread: if is_left_thread:
self.sem_right.release() self.sem_right.release()
...@@ -131,7 +136,8 @@ class TwoBatchOverlap(): ...@@ -131,7 +136,8 @@ class TwoBatchOverlap():
multi_modal_kwargs, multi_modal_kwargs,
self_device, self_device,
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
model_kwargs): model_kwargs_left,
model_kwargs_right):
if self.left_thread == None: if self.left_thread == None:
self.init_tbo_thread() self.init_tbo_thread()
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -141,7 +147,8 @@ class TwoBatchOverlap(): ...@@ -141,7 +147,8 @@ class TwoBatchOverlap():
self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_kwargs = multi_modal_kwargs
self.self_device = self_device self.self_device = self_device
self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
self.model_kwargs = model_kwargs self.model_kwargs_left = model_kwargs_left
self.model_kwargs_right = model_kwargs_right
self.model_input_left_queue.put(model_input_left) self.model_input_left_queue.put(model_input_left)
self.model_input_right_queue.put(model_input_right) self.model_input_right_queue.put(model_input_right)
...@@ -242,6 +249,16 @@ def tbo_model_executable( ...@@ -242,6 +249,16 @@ def tbo_model_executable(
batch_size_right += 1 batch_size_right += 1
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right) model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
model_kwargs_left = model_kwargs
model_kwargs_right = model_kwargs
if "previous_hidden_states" in model_kwargs:
previous_hidden_states = model_kwargs["previous_hidden_states"]
batch_size_split = [batch_size_left, batch_size_right]
split_previous_hidden_states = torch.split(previous_hidden_states, batch_size_split, dim=0)
model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
tbo_obj.step_event.record() tbo_obj.step_event.record()
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
with torch.cuda.stream(tbo_step_stream): with torch.cuda.stream(tbo_step_stream):
...@@ -255,7 +272,8 @@ def tbo_model_executable( ...@@ -255,7 +272,8 @@ def tbo_model_executable(
multi_modal_kwargs, multi_modal_kwargs,
self_device, self_device,
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
model_kwargs) model_kwargs_left,
model_kwargs_right)
tbo_obj.all_reduce() tbo_obj.all_reduce()
states_left, states_right = tbo_obj.get_model_output() states_left, states_right = tbo_obj.get_model_output()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment