Unverified Commit 105d1045 authored by ltd0924's avatar ltd0924 Committed by GitHub
Browse files

[StepVL] support close img patch (#32923)


Signed-off-by: default avatarluotingdan <luotingdan@stepfun.com>
Signed-off-by: default avatarltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: default avatarluotingdan <luotingdan@stepfun.com>
parent 566cdb6c
...@@ -142,8 +142,11 @@ class Step3VisionProcessor: ...@@ -142,8 +142,11 @@ class Step3VisionProcessor:
class ImagePatcher: class ImagePatcher:
def __init__(self, enable_patch: bool = True) -> None:
self.enable_patch = enable_patch
def determine_window_size(self, long: int, short: int) -> int: def determine_window_size(self, long: int, short: int) -> int:
if long <= 728: if long < 728:
return short if long / short > 1.5 else 0 return short if long / short > 1.5 else 0
return min(short, 504) if long / short > 4 else 504 return min(short, 504) if long / short > 4 else 504
...@@ -241,7 +244,7 @@ class ImagePatcher: ...@@ -241,7 +244,7 @@ class ImagePatcher:
window_size = self.determine_window_size( window_size = self.determine_window_size(
max(img_height, img_width), min(img_height, img_width) max(img_height, img_width), min(img_height, img_width)
) )
if window_size == 0: if window_size == 0 or not self.enable_patch:
return 0, 0 return 0, 0
else: else:
img_width, img_height = self.get_image_size_for_crop( img_width, img_height = self.get_image_size_for_crop(
...@@ -277,7 +280,7 @@ class ImagePatcher: ...@@ -277,7 +280,7 @@ class ImagePatcher:
max(new_img_height, new_img_width), min(new_img_height, new_img_width) max(new_img_height, new_img_width), min(new_img_height, new_img_width)
) )
if window_size == 0: if window_size == 0 or not self.enable_patch:
return img, [], None return img, [], None
else: else:
new_img_width, new_img_height = self.get_image_size_for_crop( new_img_width, new_img_height = self.get_image_size_for_crop(
...@@ -327,7 +330,6 @@ class Step3VLProcessor: ...@@ -327,7 +330,6 @@ class Step3VLProcessor:
self.config = config self.config = config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.image_size = 728 self.image_size = 728
self.patch_size = 504 self.patch_size = 504
self.image_preprocessor = Step3VisionProcessor( self.image_preprocessor = Step3VisionProcessor(
...@@ -340,7 +342,10 @@ class Step3VLProcessor: ...@@ -340,7 +342,10 @@ class Step3VLProcessor:
self.image_feature_placeholder = self.image_token * self.num_image_feature_size self.image_feature_placeholder = self.image_token * self.num_image_feature_size
self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
self.patcher = ImagePatcher() # Respect vision config switch to enable/disable patch extraction.
# For video understanding, it's preferable to disable patch.
enable_patch = getattr(self.config.vision_config, "enable_patch", True)
self.patcher = ImagePatcher(enable_patch=enable_patch)
@property @property
def image_token_id(self) -> int: def image_token_id(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment