Commit 9e4126a0 authored by chenych

add .float()

parent 88df767f
@@ -532,10 +532,10 @@ def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix):
            checkpoint_model = checkpoint[model_key]
            print("Load state_dict by model_key = %s" % model_key)
            break

    if checkpoint_model is None:
        checkpoint_model = checkpoint

    state_dict = model.state_dict()
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
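The body of this branch falls outside the hunk. For context, the common pattern in checkpoint loaders of this kind is to delete shape-mismatched classifier weights before a non-strict load; a sketch under that assumption (the helper name is hypothetical, not from this repo):

def drop_mismatched_head_keys(checkpoint_model, model, keys=('head.weight', 'head.bias')):
    # Hypothetical helper: remove classifier weights whose shape no longer
    # matches the current model, so load_state_dict(strict=False) can proceed.
    state_dict = model.state_dict()
    for k in keys:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print("Removing key %s from pretrained checkpoint" % k)
            del checkpoint_model[k]
    return checkpoint_model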
@@ -571,7 +571,7 @@ def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix):
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-           pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+           pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2).float()
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
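This .float() cast is the whole commit. A likely motivation (an assumption, the commit message does not say): checkpoints saved in fp16 keep the position embedding in half precision, and bicubic interpolation of half tensors is not implemented on CPU in many PyTorch builds, so the resize fails at load time unless the tokens are first cast to float32. A minimal sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

# Hypothetical sizes: a 14x14 grid of 768-dim position tokens stored in
# fp16, resized to a 16x16 grid. Without the .float() cast, bicubic
# interpolation of a half tensor raises a RuntimeError on CPU.
pos_tokens = torch.randn(1, 14 * 14, 768, dtype=torch.float16)
pos_tokens = pos_tokens.reshape(-1, 14, 14, 768).permute(0, 3, 1, 2).float()
pos_tokens = F.interpolate(pos_tokens, size=(16, 16), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)  # back to (1, 256, 768)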
@@ -850,7 +850,7 @@ def dump_predictions(args, result, file_suffix):
                each_file = os.path.join(args.task_cache_path, f"submit_{i}_{file_suffix}.json")
                with open(each_file, "r") as fp:
                    jsons += json.load(fp)

            new_jsons = []
            res_dict = dict()
            if args.task in ["coco_captioning", "nocaps"]:
@@ -869,7 +869,7 @@ def dump_predictions(args, result, file_suffix):
        os.remove(output_file)
    else:
        jsons = result

    result_file = os.path.join(args.output_dir, f"submit_{file_suffix}.json")
    if jsons is not None:
        with open(result_file, "w") as fp:
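These two hunks show the rank-0 merge: every rank writes submit_<rank>_<suffix>.json into the task cache, then rank 0 concatenates the shards and writes a single submit_<suffix>.json into the output directory. A standalone sketch of the merge step (the function name is hypothetical):

import json
import os

def merge_prediction_shards(cache_dir, world_size, file_suffix):
    # Hypothetical helper mirroring the loop above: concatenate the
    # per-rank JSON lists into a single list of predictions.
    merged = []
    for rank in range(world_size):
        shard = os.path.join(cache_dir, f"submit_{rank}_{file_suffix}.json")
        with open(shard, "r") as fp:
            merged += json.load(fp)
    return merged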
@@ -891,10 +891,10 @@ def coco_caption_eval(gt_dir, results_file, split):
    filenames = {'coco_captioning_val':'coco_karpathy_val_gt.json',
                 'coco_captioning_test':'coco_karpathy_test_gt.json',
                 'nocaps_val':'nocaps_val_gt.json'}

    download_url(urls[split], gt_dir)
    annotation_file = os.path.join(gt_dir, filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)
@@ -905,9 +905,9 @@ def coco_caption_eval(gt_dir, results_file, split):
    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    res_dict = dict()
    for metric, score in coco_eval.eval.items():
        res_dict[metric] = score

    return res_dict
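The coco_eval object is constructed in lines elided between these two hunks; in the pycocoevalcap API that construction is COCOEvalCap(coco, coco_result). A minimal end-to-end sketch of the same evaluation flow, assuming pycocotools and pycocoevalcap are installed (the file paths here are hypothetical):

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

# Hypothetical paths; the ground-truth file is whatever download_url fetched.
coco = COCO("coco_karpathy_val_gt.json")
coco_result = coco.loadRes("submit_val.json")
coco_eval = COCOEvalCap(coco, coco_result)
coco_eval.evaluate()  # SPICE is slow on the first run, then cached
res_dict = {metric: score for metric, score in coco_eval.eval.items()}
print(res_dict)  # e.g. Bleu_1..Bleu_4, METEOR, ROUGE_L, CIDEr, SPICE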