#!/bin/bash
MODEL_VERSION=vicuna-v1-5-7b
gpu_vis=0,1 # per_device_train_batch_size * gradient_accumulation_steps * n_gpus = 128
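# for this script: per_device_train_batch_size 8 * gradient_accumulation_steps 8 * 2 GPUs = 128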
MASTER_PORT=29029
deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT vtimellm/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--lora_enable True \
--model_name_or_path ./checkpoints/vicuna-7b-v1.5 \
--version v1 \
--data_path ./data/stage2.json \
--feat_folder ./feat/intern_clip_feat \
--pretrain_mm_mlp_adapter ./checkpoints/vtimellm-$MODEL_VERSION-stage1_test/mm_projector.bin \
--output_dir ./checkpoints/vtimellm-$MODEL_VERSION-stage2_test \
--bf16 True \
--num_train_epochs 2 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200 \
--save_total_limit 1 \
--learning_rate 1e-4 \
--freeze_mm_mlp_adapter True \
--lora_r 64 \
--lora_alpha 128 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
#!/bin/bash
MODEL_VERSION=chatglm3-6b
gpu_vis=0 # per_device_train_batch_size * gradient_accumulation_steps * n_gpus = 128
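# for this script: per_device_train_batch_size 8 * gradient_accumulation_steps 16 * 1 GPU = 128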
MASTER_PORT=29570
deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT vtimellm/train/train.py \
--deepspeed ./scripts/zero3.json \
--lora_enable True \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version plain \
--data_path ./data/stage2_chinese.json \
--feat_folder /path/to/stage2_feat \
--pretrain_mm_mlp_adapter ./checkpoints/vtimellm-$MODEL_VERSION-stage1/mm_projector.bin \
--output_dir ./checkpoints/vtimellm-$MODEL_VERSION-stage2 \
--bf16 True \
--num_train_epochs 2 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 1e-4 \
--freeze_mm_mlp_adapter True \
--lora_r 64 \
--lora_alpha 128 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
#!/bin/bash
MODEL_VERSION=vicuna-v1-5-7b
gpu_vis=0,1 # per_device_train_batch_size * gradient_accumulation_steps * n_gpus = 128
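# for this script: per_device_train_batch_size 8 * gradient_accumulation_steps 8 * 2 GPUs = 128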
MASTER_PORT=29029
deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT vtimellm/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--lora_enable True \
--training_stage 3 \
--model_name_or_path ./checkpoints/vicuna-7b-v1.5 \
--version v1 \
--data_path ./data/stage3.json \
--feat_folder ./feat/stage3_clip_feat \
--pretrain_mm_mlp_adapter ./checkpoints/vtimellm-$MODEL_VERSION-stage1_test/mm_projector.bin \
--stage2_path ./checkpoints/vtimellm-$MODEL_VERSION-stage2_test \
--output_dir ./checkpoints/vtimellm-$MODEL_VERSION-stage3_test \
--bf16 True \
--num_train_epochs 2 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200 \
--save_total_limit 1 \
--learning_rate 1e-4 \
--freeze_mm_mlp_adapter True \
--lora_r 64 \
--lora_alpha 128 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
#!/bin/bash
MODEL_VERSION=chatglm3-6b
gpu_vis=0 # per_device_train_batch_size * gradient_accumulation_steps * n_gpus = 128
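# for this script: per_device_train_batch_size 8 * gradient_accumulation_steps 16 * 1 GPU = 128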
MASTER_PORT=29570
deepspeed --include localhost:$gpu_vis --master_port $MASTER_PORT vtimellm/train/train.py \
--deepspeed ./scripts/zero2.json \
--lora_enable True \
--training_stage 3 \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version plain \
--data_path ./data/stage3_chinese.json \
--feat_folder /path/to/stage3_feat \
--pretrain_mm_mlp_adapter ./checkpoints/vtimellm-$MODEL_VERSION-stage1/mm_projector.bin \
--stage2_path ./checkpoints/vtimellm-$MODEL_VERSION-stage2 \
--output_dir ./checkpoints/vtimellm-$MODEL_VERSION-stage3 \
--bf16 True \
--num_train_epochs 2 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 1e-4 \
--freeze_mm_mlp_adapter True \
--lora_r 64 \
--lora_alpha 128 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"steps_per_print": 1e5,
"wall_clock_breakdown": false
}
from .model import VTimeLLMLlamaForCausalLM
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<video>"
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if len(messages) > 0 and type(messages[0][1]) is tuple:
messages = self.messages.copy()
init_role, init_msg = messages[0].copy()
init_msg = init_msg[0].replace("<image>", "").strip()
if 'mmtag' in self.version:
messages[0] = (init_role, init_msg)
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
messages.insert(1, (self.roles[1], "Received."))
else:
messages[0] = (init_role, "<image>\n" + init_msg)
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.MPT:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0: message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode == "Crop":
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
ret.append([img_str, None])
msg = msg.replace('<image>', '').strip()
if len(msg) > 0:
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llama_2 = Conversation(
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_llava_llama_2 = Conversation(
system="You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_mpt = Conversation(
system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv_llava_plain = Conversation(
system="",
roles=("", ""),
messages=(
),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="\n",
)
conv_llava_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!"),
("Assistant", "Hi there! How can I help you today?")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_llava_v0_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("Human", "Assistant"),
messages=(
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
version="v0_mmtag",
)
conv_llava_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llava_v1_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
version="v1_mmtag",
)
default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"plain": conv_llava_plain,
"v0_plain": conv_llava_plain,
"llava_v0": conv_llava_v0,
"v0_mmtag": conv_llava_v0_mmtag,
"llava_v1": conv_llava_v1,
"v1_mmtag": conv_llava_v1_mmtag,
"llava_llama_2": conv_llava_llama_2,
"mpt": conv_mpt,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())
"""
Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py
"""
import argparse
import os
root_dir = os.path.join(os.getcwd(), "..")
import sys
sys.path.append(root_dir)
import torch
import gradio as gr
import decord
decord.bridge.set_bridge('torch')
from vtimellm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from vtimellm.conversation import conv_templates, SeparatorStyle
from vtimellm.model.builder import load_pretrained_model
from vtimellm.utils import disable_torch_init
from vtimellm.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, VideoExtractor
from PIL import Image
from transformers import TextStreamer
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
from PIL import Image
BICUBIC = Image.BICUBIC
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
import clip
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")
parser.add_argument("--model_base", type=str, required=True, help="Path to your vicuna-7b-v1.5 huggingface checkpoint")
parser.add_argument("--clip_path", type=str, default=os.path.join(root_dir, "checkpoints/clip/ViT-L-14.pt"))
parser.add_argument("--pretrain_mm_mlp_adapter", type=str, default=os.path.join(root_dir, "checkpoints/vtimellm-vicuna-v1-5-7b-stage1/mm_projector.bin"))
parser.add_argument("--stage2", type=str, default=os.path.join(root_dir, "checkpoints/vtimellm-vicuna-v1-5-7b-stage2"))
parser.add_argument("--stage3", type=str, default=os.path.join(root_dir, "checkpoints/vtimellm-vicuna-v1-5-7b-stage3"))
parser.add_argument("--share", action='store_true')
args = parser.parse_args()
return args
# ========================================
# Model Initialization
# ========================================
args = parse_args()
device = f'cuda:{args.gpu_id}'
disable_torch_init()
tokenizer, model, context_len = load_pretrained_model(args, args.stage2, args.stage3)
model = model.to(torch.float16).to(device)
clip_model, _ = clip.load(args.clip_path)
clip_model.eval()
clip_model = clip_model.to(device)
transform = Compose([
Resize(224, interpolation=BICUBIC),
CenterCrop(224),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
print('Initialization Finished')
# ========================================
# Gradio Setting
# ========================================
TEXT_PLACEHOLDER = 'Upload your video first, or directly click the examples at the bottom of the page.'
def gradio_reset(chat_state, video_features_state, conv_state):
if chat_state is not None:
chat_state.messages = []
video_features_state = None
conv_state = {}
return None, gr.update(value=None, interactive=True), gr.update(value='', placeholder=TEXT_PLACEHOLDER, interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, video_features_state, conv_state
def upload_video(gr_video, chat_state, video_features_state, conv_state, chatbot):
if not gr_video:
return None, None, gr.update(interactive=True), chat_state, video_features_state, conv_state, None
else:
print(gr_video)
video_loader = VideoExtractor(N=100)
_, images = video_loader.extract({'id': None, 'video': gr_video})
transform = Compose([
Resize(224, interpolation=BICUBIC),
CenterCrop(224),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
# print(images.shape) # <N, 3, H, W>
images = transform(images / 255.0)
images = images.to(torch.float16)
with torch.no_grad():
            video_features_state = clip_model.encode_image(images.to(device))
chatbot = chatbot + [((gr_video,), None)]
chat_state = conv_templates["v1"].copy()
conv_state['first'] = True
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, video_features_state, conv_state, chatbot
def gradio_ask(user_message, chatbot, chat_state, conv_state):
if len(user_message) == 0:
conv_state['should_answer'] = False
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state, conv_state
conv_state['should_answer'] = True
chatbot = chatbot + [[user_message, None]]
if conv_state['first']:
user_message = DEFAULT_IMAGE_TOKEN + '\n' + user_message
conv_state['first'] = False
chat_state.append_message(chat_state.roles[0], user_message)
chat_state.append_message(chat_state.roles[1], None)
return '', chatbot, chat_state, conv_state
def gradio_answer(chatbot, chat_state, video_features_state, conv_state, temperature):
if not conv_state['should_answer']:
return chatbot, chat_state
prompt = chat_state.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
stop_str = chat_state.sep if chat_state.sep_style != SeparatorStyle.TWO else chat_state.sep2 # plain:sep(###) v1:sep2(None)
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=video_features_state[None,].to(device),
do_sample=True,
temperature=temperature,
max_new_tokens=1024,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria]
)
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
chat_state.messages[-1][-1] = outputs
chatbot[-1][1] = outputs
print(chat_state.get_prompt())
print(chat_state)
return chatbot, chat_state
with gr.Blocks() as demo:
gr.Markdown('''# Demo for VTimeLLM''')
with gr.Row():
with gr.Column(scale=0.5):
video = gr.Video()
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Reset")
temperature = gr.Slider(
minimum=0,
maximum=2.0,
value=0.05,
step=0.01,
interactive=True,
label="Temperature",
)
with gr.Column():
chat_state = gr.State()
video_features_state = gr.State()
conv_state = gr.State({})
chatbot = gr.Chatbot(label='VTimeLLM')
text_input = gr.Textbox(label='User', placeholder=TEXT_PLACEHOLDER, interactive=False)
with gr.Column():
gr.Examples(examples=[
[os.path.join(root_dir, f"images/demo.mp4"), "Explain why the video is funny."],
], inputs=[video, text_input])
upload_button.click(upload_video, [video, chat_state, video_features_state, conv_state, chatbot], [video, text_input, upload_button, chat_state, video_features_state, conv_state, chatbot])
text_input.submit(gradio_ask, [text_input, chatbot, chat_state, conv_state], [text_input, chatbot, chat_state, conv_state]).then(gradio_answer, [chatbot, chat_state, video_features_state, conv_state, temperature], [chatbot, chat_state])
clear.click(gradio_reset, [chat_state, video_features_state, conv_state], [chatbot, video, text_input, upload_button, chat_state, video_features_state, conv_state], queue=False)
demo.queue().launch(share=args.share)
{
"v_bXdq2zI1Ms0": {
"duration": 73.1,
"timestamps": [
[6.94, 69.08],
[37.28, 43.49],
[43.13, 55.55]
],
"sentences": ["Three men are standing on a mat.", " The man in front begins to do karate on the mat.", " He gets down on the ground and flips around."]
},
"v_CN01Gm2Yc4k": {
"duration": 17.56,
"timestamps": [
[0, 5],
[5, 12.2],
[12.2, 17.56]
],
"sentences": ["A young lady is gripping a black and silver punching bag between her legs.", "Once she has secured herself on the bag,she begins doing a set of crunches by pulling herself up.", "In between the crunches,she sits up and makes punches out into the air,before going back down."]
}
}
SOFTWARE LICENSE AGREEMENT FOR EVALUATION
This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software ("User(s)"), and Nippon Telegraph and Telephone corporation ("NTT").
READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE.
BACKGROUND
A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and related documentation listed in Exhibit A to this Agreement.
B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement.
C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement.
In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows:
1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper submitted by NTT to a certain academy. User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1.
2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software.
3. Term. This Agreement is effective whichever is earlier (i) upon User's acceptance of the Agreement, or (ii) upon User's installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User's decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof, or to destroy all such materials and provide written verification of such destruction to NTT.
4. Proprietary Rights
(a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights, copyrights and trade secret rights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software.
(b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT; (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE; OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE.
(c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied.
5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT AND DEFEND ANY ACTION RELATING TO THE SOFTWARE.
6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT.
7. Limitation of Liability. IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3.
8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent.
9. General
(a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect.
(b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter.
(c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User.
(d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding.
(e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof.
(f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT's obligation set forth under this Agreement due to any cause beyond NTT's reasonable control.
EXHIBIT A
# SODA
This repository is the implementation of "SODA: Story Oriented Dense Video Captioning Evaluation Framework", published at ECCV 2020 [pdf](https://fujiso.github.io/publications/ECCV2020_soda.pdf).
SODA measures the performance of video story description systems.
## Update
v1.1 (2021/5)
* Added a new option `--multi_reference` to handle multiple references (see the sketch below).
SODA selects the reference with the maximum F1 for each video and returns macro-averaged scores.
* Fixed BertScore import
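
For intuition, here is a minimal sketch of that aggregation (an illustration only, not the exact code path in `soda.py`):

```python
# Illustration of the --multi_reference aggregation: for each video, keep the
# F1 against its best-matching reference, then macro-average over videos.
def macro_average_best_reference(f1_per_video_per_reference):
    """f1_per_video_per_reference: list (videos) of lists (references) of F1 scores."""
    best_per_video = [max(ref_scores) for ref_scores in f1_per_video_per_reference]
    return sum(best_per_video) / len(best_per_video)
```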
## Requirements
python 3.6+ (developed with 3.7)
* Numpy
* tqdm
* [pycocoevalcap (Python3 version)](https://github.com/salaniz/pycocoevalcap)
* BERTScore (optional)
## Usage
You can run SODA by specifying the path of the system output and that of the ground truth.
Both files should be in the JSON format used for ActivityNet Captions.
```bash
python soda.py -s path/to/submission.json -r path/to/ground_truth.json
```
You can run in the multiple-reference setting with the `--multi_reference` option.
```bash
python soda.py --multi_reference -s path/to/submission.json -r path/to/ground_truth1.json path/to/ground_truth2.json
```
You can try other sentence evaluation metrics, e.g. CIDEr and BERTScore, with the `-m` option.
```bash
python soda.py -s path/to/submission.json -m BERTScore
```
## Sample input file
Please use the same format as [ActivityNet Challenge](http://activity-net.org/index.html)
```
{
version: "VERSION 1.0",
results: {
"sample_id" : [
{
sentence: "This is a sample caption.",
timestamp: [1.23, 4.56]
},
{
sentence: "This is a sample caption 2.",
timestamp: [7.89, 19.87]
}
]
}
external_data: {
used: False,
}
}
```
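
If you want to call the evaluator from Python rather than the command line, a minimal sketch is shown below. The import path is an assumption about how this copy of `soda.py` is packaged; adjust it to wherever the module lives in this repository.

```python
# Hypothetical Python usage (import path is an assumption); the prediction dict
# mirrors the JSON layout shown above.
from soda import SODA, print_score  # assumed module name

prediction = {
    "results": {
        "sample_id": [
            {"sentence": "This is a sample caption.", "timestamp": [1.23, 4.56]},
            {"sentence": "This is a sample caption 2.", "timestamp": [7.89, 19.87]},
        ]
    }
}

# build_from_prediction tokenizes the captions and aligns them with the
# ground-truth file(s) before running the evaluation.
evaluator = SODA.build_from_prediction(prediction, ["path/to/ground_truth.json"], soda_type="c")
print_score(evaluator.evaluate())
```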
## Reference
```
@inproceedings{Fujita2020soda,
    title={SODA: Story Oriented Dense Video Captioning Evaluation Framework},
author={Soichiro Fujita and Tsutomu Hirao and Hidetaka Kamigaito and Manabu Okumura and Masaaki Nagata},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
month={August},
year={2020},
}
```
## LICENSE
NTT License
According to the license, creating pull requests is not allowed.
Please feel free to send issues instead.
import numpy as np
import json
from collections import defaultdict
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from .utils import iou, remove_nonascii
class ANETCaptions:
def __init__(self, preds, gts, gt_vid, verbose=False):
self.pred_keys = ['results']
# self.pred_keys = ['results', 'version', 'external_data']
self.verbose = verbose
self.preds = preds
self.gts = gts
self.gt_vids = gt_vid
self.tokenizer = PTBTokenizer()
@classmethod
def from_load_files(cls, gt_file, pred_file, multi_reference=True, verbose=False):
gts, gt_vid = cls.load_ground_truth(gt_file, multi_reference=multi_reference, verbose=verbose)
preds = cls.load_prediction(pred_file, verbose=verbose)
# missing video
gt_vid = [x for x in gt_vid if x in preds]
gt_vid = cls.check_videos(gt_vid, preds.keys(),verbose=verbose)
return cls(preds, gts, gt_vid, verbose=verbose)
@classmethod
def from_prediction(cls, gt_file, preds, multi_reference=True, verbose=False):
results = {}
for vid in preds['results']:
results[vid] = sorted(preds["results"][vid], key=lambda x: x["timestamp"][0])
gts, gt_vid = cls.load_ground_truth(gt_file, multi_reference=multi_reference)
gt_vid = cls.check_videos(gt_vid, results.keys(),verbose=verbose)
return cls(results, gts, gt_vid, verbose=verbose)
@staticmethod
def load_ground_truth(filenames, multi_reference=False, verbose=False):
if verbose:
print(f"| Loading ground truths: {filenames}.")
if isinstance(filenames, str):
filenames = [filenames]
gt_vids = set()
gt = defaultdict(dict)
gts = []
for filename in filenames:
if isinstance(filename, dict):
_gt = filename
else:
with open(filename, "r") as f:
_gt = json.load(f)
gt_vids.update(_gt.keys())
gts.append(_gt)
if multi_reference is False:
for vid in gt_vids:
t, s = [], []
for _g in gts:
if vid not in _g:
continue
t += _g[vid]["timestamps"]
s += _g[vid]["sentences"]
sort_t, sort_s = list(zip(*sorted(zip(t, s), key=lambda x: x[0][0])))
gt[vid]["timestamps"] = sort_t
gt[vid]["sentences"] = sort_s
gts = [gt]
if verbose:
print(f"stats:\n\t n_files: {len(filenames)}, n_videos: {len(gt_vids)}")
return gts, gt_vids
@staticmethod
def load_prediction(filename, verbose=False):
if verbose: print(f"\n| Loading predictions: {filename}.")
if isinstance(filename, dict):
pred = filename
else:
with open(filename, 'r') as f:
pred = json.load(f)
# If the json file doesn’t have enough attribute
# if not all([key in pred.keys() for key in ["results"]]):
# raise IOError('Please input a correct format prediction file.')
results = {}
for vid in pred['results']:
# if vid not in self.gt_vids: continue
results[vid] = sorted(pred["results"][vid], key=lambda x: x["timestamp"][0])
return results
def preprocess(self):
if self.verbose: print("\n| Preprocessing captions...")
n_ref = len(self.gts)
p_spliter = [0]
g_spliter = [[0] for i in range(n_ref)]
times = {}
cur_preds = {}
cur_gts = [{} for i in range(n_ref)]
for i, vid in enumerate(self.gt_vids):
cur_preds.update({j+p_spliter[-1]:[{"caption": remove_nonascii(p["sentence"])}] for j,p in enumerate(self.preds[vid])})
times[i] = [p["timestamp"] for p in self.preds[vid]]
p_spliter.append(p_spliter[-1] + len(times[i]))
for n in range(n_ref):
if vid not in self.gts[n]:
g_spliter[n].append(g_spliter[n][-1])
continue
cur_gts[n].update({j+g_spliter[n][-1]:[{"caption": remove_nonascii(p)}] for j,p in enumerate(self.gts[n][vid]["sentences"])})
g_spliter[n].append(g_spliter[n][-1] + len(self.gts[n][vid]["sentences"]))
tokenize_preds = self.tokenizer.tokenize(cur_preds)
tokenize_gts = [self.tokenizer.tokenize(j) for j in cur_gts]
for i, vid in enumerate(self.gt_vids):
_p = [tokenize_preds[j] for j in range(p_spliter[i],p_spliter[i+1])]
self.preds[vid] = {"timestamps":times[i], "sentences":_p}
for n in range(n_ref):
if vid not in self.gts[n]: continue
_g = [tokenize_gts[n][j] for j in range(g_spliter[n][i],g_spliter[n][i+1])]
self.gts[n][vid]["sentences"] = _g
@staticmethod
def check_videos(gold_vid, pred_vid, verbose=True):
not_appear = set(gold_vid) - set(pred_vid)
if len(not_appear) > 0 and verbose:
print((f"Warning: some videos in ground truth file are not appeared in prediction file!\n"
f"\t{len(not_appear)} videos are not predicted: {not_appear}"))
return list(set(gold_vid) & set(pred_vid))
#!/usr/bin/env python
from bert_score.scorer import BERTScorer
class BertScore:
# def __init__(self, lang="en", model_type="bert-large-uncased"):
def __init__(self, lang="en", model_type=None):
self.lang = lang
self.model_type = model_type
self.bert = BERTScorer(model_type=model_type, lang=lang)
def compute_score(self, gts, res):
assert gts.keys() == res.keys()
# convert dict to list of str
cands = list(map(lambda x: x[0], res.values()))
refs = list(map(lambda x: x[0], gts.values()))
(P, R, F), hashname = self.bert.score(cands, refs, return_hash=True)
# print(f'{hashname}: P={P.mean().item():.6f} R={R.mean().item():.6f} F={F.mean().item():.6f}')
F = F.numpy()
return F.mean(), F
def method(self):
return "BertScore"
#!/usr/bin/env python
from bert_score.scorer import BERTScorer
class BertScore:
def __init__(self, lang="en", model_type="roberta-large"):
self.lang = lang
self.model_type = model_type
self.bert = BERTScorer(model_type=model_type, lang=lang)
def compute_score(self, gts, res):
assert gts.keys() == res.keys()
# convert dict to list of str
cands = list(map(lambda x: x[0], res.values()))
refs = list(map(lambda x: x[0], gts.values()))
(P, R, F), hashname = self.bert.score(cands, refs, return_hash=True)
#print(f'{hashname}: P={P.mean().item():.6f} R={R.mean().item():.6f} F={F.mean().item():.6f}')
R = R.numpy()
return R.mean(), R
def method(self):
return "BertScore"
#!/usr/bin/env python
import numpy as np
#from moverscore_v2 import get_idf_dict, word_mover_score
from moverscore import get_idf_dict, word_mover_score
from collections import defaultdict
class MoverScore:
def __init__(self, lang="en", model_type=None):
self.lang = lang
self.model_type=model_type
#self.model = load_model(model_type=model_type, lang=lang)
self.idf_dict_ref = None
self.idf_dict_hyp = None
def compute_score(self, gts, res):
assert gts.keys()==res.keys()
        assert self.idf_dict_ref is not None and self.idf_dict_hyp is not None
# convert dict to list of str
cands = list(map(lambda x:x[0], res.values()))
refs = list(map(lambda x:x[0], gts.values()))
scores = word_mover_score(refs, cands, self.idf_dict_ref, self.idf_dict_hyp, \
stop_words=[], n_gram=1, remove_subwords=True)
#print(np.mean(scores), max(scores))
return np.mean(scores), scores
def make_dict(self, all_gts, all_res, vids):
gold = []
pred = []
for vid in vids:
gold.extend(all_gts[vid]["sentences"])
            pred.extend([p["sentence"] for p in all_res[vid]])
self.idf_dict_ref = get_idf_dict(gold)
self.idf_dict_hyp = get_idf_dict(pred)
#print(self.idf_dict_hyp)
def method(self):
return "MoverScore"
#!/usr/bin/env python
import argparse
import json
from tqdm import tqdm
import numpy as np
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.cider.cider import Cider
from .dataset import ANETCaptions
from .utils import iou, remove_nonascii
class SODA:
def __init__(self, data, soda_type="c", tious=None, scorer="Meteor", verbose=False):
#self.data = data
self.preds = data.preds
self.gts = data.gts
self.gt_vids = data.gt_vids
self.soda_type = soda_type
self.tious = [0.0] if tious is None else tious
self.tokenizer = PTBTokenizer()
if scorer == "BertScore":
from nlpeval.bert_r_score import BertScore
self.scorer = eval(scorer)()
self.scorer_name = scorer
self.verbose = verbose
if soda_type == "a": # averaging F-measure scores with IoU threshold = 0.9, 0.7, 0.5, 0.3
self.soda_func = self.soda_a
elif soda_type == "b": # F-measure, where IoU threshold is set to 0.
self.soda_func = self.soda_b
elif soda_type == "c": # F-measure, utilizing the IoU x METEOR score
self.soda_func = self.soda_c
elif soda_type == "d": # F-measure of IoU score
self.soda_func = self.soda_d
class Dummy:
def compute_score(self, x, y):
return [0, 0]
self.scorer = Dummy()
else:
raise NotImplementedError
@classmethod
def build(cls, preds, gts, gt_vids, soda_type="c", tious=[0.0], scorer="Meteor", verbose=False):
data = ANETCaptions(preds, gts, gt_vids)
data.preprocess()
return cls(data, soda_type, tious, scorer, verbose)
@classmethod
def build_from_prediction(cls, preds, gt_files, soda_type="c", tious=[0.0], scorer="Meteor", verbose=False):
data = ANETCaptions.from_prediction(gt_files, preds)
data.preprocess()
return cls(data, soda_type, tious, scorer, verbose)
def calc_iou_matrix(self, preds, golds):
#print(preds["timestamps"], gt["timestamps"])
return np.array([[iou(pred, ct) for pred in preds["timestamps"]] for ct in golds['timestamps']])
def calc_score_matrix(self, preds, golds):
# Reformat to fit the input of pycocoevalcap scorers.
p_sent, g_sent = preds["sentences"], golds["sentences"]
res = {index: p for index, p in enumerate(p_sent)}
gts = [{index: g for index in range(len(p_sent))} for i, g in enumerate(g_sent)]
return np.array([self.scorer.compute_score(res, gt)[1] for gt in gts])
def evaluate(self,):
if self.verbose:
print(f"\n| Running SODA {self.soda_type}.")
tious = self.tious
p_best = [[] for i in range(len(tious))]
r_best = [[] for i in range(len(tious))]
f_best = [[] for i in range(len(tious))]
n_pred = []
for vid in tqdm(self.gt_vids, disable=not self.verbose):
_p = [[] for i in range(len(tious))]
_r = [[] for i in range(len(tious))]
_f = [[] for i in range(len(tious))]
pred = self.preds[vid]
n_pred.append(len(pred["sentences"]))
# empty pred
if not pred['sentences']:
for i, tiou in enumerate(tious):
p_best[i].append(0)
r_best[i].append(0)
f_best[i].append(0)
continue
for gt in self.gts:
if vid not in gt:
continue
gold = gt[vid]
# create matrix
_iou = self.calc_iou_matrix(pred, gold)
scores = self.calc_score_matrix(pred, gold)
for i, tiou in enumerate(tious):
iou = np.copy(_iou)
iou[iou < tiou] = 0.0
try:
max_score, pairs = self.soda_func(iou, scores)
except: # RecursionError
max_score, pairs = 0., None
(n_g, n_p) = iou.shape
p = max_score / n_p
r = max_score / n_g
_p[i].append(p)
_r[i].append(r)
_f[i].append(2 * p * r / (p + r) if p+r > 0 else 0)
best_idx = np.argmax(_f, axis=1)
for i, tiou in enumerate(tious):
p_best[i].append(_p[i][best_idx[i]])
r_best[i].append(_r[i][best_idx[i]])
f_best[i].append(_f[i][best_idx[i]])
precision = np.mean(p_best, axis=1)
recall = np.mean(r_best, axis=1)
f1 = np.mean(f_best, axis=1)
print(f"avg. outputs: {np.mean(n_pred)}")
# average scores across all the tIoUs
if self.verbose:
for i, tiou in enumerate(tious):
partial_result = {self.scorer_name: [precision[i], recall[i], f1[i]]}
print_score(partial_result, description=f"tIoU: {tiou}")
final_scores = [np.mean(precision), np.mean(recall), np.mean(f1)]
result = {self.scorer_name: final_scores}
return result
def soda_a(self, iou, scores):
_, pairs = self.chased_dp_assignment(iou)
r, c = (*zip(*pairs),)
max_score = np.sum(scores[r, c])
return max_score, pairs
def soda_b(self, iou, scores):
# same as soda_a
_, pairs = self.chased_dp_assignment(iou)
r, c = (*zip(*pairs),)
max_score = np.sum(scores[r, c])
return max_score, pairs
def soda_c(self, iou, scores):
max_score, pairs = self.chased_dp_assignment(iou*scores)
return max_score, pairs
def soda_d(self, iou, scores):
max_score, pairs = self.chased_dp_assignment(iou)
return max_score, pairs
def chased_dp_assignment(self, scores):
"""
Run dp matching
Recurrence:
dp[i,j] =
max(dp[i-1,j], dp[i-1,j-1] + scores[i,j], dp[i,j-1])
"""
M, N = scores.shape
dp = - np.ones((M, N))
path = np.zeros((M, N))
def transition(i, j):
if dp[i, j] >= 0:
return dp[i, j]
elif i == 0 and j == 0:
state = [-1, -1, scores[i, j]]
elif i == 0:
state = [-1, transition(i, j-1), scores[i, j]]
elif j == 0:
state = [transition(i-1, j), -1, scores[i, j]]
else:
state = [transition(i-1, j), transition(i, j-1), transition(i-1, j-1) + scores[i, j]]
dp[i, j] = np.max(state)
path[i, j] = np.argmax(state)
return dp[i, j]
def get_pairs(i, j):
p = np.where(path[i][:j+1] == 2)[0]
if i != 0 and len(p) == 0:
return get_pairs(i-1, j)
elif i == 0 or p[-1] == 0:
return [(i, p[-1])]
else:
return get_pairs(i-1, p[-1]-1) + [(i, p[-1])]
N, M = scores.shape
max_score = transition(N-1, M-1)
pairs = get_pairs(N-1, M-1)
return max_score, pairs
def print_score(result, description="SODA result"):
prf = ["precision", "recall", "f1_score"]
print('-' * 80)
print(description)
print('-' * 80)
for scorer_name, score in result.items():
print(f'| scorer:{scorer_name}')
for k, v in zip(prf, score):
print(f"\t{k}:{v*100:2.4f}")
def main(args):
# Call coco eval
data = ANETCaptions.from_load_files(args.references,
args.prediction,
multi_reference=args.multi_reference,
verbose=args.verbose,
)
data.preprocess()
if args.soda_type == 'a':
tious = args.tious
else:
tious = None
evaluator = SODA(data,
soda_type=args.soda_type,
tious=tious,
scorer=args.metric,
verbose=args.verbose
)
result = evaluator.evaluate()
print_score(result)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--prediction', type=str, required=True, default='sample.json',
help='system output file with json format for ActivityNet Challenge')
parser.add_argument('-r', '--references', type=str, nargs='+', default=['./data/val_1.json', './data/val_2.json'],
help='reference files with ground truth captions')
    parser.add_argument('-m', '--metric', type=str, default="Meteor", choices=['Meteor', 'Cider', 'BertScore'],
                        help='sentence evaluation metric used inside SODA')
    parser.add_argument('-s', '--soda_type', type=str, default="c", choices=['a', 'b', 'c', 'd'],
                        help='SODA variant to run (a, b, c, or d)')
parser.add_argument('--tious', type=float, nargs='+', default=[0.3, 0.5, 0.7, 0.9],
help='list of the tIoUs (only for SODA-a)')
    parser.add_argument('-mr', '--multi_reference', action='store_true',
                        help='evaluate against multiple reference files (best reference per video)')
parser.add_argument('-v', '--verbose', action='store_true',
help='print details')
args = parser.parse_args()
main(args)