Unverified Commit 119bfd15 authored by Tong Gao's avatar Tong Gao Committed by GitHub
Browse files

[Refactor] Move fix_id_list to Retriever (#442)

* [Refactor] Move fix_id_list to Retriever

* update

* move to base

* fix
parent 767c12a6
...@@ -23,8 +23,8 @@ CoLA_infer_cfg = dict( ...@@ -23,8 +23,8 @@ CoLA_infer_cfg = dict(
}, },
ice_token='</E>', ice_token='</E>',
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[17, 18, 19, 20, 21]),
inferencer=dict(type=PPLInferencer, fix_id_list=[17, 18, 19, 20, 21])) inferencer=dict(type=PPLInferencer))
CoLA_eval_cfg = dict(evaluator=dict(type=AccEvaluator), ) CoLA_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
......
...@@ -22,8 +22,8 @@ QQP_infer_cfg = dict( ...@@ -22,8 +22,8 @@ QQP_infer_cfg = dict(
}, },
ice_token='</E>', ice_token='</E>',
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4])) inferencer=dict(type=PPLInferencer))
QQP_eval_cfg = dict(evaluator=dict(type=AccEvaluator), ) QQP_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
......
...@@ -161,8 +161,8 @@ for _split in ["val", "test"]: ...@@ -161,8 +161,8 @@ for _split in ["val", "test"]:
]), ]),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=GenInferencer),
) )
ceval_eval_cfg = dict( ceval_eval_cfg = dict(
......
...@@ -161,8 +161,8 @@ for _split in ["val"]: ...@@ -161,8 +161,8 @@ for _split in ["val"]:
]), ]),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=GenInferencer),
) )
ceval_eval_cfg = dict( ceval_eval_cfg = dict(
......
...@@ -163,8 +163,8 @@ for _split in ["val"]: ...@@ -163,8 +163,8 @@ for _split in ["val"]:
}, },
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=PPLInferencer),
) )
ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator)) ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
......
...@@ -163,8 +163,8 @@ for _split in ["val", "test"]: ...@@ -163,8 +163,8 @@ for _split in ["val", "test"]:
}, },
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=PPLInferencer),
) )
ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator)) ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
......
...@@ -28,8 +28,8 @@ cmb_infer_cfg = dict( ...@@ -28,8 +28,8 @@ cmb_infer_cfg = dict(
), ),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=GenInferencer),
) )
cmb_datasets.append( cmb_datasets.append(
......
...@@ -96,8 +96,8 @@ for _name in cmmlu_all_sets: ...@@ -96,8 +96,8 @@ for _name in cmmlu_all_sets:
]), ]),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=GenInferencer),
) )
cmmlu_eval_cfg = dict( cmmlu_eval_cfg = dict(
......
...@@ -98,8 +98,8 @@ for _name in cmmlu_all_sets: ...@@ -98,8 +98,8 @@ for _name in cmmlu_all_sets:
}, },
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=PPLInferencer),
) )
cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator)) cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
......
...@@ -29,8 +29,8 @@ mmlu_infer_cfg = dict( ...@@ -29,8 +29,8 @@ mmlu_infer_cfg = dict(
dict(role='BOT', prompt='{target}\n') dict(role='BOT', prompt='{target}\n')
])), ])),
prompt_template=mmlu_prompt_template, prompt_template=mmlu_prompt_template,
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4])) inferencer=dict(type=GenInferencer))
mmlu_eval_cfg = dict( mmlu_eval_cfg = dict(
evaluator=dict(type=AccEvaluator), evaluator=dict(type=AccEvaluator),
......
...@@ -102,8 +102,8 @@ for _name in mmlu_all_sets: ...@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
), ),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=GenInferencer),
) )
mmlu_eval_cfg = dict( mmlu_eval_cfg = dict(
......
...@@ -87,8 +87,8 @@ for _name in mmlu_all_sets: ...@@ -87,8 +87,8 @@ for _name in mmlu_all_sets:
f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:", f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:",
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=GenInferencer),
) )
mmlu_eval_cfg = dict( mmlu_eval_cfg = dict(
......
...@@ -102,8 +102,8 @@ for _name in mmlu_all_sets: ...@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
), ),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=GenInferencer),
) )
mmlu_eval_cfg = dict( mmlu_eval_cfg = dict(
......
...@@ -93,8 +93,8 @@ for _name in mmlu_all_sets: ...@@ -93,8 +93,8 @@ for _name in mmlu_all_sets:
}, },
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]), inferencer=dict(type=PPLInferencer),
) )
mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), ) mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
......
...@@ -44,8 +44,8 @@ for k in [0, 1, 5]: ...@@ -44,8 +44,8 @@ for k in [0, 1, 5]:
), ),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))), inferencer=dict(type=GenInferencer, max_out_len=50),
) )
nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT") nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")
......
...@@ -45,8 +45,8 @@ for k in [0, 1, 5]: ...@@ -45,8 +45,8 @@ for k in [0, 1, 5]:
), ),
ice_token="</E>", ice_token="</E>",
), ),
retriever=dict(type=FixKRetriever), retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))), inferencer=dict(type=GenInferencer, max_out_len=50),
) )
triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT") triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT")
......
...@@ -34,8 +34,8 @@ infer_cfg = dict( ...@@ -34,8 +34,8 @@ infer_cfg = dict(
template='Solve the following questions.\n</E>{question}\n{answer}', template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>" ice_token="</E>"
), ),
retriever=dict(type=FixKRetriever), # Definition of how to retrieve in-context examples. retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Definition of how to retrieve in-context examples.
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # Method used to generate predictions. inferencer=dict(type=GenInferencer), # Method used to generate predictions.
) )
``` ```
......
...@@ -34,8 +34,8 @@ infer_cfg=dict( ...@@ -34,8 +34,8 @@ infer_cfg=dict(
template='Solve the following questions.\n</E>{question}\n{answer}', template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>" ice_token="</E>"
), ),
retriever=dict(type=FixKRetriever), # 定义 in context example 的获取方式 retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # 定义 in context example 的获取方式
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # 使用何种方式推理得到 prediction inferencer=dict(type=GenInferencer), # 使用何种方式推理得到 prediction
) )
``` ```
......
...@@ -55,9 +55,6 @@ class AgentInferencer(BaseInferencer): ...@@ -55,9 +55,6 @@ class AgentInferencer(BaseInferencer):
output_json_filename = self.output_json_filename output_json_filename = self.output_json_filename
# 2. Get results of retrieval process # 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve() ice_idx_list = retriever.retrieve()
# Create tmp json file for saving intermediate results and future # Create tmp json file for saving intermediate results and future
......
...@@ -59,7 +59,6 @@ class AttackInferencer(BaseInferencer): ...@@ -59,7 +59,6 @@ class AttackInferencer(BaseInferencer):
output_json_filepath: Optional[str] = './icl_inference_output', output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions', output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None, save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
dataset_cfg: Optional[List[int]] = None, dataset_cfg: Optional[List[int]] = None,
**kwargs) -> None: **kwargs) -> None:
super().__init__( super().__init__(
...@@ -78,7 +77,6 @@ class AttackInferencer(BaseInferencer): ...@@ -78,7 +77,6 @@ class AttackInferencer(BaseInferencer):
self.output_column = dataset_cfg['reader_cfg']['output_column'] self.output_column = dataset_cfg['reader_cfg']['output_column']
self.gen_field_replace_token = gen_field_replace_token self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len self.max_out_len = max_out_len
self.fix_id_list = fix_id_list
if self.model.is_api and save_every is None: if self.model.is_api and save_every is None:
save_every = 1 save_every = 1
...@@ -94,9 +92,6 @@ class AttackInferencer(BaseInferencer): ...@@ -94,9 +92,6 @@ class AttackInferencer(BaseInferencer):
output_json_filename = self.output_json_filename output_json_filename = self.output_json_filename
# 2. Get results of retrieval process # 2. Get results of retrieval process
if 'Fix' in self.retriever.__class__.__name__:
ice_idx_list = self.retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = self.retriever.retrieve() ice_idx_list = self.retriever.retrieve()
# 3. Generate prompts for testing input # 3. Generate prompts for testing input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment