Unverified Commit 69bff9bc authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

fix float16 support for kimi-vl (#17156)


Co-authored-by: default avatarzhouzaida <zhouzaida@msh.team>
parent 41ca7eb4
......@@ -340,8 +340,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
else:
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
patch_size)
# fp32 -> bf16
pixel_values = pixel_values.to(torch.bfloat16)
pixel_values = pixel_values.to(self.vision_tower.dtype)
# image_grid_hws.shape = (N, 2)
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment