Unverified Commit 8a466026 authored by Chatcharin Sangbutsarakum's avatar Chatcharin Sangbutsarakum Committed by GitHub
Browse files

[Model] Remove unnecessary CUDA sync of GLM-4.1V image and video preprocess (#24332)


Signed-off-by: default avatarWin <chatcharinsang@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 61aa4b29
......@@ -1429,6 +1429,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
......@@ -1443,13 +1444,15 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = self.visual(pixel_values,
grid_thw=grid_thw.tolist())
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return image_embeds.split(sizes.tolist())
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
(merge_size * merge_size)).tolist()
return image_embeds.split(sizes)
def _process_video_input(
self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
......@@ -1466,8 +1469,9 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
grid_thw=grid_thw.tolist())
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
(merge_size * merge_size)).tolist()
return video_embeds.split(sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment