Unverified Commit a41357e9 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[VLM] Improve consistency between feature size calculation and dummy data for profiling (#6146)

parent ae96ef8f
...@@ -37,6 +37,9 @@ _KEYS_TO_MODIFY_MAPPING = { ...@@ -37,6 +37,9 @@ _KEYS_TO_MODIFY_MAPPING = {
"language_model.model": "language_model", "language_model.model": "language_model",
} }
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
class LlavaNextImagePixelInputs(TypedDict): class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
...@@ -128,13 +131,11 @@ def get_llava_next_image_feature_size( ...@@ -128,13 +131,11 @@ def get_llava_next_image_feature_size(
def get_max_llava_next_image_tokens(ctx: InputContext): def get_max_llava_next_image_tokens(ctx: InputContext):
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
dummy_height = dummy_width = 448
return get_llava_next_image_feature_size( return get_llava_next_image_feature_size(
ctx.get_hf_config(LlavaNextConfig), ctx.get_hf_config(LlavaNextConfig),
input_height=dummy_height, input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=dummy_width, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
) )
...@@ -142,13 +143,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): ...@@ -142,13 +143,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(LlavaNextConfig) hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
# Result in the max possible feature size (2x2 grid of 336x336px tiles) image_feature_size = get_max_llava_next_image_tokens(ctx)
dummy_height = dummy_width = 448
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=dummy_height,
input_width=dummy_width,
)
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
...@@ -160,8 +155,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): ...@@ -160,8 +155,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
vision_config, vision_config,
image_width_override=dummy_width, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=dummy_height, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
return seq_data, mm_data return seq_data, mm_data
......
...@@ -53,6 +53,10 @@ _KEYS_TO_MODIFY_MAPPING = { ...@@ -53,6 +53,10 @@ _KEYS_TO_MODIFY_MAPPING = {
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044 _IMAGE_TOKEN_ID = 32044
# Result in the max possible feature size (h:w = 16:1)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
hidden_act="quick_gelu", hidden_act="quick_gelu",
hidden_size=1024, hidden_size=1024,
...@@ -322,24 +326,17 @@ def get_phi3v_image_feature_size( ...@@ -322,24 +326,17 @@ def get_phi3v_image_feature_size(
def get_max_phi3v_image_tokens(ctx: InputContext): def get_max_phi3v_image_tokens(ctx: InputContext):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50
return get_phi3v_image_feature_size( return get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig), ctx.get_hf_config(PretrainedConfig),
input_height=dummy_height, input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=dummy_width, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
) )
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50 image_feature_size = get_max_phi3v_image_tokens(ctx)
image_feature_size = get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip( seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
...@@ -349,8 +346,8 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): ...@@ -349,8 +346,8 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
) )
mm_data = dummy_image_for_clip( mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG, CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=dummy_height, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
) )
return seq_data, mm_data return seq_data, mm_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment