Commit 6646c0c7 (unverified), authored by labAxiaoming, committed by GitHub
Browse files

[Opt] Optimize deepstack buffer handling for multimodal Qwen3 models (#40145)


Signed-off-by: xiaoming <1259730330@qq.com>
parent 95995bbe
...@@ -1753,6 +1753,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1753,6 +1753,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) )
for _ in range(self.deepstack_num_level) for _ in range(self.deepstack_num_level)
] ]
# Tracks the valid token span currently stored in the buffer.
# Zero means there is no active deepstack payload to consume.
self.deepstack_input_embeds_num_tokens = 0
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM( self.language_model = Qwen3MoeLLMForCausalLM(
...@@ -1773,6 +1776,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1773,6 +1776,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) -> IntermediateTensors | None: ) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None): if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped return None # If vision tower is skipped
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return None
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested more deepstack tokens than available in buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
# get deepstack_input_embeds from buffer, and clear the buffer # get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors( return IntermediateTensors(
...@@ -1804,15 +1814,25 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1804,15 +1814,25 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self.deepstack_input_embeds[idx][:num_tokens].copy_( self.deepstack_input_embeds[idx][:num_tokens].copy_(
deepstack_input_embeds[idx] deepstack_input_embeds[idx]
) )
self.deepstack_input_embeds_num_tokens = num_tokens
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None): if not getattr(self, "deepstack_input_embeds", None):
return return
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return
# clear deepstack_input_embeds in buffer # clear deepstack_input_embeds in buffer
if num_tokens > 0: if num_tokens > 0:
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested to clear more deepstack tokens than available in "
"buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
for idx in range(self.deepstack_num_level): for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_() self.deepstack_input_embeds[idx][:num_tokens].zero_()
self.deepstack_input_embeds_num_tokens = 0
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {} mm_input_by_modality = {}
......
...@@ -1675,6 +1675,9 @@ class Qwen3VLForConditionalGeneration( ...@@ -1675,6 +1675,9 @@ class Qwen3VLForConditionalGeneration(
) )
for _ in range(self.deepstack_num_level) for _ in range(self.deepstack_num_level)
] ]
# Tracks the valid token span currently stored in the buffer.
# Zero means there is no active deepstack payload to consume.
self.deepstack_input_embeds_num_tokens = 0
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = Qwen3LLMForCausalLM( self.language_model = Qwen3LLMForCausalLM(
...@@ -1702,6 +1705,13 @@ class Qwen3VLForConditionalGeneration( ...@@ -1702,6 +1705,13 @@ class Qwen3VLForConditionalGeneration(
) -> IntermediateTensors | None: ) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None): if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped return None # If vision tower is skipped
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return None
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested more deepstack tokens than available in buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
# get deepstack_input_embeds from buffer, and clear the buffer # get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors( return IntermediateTensors(
...@@ -1733,15 +1743,25 @@ class Qwen3VLForConditionalGeneration( ...@@ -1733,15 +1743,25 @@ class Qwen3VLForConditionalGeneration(
self.deepstack_input_embeds[idx][:num_tokens].copy_( self.deepstack_input_embeds[idx][:num_tokens].copy_(
deepstack_input_embeds[idx] deepstack_input_embeds[idx]
) )
self.deepstack_input_embeds_num_tokens = num_tokens
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None): if not getattr(self, "deepstack_input_embeds", None):
return return
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return
# clear deepstack_input_embeds in buffer # clear deepstack_input_embeds in buffer
if num_tokens > 0: if num_tokens > 0:
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested to clear more deepstack tokens than available in "
"buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
for idx in range(self.deepstack_num_level): for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_() self.deepstack_input_embeds[idx][:num_tokens].zero_()
self.deepstack_input_embeds_num_tokens = 0
# -- SupportsEncoderCudaGraph protocol methods -- # -- SupportsEncoderCudaGraph protocol methods --
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment