Unverified commit 119bfd15, authored by Tong Gao, committed by GitHub

[Refactor] Move fix_id_list to Retriever (#442)

* [Refactor] Move fix_id_list to Retriever

* update

* move to base

* fix
parent 767c12a6
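
Every dataset config touched by this commit follows the same pattern: `fix_id_list` moves out of the inferencer config and into the `FixKRetriever` config. A minimal sketch of the before/after shape, assuming the import paths commonly used in OpenCompass dataset configs (the id values are illustrative, taken from the hunks below):

```python
# Assumed import paths; the actual configs import these names the same way.
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_retriever import FixKRetriever

# Before: fix_id_list was an inferencer argument.
infer_cfg_old = dict(
    retriever=dict(type=FixKRetriever),
    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
)

# After: the retriever owns fix_id_list and the inferencer config stays generic.
infer_cfg_new = dict(
    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
    inferencer=dict(type=GenInferencer),
)
```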
@@ -23,8 +23,8 @@ CoLA_infer_cfg = dict(
},
ice_token='</E>',
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=PPLInferencer, fix_id_list=[17, 18, 19, 20, 21]))
+ retriever=dict(type=FixKRetriever, fix_id_list=[17, 18, 19, 20, 21]),
+ inferencer=dict(type=PPLInferencer))
CoLA_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

@@ -22,8 +22,8 @@ QQP_infer_cfg = dict(
},
ice_token='</E>',
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=PPLInferencer))
QQP_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

@@ -161,8 +161,8 @@ for _split in ["val", "test"]:
]),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer),
)
ceval_eval_cfg = dict(

@@ -161,8 +161,8 @@ for _split in ["val"]:
]),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer),
)
ceval_eval_cfg = dict(

@@ -163,8 +163,8 @@ for _split in ["val"]:
},
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=PPLInferencer),
)
ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

@@ -163,8 +163,8 @@ for _split in ["val", "test"]:
},
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=PPLInferencer),
)
ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

@@ -28,8 +28,8 @@ cmb_infer_cfg = dict(
),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer),
)
cmb_datasets.append(

@@ -96,8 +96,8 @@ for _name in cmmlu_all_sets:
]),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer),
)
cmmlu_eval_cfg = dict(

@@ -98,8 +98,8 @@ for _name in cmmlu_all_sets:
},
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=PPLInferencer),
)
cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

@@ -29,8 +29,8 @@ mmlu_infer_cfg = dict(
dict(role='BOT', prompt='{target}\n')
])),
prompt_template=mmlu_prompt_template,
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer))
mmlu_eval_cfg = dict(
evaluator=dict(type=AccEvaluator),

@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer),
)
mmlu_eval_cfg = dict(

@@ -87,8 +87,8 @@ for _name in mmlu_all_sets:
f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:",
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer),
)
mmlu_eval_cfg = dict(

@@ -102,8 +102,8 @@ for _name in mmlu_all_sets:
),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=GenInferencer),
)
mmlu_eval_cfg = dict(

@@ -93,8 +93,8 @@ for _name in mmlu_all_sets:
},
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+ inferencer=dict(type=PPLInferencer),
)
mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

@@ -44,8 +44,8 @@ for k in [0, 1, 5]:
),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+ retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+ inferencer=dict(type=GenInferencer, max_out_len=50),
)
nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")

@@ -45,8 +45,8 @@ for k in [0, 1, 5]:
),
ice_token="</E>",
),
- retriever=dict(type=FixKRetriever),
- inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+ retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+ inferencer=dict(type=GenInferencer, max_out_len=50),
)
triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT")

@@ -34,8 +34,8 @@ infer_cfg = dict(
template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>"
),
- retriever=dict(type=FixKRetriever), # Definition of how to retrieve in-context examples.
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # Method used to generate predictions.
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Definition of how to retrieve in-context examples.
+ inferencer=dict(type=GenInferencer), # Method used to generate predictions.
)
```

@@ -34,8 +34,8 @@ infer_cfg=dict(
template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>"
),
- retriever=dict(type=FixKRetriever), # Defines how the in-context examples are retrieved.
- inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # The inference method used to obtain predictions.
+ retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Defines how the in-context examples are retrieved.
+ inferencer=dict(type=GenInferencer), # The inference method used to obtain predictions.
)
```

@@ -55,10 +55,7 @@ class AgentInferencer(BaseInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
- if 'Fix' in retriever.__class__.__name__:
-     ice_idx_list = retriever.retrieve(self.fix_id_list)
- else:
-     ice_idx_list = retriever.retrieve()
+ ice_idx_list = retriever.retrieve()
# Create tmp json file for saving intermediate results and future
# resuming

@@ -59,7 +59,6 @@ class AttackInferencer(BaseInferencer):
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
- fix_id_list: Optional[List[int]] = None,
dataset_cfg: Optional[List[int]] = None,
**kwargs) -> None:
super().__init__(

@@ -78,7 +77,6 @@ class AttackInferencer(BaseInferencer):
self.output_column = dataset_cfg['reader_cfg']['output_column']
self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
- self.fix_id_list = fix_id_list
if self.model.is_api and save_every is None:
save_every = 1

@@ -94,10 +92,7 @@ class AttackInferencer(BaseInferencer):
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
- if 'Fix' in self.retriever.__class__.__name__:
-     ice_idx_list = self.retriever.retrieve(self.fix_id_list)
- else:
-     ice_idx_list = self.retriever.retrieve()
+ ice_idx_list = self.retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices( # noqa

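With the `if 'Fix' in ...` branches removed, the inferencers now call `retriever.retrieve()` with no arguments, which implies the retriever itself holds the id list. The class below is a hypothetical sketch of that contract under this assumption, not the actual OpenCompass `FixKRetriever` implementation:

```python
from typing import List, Sequence


class FixedIdsRetrieverSketch:
    """Hypothetical stand-in for a FixKRetriever-style retriever that owns
    its fix_id_list instead of receiving it from the inferencer."""

    def __init__(self, test_dataset: Sequence, fix_id_list: List[int]):
        self.test_dataset = test_dataset
        self.fix_id_list = list(fix_id_list)

    def retrieve(self) -> List[List[int]]:
        # Return the same fixed in-context example ids for every test sample,
        # so call sites can simply do `ice_idx_list = retriever.retrieve()`.
        return [list(self.fix_id_list) for _ in self.test_dataset]
```

Under this contract, AgentInferencer and AttackInferencer no longer need a `fix_id_list` argument of their own, which is exactly what the hunks above delete.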