Unverified commit 119bfd15, authored by Tong Gao, committed by GitHub

[Refactor] Move fix_id_list to Retriever (#442)

* [Refactor] Move fix_id_list to Retriever

* update

* move to base

* fix
parent 767c12a6
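
In practice, this change moves `fix_id_list` from the inferencer entry of a dataset's `infer_cfg` to its retriever entry. A minimal sketch of the migration for downstream configs (the import paths and the `[0, 1, 2, 3, 4]` indices are illustrative):

```python
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_retriever import FixKRetriever

# Before this change: the inferencer owned the fixed example indices.
infer_cfg_old = dict(
    retriever=dict(type=FixKRetriever),
    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
)

# After this change: the retriever owns them. The old spelling now
# raises a ValueError (see the BaseInferencer hunk below).
infer_cfg_new = dict(
    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
    inferencer=dict(type=GenInferencer),
)
```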
@@ -25,9 +25,6 @@ class BaseInferencer:
             `JSON` file.
         output_json_filename (:obj:`str`, optional): File name for output
             `JSON` file.
-        api_name (:obj:`str`, optional): Name of API service.
-        call_api (:obj:`bool`): If ``True``, an API for LM models will be used,
-            determined by :obj:`api_name`.
     """
     model = None
......@@ -38,8 +35,15 @@ class BaseInferencer:
batch_size: Optional[int] = 1,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
fix_id_list: Optional[List[int]] = None,
**kwargs,
) -> None:
if fix_id_list:
raise ValueError('Passing fix_id_list to Inferencer is no longer '
'allowed. Please pass it to FixKRetriever '
'instead.')
self.model = model
self.max_seq_len = max_seq_len
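A hypothetical snippet of what the new guard means for callers; `model` here is assumed to be an already-built model instance:

```python
from opencompass.openicl.icl_inferencer import GenInferencer

# fix_id_list is no longer a recognised inferencer argument: it falls
# through **kwargs to BaseInferencer.__init__, which rejects it.
try:
    inferencer = GenInferencer(model=model, fix_id_list=[0, 1, 2, 3, 4])
except ValueError as err:
    print(err)  # Passing fix_id_list to Inferencer is no longer allowed. ...
```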
@@ -54,7 +54,6 @@ class CLPInferencer(BaseInferencer):
             batch_size: Optional[int] = 1,
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
-            fix_id_list: Optional[List[int]] = None,
             single_token: bool = True,
             **kwargs) -> None:
         super().__init__(
@@ -66,7 +65,6 @@ class CLPInferencer(BaseInferencer):
             **kwargs,
         )
-        self.fix_id_list = fix_id_list
         # TODO: support multiple token
         assert single_token, 'Only support single token choice currently.'
         self.single_token = single_token
@@ -103,9 +101,6 @@ class CLPInferencer(BaseInferencer):
             raise ValueError(err_msg)
         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
         # 3. Generate in-context examples for testing inputs
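The same simplification repeats in every inferencer touched below: the call site no longer inspects which retriever it holds. A minimal sketch (not the repo's actual base class) of the interface contract this establishes:

```python
from typing import List


class RetrieverContract:
    """Sketch: every retriever, FixKRetriever included, answers a
    zero-argument retrieve() returning one list of in-context example
    indices per test example."""

    def retrieve(self) -> List[List[int]]:
        raise NotImplementedError
```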
@@ -51,7 +51,6 @@ class GenInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
             model=model,
@@ -64,7 +63,6 @@ class GenInferencer(BaseInferencer):
         self.gen_field_replace_token = gen_field_replace_token
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list
         if self.model.is_api and save_every is None:
             save_every = 1
@@ -85,9 +83,6 @@ class GenInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
         # 3. Generate prompts for testing input
@@ -220,9 +215,6 @@ class GLMChoiceInferencer(GenInferencer):
             output_json_filename = self.output_json_filename
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
         # 3. Generate prompts for testing input
@@ -41,7 +41,6 @@ class PPLInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             labels: Optional[List] = None,
-            fix_id_list: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
             model=model,
@@ -53,7 +52,6 @@ class PPLInferencer(BaseInferencer):
         )
         self.labels = labels
-        self.fix_id_list = fix_id_list
     def inference(self,
                   retriever: BaseRetriever,
@@ -75,9 +73,6 @@ class PPLInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename
         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
         # 3. Get labels of all the classes
@@ -52,7 +52,6 @@ class SCInferencer(BaseInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             sc_size: Optional[int] = 1,
             infer_type: Optional[str] = '',
             generation_kwargs: dict = {},
@@ -69,7 +68,6 @@ class SCInferencer(BaseInferencer):
         self.gen_field_replace_token = gen_field_replace_token
         self.generation_kwargs = generation_kwargs
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list
         self.sc_size = sc_size
         if self.model.is_api and save_every is None:
@@ -91,9 +89,6 @@ class SCInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
         # 3. Generate prompts for testing input
@@ -46,7 +46,6 @@ class ToTInferencer(GenInferencer):
             `save_every` epochs.
         generation_kwargs (:obj:`Dict`, optional): Parameters for the
             :obj:`model.generate()` method.
-        fix_id_list (:obj:`List[int]`, optional): List of indices to fix
         naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of
             ToT + BFS.
         prompt_wrapper (:obj:`dict`): wrapper for prompts
@@ -76,7 +75,6 @@ class ToTInferencer(GenInferencer):
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             naive_run: bool = False,
             prompt_wrapper: dict = {},
             prompt_sample: str = 'standard',
@@ -97,7 +95,6 @@ class ToTInferencer(GenInferencer):
             output_json_filename=output_json_filename,
             output_json_filepath=output_json_filepath,
             save_every=save_every,
-            fix_id_list=fix_id_list,
             sc_size=n_evaluate_sample,
             **kwargs,
         )
@@ -319,9 +316,6 @@ class ToTInferencer(GenInferencer):
             output_json_filename = self.output_json_filename
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
         # 3. Generate prompts for testing input
@@ -19,6 +19,8 @@ class FixKRetriever(BaseRetriever):
     Args:
         dataset (`BaseDataset`): Any BaseDataset instances.
             Attributes of ``reader``, ``train`` and ``test`` will be used.
+        fix_id_list (List[int]): List of in-context example indices used for
+            every test prompt.
         ice_separator (`Optional[str]`): The separator between each in-context
             example template when origin `PromptTemplate` is provided. Defaults
             to '\n'.
@@ -31,22 +33,19 @@ class FixKRetriever(BaseRetriever):
     def __init__(self,
                  dataset,
+                 fix_id_list: List[int],
                  ice_separator: Optional[str] = '\n',
                  ice_eos_token: Optional[str] = '\n',
                  ice_num: Optional[int] = 1) -> None:
         super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
+        self.fix_id_list = fix_id_list
-    def retrieve(self, id_list: List[int]):
-        """Retrieve the in-context example index for each test example.
-
-        Args:
-            id_list (List[int]): List of in-context example indices for every
-                test prompts.
-        """
+    def retrieve(self):
+        """Retrieve the in-context example index for each test example."""
         num_idx = len(self.index_ds)
-        for idx in id_list:
+        for idx in self.fix_id_list:
             assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'
         rtr_idx_list = []
         for _ in trange(len(self.test_ds), disable=not self.is_main_process):
-            rtr_idx_list.append(id_list)
+            rtr_idx_list.append(self.fix_id_list)
         return rtr_idx_list
@@ -56,6 +56,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()
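The "compromise" above keeps prompt hashes stable across the refactor: before hashing, a config that carries `fix_id_list` on the retriever is rewritten back to the pre-refactor layout, so configs written before and after this change hash identically. A self-contained sketch of that idea using plain dicts (names are illustrative, not the repo's API):

```python
import hashlib
import json


def normalized_hash(infer_cfg: dict) -> str:
    cfg = json.loads(json.dumps(infer_cfg))  # cheap deep copy
    # Rewrite to the pre-refactor layout before hashing.
    if 'fix_id_list' in cfg['retriever']:
        cfg['inferencer']['fix_id_list'] = cfg['retriever'].pop('fix_id_list')
    return hashlib.sha256(json.dumps(cfg, sort_keys=True).encode()).hexdigest()


old_style = dict(retriever=dict(type='FixKRetriever'),
                 inferencer=dict(type='GenInferencer', fix_id_list=[0, 1, 2]))
new_style = dict(retriever=dict(type='FixKRetriever', fix_id_list=[0, 1, 2]),
                 inferencer=dict(type='GenInferencer'))

assert normalized_hash(old_style) == normalized_hash(new_style)
```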
@@ -61,7 +61,6 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     infer_cfg = dataset_cfg.get('infer_cfg')
-    fix_id_list = infer_cfg.inferencer.get('fix_id_list', [])
     dataset = build_dataset_from_cfg(dataset_cfg)
     ice_template = None
@@ -76,9 +75,6 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     infer_cfg['retriever']['dataset'] = dataset
     retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])
-    if fix_id_list:
-        ice_idx_list = retriever.retrieve(fix_id_list)
-    else:
-        ice_idx_list = retriever.retrieve()
+    ice_idx_list = retriever.retrieve()
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
@@ -45,6 +45,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()