Unverified Commit 6646c0c7 authored by labAxiaoming's avatar labAxiaoming Committed by GitHub
Browse files

[Opt] Optimize deepstack buffer handling for multimodal Qwen3 models (#40145)


Signed-off-by: xiaoming <1259730330@qq.com>
parent 95995bbe
......@@ -1753,6 +1753,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
for _ in range(self.deepstack_num_level)
]
# Tracks the valid token span currently stored in the buffer.
# Zero means there is no active deepstack payload to consume.
self.deepstack_input_embeds_num_tokens = 0
with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
......@@ -1773,6 +1776,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return None
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested more deepstack tokens than available in buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(
......@@ -1804,15 +1814,25 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self.deepstack_input_embeds[idx][:num_tokens].copy_(
deepstack_input_embeds[idx]
)
self.deepstack_input_embeds_num_tokens = num_tokens
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
# Invalidate the deepstack payload held in the per-level buffers: zero the
# leading `num_tokens` entries of every level's buffer and reset the
# valid-token counter to 0.
#
# Raises ValueError when `num_tokens` exceeds the number of tokens the
# counter says are currently stored.
#
# NOTE(review): this view has lost its indentation, so it cannot be told
# from here whether the final counter reset belongs inside the
# `if num_tokens > 0:` branch or runs after it -- confirm against the file.

# Attribute absent, None, or empty: there are no buffers to clear.
if not getattr(self, "deepstack_input_embeds", None):
return
# A zero counter means no active payload is stored, so skip the work.
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
# Refuse to clear past the span recorded as valid.
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested to clear more deepstack tokens than available in "
"buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
# One buffer per deepstack level; only the leading slice is zeroed in place.
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
# Mark the buffer as holding no active payload.
self.deepstack_input_embeds_num_tokens = 0
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
......
......@@ -1675,6 +1675,9 @@ class Qwen3VLForConditionalGeneration(
)
for _ in range(self.deepstack_num_level)
]
# Tracks the valid token span currently stored in the buffer.
# Zero means there is no active deepstack payload to consume.
self.deepstack_input_embeds_num_tokens = 0
with self._mark_language_model(vllm_config):
self.language_model = Qwen3LLMForCausalLM(
......@@ -1702,6 +1705,13 @@ class Qwen3VLForConditionalGeneration(
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return None
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested more deepstack tokens than available in buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(
......@@ -1733,15 +1743,25 @@ class Qwen3VLForConditionalGeneration(
self.deepstack_input_embeds[idx][:num_tokens].copy_(
deepstack_input_embeds[idx]
)
self.deepstack_input_embeds_num_tokens = num_tokens
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
# Invalidate the deepstack payload held in the per-level buffers: zero the
# leading `num_tokens` entries of every level's buffer and reset the
# valid-token counter to 0.
#
# Raises ValueError when `num_tokens` exceeds the number of tokens the
# counter says are currently stored.
#
# NOTE(review): this view has lost its indentation, so it cannot be told
# from here whether the final counter reset belongs inside the
# `if num_tokens > 0:` branch or runs after it -- confirm against the file.

# Attribute absent, None, or empty: there are no buffers to clear.
if not getattr(self, "deepstack_input_embeds", None):
return
# A zero counter means no active payload is stored, so skip the work.
if getattr(self, "deepstack_input_embeds_num_tokens", 0) == 0:
return
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
# Refuse to clear past the span recorded as valid.
if num_tokens > self.deepstack_input_embeds_num_tokens:
raise ValueError(
"Requested to clear more deepstack tokens than available in "
"buffer: "
f"{num_tokens=} > {self.deepstack_input_embeds_num_tokens=}"
)
# One buffer per deepstack level; only the leading slice is zeroed in place.
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
# Mark the buffer as holding no active payload.
self.deepstack_input_embeds_num_tokens = 0
# -- SupportsEncoderCudaGraph protocol methods --
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment