# NOTE(review): this fragment was reconstructed from whitespace-mangled source;
# statement order is preserved, but the if/else indentation is inferred — confirm
# against the original file.
# Sanity-check that the scores accumulated from eval_preds line up with the
# sample just read from the data file: query ids must match, and the number of
# scored keys must equal the number of candidate keys in the sample.
assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
if "history" in sample:
    assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
else:
    assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos']+sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
# accumulate scores of different keys for the same query
# log likelihood (negated so that a higher value means a better key)
teacher_scores.append(-score)
prev_query_id = query_id
# NOTE: the last line
sample = json.loads(f.readline().strip())
assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
if "history" in sample:
    assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
else:
    assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos']+sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
# Validate inputs and resolve the default output directory before converting an
# encoder checkpoint into a sentence_transformers model.
assert os.path.exists(src_dir), f"Make sure the encoder path {src_dir} is valid on disk!"
# Decoder-based pooling cannot be exported: sentence_transformers only supports
# encoder-style pooling to produce sentence embeddings.
# NOTE(review): message previously said 'decode' while the check is for
# 'decoder'; the message now matches the condition.
assert "decoder" not in pooling_method, "Pooling method 'decoder' cannot be saved as sentence_transformers because it uses the decoder stack to produce sentence embedding."
# By default, save the converted model next to the source checkpoint.
if dest_dir is None:
    dest_dir = src_dir
print(f"loading model from {src_dir} and saving the sentence_transformer model at {dest_dir}...")
# Find the start of target. All retrieval operations start from the chunk
# preceding the target.
# Valid (non-ignored) label positions; argmax(-1) gives the first index where
# the mask is 1, i.e. where the target begins in each sequence.
is_valid = (labels != -100).float()
target_start_index = is_valid.argmax(-1)
assert (target_start_index == target_start_index[0]).all(), f"Make sure all targets in the batch starts from the same token index!"
target_start_index = target_start_index[0].item()
# The target must begin on a chunk boundary so chunks align with retrieval keys.
assert target_start_index % self.chunk_size == 0, f"Make sure the target_length ({inputs_length} - {target_start_index} = {inputs_length-target_start_index}) is divisible by chunk_size ({self.chunk_size})!"
# Need room for the target chunks plus one preceding chunk plus 2*key_num
# retrieved chunks inside the context window.
assert n_window_chunk >= (n_target_chunk + 1 + 2 * self.key_num), f"Make sure there are at least k * 2 + 1 + n_target_chunk = {self.key_num*2+1+n_target_chunk} chunks (found {context_window_size} / {self.chunk_size} = {n_window_chunk}) that can be replaced with retrieved contents!"
# these tokens will be directly concatenated with retrieved chunks
metadata={'help':'Default path to save language models.'}
)
# Where to cache huggingface datasets downloads.
dataset_cache_dir: Optional[str] = field(
    default=None,
    metadata={'help': 'Default path to save huggingface datasets.'},
)
# Base directory that train_data / eval_data / corpus paths are resolved against.
data_root: str = field(
    default="/data/llm-embedder",
    metadata={'help': 'The base directory storing all data used for training and evaluation. If specified, make sure all train_data, eval_data, and corpus are path relative to data_root!'},
)
train_data: Optional[List[str]] = field(
    default=None,
    metadata={'help': 'Training json file or glob to match a list of files.'},
)
eval_data: Optional[str] = field(
    default=None,
    metadata={'help': 'Evaluation json file.'},
)
# NOTE(review): annotated str but defaults to None — presumably required at
# runtime; confirm how the argument parser treats it.
corpus: str = field(
    default=None,
    metadata={'help': 'Corpus jsonl file.'},
)
# Template used to join corpus columns into a single retrieval key.
key_template: str = field(
    default="{title} {text}",
    metadata={'help': 'How to concatenate columns in the corpus to form one key?'},
)
metrics: List[str] = field(
    default_factory=lambda: ["mrr", "recall", "ndcg"],
    metadata={'help': 'List of metrics'},
)
cutoffs: List[int] = field(
    default_factory=lambda: [1, 5, 10, 100],
    metadata={'help': 'Cutoffs to evaluate retrieval metrics.'},
)
filter_answers: bool = field(
    default=False,
    metadata={'help': 'Remove negatives that contain the desired answer when collating negatives?'},
)
# Cap on the number of mined negatives kept per query.
# NOTE(review): the closing parenthesis was missing in the mangled source
# (the field call was left unterminated); restored here.
max_neg_num: int = field(
    default=100,
    metadata={'help': 'Maximum negative number to mine.'},
)
# NOTE(review): fragment reconstructed from whitespace-mangled source; the
# with/for nesting is inferred from syntax — confirm against the original file.
logger.info(f"saving {max_neg_num} negatives to {save_path}...")
with open(eval_data) as f, open(save_path, "w") as g:
    for line in tqdm(f, desc="Collating Negatives"):
        item = json.loads(line)
        query_id = item["query_id"]
        # NOTE: some queries may not correspond to any negatives (especially in case of BM25), just skip them
        if query_id not in query_id_2_pred:
            continue
        pred = query_id_2_pred[query_id]
        if "pos" in item:
            pos = set(item["pos"])
        else:
            # sometimes we do not have pre-defined pos; instead, the pos will be selected from neg based on teacher scores
            # NOTE(review): pos is a set in the branch above but a list here —
            # membership tests work for both, but confirm downstream usage.
            pos = []
        # first filter out positive documents
        if "pos_index" in item:
            pos_index = item["pos_index"]
            pred = [i for i in pred if i != pos_index]
        neg = corpus[pred]["content"]
        # remove key that is the same as pos
        # NOTE: here we do not use pos_index to distinguish pos and neg, because different pos_index may correspond to the same content due to duplication in the corpus