Unverified commit 65f1d065, authored by Haian Huang(深度眸) and committed by GitHub
Browse files

[Bug] Fix Intern-S1 model accuracy and support /generate interface with input_ids (#12367)

parent 9434a0e5
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -50,8 +50,6 @@ class InternS1ForConditionalGeneration(nn.Module): ...@@ -50,8 +50,6 @@ class InternS1ForConditionalGeneration(nn.Module):
(image_size // patch_size) ** 2 * (config.downsample_ratio**2) (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
) )
self.downsample_ratio = config.downsample_ratio self.downsample_ratio = config.downsample_ratio
self.ps_version = getattr(config, "ps_version", "v1")
# self.template = getattr(config, 'template', 'internvl2_5')
config.vision_config.use_flash_attn = True if use_flash_attn else False config.vision_config.use_flash_attn = True if use_flash_attn else False
config.text_config._attn_implementation = ( config.text_config._attn_implementation = (
...@@ -59,7 +57,6 @@ class InternS1ForConditionalGeneration(nn.Module): ...@@ -59,7 +57,6 @@ class InternS1ForConditionalGeneration(nn.Module):
) )
logger.info(f"num_image_token: {self.num_image_token}") logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"ps_version: {self.ps_version}")
self.vision_model = InternVisionModel(config.vision_config) self.vision_model = InternVisionModel(config.vision_config)
if config.text_config.architectures[0] == "Qwen2ForCausalLM": if config.text_config.architectures[0] == "Qwen2ForCausalLM":
...@@ -104,13 +101,7 @@ class InternS1ForConditionalGeneration(nn.Module): ...@@ -104,13 +101,7 @@ class InternS1ForConditionalGeneration(nn.Module):
int(w * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)), int(c / (scale_factor * scale_factor)),
) )
if self.ps_version == "v1": x = x.permute(0, 2, 1, 3).contiguous()
logger.warn(
"In ps_version 'v1', the height and width have not been swapped back, "
"which results in a transposed image."
)
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x return x
def extract_feature(self, pixel_values): def extract_feature(self, pixel_values):
...@@ -224,7 +215,6 @@ class InternS1ForConditionalGeneration(nn.Module): ...@@ -224,7 +215,6 @@ class InternS1ForConditionalGeneration(nn.Module):
) )
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
...@@ -280,13 +270,5 @@ class InternS1ForConditionalGeneration(nn.Module): ...@@ -280,13 +270,5 @@ class InternS1ForConditionalGeneration(nn.Module):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
return loaded_params
EntryClass = [InternS1ForConditionalGeneration] EntryClass = InternS1ForConditionalGeneration
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
...@@ -598,7 +598,6 @@ class InternVLChatModel(nn.Module): ...@@ -598,7 +598,6 @@ class InternVLChatModel(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
...@@ -678,22 +677,5 @@ class InternVLChatModel(nn.Module): ...@@ -678,22 +677,5 @@ class InternVLChatModel(nn.Module):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
# Skip params that are created by quantization wrappers and are not expected in the ckpt
_quant_only_fragments = (
"weight_scale", # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
)
unloaded_params = {
n
for n in unloaded_params
if not any(frag in n for frag in _quant_only_fragments)
}
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
return loaded_params
EntryClass = InternVLChatModel EntryClass = InternVLChatModel
...@@ -231,7 +231,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor): ...@@ -231,7 +231,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
pixel_values = torch.cat(pixel_values, dim=0) pixel_values = torch.cat(pixel_values, dim=0)
original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>" original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
input_text = base_output.input_text.replace(
self.IMG_CONTEXT_TOKEN, original_placeholder
)
input_text_updated = input_text input_text_updated = input_text
for num_patches in num_patches_list: for num_patches in num_patches_list:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment