Unverified Commit 119bfd15 authored by Tong Gao, committed by GitHub

[Refactor] Move fix_id_list to Retriever (#442)

* [Refactor] Move fix_id_list to Retriever

* update

* move to base

* fix
parent 767c12a6
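To make the intent of the refactor concrete, here is a minimal sketch of the expected config migration, assuming a dict-style dataset `infer_cfg`; the `type` values and example indices are placeholders, not taken from this diff:

# Old style: fix_id_list passed to the inferencer (now rejected, see the diff below).
old_infer_cfg = dict(
    retriever=dict(type='FixKRetriever'),
    inferencer=dict(type='GenInferencer', fix_id_list=[0, 1, 2, 3, 4]),
)

# New style: fix_id_list is a constructor argument of FixKRetriever.
new_infer_cfg = dict(
    retriever=dict(type='FixKRetriever', fix_id_list=[0, 1, 2, 3, 4]),
    inferencer=dict(type='GenInferencer'),
)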
@@ -25,9 +25,6 @@ class BaseInferencer:
             `JSON` file.
         output_json_filename (:obj:`str`, optional): File name for output
             `JSON` file.
-        api_name (:obj:`str`, optional): Name of API service.
-        call_api (:obj:`bool`): If ``True``, an API for LM models will be used,
-            determined by :obj:`api_name`.
     """

     model = None
@@ -38,8 +35,15 @@ class BaseInferencer:
                  batch_size: Optional[int] = 1,
                  output_json_filepath: Optional[str] = './icl_inference_output',
                  output_json_filename: Optional[str] = 'predictions',
+                 fix_id_list: Optional[List[int]] = None,
                  **kwargs,
                  ) -> None:
+        if fix_id_list:
+            raise ValueError('Passing fix_id_list to Inferencer is no longer '
+                             'allowed. Please pass it to FixKRetriever '
+                             'instead.')
         self.model = model
         self.max_seq_len = max_seq_len
...
@@ -54,7 +54,6 @@ class CLPInferencer(BaseInferencer):
                  batch_size: Optional[int] = 1,
                  output_json_filepath: Optional[str] = './icl_inference_output',
                  output_json_filename: Optional[str] = 'predictions',
-                 fix_id_list: Optional[List[int]] = None,
                  single_token: bool = True,
                  **kwargs) -> None:
         super().__init__(
@@ -66,7 +65,6 @@ class CLPInferencer(BaseInferencer):
             **kwargs,
         )
-        self.fix_id_list = fix_id_list
         # TODO: support multiple token
         assert single_token, 'Only support single token choice currently.'
         self.single_token = single_token
@@ -103,10 +101,7 @@ class CLPInferencer(BaseInferencer):
             raise ValueError(err_msg)

         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate in-context examples for testing inputs
         for idx in range(len(ice_idx_list)):
...
@@ -51,7 +51,6 @@ class GenInferencer(BaseInferencer):
                  output_json_filepath: Optional[str] = './icl_inference_output',
                  output_json_filename: Optional[str] = 'predictions',
                  save_every: Optional[int] = None,
-                 fix_id_list: Optional[List[int]] = None,
                  **kwargs) -> None:
         super().__init__(
             model=model,
@@ -64,7 +63,6 @@ class GenInferencer(BaseInferencer):
         self.gen_field_replace_token = gen_field_replace_token
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list

         if self.model.is_api and save_every is None:
             save_every = 1
@@ -85,10 +83,7 @@ class GenInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
@@ -220,10 +215,7 @@ class GLMChoiceInferencer(GenInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
...
@@ -41,7 +41,6 @@ class PPLInferencer(BaseInferencer):
                  output_json_filepath: Optional[str] = './icl_inference_output',
                  output_json_filename: Optional[str] = 'predictions',
                  labels: Optional[List] = None,
-                 fix_id_list: Optional[List[int]] = None,
                  **kwargs) -> None:
         super().__init__(
             model=model,
@@ -53,7 +52,6 @@ class PPLInferencer(BaseInferencer):
         )
         self.labels = labels
-        self.fix_id_list = fix_id_list

     def inference(self,
                   retriever: BaseRetriever,
@@ -75,10 +73,7 @@ class PPLInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Get labels of all the classes
         if self.labels is None:
...
@@ -52,7 +52,6 @@ class SCInferencer(BaseInferencer):
                  output_json_filepath: Optional[str] = './icl_inference_output',
                  output_json_filename: Optional[str] = 'predictions',
                  save_every: Optional[int] = None,
-                 fix_id_list: Optional[List[int]] = None,
                  sc_size: Optional[int] = 1,
                  infer_type: Optional[str] = '',
                  generation_kwargs: dict = {},
@@ -69,7 +68,6 @@ class SCInferencer(BaseInferencer):
         self.gen_field_replace_token = gen_field_replace_token
         self.generation_kwargs = generation_kwargs
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list
         self.sc_size = sc_size

         if self.model.is_api and save_every is None:
@@ -91,10 +89,7 @@ class SCInferencer(BaseInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
...
@@ -46,7 +46,6 @@ class ToTInferencer(GenInferencer):
             `save_every` epochs.
         generation_kwargs (:obj:`Dict`, optional): Parameters for the
             :obj:`model.generate()` method.
-        fix_id_list (:obj:`List[int]`, optional): List of indices to fix
         naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of
             ToT + BFS.
         prompt_wrapper (:obj:`dict`): wrapper for prompts
@@ -76,7 +75,6 @@ class ToTInferencer(GenInferencer):
                  output_json_filepath: Optional[str] = './icl_inference_output',
                  output_json_filename: Optional[str] = 'predictions',
                  save_every: Optional[int] = None,
-                 fix_id_list: Optional[List[int]] = None,
                  naive_run: bool = False,
                  prompt_wrapper: dict = {},
                  prompt_sample: str = 'standard',
@@ -97,7 +95,6 @@ class ToTInferencer(GenInferencer):
             output_json_filename=output_json_filename,
             output_json_filepath=output_json_filepath,
             save_every=save_every,
-            fix_id_list=fix_id_list,
             sc_size=n_evaluate_sample,
             **kwargs,
         )
@@ -319,10 +316,7 @@ class ToTInferencer(GenInferencer):
             output_json_filename = self.output_json_filename

         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()

         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
...
@@ -19,6 +19,8 @@ class FixKRetriever(BaseRetriever):
     Args:
         dataset (`BaseDataset`): Any BaseDataset instances.
             Attributes of ``reader``, ``train`` and ``test`` will be used.
+        fix_id_list (List[int]): List of in-context example indices for every
+            test prompt.
         ice_separator (`Optional[str]`): The separator between each in-context
             example template when origin `PromptTemplate` is provided. Defaults
             to '\n'.
@@ -31,22 +33,19 @@ class FixKRetriever(BaseRetriever):
     def __init__(self,
                  dataset,
+                 fix_id_list: List[int],
                  ice_separator: Optional[str] = '\n',
                  ice_eos_token: Optional[str] = '\n',
                  ice_num: Optional[int] = 1) -> None:
         super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
+        self.fix_id_list = fix_id_list

-    def retrieve(self, id_list: List[int]):
-        """Retrieve the in-context example index for each test example.
-
-        Args:
-            id_list (List[int]): List of in-context example indices for every
-                test prompts.
-        """
+    def retrieve(self):
+        """Retrieve the in-context example index for each test example."""

         num_idx = len(self.index_ds)
-        for idx in id_list:
+        for idx in self.fix_id_list:
             assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'
         rtr_idx_list = []
         for _ in trange(len(self.test_ds), disable=not self.is_main_process):
-            rtr_idx_list.append(id_list)
+            rtr_idx_list.append(self.fix_id_list)
         return rtr_idx_list
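A short usage sketch of the new `retrieve()` contract above; `dataset` is a placeholder for any BaseDataset instance whose index split has at least five examples:

retriever = FixKRetriever(dataset, fix_id_list=[0, 1, 2, 3, 4])
ice_idx_list = retriever.retrieve()  # no argument: the indices now live on the retriever
# Every test example reuses the same fixed in-context indices, e.g.
# ice_idx_list == [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], ...]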
@@ -56,6 +56,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()
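The "compromise for the hash consistency" above keeps prompt hashes stable whether a config still carries `fix_id_list` on the inferencer (old style) or on the retriever (new style). A minimal self-contained sketch with plain dicts (hypothetical values, not the real ConfigDict objects) of why relocating the key before hashing makes both styles hash identically:

import hashlib
import json

def normalized_hash(infer_cfg):
    cfg = json.loads(json.dumps(infer_cfg))  # cheap deep copy
    # Mirror the compromise: move fix_id_list from the retriever to the inferencer.
    if 'fix_id_list' in cfg['retriever']:
        cfg['inferencer']['fix_id_list'] = cfg['retriever'].pop('fix_id_list')
    return hashlib.sha256(json.dumps(cfg, sort_keys=True).encode()).hexdigest()

old_style = {'retriever': {'type': 'FixKRetriever'},
             'inferencer': {'type': 'GenInferencer', 'fix_id_list': [0, 1, 2]}}
new_style = {'retriever': {'type': 'FixKRetriever', 'fix_id_list': [0, 1, 2]},
             'inferencer': {'type': 'GenInferencer'}}
assert normalized_hash(old_style) == normalized_hash(new_style)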
...
@@ -61,7 +61,6 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     infer_cfg = dataset_cfg.get('infer_cfg')

-    fix_id_list = infer_cfg.inferencer.get('fix_id_list', [])

     dataset = build_dataset_from_cfg(dataset_cfg)

     ice_template = None
@@ -76,10 +75,7 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     infer_cfg['retriever']['dataset'] = dataset
     retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])

-    if fix_id_list:
-        ice_idx_list = retriever.retrieve(fix_id_list)
-    else:
-        ice_idx_list = retriever.retrieve()
+    ice_idx_list = retriever.retrieve()

     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
...
@@ -45,6 +45,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()
...