# -------------------------------------------------------- # InternVL # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import argparse import time import torch from mmcv.cnn import get_model_complexity_info from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string from models.intern_vit_6b import InternViT6B from tqdm import tqdm parser = argparse.ArgumentParser(description='Hyperparams') parser.add_argument('config', nargs='?', type=str, default=None) args = parser.parse_args() configs = { 'a': { 'embed_dim': 3968, 'num_heads': 62, 'mlp_ratio': 4, 'depth': 32 }, 'e': { 'embed_dim': 3200, 'num_heads': 50, 'mlp_ratio': 4, 'depth': 48 }, 'f': { 'embed_dim': 3200, 'num_heads': 25, 'mlp_ratio': 4, 'depth': 48 }, 'g': { 'embed_dim': 2496, 'num_heads': 39, 'mlp_ratio': 8, 'depth': 48 }, 'i': { 'embed_dim': 2816, 'num_heads': 44, 'mlp_ratio': 4, 'depth': 64 }, 'm': { 'embed_dim': 2496, 'num_heads': 39, 'mlp_ratio': 4, 'depth': 80 }, } def sa_flops(h, w, dim): return 2 * h * w * h * w * dim def get_flops(model, input_shape): flops, params = get_model_complexity_info(model, input_shape, as_strings=False) _, H, W = input_shape print(flops, params) for i in range(model.depth): flops += sa_flops(H // model.patch_size, W // model.patch_size, model.embed_dim) return flops_to_string(flops), params_to_string(params) if __name__ == '__main__': input_shape = (3, 224, 224) config = configs[args.config] print(config) model = InternViT6B(in_chans=3, patch_size=14, img_size=224, pretrain_size=224, qkv_bias=False, drop_path_rate=0.0, embed_dim=config['embed_dim'], num_heads=config['num_heads'], mlp_ratio=config['mlp_ratio'], init_values=0.1, qk_normalization=True, depth=config['depth'], use_flash_attn=True, with_cp=True, freeze_vit=True, cls_target='cls_patch_concat', num_classes=0, attn_pool_num_heads=16, clip_embed_dim=768, head_norm_type='bn').to(torch.bfloat16) for k, v in model.named_parameters(): v.requires_grad = True if torch.cuda.is_available(): model.cuda() model.eval() flops, params = get_flops(model, input_shape) split_line = '=' * 30 print(f'{split_line}\nInput shape: {input_shape}\n' f'Flops: {flops}\nParams: {params}\n{split_line}') print('!!!Please be cautious if you use the results in papers. ' 'You may need to check if all ops are supported and verify that the ' 'flops computation is correct.') image = torch.rand(128, 3, 224, 224).to(torch.bfloat16).cuda() torch.cuda.synchronize() start_time = time.time() with torch.no_grad(): for i in tqdm(range(10)): out = model(image) torch.cuda.synchronize() end_time = time.time() print('warmup time: ', end_time - start_time) torch.cuda.synchronize() start_time = time.time() with torch.no_grad(): for i in tqdm(range(50)): out = model(image) torch.cuda.synchronize() end_time = time.time() print('using time: ', (end_time - start_time)) print('FPS: ', 50 * 128 / (end_time - start_time)) print(config)