Unverified Commit 65f1d065 authored by Haian Huang (深度眸), committed by GitHub

[Bug] Fix Intern-S1 model accuracy and support /generate interface with input_ids (#12367)

parent 9434a0e5
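
For orientation, a minimal sketch of the request shape the second half of this fix targets: sending a pre-tokenized prompt to the /generate HTTP endpoint instead of raw text. The port, token IDs, and image path below are placeholders; only the field names follow SGLang's documented /generate schema.

# Minimal sketch, assuming an SGLang server is already running locally.
# Token IDs and the image path are placeholders, not real Intern-S1 inputs.
import requests

payload = {
    "input_ids": [1, 2, 3, 4],  # pre-tokenized prompt instead of "text"
    "image_data": "example.jpg",  # illustrative image reference
    "sampling_params": {"max_new_tokens": 64, "temperature": 0.0},
}

resp = requests.post("http://127.0.0.1:30000/generate", json=payload)
print(resp.json()["text"])
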
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
@@ -50,8 +50,6 @@ class InternS1ForConditionalGeneration(nn.Module):
             (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
         )
         self.downsample_ratio = config.downsample_ratio
-        self.ps_version = getattr(config, "ps_version", "v1")
-        # self.template = getattr(config, 'template', 'internvl2_5')
         config.vision_config.use_flash_attn = True if use_flash_attn else False
         config.text_config._attn_implementation = (
@@ -59,7 +57,6 @@ class InternS1ForConditionalGeneration(nn.Module):
         )
         logger.info(f"num_image_token: {self.num_image_token}")
-        logger.info(f"ps_version: {self.ps_version}")
         self.vision_model = InternVisionModel(config.vision_config)
         if config.text_config.architectures[0] == "Qwen2ForCausalLM":
@@ -104,13 +101,7 @@ class InternS1ForConditionalGeneration(nn.Module):
             int(w * scale_factor),
             int(c / (scale_factor * scale_factor)),
         )
-        if self.ps_version == "v1":
-            logger.warn(
-                "In ps_version 'v1', the height and width have not been swapped back, "
-                "which results in a transposed image."
-            )
-        else:
-            x = x.permute(0, 2, 1, 3).contiguous()
+        x = x.permute(0, 2, 1, 3).contiguous()
         return x

     def extract_feature(self, pixel_values):
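
For reference, a self-contained sketch of the InternVL-style pixel shuffle with the simplification made in this hunk: the final height/width swap is applied unconditionally instead of being gated on ps_version. The earlier reshape/permute steps and the example shapes are reconstructed from the common InternVL implementation, not taken from this diff.

import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    # Sketch only: merge spatial positions into channels to downsample vision tokens.
    n, w, h, c = x.size()
    # N, W, H, C -> N, W, H * scale, C / scale
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    # N, W, H * scale, C / scale -> N, H * scale, W, C / scale
    x = x.permute(0, 2, 1, 3).contiguous()
    # N, H * scale, W, C / scale -> N, H * scale, W * scale, C / (scale^2)
    x = x.view(
        n,
        int(h * scale_factor),
        int(w * scale_factor),
        int(c / (scale_factor * scale_factor)),
    )
    # Swap height and width back (previously skipped when ps_version == "v1").
    x = x.permute(0, 2, 1, 3).contiguous()
    return x

# Illustrative shapes: 32x32 vision tokens of dim 1024 -> 16x16 tokens of dim 4096.
feats = torch.randn(1, 32, 32, 1024)
print(pixel_shuffle(feats).shape)  # torch.Size([1, 16, 16, 4096])
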
@@ -224,7 +215,6 @@ class InternS1ForConditionalGeneration(nn.Module):
         )
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -280,13 +270,5 @@ class InternS1ForConditionalGeneration(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
-        return loaded_params
-EntryClass = [InternS1ForConditionalGeneration]
+EntryClass = InternS1ForConditionalGeneration
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, List, Optional, Tuple, Union
 import torch
@@ -598,7 +598,6 @@ class InternVLChatModel(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -678,22 +677,5 @@ class InternVLChatModel(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        unloaded_params = params_dict.keys() - loaded_params
-        # Skip params that are created by quantization wrappers and are not expected in the ckpt
-        _quant_only_fragments = (
-            "weight_scale",  # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
-        )
-        unloaded_params = {
-            n
-            for n in unloaded_params
-            if not any(frag in n for frag in _quant_only_fragments)
-        }
-        if unloaded_params:
-            raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
-        return loaded_params
 EntryClass = InternVLChatModel
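
The hunks above drop the loaded_params bookkeeping and the missing-weight check from both models' load_weights. A minimal sketch of the remaining loop shape follows; the plain-copy fallback loader and the single skip rule are simplifications for illustration, not the repository's default_weight_loader or name-remapping logic.

from typing import Iterable, Tuple

import torch
from torch import nn


def load_weights_sketch(model: nn.Module, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    # Sketch only: route each checkpoint tensor to its parameter's weight_loader.
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue  # rotary inverse frequencies are recomputed at runtime
        param = params_dict[name]
        # Parameters without a custom loader get a plain in-place copy.
        weight_loader = getattr(param, "weight_loader", lambda p, w: p.data.copy_(w))
        weight_loader(param, loaded_weight)
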
@@ -231,7 +231,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         pixel_values = torch.cat(pixel_values, dim=0)
         original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
-        input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
+        input_text = base_output.input_text.replace(
+            self.IMG_CONTEXT_TOKEN, original_placeholder
+        )
+        input_text_updated = input_text
         for num_patches in num_patches_list:
......
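
For orientation, a standalone sketch of the per-image placeholder expansion that the loop above performs; the token strings, the num_image_token value, and the helper name are assumptions based on the usual InternVL convention, not taken from this diff.

# Standalone sketch; token strings and counts are illustrative only.
IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
PLACEHOLDER = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"


def expand_image_tokens(input_text, num_patches_list, num_image_token=256):
    # Protect existing context tokens so each image slot is expanded exactly once.
    text = input_text.replace(IMG_CONTEXT_TOKEN, PLACEHOLDER)
    for num_patches in num_patches_list:
        image_tokens = (
            IMG_START_TOKEN
            + IMG_CONTEXT_TOKEN * (num_image_token * num_patches)
            + IMG_END_TOKEN
        )
        # Consume one placeholder per image, in prompt order.
        text = text.replace(PLACEHOLDER, image_tokens, 1)
    return text


print(expand_image_tokens("Describe <IMG_CONTEXT>", [2], num_image_token=3).count(IMG_CONTEXT_TOKEN))  # 6
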