"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "28d4d4728088f551f13edfcafadf12484b32ee64"
Unverified commit 719ba34d, authored by Tong Gao and committed by GitHub
Browse files

[Enhancement] Update prompt hash computation (#2)

parent 16e759b9
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
import hashlib import hashlib
import json import json
from copy import deepcopy from copy import deepcopy
from typing import Dict, Union from typing import Dict, List, Union
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
...@@ -24,15 +24,23 @@ def safe_format(input_str: str, **kwargs) -> str: ...@@ -24,15 +24,23 @@ def safe_format(input_str: str, **kwargs) -> str:
return input_str return input_str
def get_prompt_hash(dataset_cfg: ConfigDict) -> str: def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
"""Get the hash of the prompt configuration. """Get the hash of the prompt configuration.
Args: Args:
dataset_cfg (ConfigDict): The dataset configuration. dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
configuration.
Returns: Returns:
str: The hash of the prompt configuration. str: The hash of the prompt configuration.
""" """
if isinstance(dataset_cfg, list):
if len(dataset_cfg) == 1:
dataset_cfg = dataset_cfg[0]
else:
hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
hash_object = hashlib.sha256(hashes.encode())
return hash_object.hexdigest()
if 'reader_cfg' in dataset_cfg.infer_cfg: if 'reader_cfg' in dataset_cfg.infer_cfg:
# new config # new config
reader_cfg = dict(type='DatasetReader', reader_cfg = dict(type='DatasetReader',
...@@ -48,7 +56,7 @@ def get_prompt_hash(dataset_cfg: ConfigDict) -> str: ...@@ -48,7 +56,7 @@ def get_prompt_hash(dataset_cfg: ConfigDict) -> str:
'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split 'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
for k, v in dataset_cfg.infer_cfg.items(): for k, v in dataset_cfg.infer_cfg.items():
dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1] dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
d_json = json.dumps(dataset_cfg.infer_cfg, sort_keys=True) d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
hash_object = hashlib.sha256(d_json.encode()) hash_object = hashlib.sha256(d_json.encode())
return hash_object.hexdigest() return hash_object.hexdigest()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment