Unverified Commit 6772bb0f authored by Yuanyuan Chen's avatar Yuanyuan Chen Committed by GitHub
Browse files

Remove unnecessary CUDA sync of qwen image and video preprocess (#22792)


Signed-off-by: default avatarcyy <cyyever@outlook.com>
Signed-off-by: default avatarYuanyuan Chen <cyyever@outlook.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent fceafaf5
...@@ -976,10 +976,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -976,10 +976,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item. # Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
(merge_size * merge_size)).tolist()
return image_embeds.split(sizes.tolist()) return image_embeds.split(sizes)
def _process_video_input( def _process_video_input(
self, self,
...@@ -998,9 +1000,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -998,9 +1000,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
# Split concatenated embeddings for each video item. # Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
(merge_size * merge_size)).tolist()
return video_embeds.split(sizes.tolist()) return video_embeds.split(sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {} mm_input_by_modality = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment