Commit fe6cdd2e authored by zhe chen's avatar zhe chen
Browse files

Update huggingface model


Update huggingface model


Update README.md


Update README.md


Update README.md


Update huggingface model


Update huggingface model
parent 3bd2e7b9
import torch
from PIL import Image
from transformers import (AutoConfig, AutoModel,
AutoModelForImageClassification, CLIPImageProcessor)
def test_model(model_name):
print('model_name:', model_name)
image_path = 'img_1.png'
image_processor = CLIPImageProcessor.from_pretrained(model_name)
image = Image.open(image_path)
image = image_processor(images=image, return_tensors='pt').pixel_values
print('image shape:', image.shape)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
output = model(image)
for k, v in output.items():
if type(v) == list:
for idx, item in enumerate(v):
print(f'{k}_{idx} shape: {item.shape}')
elif v is None:
print(f'{k} is None')
else:
print(f'{k} shape: {v.shape}')
print('------------------------')
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
output = model(image)
for k, v in output.items():
if type(v) == list:
for idx, item in enumerate(v):
print(f'{k}_{idx} shape: {item.shape}')
elif v is None:
print(f'{k} is None')
else:
print(f'{k} shape: {v.shape}')
logits = output['logits']
argmax = int(torch.argmax(logits, dim=1))
print(argmax)
# Smoke-test every released checkpoint: ImageNet-22k pretrained models
# first, then the ImageNet-1k (and 22k-to-1k finetuned) models.
for checkpoint in (
    './22k_model/internimage_l_22k_384',
    './22k_model/internimage_xl_22k_384',
    './22k_model/internimage_h_jointto22k_384',
    './22k_model/internimage_g_jointto22k_384',
    './in1k_model/internimage_t_1k_224',
    './in1k_model/internimage_s_1k_224',
    './in1k_model/internimage_b_1k_224',
    './in1k_model/internimage_l_22kto1k_384',
    './in1k_model/internimage_xl_22kto1k_384',
    './in1k_model/internimage_h_22kto1k_640',
    './in1k_model/internimage_g_22kto1k_512',
):
    test_model(checkpoint)
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
......@@ -231,7 +231,7 @@ class AttentionPoolingBlock(AttentiveBlock):
class StemLayer(nn.Module):
r""" Stem layer of InternImage
r"""Stem layer of InternImage
Args:
in_chans (int): number of input channels
out_chans (int): number of output channels
......@@ -271,7 +271,7 @@ class StemLayer(nn.Module):
class DownsampleLayer(nn.Module):
r""" Downsample layer of InternImage
r"""Downsample layer of InternImage
Args:
channels (int): number of input channels
norm_layer (str): normalization layer
......@@ -295,7 +295,7 @@ class DownsampleLayer(nn.Module):
class MLPLayer(nn.Module):
r""" MLP layer of InternImage
r"""MLP layer of InternImage
Args:
in_features (int): number of input features
hidden_features (int): number of hidden features
......@@ -328,7 +328,7 @@ class MLPLayer(nn.Module):
class InternImageLayer(nn.Module):
r""" Basic layer of InternImage
r"""Basic layer of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
......@@ -432,7 +432,7 @@ class InternImageLayer(nn.Module):
class InternImageBlock(nn.Module):
r""" Block of InternImage
r"""Block of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
......@@ -526,7 +526,7 @@ class InternImageBlock(nn.Module):
class InternImage(nn.Module):
r""" InternImage
r"""InternImage
A PyTorch impl of: `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
https://arxiv.org/abs/2211.05778
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment