"git@developer.sourcefind.cn:xdb4_94051/vllm.git" did not exist on "c267b1a02c952b68a897c96201f32ad57e0b955e"
Commit 799a38c5 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #616 failed with stages
in 0 seconds
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import warnings
import torch
import numpy as np
from data import data_utils
from data.ofa_dataset import OFADataset
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
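# collate() pads the per-example tensors produced by __getitem__ into a single
# mini-batch dict; the same helper is repeated in each of the GLUE dataset files below.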
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
ref_dict = None
if samples[0].get("ref_dict", None) is not None:
ref_dict = np.array([s['ref_dict'] for s in samples])
constraint_masks = None
if samples[0].get("constraint_mask", None) is not None:
constraint_masks = merge("constraint_mask")
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens
},
"ref_dict": ref_dict,
"constraint_masks": constraint_masks,
"target": target,
}
return batch
class MNLIDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=512,
max_tgt_length=30,
constraint_trie=None,
prompt_type="none"
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.constraint_trie = constraint_trie
self.prompt_type = prompt_type
def __getitem__(self, index):
sentence1, sentence2, label = self.dataset[index]
if label == '0':
label = 'maybe'
elif label == '1':
label = 'yes'
elif label == '2':
label = 'no'
else:
raise NotImplementedError
sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
src_item = self.encode_text(
' can text1 " {} " imply text2 " {} "?'.format(sentence1, sentence2)
)
tgt_item = self.encode_text(" {}".format(label))
assert tgt_item.size(0) == 1
ref_dict = {label: 1.0}
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
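        # prompt_type controls what the decoder is conditioned on:
        #   'none'        -> decoder starts from BOS and predicts only the label token
        #   'src'         -> the full source prompt (with EOS) is fed as the decoder prefix
        #   'prev_output' -> the source prompt without its EOS is used as the prefix
        # In the prefix variants, every target position except the final label token
        # is padded out of the loss below.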
if self.prompt_type == 'none':
prev_output_item = self.bos_item
target_item = tgt_item
elif self.prompt_type == 'src':
prev_output_item = src_item.clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
elif self.prompt_type == 'prev_output':
prev_output_item = src_item[:-1].clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
else:
raise NotImplementedError
target_item[:-1] = self.tgt_dict.pad()
example = {
"source": src_item,
"target": target_item,
"prev_output_tokens": prev_output_item,
"ref_dict": ref_dict,
}
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
constraint_mask[-1][constraint_nodes] = True
example["constraint_mask"] = constraint_mask
return example
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch containing the data of the task
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
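# A minimal usage sketch (illustrative only -- the dictionary path, the BPE wrapper
# construction and the raw triples below are assumptions, not part of this module):
#
#   src_dict = Dictionary.load("path/to/dict.txt")        # fairseq Dictionary
#   raw = [("a man is sleeping", "a person rests", "1")]   # (sentence1, sentence2, label)
#   ds = MNLIDataset("valid", raw, bpe, src_dict, src_dict)
#   batch = ds.collater([ds[0]])
#   # batch["net_input"]["src_tokens"] holds the padded prompt tokens and
#   # batch["target"] holds the single label token ("yes"/"no"/"maybe").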
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import warnings
import torch
import numpy as np
from data import data_utils
from data.ofa_dataset import OFADataset
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
ref_dict = None
if samples[0].get("ref_dict", None) is not None:
ref_dict = np.array([s['ref_dict'] for s in samples])
constraint_masks = None
if samples[0].get("constraint_mask", None) is not None:
constraint_masks = merge("constraint_mask")
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens
},
"ref_dict": ref_dict,
"constraint_masks": constraint_masks,
"target": target,
}
return batch
class MRPCDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=512,
max_tgt_length=30,
constraint_trie=None,
prompt_type="none"
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.constraint_trie = constraint_trie
self.prompt_type = prompt_type
def __getitem__(self, index):
sentence1, sentence2, label = self.dataset[index]
if label == '0':
label = 'no'
elif label == '1':
label = 'yes'
else:
raise NotImplementedError
sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
src_item = self.encode_text(
' does text1 " {} " and text2 " {} " have the same semantics?'.format(sentence1, sentence2),
)
tgt_item = self.encode_text(" {}".format(label))
assert tgt_item.size(0) == 1
ref_dict = {label: 1.0}
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
if self.prompt_type == 'none':
prev_output_item = self.bos_item
target_item = tgt_item
elif self.prompt_type == 'src':
prev_output_item = src_item.clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
elif self.prompt_type == 'prev_output':
prev_output_item = src_item[:-1].clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
else:
raise NotImplementedError
target_item[:-1] = self.tgt_dict.pad()
example = {
"source": src_item,
"target": target_item,
"prev_output_tokens": prev_output_item,
"ref_dict": ref_dict,
}
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
constraint_mask[-1][constraint_nodes] = True
example["constraint_mask"] = constraint_mask
return example
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch containing the data of the task
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import warnings
import torch
import numpy as np
from data import data_utils
from data.ofa_dataset import OFADataset
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
ref_dict = None
if samples[0].get("ref_dict", None) is not None:
ref_dict = np.array([s['ref_dict'] for s in samples])
constraint_masks = None
if samples[0].get("constraint_mask", None) is not None:
constraint_masks = merge("constraint_mask")
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens
},
"ref_dict": ref_dict,
"constraint_masks": constraint_masks,
"target": target,
}
return batch
class QNLIDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=512,
max_tgt_length=30,
constraint_trie=None,
prompt_type="none"
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.constraint_trie = constraint_trie
self.prompt_type = prompt_type
def __getitem__(self, index):
question, sentence, label = self.dataset[index]
if label == '0' or label == 'not_entailment':
label = 'no'
elif label == '1' or label == 'entailment':
label = 'yes'
else:
raise NotImplementedError
question = ' '.join(question.lower().strip().split()[:self.max_src_length])
sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
src_item = self.encode_text(
' does " {} " contain the answer to question " {} "?'.format(sentence, question)
)
tgt_item = self.encode_text(" {}".format(label))
assert tgt_item.size(0) == 1
ref_dict = {label: 1.0}
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
if self.prompt_type == 'none':
prev_output_item = self.bos_item
target_item = tgt_item
elif self.prompt_type == 'src':
prev_output_item = src_item.clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
elif self.prompt_type == 'prev_output':
prev_output_item = src_item[:-1].clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
else:
raise NotImplementedError
target_item[:-1] = self.tgt_dict.pad()
example = {
"source": src_item,
"target": target_item,
"prev_output_tokens": prev_output_item,
"ref_dict": ref_dict,
}
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
constraint_mask[-1][constraint_nodes] = True
example["constraint_mask"] = constraint_mask
return example
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch containing the data of the task
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import warnings
import torch
import numpy as np
from data import data_utils
from data.ofa_dataset import OFADataset
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
ref_dict = None
if samples[0].get("ref_dict", None) is not None:
ref_dict = np.array([s['ref_dict'] for s in samples])
constraint_masks = None
if samples[0].get("constraint_mask", None) is not None:
constraint_masks = merge("constraint_mask")
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens
},
"ref_dict": ref_dict,
"constraint_masks": constraint_masks,
"target": target,
}
return batch
class QQPDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=512,
max_tgt_length=30,
constraint_trie=None,
prompt_type="none"
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.constraint_trie = constraint_trie
self.prompt_type = prompt_type
def __getitem__(self, index):
question1, question2, label = self.dataset[index]
if label == '0':
label = 'no'
elif label == '1':
label = 'yes'
else:
raise NotImplementedError
question1 = ' '.join(question1.lower().strip().split()[:self.max_src_length])
question2 = ' '.join(question2.lower().strip().split()[:self.max_src_length])
src_item = self.encode_text(
' is question " {} " and question " {} " equivalent?'.format(question1, question2)
)
tgt_item = self.encode_text(" {}".format(label))
assert tgt_item.size(0) == 1
ref_dict = {label: 1.0}
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
if self.prompt_type == 'none':
prev_output_item = self.bos_item
target_item = tgt_item
elif self.prompt_type == 'src':
prev_output_item = src_item.clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
elif self.prompt_type == 'prev_output':
prev_output_item = src_item[:-1].clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
else:
raise NotImplementedError
target_item[:-1] = self.tgt_dict.pad()
example = {
"source": src_item,
"target": target_item,
"prev_output_tokens": prev_output_item,
"ref_dict": ref_dict,
}
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
constraint_mask[-1][constraint_nodes] = True
example["constraint_mask"] = constraint_mask
return example
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch containing the data of the task
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import warnings
import torch
import numpy as np
from data import data_utils
from data.ofa_dataset import OFADataset
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
ref_dict = None
if samples[0].get("ref_dict", None) is not None:
ref_dict = np.array([s['ref_dict'] for s in samples])
constraint_masks = None
if samples[0].get("constraint_mask", None) is not None:
constraint_masks = merge("constraint_mask")
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens
},
"ref_dict": ref_dict,
"constraint_masks": constraint_masks,
"target": target,
}
return batch
class RTEDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=512,
max_tgt_length=30,
constraint_trie=None,
prompt_type="none"
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.constraint_trie = constraint_trie
self.prompt_type = prompt_type
def __getitem__(self, index):
sentence1, sentence2, label = self.dataset[index]
if label == 'not_entailment':
label = 'no'
elif label == 'entailment':
label = 'yes'
else:
raise NotImplementedError
sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
src_item = self.encode_text(
' can text1 " {} " imply text2 " {} "?'.format(sentence1, sentence2),
)
tgt_item = self.encode_text(" {}".format(label))
assert tgt_item.size(0) == 1
ref_dict = {label: 1.0}
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
if self.prompt_type == 'none':
prev_output_item = self.bos_item
target_item = tgt_item
elif self.prompt_type == 'src':
prev_output_item = src_item.clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
elif self.prompt_type == 'prev_output':
prev_output_item = src_item[:-1].clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
else:
raise NotImplementedError
target_item[:-1] = self.tgt_dict.pad()
example = {
"source": src_item,
"target": target_item,
"prev_output_tokens": prev_output_item,
"ref_dict": ref_dict,
}
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
constraint_mask[-1][constraint_nodes] = True
example["constraint_mask"] = constraint_mask
return example
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch containing the data of the task
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import warnings
import torch
import numpy as np
from data import data_utils
from data.ofa_dataset import OFADataset
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
ref_dict = None
if samples[0].get("ref_dict", None) is not None:
ref_dict = np.array([s['ref_dict'] for s in samples])
constraint_masks = None
if samples[0].get("constraint_mask", None) is not None:
constraint_masks = merge("constraint_mask")
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor(
[s["target"].ne(pad_idx).long().sum() for s in samples]
)
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens
},
"ref_dict": ref_dict,
"constraint_masks": constraint_masks,
"target": target,
}
return batch
class SST2Dataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=512,
max_tgt_length=30,
constraint_trie=None,
prompt_type="none"
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.constraint_trie = constraint_trie
self.prompt_type = prompt_type
def __getitem__(self, index):
sentence, label = self.dataset[index]
if label == '0':
label = 'negative'
elif label == '1':
label = 'positive'
else:
raise NotImplementedError
sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
src_item = self.encode_text(' is the sentiment of text " {} " positive or negative?'.format(sentence))
tgt_item = self.encode_text(" {}".format(label))
assert tgt_item.size(0) == 1
ref_dict = {label: 1.0}
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
if self.prompt_type == 'none':
prev_output_item = self.bos_item
target_item = tgt_item
elif self.prompt_type == 'src':
prev_output_item = src_item.clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
elif self.prompt_type == 'prev_output':
prev_output_item = src_item[:-1].clone()
target_item = torch.cat([prev_output_item[1:], tgt_item])
else:
raise NotImplementedError
target_item[:-1] = self.tgt_dict.pad()
example = {
"source": src_item,
"target": target_item,
"prev_output_tokens": prev_output_item,
"ref_dict": ref_dict,
}
if self.constraint_trie is not None:
constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
constraint_mask[-1][constraint_nodes] = True
example["constraint_mask"] = constraint_mask
return example
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch containing the data of the task
"""
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import re
import torch.utils.data
from fairseq.data import FairseqDataset
import string
CHINESE_PUNCTUATION = '＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·！？。。'
ENGLISH_PUNCTUATION = string.punctuation
logger = logging.getLogger(__name__)
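# OFADataset is the shared base class: it wraps a raw table-style dataset together
# with the BPE tokenizer and the fairseq source/target dictionaries, and provides
# the text-encoding and text-cleaning helpers used by every task-specific dataset.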
class OFADataset(FairseqDataset):
def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
self.split = split
self.dataset = dataset
self.bpe = bpe
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.bos = src_dict.bos()
self.eos = src_dict.eos()
self.pad = src_dict.pad()
self.bos_item = torch.LongTensor([self.bos])
self.eos_item = torch.LongTensor([self.eos])
def __len__(self):
return len(self.dataset)
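    # encode_text() BPE-encodes the string (unless use_bpe=False) and maps it to
    # dictionary ids; BOS/EOS are only added when explicitly requested, so callers
    # wrap the result with self.bos_item / self.eos_item themselves.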
def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
s = self.tgt_dict.encode_line(
line=self.bpe.encode(text) if use_bpe else text,
add_if_not_exist=False,
append_eos=False
).long()
if length is not None:
s = s[:length]
if append_bos:
s = torch.cat([self.bos_item, s])
if append_eos:
s = torch.cat([s, self.eos_item])
return s
def pre_question(self, question, max_ques_words=None):
question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
question = re.sub(
r"\s{2,}",
' ',
question,
)
question = question.rstrip('\n')
question = question.strip(' ')
# truncate question
question_words = question.split(' ')
if max_ques_words is not None and len(question_words) > max_ques_words:
question = ' '.join(question_words[:max_ques_words])
return question
def pre_caption(self, caption, max_words=None):
caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
caption = re.sub(
r"\s{2,}",
' ',
caption,
)
caption = caption.rstrip('\n')
caption = caption.strip(' ')
# truncate caption
caption_words = caption.split(' ')
if max_words is not None and len(caption_words) > max_words:
caption = ' '.join(caption_words[:max_words])
return caption
def pre_chinese(self, text, max_words):
        # Map each punctuation character to a space; the original chained
        # str.replace() calls only matched the entire punctuation string at once.
        punctuation = CHINESE_PUNCTUATION + ENGLISH_PUNCTUATION
        text = text.lower().translate(str.maketrans(punctuation, ' ' * len(punctuation)))
text = re.sub(
r"\s{2,}",
' ',
text,
)
text = text.rstrip('\n')
text = text.strip(' ')[:max_words]
return text
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from io import BytesIO
import math
import logging
import random
import warnings
import numpy as np
import torch
import base64
from torchvision import transforms
from PIL import Image, ImageFile
from data import data_utils
from data.ofa_dataset import OFADataset
from utils.vision_helper import RandomAugment
import utils.transforms as T
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
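# get_whole_word_mask() marks which dictionary entries begin a word under the BPE,
# so that text infilling masks whole words rather than individual subword pieces.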
def get_whole_word_mask(bpe, dictionary):
if bpe is not None:
def is_beginning_of_word(i):
if i < dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = dictionary[i]
if tok.startswith("madeupword"):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(
list(map(is_beginning_of_word, range(len(dictionary))))
)
return mask_whole_words
return None
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=eos_idx,
)
id = np.array([s["id"] for s in samples])
src_tokens = merge("source")
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
code_masks = None
if samples[0].get("code_mask", None) is not None:
code_masks = torch.cat([sample['code_mask'] for sample in samples])
conf = torch.cat([s['conf'] for s in samples], dim=0)
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target")
tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens")
else:
ntokens = src_lengths.sum().item()
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"patch_images": patch_images,
"patch_masks": patch_masks,
"code_masks": code_masks,
"prev_output_tokens": prev_output_tokens
},
"target": target,
"conf": conf
}
return batch
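# UnifyDataset mixes the main image-text pair stream with optional pure-text,
# pure-image (code infilling) and detection datasets, so a single pretraining batch
# can contain examples from several tasks and modalities.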
class UnifyDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
max_src_length=128,
max_tgt_length=30,
seed=7,
code_dict_size=8192,
num_bins=1000,
patch_image_size=384,
code_image_size=128,
pure_text_dataset=None,
pure_image_dataset=None,
detection_dataset=None,
all_object_list=None,
all_caption_list=None,
type2ans_dict=None,
ans2type_dict=None,
max_image_size=512,
mask_ratio=0.3,
random_ratio=0.0,
keep_ratio=0.0,
mask_length="span-poisson",
poisson_lambda=3.0,
replace_length=1
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.seed = seed
self.code_dict_size = code_dict_size
self.num_bins = num_bins
self.patch_image_size = patch_image_size
self.code_image_size = code_image_size
self.pure_text_dataset = pure_text_dataset
self.pure_image_dataset = pure_image_dataset
self.detection_dataset = detection_dataset
self.epoch = 0
self.all_object_list = all_object_list
self.all_caption_list = all_caption_list
self.type2ans_dict = type2ans_dict
self.ans2type_dict = ans2type_dict
self.mask_ratio = mask_ratio
self.random_ratio = random_ratio
self.keep_ratio = keep_ratio
self.mask_length = mask_length
self.poisson_lambda = poisson_lambda
self.replace_length = replace_length
if self.replace_length not in [-1, 0, 1]:
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
if self.mask_length not in ["subword", "word", "span-poisson"]:
raise ValueError(f"invalid arg: mask-length={self.mask_length}")
if self.mask_length == "subword" and self.replace_length not in [0, 1]:
raise ValueError(f"if using subwords, use replace-length=1 or 0")
self.mask_idx = src_dict.index("<mask>")
self.mask_whole_word = (
get_whole_word_mask(self.bpe, self.src_dict)
if self.mask_length != "subword"
else None
)
self.mask_span_distribution = None
if self.mask_length == "span-poisson":
_lambda = self.poisson_lambda
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-_lambda)
k_factorial = 1
ps = []
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= _lambda
k_factorial *= k + 1
if ps[-1] < 0.0000001:
break
ps = torch.FloatTensor(ps)
self.mask_span_distribution = torch.distributions.Categorical(ps)
self.pos_tgt_item = self.encode_text(" yes")
self.neg_tgt_item = self.encode_text(" no")
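        # Coordinates of the central square that is blanked out for the
        # image-infilling task; mask_ids enumerates the flattened grid positions
        # that fall outside that square.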
self.mask_left = self.mask_top = int(0.5 * self.code_image_size)
self.mask_right = self.mask_bottom = int(1.5 * self.code_image_size)
self.mask_ids = [
i*self.code_image_size*2+j
for i in range(self.code_image_size*2) for j in range(self.code_image_size*2)
if not (self.mask_left <= i < self.mask_right and self.mask_top <= j < self.mask_bottom)
]
scales = np.arange(patch_image_size, 481).tolist()
# for image-text pair
self.patch_resize_transform = transforms.Compose([
T.RandomResize(scales, max_size=672),
transforms.CenterCrop(patch_image_size),
RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# for pure image
self.patch_crop_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# for detection
self.detection_transform = T.Compose([
T.RandomHorizontalFlip(),
T.LargeScaleJitter(output_size=self.code_image_size*2, aug_scale_min=1.0, aug_scale_max=1.5),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=max_image_size)
])
# for visual grounding
self.visual_grounding_transform = T.Compose([
T.RandomResize(scales, max_size=672),
T.ObjectCenterCrop((patch_image_size, patch_image_size)),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=max_image_size)
])
def set_epoch(self, epoch, **unused):
self.epoch = epoch
def get_negative_caption(self, caption, gt_objects):
prob = random.random()
if gt_objects is not None and gt_objects != '' and prob > 0.6:
gt_object = random.choice(gt_objects.strip().split('&&'))
negative_object = random.choice(self.all_object_list[:-1])
negative_object = self.all_object_list[-1] if negative_object == gt_object else negative_object
negative_caption = caption.replace(gt_object, negative_object)
else:
negative_caption = random.choice(self.all_caption_list)
return negative_caption
def get_negative_answer(self, answer, conf):
prob = random.random()
if conf > (prob + 0.1) and answer in self.ans2type_dict:
negative_answer_type = self.ans2type_dict[answer]
if negative_answer_type == 'how many' and answer.isdigit() and prob > 0.5:
                # answer is a string here, so compare its integer value ("0" maps to 1)
                negative_answer = int(answer) + random.choice([-1, 1]) if int(answer) != 0 else 1
else:
negative_answer_list = self.type2ans_dict[negative_answer_type]
negative_answer = random.choice(negative_answer_list[:-1])
negative_answer = negative_answer_list[-1] if negative_answer == answer else negative_answer
return negative_answer
negative_answer_list = self.type2ans_dict['other']
negative_answer = random.choice(negative_answer_list[:-1])
negative_answer = negative_answer_list[-1] if negative_answer == answer else negative_answer
return negative_answer
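    # process_image_text_pair() builds the primary seq2seq example for the sample
    # (captioning, VQA or visual grounding) and appends an auxiliary example:
    # a region-to-caption example for grounding data, or (during training) an
    # image-text matching " yes"/" no" example built from a positive or negative prompt.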
def process_image_text_pair(self, index):
uniq_id, image, caption, question, refs, gt_objects, dataset_name, type = self.dataset[index]
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
patch_image = self.patch_resize_transform(image) if type != 'visual_grounding' else None
patch_mask = torch.tensor([True])
conf = torch.tensor([1.0])
if type == 'caption':
tgt_caption = self.pre_caption(caption, self.max_tgt_length)
pos_src_caption = self.pre_caption(caption, self.max_src_length)
neg_src_caption = self.pre_caption(self.get_negative_caption(caption, gt_objects), self.max_src_length)
src_item = self.encode_text(" what does the image describe?")
tgt_item = self.encode_text(" {}".format(tgt_caption))
pos_src_item = self.encode_text(' does the image describe " {} "?'.format(pos_src_caption))
neg_src_item = self.encode_text(' does the image describe " {} "?'.format(neg_src_caption))
elif type == 'qa':
question = self.pre_question(question, self.max_src_length)
ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in refs.split('&&')}
answer = max(ref_dict, key=ref_dict.get)
conf = ref_dict[answer]
src_item = self.encode_text(" {}".format(question))
tgt_item = self.encode_text(" {}".format(answer))
conf = torch.tensor([conf])
pos_src_item = self.encode_text(' what is the answer to question " {} ". is " {} "?'.format(question, answer))
neg_src_item = self.encode_text(
' what is the answer to question " {} ". is " {} "?'.format(question, self.get_negative_answer(answer, conf))
)
elif type == 'visual_grounding':
conf = torch.tensor([1.0])
w, h = image.size
boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
x0, y0, x1, y1 = refs.strip().split(',')
boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
boxes_target["labels"] = np.array([0])
boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
patch_image, boxes_target = self.visual_grounding_transform(image, boxes_target)
quant_x0 = "<bin_{}>".format(int((boxes_target["boxes"][0][0] * (self.num_bins - 1)).round()))
quant_y0 = "<bin_{}>".format(int((boxes_target["boxes"][0][1] * (self.num_bins - 1)).round()))
quant_x1 = "<bin_{}>".format(int((boxes_target["boxes"][0][2] * (self.num_bins - 1)).round()))
quant_y1 = "<bin_{}>".format(int((boxes_target["boxes"][0][3] * (self.num_bins - 1)).round()))
region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
src_caption = self.pre_caption(caption, self.max_src_length)
src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
tgt_item = self.encode_text(region_coord, use_bpe=False)
else:
logger.info('type {} is not implemented'.format(type))
raise NotImplementedError
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
pos_src_item = torch.cat([self.bos_item, pos_src_item, self.eos_item]) if type != 'visual_grounding' else None
neg_src_item = torch.cat([self.bos_item, neg_src_item, self.eos_item]) if type != 'visual_grounding' else None
if type == 'caption' and dataset_name == 'cc12m':
target_item[:2] = self.src_dict.pad()
target_item[-1] = self.eos_item
example = {
"id": uniq_id,
"source": src_item,
"patch_image": patch_image,
"patch_mask": patch_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"conf": conf,
}
examples = [example]
prob = random.random()
if type == 'visual_grounding':
region_example = example.copy()
region_prefix_item = self.encode_text(' what does the region describe? region:')
region_coord_item = self.encode_text('{}'.format(region_coord), use_bpe=False)
region_src_item = torch.cat([region_prefix_item, region_coord_item])
region_tgt_item = self.encode_text(' {}'.format(self.pre_caption(caption, self.max_tgt_length)))
region_example["source"] = torch.cat([self.bos_item, region_src_item, self.eos_item])
region_example["target"] = torch.cat([region_tgt_item, self.eos_item])
region_example["prev_output_tokens"] = torch.cat([self.bos_item, region_tgt_item])
region_example["conf"] = torch.tensor([1.0])
examples.append(region_example)
elif prob >= 0.5 and self.split == 'train':
pos_example = example.copy()
pos_example["source"] = pos_src_item
pos_example["target"] = torch.cat([self.pos_tgt_item, self.eos_item])
pos_example["prev_output_tokens"] = torch.cat([self.bos_item, self.pos_tgt_item])
examples.append(pos_example)
elif self.split == 'train':
neg_example = example.copy()
neg_example["source"] = neg_src_item
neg_example["target"] = torch.cat([self.neg_tgt_item, self.eos_item])
neg_example["prev_output_tokens"] = torch.cat([self.bos_item, self.neg_tgt_item])
examples.append(neg_example)
return examples
def process_pure_text(self, index):
patch_image = torch.zeros((3, self.code_image_size*2, self.code_image_size*2))
patch_mask = torch.tensor([False])
code_mask = torch.tensor([False])
conf = torch.tensor([2.0])
examples = []
for _ in range(2):
uniq_id, text = self.pure_text_dataset[index]
text = text.strip().lower()
text_item = self.encode_text(" {}".format(text), length=512)
text_item = text_item[-256:]
text_item = torch.cat([self.bos_item, text_item, self.eos_item])
mask_text_item = self.add_whole_word_mask(text_item.clone(), self.mask_ratio)
prefix_item = self.encode_text(' what is the complete text of " "?')
src_item = torch.cat([prefix_item[:-2], mask_text_item[1:-1], prefix_item[-2:]])
tgt_item = text_item[1:-1]
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
example = {
"id": uniq_id,
"source": src_item,
"patch_image": patch_image,
"patch_mask": patch_mask,
"code_mask": code_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"conf": conf,
}
examples.append(example)
return examples
def process_pure_image(self, index):
image_id, image, code = self.pure_image_dataset[index]
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
patch_image = self.patch_crop_transform(image)
patch_image[:, self.mask_top:self.mask_bottom, self.mask_left:self.mask_right] = 0
patch_mask = torch.tensor([True])
src_item = self.encode_text(" what is the image in the middle part?")
image_code = torch.LongTensor([int(num) for num in code.strip().split()])
tgt_item = image_code + len(self.src_dict) - self.code_dict_size - self.num_bins
code_mask = torch.tensor([True])
conf = torch.tensor([2.0])
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
example = {
"id": image_id,
"source": src_item,
"patch_image": patch_image,
"patch_mask": patch_mask,
"code_mask": code_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"conf": conf,
}
return [example]
def process_detection(self, index):
image_id, image, label = self.detection_dataset[index]
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
w, h = image.size
boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
label_list = label.strip().split('&&')
for label in label_list:
x0, y0, x1, y1, cat_id, cat = label.strip().split(',', 5)
boxes_target["boxes"].append([float(x0), float(y0), float(x1), float(y1)])
boxes_target["labels"].append(cat)
boxes_target["area"].append((float(x1) - float(x0)) * (float(y1) - float(y0)))
boxes_target["boxes"] = torch.tensor(boxes_target["boxes"])
boxes_target["labels"] = np.array(boxes_target["labels"])
boxes_target["area"] = torch.tensor(boxes_target["area"])
patch_image, boxes_target = self.detection_transform(image, boxes_target)
patch_mask = torch.tensor([True])
code_mask = torch.tensor([False])
conf = torch.tensor([2.0])
quant_boxes = []
for i, box in enumerate(boxes_target["boxes"]):
quant_boxes.extend(["<bin_{}>".format(int((pos * (self.num_bins - 1)).round())) for pos in box[:4]])
quant_boxes.append(self.bpe.encode(' {}'.format(boxes_target["labels"][i])))
src_item = self.encode_text(' what are the objects in the image?')
tgt_item = self.encode_text(' '.join(quant_boxes), use_bpe=False)
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
example = {
"id": image_id,
"source": src_item,
"patch_image": patch_image,
"patch_mask": patch_mask,
"code_mask": code_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"conf": conf,
}
return [example]
def __getitem__(self, index):
with data_utils.numpy_seed(self.seed, self.epoch):
pair_samples = self.process_image_text_pair(index)
extra_samples = []
if self.split == 'train' and self.dataset.data_cnt % 8 == 0:
extra_samples += self.process_pure_text(0) if self.pure_text_dataset else []
extra_samples += self.process_pure_image(0) if self.pure_image_dataset else []
extra_samples += self.process_detection(0) if self.detection_dataset else []
return pair_samples, extra_samples
def word_starts(self, source):
if self.mask_whole_word is not None:
is_word_start = self.mask_whole_word.gather(0, source)
else:
is_word_start = torch.ones(source.size())
is_word_start[0] = 0
is_word_start[-1] = 0
return is_word_start
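    # add_whole_word_mask() applies BART-style text infilling: span lengths are drawn
    # from a Poisson distribution, each selected span is collapsed to a single <mask>
    # (replace_length=1), deleted (replace_length=0), or masked token-by-token
    # (replace_length=-1), and a fraction of masks is replaced by random vocabulary tokens.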
def add_whole_word_mask(self, source, p):
is_word_start = self.word_starts(source)
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
num_inserts = 0
if num_to_mask == 0:
return source
if self.mask_span_distribution is not None:
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < num_to_mask:
lengths = torch.cat(
[
lengths,
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
],
dim=0,
)
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
i = 0
while cum_length[i] < num_to_mask:
i += 1
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
num_to_mask = i + 1
lengths = lengths[:num_to_mask]
# Handle 0-length mask (inserts) separately
lengths = lengths[lengths > 0]
num_inserts = num_to_mask - lengths.size(0)
num_to_mask -= num_inserts
if num_to_mask == 0:
return self.add_insertion_noise(source, num_inserts / source.size(0))
assert (lengths > 0).all()
else:
lengths = torch.ones((num_to_mask,)).long()
assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero(as_tuple=False)
indices = word_starts[
torch.randperm(word_starts.size(0))[:num_to_mask]
].squeeze(1)
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
source_length = source.size(0)
assert source_length - 1 not in indices
to_keep = torch.ones(source_length, dtype=torch.bool)
is_word_start[
-1
] = 255 # acts as a long length, so spans don't go over the end of doc
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
)
if self.mask_span_distribution is not None:
assert len(lengths.size()) == 1
assert lengths.size() == indices.size()
lengths -= 1
while indices.size(0) > 0:
assert lengths.size() == indices.size()
lengths -= is_word_start[indices + 1].long()
uncompleted = lengths >= 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
)
else:
# A bit faster when all lengths are 1
while indices.size(0) > 0:
uncompleted = is_word_start[indices + 1] == 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
)
assert source_length - 1 not in indices
source = source[to_keep]
if num_inserts > 0:
source = self.add_insertion_noise(source, num_inserts / source.size(0))
return source
def add_insertion_noise(self, tokens, p):
if p == 0.0:
return tokens
num_tokens = len(tokens)
n = int(math.ceil(num_tokens * p))
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
noise_mask[noise_indices] = 1
result = torch.LongTensor(n + len(tokens)).fill_(-1)
num_random = int(math.ceil(n * self.random_ratio))
result[noise_indices[num_random:]] = self.mask_idx
result[noise_indices[:num_random]] = torch.randint(
low=4, high=len(self.tgt_dict)-self.code_dict_size-self.num_bins, size=(num_random,)
)
result[~noise_mask] = tokens
assert (result >= 0).all()
return result
def collater(self, samples, pad_to_length=None):
"""Merge samples of different tasks to form two mini-batches.
Args:
samples (List[Tuple]): samples to collate
Returns:
            Tuple[dict]: two mini-batches containing the data of the different tasks
"""
samples_v1 = [] # containing image-text pairs
samples_v2 = [] # containing detection data, text data and image data
for sample_tuple in samples:
samples_v1 += sample_tuple[0]
samples_v2 += sample_tuple[1]
if samples_v2 != []:
res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
res_v2 = collate(samples_v2, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
return res_v1, res_v2
else:
res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
return res_v1
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from io import BytesIO
import math
import logging
import random
import warnings
import numpy as np
import torch
import base64
from torchvision import transforms
from PIL import Image, ImageFile
from data import data_utils
from data.ofa_dataset import OFADataset
from utils.vision_helper import RandomAugment
import utils.transforms as T
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.feature_transforms import *
from fairseq.data.audio.audio_utils import (
convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank
)
from pathlib import Path
import soundfile as sf
import librosa
import torchaudio
from typing import List
from pypinyin import pinyin, Style
from utils.text2phone import Text2Phone
from g2p_en import G2p
g2p = G2p()
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
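# This collate() extends the text/image version above with speech-specific fields:
# padded filterbank features, phone-token sequences (also exposed as ctc_outputs),
# and per-position constraint masks over the decoder vocabulary.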
def collate(
samples,
pad_idx,
eos_idx,
left_pad_source=False,
left_pad_target=False,
    feature_only=True,
    mask=False,
    mask_prob=0.0
):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx,
left_pad,
move_eos_to_beginning,
)
def _collate_frames(
frames: List[torch.Tensor]
):
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
            frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is the
                length of the i-th frame and f_dim is the static feature dimension
Returns:
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
max_len = max(frame.size(0) for frame in frames)
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
for i, v in enumerate(frames):
out[i, : v.size(0)] = v
return out
def _collate_constraint_masks(
frames: List[torch.Tensor]
):
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
length of i-th frame and f_dim is static dimension of features
Returns:
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
max_len = max(frame.size(0) for frame in frames)
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))).bool()
for i, v in enumerate(frames):
out[i, : v.size(0)] = v
return out
id = np.array([s["id"] for s in samples])
src_tokens = merge("source", left_pad=left_pad_source)
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
fbank = None
fbank_length = None
fbank_masks = None
if samples[0].get("fbank", None) is not None:
fbank = _collate_frames([s["fbank"] for s in samples])
fbank_length = torch.tensor([s["fbank"].size(0) for s in samples], dtype=torch.long)
fbank_masks = torch.tensor([s["fbank_mask"] for s in samples])
audio_code_masks = None
if samples[0].get("audio_code_mask", None) is not None:
audio_code_masks = torch.cat([sample['audio_code_mask'] for sample in samples])
phone_items = None
phone_lengths = None
if samples[0].get("phone_item", None) is not None:
phone_items = merge("phone_item", left_pad=left_pad_source)
phone_lengths = torch.LongTensor([len(s["phone_item"]) for s in samples])
phone_masks = None
if samples[0].get("phone_mask", None) is not None:
phone_masks = torch.cat([sample['phone_mask'] for sample in samples])
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge("target", left_pad=left_pad_target)
tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
ntokens = tgt_lengths.sum().item()
if samples[0].get("prev_output_tokens", None) is not None:
prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
else:
ntokens = src_lengths.sum().item()
constraint_masks = None
if samples[0].get("constraint_masks", None) is not None:
constraint_masks = _collate_constraint_masks([s["constraint_masks"] for s in samples])
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"fbank": fbank,
"fbank_length": fbank_length,
"fbank_masks": fbank_masks,
"phone_items": phone_items,
"phone_lengths": phone_lengths,
"phone_masks": phone_masks,
"audio_code_masks": audio_code_masks,
"prev_output_tokens": prev_output_tokens,
"encoder_features_only": feature_only,
"mask": mask,
"mask_prob": mask_prob
},
"target": target,
"ctc_outputs": phone_items,
"ctc_output_lengths": phone_lengths,
"constraint_masks": constraint_masks
}
return batch
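# The speech variant of UnifyDataset mixes phone-to-text examples (pure text),
# audio-to-discrete-code examples (pure audio) and speech-to-text pairs with
# filterbank features; which streams are drawn is gated by train_stage in __getitem__.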
class UnifyDataset(OFADataset):
def __init__(
self,
split,
dataset,
bpe,
src_dict,
tgt_dict=None,
phone_dict=None,
max_src_length=128,
max_tgt_length=30,
seed=7,
code_dict_size=8192,
audio_code_dict_size=30000,
num_bins=1000,
pure_text_dataset=None,
pure_audio_dataset=None,
speech_text_dataset=None,
config_yaml_path=None,
lang="zh",
text2phone_path=None,
train_stage=2,
n_frames_per_step=1,
sample_rate=16000,
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.phone_dict = phone_dict
self.max_src_length = max_src_length
self.max_tgt_length = max_tgt_length
self.seed = seed
self.code_dict_size = code_dict_size
self.audio_code_dict_size = audio_code_dict_size
self.num_bins = num_bins
self.pure_text_dataset = pure_text_dataset
self.pure_audio_dataset = pure_audio_dataset
self.speech_text_dataset = speech_text_dataset
self.epoch = 0
self.remove_pure_audio = self.pure_audio_dataset is None
self.remove_pure_text = self.pure_text_dataset is None
        # config_yaml_path = Path(cfg.user_dir) / cfg.config_yaml
self.data_cfg = S2TDataConfig(Path(config_yaml_path))
self.lang = lang
self.train_stage= train_stage
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
self.data_cfg.get_feature_transforms(split, split.startswith("train"))
)
self.n_frames_per_step = n_frames_per_step
self.sample_rate = sample_rate
self.blank_id = self.phone_dict.index("<blank>")
self.phone_mask_idx = self.phone_dict.index("<mask>")
self.text2phone_tokenizer = None
if text2phone_path is not None:
self.blank_id = self.phone_dict.index("<unk>")
self.text2phone_tokenizer = Text2Phone(text2phone_path)
def set_epoch(self, epoch, **unused):
self.epoch = epoch
def process_pure_text(self, index):
if self.train_stage == 1:
speech_id, text = self.dataset[index]
else:
speech_id, text = self.pure_text_dataset[index]
conf = torch.tensor([1.0])
# fake input
fbank = torch.zeros((8, self.data_cfg.input_feat_per_channel))
fbank_mask = torch.tensor([False])
#
audio_code_mask = torch.tensor([False])
if self.lang == "en":
text = self.pre_caption(text, self.max_tgt_length)
elif self.lang == "zh":
text = self.pre_chinese(text, self.max_tgt_length)
else:
raise ValueError("lang must be en or zh")
phone = self.to_phone(text, self.lang)
phone_item = [int(x) for x in phone]
phone_item = torch.tensor(phone_item)
phone_item = self.add_noise_to_phone(phone_item, 0.3)
phone_mask = torch.tensor([True])
target = text
src_item = self.encode_text(" what does the phone say?")
tgt_item = self.encode_text(" {}".format(target))
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
constraint_masks = torch.stack([torch.arange(len(self.tgt_dict)) < len(
self.tgt_dict) - self.audio_code_dict_size - self.code_dict_size - self.num_bins for _ in
range(len(target_item))])
example = {
"id": speech_id,
"source": src_item,
"fbank": fbank,
"fbank_mask": fbank_mask,
"phone_item": phone_item,
"phone_mask": phone_mask,
"audio_code_mask": audio_code_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"conf": conf,
"constraint_masks": constraint_masks
}
return [example]
def process_pure_audio(self, index):
if self.train_stage == 2:
speech_id, wav_data, code = self.dataset[index]
else:
speech_id, wav_data, code = self.pure_audio_dataset[index]
# fake input
phone_item = [6, 6, 6]
phone_item = torch.tensor(phone_item)
phone_mask = torch.tensor([False])
# speed
if self.split == "train":
speed = random.choice([0.9, 1.0, 1.1])
else:
speed = 1.0
wav, sr = sf.read(wav_data)
# spec_augmentation
fbank = self.prepare_fbank(torch.tensor([wav], dtype=torch.float32), sr, speed)
fbank_mask = torch.tensor([True])
audio_code_mask = torch.tensor([True])
if code is not None and len(code) > 0:
text = torch.LongTensor([int(num) for num in code.strip().split(",")])
tgt_item = text + len(self.tgt_dict) - self.audio_code_dict_size - self.code_dict_size - self.num_bins
else:
# fake
text = torch.LongTensor([1, 2, 3])
tgt_item = text
conf = torch.tensor([1.0])
# useless
src_item = self.encode_text(' what does the audio say?')
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
constraint_masks = torch.stack([torch.arange(len(self.tgt_dict)) >= len(
self.tgt_dict) - self.audio_code_dict_size - self.code_dict_size - self.num_bins for _ in
range(len(target_item))])
constraint_masks[:, :3] = True
example = {
"id": speech_id,
"source": src_item,
"fbank": fbank,
"fbank_mask": fbank_mask,
"phone_item": phone_item,
"phone_mask": phone_mask,
"audio_code_mask": audio_code_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"conf": conf,
}
return [example]
def process_speech_text_pair(self, index, dataset=None):
if dataset is not None:
speech_id, wav_data, text = dataset[index]
elif self.train_stage == 2:
speech_id, wav_data, text = self.speech_text_dataset[index]
else:
speech_id, wav_data, text = self.dataset[index]
conf = torch.tensor([1.0])
audio_code_mask = torch.tensor([False])
# speed
if self.split == "train":
speed = random.choice([0.9, 1.0, 1.1])
else:
speed = 1.0
# wav, sr = sf.read(wav_data)
        # pass the target sample rate as a keyword; recent librosa versions make sr keyword-only
        wav, sr = librosa.load(wav_data, sr=self.sample_rate)
# spec_augmentation
fbank = self.prepare_fbank(torch.tensor([wav], dtype=torch.float32), sr, speed, speech_id)
fbank_mask = torch.tensor([True])
if self.lang == "en":
text = self.pre_caption(text, self.max_tgt_length)
elif self.lang == "zh":
text = self.pre_chinese(text, self.max_tgt_length)
else:
raise ValueError("lang must be en or zh")
target = text
phone_item = self.to_phone(text, self.lang)-3
phone_mask = torch.tensor([False])
src_item = self.encode_text(" what does the audio say?")
tgt_item = self.encode_text(" {}".format(target))
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
target_item = torch.cat([tgt_item, self.eos_item])
prev_output_item = torch.cat([self.bos_item, tgt_item])
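        # same plain-text decoding constraint as in process_pure_text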
        text_vocab_size = len(self.tgt_dict) - self.audio_code_dict_size - self.code_dict_size - self.num_bins
        constraint_masks = torch.stack([
            torch.arange(len(self.tgt_dict)) < text_vocab_size
            for _ in range(len(target_item))
        ])
example = {
"id": speech_id,
"source": src_item,
"fbank": fbank,
"fbank_mask": fbank_mask,
"phone_item": phone_item,
"phone_mask": phone_mask,
"audio_code_mask": audio_code_mask,
"target": target_item,
"prev_output_tokens": prev_output_item,
"conf": conf,
"constraint_masks": constraint_masks
}
return [example]
def __getitem__(self, index):
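        # Returns a tuple of three sample lists; which ones are filled depends on the stage:
        # stage 1 -> (pure-text, [], []); stage 2 -> (speech-text, pure-text, pure-audio);
        # otherwise -> (speech-text, pure-text, []).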
with data_utils.numpy_seed(self.seed, self.epoch):
if self.train_stage == 1:
extra_samples = []
if self.dataset is not None:
extra_samples += self.process_pure_text(index) if not self.remove_pure_text else []
return extra_samples, [], []
elif self.train_stage == 2:
pair_examples = []
audio_examples = []
extra_samples = []
if self.split == 'train':
if self.dataset is not None:
audio_examples += self.process_pure_audio(index) if not self.remove_pure_audio else []
if self.speech_text_dataset is not None and self.dataset.data_cnt % 4 == 0:
pair_examples += self.process_speech_text_pair(index)
if self.pure_text_dataset is not None and self.dataset.data_cnt % 2 == 0:
extra_samples += self.process_pure_text(index) if not self.remove_pure_text else []
else:
if self.dataset is not None:
pair_examples += self.process_speech_text_pair(index, self.dataset)
return pair_examples, extra_samples, audio_examples
else:
pair_examples = []
extra_samples = []
if self.split == 'train':
if self.dataset is not None:
pair_examples += self.process_speech_text_pair(index)
if self.pure_text_dataset is not None and self.dataset.data_cnt % 2 == 0:
extra_samples += self.process_pure_text(index) if not self.remove_pure_text else []
else:
if self.dataset is not None:
pair_examples += self.process_speech_text_pair(index, self.dataset)
return pair_examples, extra_samples, []
def to_phone(self, text, lang):
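        """Convert text to phone-token ids: English via g2p; Chinese via the
        text2phone tokenizer if available, otherwise via pypinyin initials/finals."""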
if lang == "en":
phone_result = None
try:
phone_result = " ".join(p for p in g2p(text))
except Exception as e:
print(e, text)
return self.encode_phone(phone_result)
elif lang == "zh":
if self.text2phone_tokenizer is not None:
final_phone = self.text2phone_tokenizer.trans(text)
return self.encode_phone(final_phone)
else:
shengmu = pinyin(text, style=Style.INITIALS, strict=False)
yunmu = pinyin(text, style=Style.FINALS_TONE3, strict=False)
assert len(shengmu) == len(yunmu)
final_phone = []
for s, y in zip(shengmu, yunmu):
if s[0] == y[0] or s[0] == "":
final_phone.append(y[0])
else:
final_phone.append(s[0] + " " + y[0])
return self.encode_phone(" ".join(final_phone))
def encode_phone(self, phone_item):
tokens = self.phone_dict.encode_line(
line=phone_item, add_if_not_exist=False, append_eos=False).long()
return tokens
def add_noise_to_phone(self, phone, p, random_p=0.1):
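        """Mask a fraction p of the phone tokens; a fraction random_p of the masked
        positions receive a random phone id instead of the mask symbol."""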
num_to_mask = int(math.ceil(phone.size(0) * p))
indices = torch.randperm(phone.size(0))[:num_to_mask]
mask_random = torch.FloatTensor(num_to_mask).uniform_() < random_p
phone[indices] = self.phone_mask_idx
if mask_random.sum() > 0:
phone[indices[mask_random]] = torch.randint(
4, self.phone_mask_idx, size=(mask_random.sum(),)
)
return phone
def prepare_fbank(self, waveform, sample_rate, speed, speech_id=None):
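        """Apply speed perturbation, convert to mono with volume normalization, and
        extract 80-dim filterbank features (Kaldi if available, otherwise torchaudio);
        optional feature transforms and frame packing are applied afterwards."""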
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, sample_rate,
[['speed', str(speed)], ['rate', str(sample_rate)]])
_waveform, _ = convert_waveform(waveform, sample_rate, to_mono=True, normalize_volume=True)
# Kaldi compliance: 16-bit signed integers
_waveform = _waveform * (2 ** 15)
_waveform = _waveform.numpy()
fbank = _get_kaldi_fbank(_waveform, sample_rate, 80)
if fbank is None:
fbank = _get_torchaudio_fbank(_waveform, sample_rate, 80)
if fbank is None:
raise ImportError(
"Please install pyKaldi or torchaudio to enable fbank feature extraction"
)
if self.feature_transforms is not None:
fbank = self.feature_transforms(fbank)
fbank = torch.from_numpy(fbank).float()
fbank = self.pack_frames(fbank)
return fbank
def pack_frames(self, feature: torch.Tensor):
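        """Stack every n_frames_per_step consecutive frames into a single feature
        vector (a no-op when n_frames_per_step is 1)."""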
if self.n_frames_per_step == 1:
return feature
n_packed_frames = feature.shape[0] // self.n_frames_per_step
feature = feature[: self.n_frames_per_step * n_packed_frames]
return feature.reshape(n_packed_frames, -1)
def collater(self, samples, pad_to_length=None):
"""Merge samples of different tasks to form two mini-batches.
Args:
samples (List[Tuple]): samples to collate
Returns:
Tuple[dict]: two mini-batch containing the data of different tasks
"""
samples_v1 = [] # containing phone-text pairs at stage-1, containing speech-text pairs at stage-2
samples_v2 = [] # containing phone-text pairs
samples_v3 = [] # containing pure_audio_pairs
for sample_tuple in samples:
samples_v1 += sample_tuple[0]
samples_v2 += sample_tuple[1]
if len(sample_tuple) > 2:
samples_v3 += sample_tuple[2]
if samples_v1 == []:
if self.train_stage == 1:
samples_v1 += self.process_pure_text(0)
else:
samples_v1 += self.process_speech_text_pair(0)
mask = False
mask_prob = None
if self.split == "train" and self.train_stage != 1:
mask = True
mask_prob = 0.3
res_v1 = collate(
samples_v1,
pad_idx=self.src_dict.pad(),
eos_idx=self.eos,
feature_only=True,
mask=mask,
mask_prob=mask_prob
)
if self.split == 'train' and self.train_stage != 1:
if samples_v2 == []:
if self.pure_text_dataset is not None:
samples_v2 += self.process_pure_text(0) if not self.remove_pure_text else []
res_v2 = collate(
samples_v2,
pad_idx=self.src_dict.pad(),
eos_idx=self.eos
)
if samples_v3 == []:
if self.pure_audio_dataset is not None:
samples_v3 += self.process_pure_audio(0) if not self.remove_pure_audio else []
else:
return res_v1, res_v2
res_v3 = collate(
samples_v3,
pad_idx=self.src_dict.pad(),
eos_idx=self.eos,
feature_only=False,
mask=True
)
return res_v1, res_v2, res_v3
else:
return res_v1
# Datasets
We provide links to download our preprocessed datasets. If you would like to process the data on your own, we will soon provide scripts for doing so.
## Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/pretrain_data/pretrain_data_examples.zip"> A small subset of the pretraining data </a>
The pretraining datasets used in OFA are all publicly available. Here we provide the public links to these datasets; it is recommended that you download the data from the links first and then process the downloaded datasets into a format similar to the examples we provide (a minimal loading sketch follows the list below).
- _CC12M_: https://github.com/google-research-datasets/conceptual-12m
- _CC3M_: https://github.com/google-research-datasets/conceptual-captions
- _SBU_: https://www.cs.virginia.edu/~vicente/sbucaptions
- _COCO_: https://cocodataset.org/#home
- _VG_: https://visualgenome.org/
- _VQAv2_: https://visualqa.org/
- _GQA_: https://cs.stanford.edu/people/dorarad/gqa/about.html
- _RefCOCO_/_RefCOCO+_/_RefCOCOg_: https://github.com/lichengunc/refer
- _OpenImages_: https://storage.googleapis.com/openimages/web/index.html
- _Object365_: https://www.objects365.org/overview.html
- _YFCC100M (subset)_: https://github.com/openai/CLIP/blob/main/data/yfcc100m.md
- _ImageNet-21K_: https://image-net.org/index.php
- _Pile_: https://pile.eleuther.ai
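As a reference, below is a minimal sketch of how one of the preprocessed `.tsv` files could be inspected. The file name and the column layout (id, caption, base64-encoded image) are assumptions for illustration only; check the downloaded examples for the actual format.

```python
# Minimal sketch for peeking at a preprocessed .tsv file.
# NOTE: the path and the column order below are assumptions, not the definitive format.
import base64
import io

from PIL import Image


def preview_tsv(tsv_path, max_rows=5):
    with open(tsv_path, encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= max_rows:
                break
            fields = line.rstrip("\n").split("\t")
            uniq_id, caption, image_b64 = fields[0], fields[1], fields[-1]  # hypothetical column order
            image = Image.open(io.BytesIO(base64.urlsafe_b64decode(image_b64)))
            print(uniq_id, image.size, caption[:60])


if __name__ == "__main__":
    preview_tsv("caption_data/caption_stage1_train.tsv")  # hypothetical path
```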
## Vision & Language Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/vqa_data/vqa_data.zip"> Dataset for VQAv2 </a> (we have also provided chunked parts of the dataset files for more convenient downloading, please refer to <a href="https://github.com/OFA-Sys/OFA/issues/68#issuecomment-1096837349">issue #68</a>)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/snli_ve_data/snli_ve_data.zip"> Dataset for SNLI-VE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen.zip"> Dataset for Text-to-Image Genearion </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen_origin_id.zip"> Dataset for Text-to-Image Genearion (with original id) </a>
## Vision Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/imagenet_1k_data/imagenet_1k_data.zip"> Dataset for ImageNet-1K </a>
## Language Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/cola_data.zip"> Dataset for COLA </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mnli_data.zip"> Dataset for MNLI </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mrpc_data.zip"> Dataset for MRPC </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qnli_data.zip"> Dataset for QNLI </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qqp_data.zip"> Dataset for QQP </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/rte_data.zip"> Dataset for RTE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/sst2_data.zip"> Dataset for SST2 </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/gigaword_data/gigaword_data.zip"> Dataset for Gigaword </a>
## OFA Raw Images for Case Study
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/ofa_images.zip"> OFA Raw Images for Case Study </a>
Here we provide raw image files for visualization examples in OFA.
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest
ENV DEBIAN_FRONTEND=noninteractive
# RUN yum update && yum install -y git cmake wget build-essential
# Note: sourcing in a RUN step only affects that build layer; it does not persist to later layers or the running container
RUN source /opt/dtk-23.04/env.sh
# Install pip dependencies
COPY requirements.txt requirements.txt
RUN pip3 install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -r requirements.txt
opencv-python==4.5.5.64
timm
ftfy==6.0.3
tensorboardX==2.4.1
pycocotools==2.0.7
pycocoevalcap==1.2
pytorch_lightning
einops
datasets
rouge_score
soundfile
editdistance
librosa
python-Levenshtein
zhconv
pypinyin==0.47.1
g2p_en
tensorboard
protobuf==3.20.2
numpy==1.23.5
docker run -it -v /parastor/home/chenzk/:/home/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name ofa 2bb84d403fac bash
#!/usr/bin/env python3 -u
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import os
import sys
import numpy as np
import torch
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig
from utils import checkpoint_utils
from utils.eval_utils import eval_step, merge_results
from utils.zero_shot_utils import zero_shot_step
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")
def apply_half(t):
if t.dtype is torch.float32:
return t.to(dtype=torch.half)
return t
def main(cfg: DictConfig, **kwargs):
utils.import_user_module(cfg.common)
reset_logging()
logger.info(cfg)
assert (
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size"
# Fix seed for stochastic decoding
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(cfg.common.seed)
utils.set_torch_seed(cfg.common.seed)
use_fp16 = cfg.common.fp16
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
if use_cuda:
torch.cuda.set_device(cfg.distributed_training.device_id)
# Load ensemble
overrides = eval(cfg.common_eval.model_overrides)
# Deal with beam-search / all-candidate VQA eval
if cfg.task._name == "vqa_gen":
overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
if kwargs["zero_shot"]:
task = tasks.setup_task(cfg.task)
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
task=task,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
else:
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# loading the dataset should happen after the checkpoint has been loaded
# so we can give it the saved task config
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
if cfg.generation.lm_path is not None:
overrides["data"] = cfg.task.data
try:
lms, _ = checkpoint_utils.load_model_ensemble(
[cfg.generation.lm_path], arg_overrides=overrides, task=None
)
        except Exception:
logger.warning(
f"Failed to load language model! Please make sure that the language model dict is the same "
f"as target dict and is located in the data dir ({cfg.task.data})"
)
raise
assert len(lms) == 1
else:
lms = [None]
# Move models to GPU
for model, ckpt_path in zip(
models, utils.split_paths(
cfg.common_eval.path)):
if kwargs['ema_eval']:
logger.info("loading EMA weights from {}".format(ckpt_path))
model.load_state_dict(
checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
model.eval()
if use_fp16:
model.half()
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(cfg)
# Load dataset (possibly sharded)
itr = task.get_batch_iterator(
dataset=task.dataset(cfg.dataset.gen_subset),
max_tokens=cfg.dataset.max_tokens,
max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
task.max_positions(), *[m.max_positions() for m in models]
),
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
seed=cfg.common.seed,
num_shards=cfg.distributed_training.distributed_world_size,
shard_id=cfg.distributed_training.distributed_rank,
num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar(
itr,
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
default_log_format=(
"tqdm" if not cfg.common.no_progress_bar else "simple"),
)
# Initialize generator
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
generator = task.build_generator(models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs)
results = []
score_sum = torch.FloatTensor([0]).cuda()
score_cnt = torch.FloatTensor([0]).cuda()
for sample in progress:
if "net_input" not in sample:
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
sample = utils.apply_to_sample(
apply_half, sample) if cfg.common.fp16 else sample
with torch.no_grad():
if kwargs["zero_shot"]:
result, scores = zero_shot_step(
task, generator, models, sample)
else:
result, scores = eval_step(
task, generator, models, sample, **kwargs)
results += result
if scores and isinstance(scores[0], tuple):
score_sum += sum([s[0] for s in scores])
score_cnt += sum([s[1] for s in scores])
else:
score_sum += sum(scores) if scores is not None else 0
score_cnt += len(scores) if scores is not None else 0
progress.log({"sentences": sample["nsentences"]})
merge_results(task, cfg, logger, score_cnt, score_sum, results)
def cli_main():
parser = options.get_generation_parser()
parser.add_argument(
"--ema-eval",
action='store_true',
help="Use EMA weights to make evaluation.")
parser.add_argument(
"--beam-search-vqa-eval",
action='store_true',
help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
parser.add_argument("--zero-shot", action='store_true')
args = options.parse_args_and_arch(parser)
cfg = convert_namespace_to_omegaconf(args)
distributed_utils.call_main(
cfg,
main,
ema_eval=args.ema_eval,
beam_search_vqa_eval=args.beam_search_vqa_eval,
zero_shot=args.zero_shot,
)
if __name__ == "__main__":
cli_main()
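# Example invocation (the script name, data path, task and checkpoint below are placeholders;
# the other flags are standard fairseq generation options plus --ema-eval defined above):
#   python evaluate.py /path/to/dataset --task caption --path /path/to/checkpoint.pt \
#       --batch-size 16 --beam 5 --gen-subset test --ema-eval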