Commit 97e8278b authored by zzg_666's avatar zzg_666
Browse files

适配后端vllm

parents
Pipeline #3071 canceled with stages
import pandas as pd
from typing import List, Tuple
# Assuming these are the correct import paths for your framework
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
# Import the provided executor script
# Assuming shared_vis_python_exe.py is in the same directory or accessible via PYTHONPATH
from .python_executor import PythonExecutor
@OPERATOR_REGISTRY.register()
class CodeSandboxSampleEvaluator(OperatorABC):
    """
    Executes code snippets in a secure, isolated environment to verify their
    correctness. It leverages a robust PythonExecutor to handle process
    isolation, timeouts, and capturing results. This is the final validation
    step in the data synthesis pipeline.
    """

    def __init__(self, language: str = "python", timeout_length: int = 15, use_process_isolation: bool = True):
        """
        Initializes the operator and the underlying PythonExecutor.

        Args:
            language: Declared language of the snippets. Stored for reference;
                the executor itself runs Python code.
            timeout_length: Maximum execution time in seconds for each code snippet.
            use_process_isolation: Whether to run code in a separate process for
                security. Recommended to keep True.
        """
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        # Fix: the language argument was previously accepted and discarded;
        # keep it so callers can introspect the configured language.
        self.language = language
        # Initialize the PythonExecutor once; it is reused for all snippets.
        self.executor = PythonExecutor(
            get_answer_from_stdout=True,  # capture print output as the primary result
            timeout_length=timeout_length,
            use_process_isolation=use_process_isolation
        )
        self.score_name = 'SandboxValidationScore'
        self.logger.info(f'{self.__class__.__name__} initialized.')

    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provides a description of the operator's function and parameters.
        """
        if lang == "zh":
            return (
                "该算子在一个安全的沙箱环境中执行代码片段以验证其正确性。\n\n"
                "输入参数:\n"
                "- input_code_key: 包含待执行代码的字段名 (默认: 'generated_code')\n"
                "输出参数:\n"
                "- output_status_key: 用于存储执行状态 ('PASS' 或 'FAIL') 的字段名 (默认: 'sandbox_status')\n"
                "- output_log_key: 用于存储执行日志或错误信息的字段名 (默认: 'sandbox_log')\n"
            )
        else:  # Default to English
            return (
                "This operator executes code snippets in a secure sandbox environment to verify their correctness.\n\n"
                "Input Parameters:\n"
                "- input_code_key: Field name containing the code to be executed (default: 'generated_code')\n"
                "Output Parameters:\n"
                "- output_status_key: Field name to store the execution status ('PASS' or 'FAIL') (default: 'sandbox_status')\n"
                "- output_log_key: Field name to store the execution log or error message (default: 'sandbox_log')\n"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validates that the required code column exists and the output columns
        do not, so existing data is never silently overwritten.

        Raises:
            ValueError: If a required column is missing or an output column
                already exists.
        """
        required_keys = [self.input_key]
        forbidden_keys = [self.output_status_key, self.output_log_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s) for CodeSandboxSampleEvaluator: {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten by CodeSandboxSampleEvaluator: {conflict}")

    def _score_func(self, code: str) -> Tuple[str, str]:
        """
        Execute a single code snippet and return status and log.

        Args:
            code: Code snippet to execute.

        Returns:
            Tuple of (status, log) where status is 'PASS' or 'FAIL'.
        """
        try:
            result, report = self.executor.execute([code], messages=[])
            # "Done" is the executor's success marker; anything else is a failure.
            if report == "Done":
                status = "PASS"
                log = result.get('text', '') if isinstance(result, dict) else str(result)
            else:
                status = "FAIL"
                log = report
            return status, log
        except Exception as e:
            return "FAIL", f"Execution error: {str(e)}"

    def _execute_code_batch(self, code_list: List[str]) -> List[Tuple[str, str]]:
        """
        Execute a batch of code snippets using the PythonExecutor.

        Args:
            code_list: A list of strings, where each string is a code snippet.

        Returns:
            A list of (status, log) tuples. Status is 'PASS' or 'FAIL'; log
            contains execution output or an error message.
        """
        # batch_apply takes the code list and a 'messages' context; for plain
        # validation the context can be empty.
        results_with_reports = self.executor.batch_apply(code_list, messages=[])
        processed_results = []
        for (result, report) in results_with_reports:
            # "Done" indicates success. Anything else (e.g. "Error: ...",
            # "Timeout Error") indicates failure.
            if report == "Done":
                status = "PASS"
                # 'result' can be a dict with 'text'/'images'; keep the text log.
                log = result.get('text', '') if isinstance(result, dict) else result
            else:
                status = "FAIL"
                # The report itself is the most informative log on failure.
                log = report
            processed_results.append((status, log))
        return processed_results

    def eval(self, dataframe: pd.DataFrame, input_key: str) -> Tuple[List[str], List[str]]:
        """
        Execute code snippets and return statuses and logs.

        Args:
            dataframe: Input DataFrame.
            input_key: Field name containing code snippets.

        Returns:
            Tuple of (statuses, logs) lists, aligned with the DataFrame rows.
        """
        self.logger.info(f"Evaluating {self.score_name}...")
        code_list = dataframe[input_key].tolist()
        execution_results = self._execute_code_batch(code_list)
        # Bug fix: zip(*[]) raises "not enough values to unpack" on an empty
        # DataFrame; return empty result lists instead of crashing.
        if not execution_results:
            self.logger.info("Evaluation complete!")
            return [], []
        statuses, logs = zip(*execution_results)
        self.logger.info("Evaluation complete!")
        return list(statuses), list(logs)

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str,
        output_status_key: str = "sandbox_status",
        output_log_key: str = "sandbox_log"
    ):
        """
        Executes the sandbox validation process.

        Args:
            storage: Data storage object.
            input_key: Field name containing code snippets.
            output_status_key: Field name for execution status.
            output_log_key: Field name for execution logs.
        """
        self.input_key = input_key
        self.output_status_key = output_status_key
        self.output_log_key = output_log_key
        dataframe = storage.read("dataframe")
        # Consistency fix: _validate_dataframe was defined but never invoked.
        self._validate_dataframe(dataframe)
        statuses, logs = self.eval(dataframe, input_key)
        dataframe[self.output_status_key] = statuses
        dataframe[self.output_log_key] = logs
        storage.write(dataframe)

    def __del__(self):
        """
        Ensures the executor's resources are cleaned up when the operator is destroyed.
        """
        if hasattr(self, 'executor') and self.executor:
            # The executor's __del__ method handles terminating the worker process.
            del self.executor
\ No newline at end of file
import pandas as pd
from typing import Dict
from tqdm import tqdm
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
@OPERATOR_REGISTRY.register()
class CodeTextCompositionSampleEvaluator(OperatorABC):
    """
    Evaluates code samples based on character composition to provide scores for
    filtering binary files, encrypted content, and other non-readable text.
    It analyzes the ratio of alphabetic and alphanumeric characters to ensure
    readable content.
    """
    # Languages whose readable source is symbol-heavy, so the stricter
    # alphabetic-ratio check is replaced by the alphanumeric-ratio check.
    SPECIAL_LANGS = {"Motorola 68K Assembly", "WebAssembly"}

    def __init__(self):
        """
        Initialize the operator and set up the logger.
        """
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.score_name = 'CodeTextCompositionScore'
        self.logger.info(f'{self.__class__.__name__} initialized.')

    @staticmethod
    def get_desc(lang: str = "zh"):
        """
        Provide operator functionality description and parameter documentation.
        """
        if lang == "zh":
            return (
                "基于字符组成评估代码样本,分析字母字符和字母数字字符的比例。\n\n"
                "评估指标:\n"
                "- CodeTextCompositionAlphaRatio: 字母字符比例\n"
                "- CodeTextCompositionAlnumRatio: 字母数字字符比例\n"
                "- CodeTextCompositionScore: 综合字符组成得分 (0-1,1表示通过字符组成检查)\n\n"
                "输入要求:需要包含'text'和'language'列\n\n"
                "输出参数:\n"
                "- CodeTextCompositionAlphaRatio: 字母字符比例\n"
                "- CodeTextCompositionAlnumRatio: 字母数字字符比例\n"
                "- CodeTextCompositionScore: 综合字符组成得分"
            )
        else:
            return (
                "Evaluate code samples based on character composition, analyzing ratios of alphabetic and alphanumeric characters.\n\n"
                "Evaluation Metrics:\n"
                "- CodeTextCompositionAlphaRatio: Alphabetic character ratio\n"
                "- CodeTextCompositionAlnumRatio: Alphanumeric character ratio\n"
                "- CodeTextCompositionScore: Comprehensive composition score (0-1, 1 means passes composition checks)\n\n"
                "Input Requirement: Requires 'text' and 'language' columns\n\n"
                "Output Parameters:\n"
                "- CodeTextCompositionAlphaRatio: Alphabetic character ratio\n"
                "- CodeTextCompositionAlnumRatio: Alphanumeric character ratio\n"
                "- CodeTextCompositionScore: Comprehensive composition score"
            )

    def _score_func(self, sample):
        """
        Calculate composition-based scores for a single code sample.

        Args:
            sample: Dictionary containing 'text' and 'language' keys.

        Returns:
            Dictionary containing the two ratios and the binary composition score.
        """
        text = sample.get('text', '')
        language = sample.get('language', '')
        # Character ratios; max(1, len) guards against division by zero on empty text.
        alpha_ratio = sum(c.isalpha() for c in text) / max(1, len(text))
        alnum_ratio = sum(c.isalnum() for c in text) / max(1, len(text))
        # Binary score: 1.0 passes, 0.0 fails the composition check.
        score = 1.0
        if language in self.SPECIAL_LANGS:
            # For assembly languages, check alphanumeric character ratio.
            if alnum_ratio < 0.25:
                score = 0.0
        else:
            # For normal languages, check alphabetic character ratio.
            if alpha_ratio < 0.25:
                score = 0.0
        return {
            'CodeTextCompositionAlphaRatio': alpha_ratio,
            'CodeTextCompositionAlnumRatio': alnum_ratio,
            'CodeTextCompositionScore': score
        }

    def eval(self, dataframe, input_key):
        """
        Evaluate character composition for all samples in the dataframe.

        Args:
            dataframe: Input DataFrame.
            input_key: Key containing the sample data. May hold a dict with
                'text'/'language', or a bare string (language then 'unknown').

        Returns:
            List of score dictionaries, one per row in order.
        """
        scores_list = []
        self.logger.info(f"Evaluating {self.score_name}...")
        for _, row in dataframe.iterrows():
            sample = row[input_key] if isinstance(row[input_key], dict) else {"text": row[input_key], "language": "unknown"}
            scores = self._score_func(sample)
            scores_list.append(scores)
        self.logger.info("Evaluation complete!")
        return scores_list

    def run(self, storage: DataFlowStorage, input_key: str):
        """
        Execute character composition evaluation operation.

        Args:
            storage: Data storage object.
            input_key: Key name for input data.
        """
        self.input_key = input_key
        dataframe = storage.read("dataframe")
        self.logger.info("CodeTextCompositionScore ready to evaluate.")
        scores = self.eval(dataframe, input_key)
        # Bug fix: the previous enumerate-based positions were used as index
        # *labels* with .at, which misaligns (or appends phantom rows) whenever
        # the DataFrame index is not the default 0..n-1. Pair each score dict
        # with the frame's own index label instead.
        for row_label, score_dict in zip(dataframe.index, scores):
            for key, value in score_dict.items():
                dataframe.at[row_label, key] = value
        storage.write(dataframe)
# This code is referenced from PyVision: Agentic Vision with Dynamic Tooling
# GitHub: https://github.com/agents-x-project/PyVision
# Paper: PyVision: Agentic Vision with Dynamic Tooling (arXiv:2507.07998)
import os
import io
import regex
import pickle
import traceback
import copy
import datetime
import dateutil.relativedelta
import multiprocessing
from multiprocessing import Queue, Process
from typing import Any, Dict, Optional, Tuple, List, Union
from tqdm import tqdm
from concurrent.futures import TimeoutError
from contextlib import redirect_stdout
import base64
from io import BytesIO
from PIL import Image
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
MATPLOTLIB_AVAILABLE = True
except ImportError:
MATPLOTLIB_AVAILABLE = False
plt = None
import numpy as np
import time
import queue
def encode_image(image_path):
    """Read an image file and return its contents as a base64-encoded ASCII string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
def base64_to_image(
    base64_str: str,
    remove_prefix: bool = True,
    convert_mode: Optional[str] = "RGB"
) -> Union[Image.Image, None]:
    """
    Convert a Base64-encoded image string to a PIL Image object.

    Args:
        base64_str: Base64-encoded image string (can include data: prefix)
        remove_prefix: Whether to automatically remove the "data:image/..." prefix (default True)
        convert_mode: Convert to the specified mode (such as "RGB"/"RGBA", None means no conversion)

    Returns:
        PIL.Image.Image object, or None if decoding fails

    Examples:
        >>> img = base64_to_image("data:image/png;base64,iVBORw0KGg...")
        >>> img = base64_to_image("iVBORw0KGg...", remove_prefix=False)
    """
    try:
        # 1. Strip an optional "data:image/...;base64," prefix.
        if remove_prefix and "," in base64_str:
            base64_str = base64_str.split(",")[1]
        # 2. Decode Base64.
        image_data = base64.b64decode(base64_str)
        # 3. Convert to PIL Image.
        image = Image.open(BytesIO(image_data))
        # 4. Optional mode conversion.
        if convert_mode:
            image = image.convert(convert_mode)
        return image
    except Exception as e:
        # Fix: the original clause caught (base64.binascii.Error, OSError,
        # Exception) — a redundant tuple, since Exception already subsumes the
        # other two. Behavior (best-effort: log and return None) is unchanged.
        print(f"Base64 decode failed: {str(e)}")
        return None
class PersistentWorker:
    """
    Persistent worker process.

    A long-lived child process that executes code snippets on request. Tasks
    are sent through `input_queue` as dicts with a 'type' of 'init', 'execute'
    or 'reset'; one response dict per task comes back on `output_queue`.
    Runtime classes cross the process boundary as string identifiers (see
    RUNTIME_REGISTRY) so task payloads stay pickle-safe.
    """
    # Runtime class registry for pickle-safe serialization.
    # Values are populated at module bottom, after the runtime classes are
    # defined, to avoid a definition-order / circular-import problem.
    RUNTIME_REGISTRY = {
        'ImageRuntime': None,  # Will be set later to avoid circular import
        'DateRuntime': None,
        'ColorObjectRuntime': None,
        'GenericRuntime': None,
    }

    @classmethod
    def _get_runtime_class(cls, runtime_identifier):
        """Get runtime class from identifier (class name or class object)."""
        if isinstance(runtime_identifier, str):
            # String identifier - look up in registry
            if runtime_identifier in cls.RUNTIME_REGISTRY:
                return cls.RUNTIME_REGISTRY[runtime_identifier]
            else:
                # Default to ImageRuntime if not found
                return cls.RUNTIME_REGISTRY.get('ImageRuntime', ImageRuntime)
        elif isinstance(runtime_identifier, type):
            # Class object - get its name and look up
            class_name = runtime_identifier.__name__
            return cls.RUNTIME_REGISTRY.get(class_name, runtime_identifier)
        else:
            # Default fallback (e.g. None was passed)
            return cls.RUNTIME_REGISTRY.get('ImageRuntime', ImageRuntime)

    @classmethod
    def _get_runtime_identifier(cls, runtime_class):
        """Convert runtime class to pickle-safe identifier (a class name string)."""
        if runtime_class is None:
            return 'ImageRuntime'
        elif isinstance(runtime_class, str):
            return runtime_class
        else:
            return runtime_class.__name__

    def __init__(self):
        # Queues carry pickled task/response dicts between parent and worker.
        self.input_queue = multiprocessing.Queue()
        self.output_queue = multiprocessing.Queue()
        self.process = None
        self.start()

    def start(self):
        """Start the worker process."""
        self.process = Process(target=self._worker_loop)
        # Daemon so the worker cannot outlive the parent process.
        self.process.daemon = True
        self.process.start()

    def _worker_loop(self):
        """
        Main loop for the worker process.

        Runs inside the child: blocks on input_queue, dispatches on the task's
        'type' ('init' / 'execute' / 'reset'), and pushes exactly one response
        dict per task onto output_queue. A `None` task terminates the loop.
        The runtime (and its global namespace) persists across 'execute' tasks
        until a 'reset' or 'init' replaces it.
        """
        runtime = None
        runtime_class = None
        while True:
            try:
                # Get task (blocking)
                task = self.input_queue.get()
                if task is None:  # Termination signal
                    break
                task_type = task.get('type')
                if task_type == 'init':
                    # Initialize runtime
                    messages = task.get('messages', [])
                    runtime_identifier = task.get('runtime_class', 'ImageRuntime')
                    runtime_class = self._get_runtime_class(runtime_identifier)
                    runtime = runtime_class(messages)
                    self.output_queue.put({
                        'status': 'success',
                        'result': 'Initialized'
                    })
                elif task_type == 'execute':
                    # Execute code; lazily create a runtime if none exists yet.
                    if runtime is None:
                        messages = task.get('messages', [])
                        runtime_identifier = task.get('runtime_class', 'ImageRuntime')
                        runtime_class = self._get_runtime_class(runtime_identifier)
                        runtime = runtime_class(messages)
                    code = task.get('code')
                    get_answer_from_stdout = task.get('get_answer_from_stdout', True)
                    answer_symbol = task.get('answer_symbol')
                    answer_expr = task.get('answer_expr')
                    try:
                        # Record the number of images before execution so only
                        # figures produced by THIS snippet are returned.
                        pre_figures_count = len(runtime._global_vars.get("_captured_figures", []))
                        if get_answer_from_stdout:
                            # Capture everything the snippet prints.
                            program_io = io.StringIO()
                            with redirect_stdout(program_io):
                                runtime.exec_code("\n".join(code))
                            program_io.seek(0)
                            result = program_io.read()
                        elif answer_symbol:
                            # Read the result from a named variable after execution.
                            runtime.exec_code("\n".join(code))
                            result = runtime._global_vars.get(answer_symbol, "")
                        elif answer_expr:
                            # Evaluate an expression for the result after execution.
                            runtime.exec_code("\n".join(code))
                            result = runtime.eval_code(answer_expr)
                        else:
                            # Default: exec all but the last line, eval the last line.
                            if len(code) > 1:
                                runtime.exec_code("\n".join(code[:-1]))
                                result = runtime.eval_code(code[-1])
                            else:
                                runtime.exec_code("\n".join(code))
                                result = ""
                        # Get only the images generated by this execution.
                        all_figures = runtime._global_vars.get("_captured_figures", [])
                        new_figures = all_figures[pre_figures_count:]
                        # Build result: dict with 'text' and/or 'images' keys,
                        # omitting empty parts.
                        if new_figures:
                            result = {
                                'text': result,
                                'images': new_figures
                            } if result else {'images': new_figures}
                        else:
                            result = {'text': result} if result else {}
                        self.output_queue.put({
                            'status': 'success',
                            'result': result,
                            'report': 'Done'
                        })
                    except Exception as e:
                        # Snippet-level failure: report it but keep the worker alive.
                        self.output_queue.put({
                            'status': 'error',
                            'error': str(e),
                            'traceback': traceback.format_exc(),
                            'report': f'Error: {str(e)}'
                        })
                elif task_type == 'reset':
                    # Reset runtime: discard all state, build a fresh namespace.
                    messages = task.get('messages', [])
                    runtime_identifier = task.get('runtime_class', 'ImageRuntime')
                    runtime_class = self._get_runtime_class(runtime_identifier)
                    runtime = runtime_class(messages)
                    self.output_queue.put({
                        'status': 'success',
                        'result': 'Reset'
                    })
            except Exception as e:
                # Task-dispatch failure (e.g. malformed task dict).
                self.output_queue.put({
                    'status': 'error',
                    'error': f'Worker error: {str(e)}',
                    'traceback': traceback.format_exc()
                })

    def execute(self, code: List[str], messages: list = None, runtime_class=None,
                get_answer_from_stdout=True, answer_symbol=None, answer_expr=None, timeout: int = 30):
        """
        Execute code (a list of source lines) in the worker and wait for the
        response dict, up to `timeout` seconds.

        NOTE(review): on timeout the worker keeps running the old task, and its
        eventual response stays in output_queue where the NEXT call may consume
        it — queues can desynchronize. Confirm whether callers are expected to
        reset/restart the worker after a timeout.
        """
        # Convert runtime class to pickle-safe identifier
        runtime_identifier = self._get_runtime_identifier(runtime_class)
        self.input_queue.put({
            'type': 'execute',
            'code': code,
            'messages': messages,
            'runtime_class': runtime_identifier,
            'get_answer_from_stdout': get_answer_from_stdout,
            'answer_symbol': answer_symbol,
            'answer_expr': answer_expr
        })
        try:
            result = self.output_queue.get(timeout=timeout)
            return result
        except queue.Empty:
            return {
                'status': 'error',
                'error': 'Execution timeout',
                'report': 'Timeout Error'
            }

    def init_runtime(self, messages: list, runtime_class=None):
        """Initialize runtime in the worker; blocks until the worker confirms."""
        # Convert runtime class to pickle-safe identifier
        runtime_identifier = self._get_runtime_identifier(runtime_class)
        self.input_queue.put({
            'type': 'init',
            'messages': messages,
            'runtime_class': runtime_identifier
        })
        return self.output_queue.get()

    def reset_runtime(self, messages: list = None, runtime_class=None):
        """Reset the worker's runtime to a fresh namespace; blocks until confirmed."""
        # Convert runtime class to pickle-safe identifier
        runtime_identifier = self._get_runtime_identifier(runtime_class)
        self.input_queue.put({
            'type': 'reset',
            'messages': messages,
            'runtime_class': runtime_identifier
        })
        return self.output_queue.get()

    def terminate(self):
        """Terminate the worker process: ask politely, then force-kill after 5s."""
        if self.process and self.process.is_alive():
            self.input_queue.put(None)  # cooperative shutdown signal
            self.process.join(timeout=5)
            if self.process.is_alive():
                self.process.terminate()
class GenericRuntime:
    """
    Minimal sandboxed execution environment: a mutable global namespace seeded
    from GLOBAL_DICT, primed by running the HEADERS snippets, with a small
    guard against obviously dangerous calls.
    """
    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []

    def __init__(self):
        # Shallow-copy the class-level defaults so instances never share state.
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        if self.LOCAL_DICT:
            self._local_vars = copy.copy(self.LOCAL_DICT)
        else:
            self._local_vars = None
        self._captured_figures = []
        # Prime the namespace with the header snippets (imports, setup code).
        for header_snippet in self.HEADERS:
            self.exec_code(header_snippet)

    def exec_code(self, code_piece: str) -> None:
        """Execute a code snippet inside this runtime's global namespace."""
        # Security check: reject a few obviously dangerous calls up front.
        if regex.search(r"(\s|^)?(input|os\.system|subprocess)\(", code_piece):
            raise RuntimeError("Forbidden function calls detected")
        to_execute = code_piece
        # Rewrite plt.show() so figures are captured as base64 PNG strings
        # instead of being displayed.
        if MATPLOTLIB_AVAILABLE and "plt.show()" in code_piece:
            to_execute = code_piece.replace("plt.show()", """
# Capture current figure
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
_captured_image = base64.b64encode(buf.read()).decode('utf-8')
_captured_figures.append(_captured_image)
plt.close()
""")
            # The injected snippet appends to _captured_figures; make sure it exists.
            self._global_vars.setdefault("_captured_figures", [])
        exec(to_execute, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate an expression in the runtime namespace and return its value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Copy the given variables into the runtime's global namespace."""
        self._global_vars.update(var_dict)

    @property
    def answer(self):
        # Value of the conventional 'answer' variable, if executed code set one.
        return self._global_vars.get("answer", None)

    @property
    def captured_figures(self):
        # Base64 figures collected by the rewritten plt.show() calls.
        return self._global_vars.get("_captured_figures", [])
class ImageRuntime(GenericRuntime):
    """
    Runtime pre-loaded with imaging libraries (matplotlib, PIL, numpy) that
    also decodes any images found in the chat messages and exposes them to
    executed code as `image_clue_<i>` variables.
    """
    HEADERS = [
        """try:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    plt = None
""",
        "from PIL import Image",
        "import io",
        "import base64",
        "import numpy as np",
        "_captured_figures = []",  # Initialize image capture list
    ]

    def __init__(self, messages):
        """
        Args:
            messages: Chat-style message list. Each item's 'content' is scanned
                for dicts of type "image_url"; every decodable image becomes an
                `image_clue_<i>` variable and seeds the captured-figures list.
        """
        super().__init__()
        image_var_dict = {}
        image_var_idx = 0
        init_captured_figures = []
        for message_item in messages:
            content = message_item['content']
            for item in content:
                if isinstance(item, dict):
                    item_type = item.get('type')
                    if item_type == "image_url":
                        item_image_url = item['image_url']['url']
                        image = base64_to_image(item_image_url)
                        if image:
                            image_var_dict[f"image_clue_{image_var_idx}"] = image
                            # Bug fix: the original encoded image.tobytes() —
                            # raw pixel data with no header — which downstream
                            # consumers cannot decode as an image. Re-encode as
                            # PNG so the stored figure matches the base64-PNG
                            # format produced by the plt.show() capture path.
                            png_buffer = BytesIO()
                            image.save(png_buffer, format="PNG")
                            init_captured_figures.append(
                                base64.b64encode(png_buffer.getvalue()).decode('utf-8'))
                            image_var_idx += 1
        image_var_dict["_captured_figures"] = init_captured_figures
        self.inject(image_var_dict)
class DateRuntime(GenericRuntime):
    """Runtime pre-loaded with date/time helpers for calendar-arithmetic snippets."""
    GLOBAL_DICT = {}
    HEADERS = [
        "import datetime",
        "from dateutil.relativedelta import relativedelta",
        # Alias so snippets that write `timedelta(...)` get calendar-aware
        # relativedelta semantics (months/years) instead of datetime.timedelta.
        "timedelta = relativedelta"
    ]
class CustomDict(dict):
    """dict whose iterator walks a snapshot of the keys, so the dict can be
    mutated while an existing iterator is being consumed."""

    def __iter__(self):
        # Materialize the keys first; iterating the copy avoids
        # "dictionary changed size during iteration" errors.
        key_snapshot = list(super().__iter__())
        return iter(key_snapshot)
class ColorObjectRuntime(GenericRuntime):
    """Runtime whose `dict` builtin is CustomDict, so snippets iterating a
    dict while mutating it do not raise."""
    GLOBAL_DICT = {"dict": CustomDict}
class PythonExecutor:
    """
    High-level interface for executing Python code snippets — optionally inside
    a persistent isolated worker process — with stdout/figure capture, result
    truncation, and per-snippet timeouts.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
        use_process_isolation: bool = True,
    ) -> None:
        """
        Args:
            runtime_class: Runtime used to execute code (defaults to ImageRuntime).
            get_answer_symbol: Variable name to read the result from after execution.
            get_answer_expr: Expression evaluated for the result after execution.
            get_answer_from_stdout: Capture printed output as the result (default).
            timeout_length: Per-execution timeout in seconds.
            use_process_isolation: Execute inside a separate worker process.
        """
        self.runtime_class = runtime_class if runtime_class else ImageRuntime
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        self.timeout_length = timeout_length
        self.use_process_isolation = use_process_isolation
        # Created lazily on first isolated execution.
        self.persistent_worker = None

    def _ensure_worker(self):
        """Lazily create the persistent worker process."""
        if self.persistent_worker is None:
            self.persistent_worker = PersistentWorker()

    def process_generation_to_code(self, gens: List[str]) -> List[List[str]]:
        """Split each generated snippet into its individual source lines.

        (Annotation fix: the parameter was annotated ``str`` but is iterated
        as a list of snippet strings.)
        """
        return [g.split("\n") for g in gens]

    def execute(
        self,
        code,
        messages,
        get_answer_from_stdout=True,
        runtime_class=None,
        answer_symbol=None,
        answer_expr=None,
    ) -> Tuple[Union[str, Dict[str, Any]], str]:
        """
        Execute one snippet (given as a list of source lines).

        Returns:
            (result, report): result is a dict (possibly with 'text'/'images'
            or 'error'/'traceback' keys) or a string; report is 'Done' on
            success, otherwise an error/timeout description.
        """
        if self.use_process_isolation:
            # Ensure worker process exists
            self._ensure_worker()
            # Execute code in the worker, bounded by the configured timeout.
            result = self.persistent_worker.execute(
                code,
                messages,
                runtime_class or self.runtime_class,
                get_answer_from_stdout,
                answer_symbol,
                answer_expr,
                timeout=self.timeout_length
            )
            if result['status'] == 'success':
                return result['result'], result.get('report', 'Done')
            else:
                error_result = {
                    'error': result.get('error', 'Unknown error'),
                    'traceback': result.get('traceback', '')
                }
                return error_result, result.get('report', f"Error: {result.get('error', 'Unknown error')}")
        else:
            # Non-isolation mode (for backward compatibility): run in-process.
            # NOTE: no timeout applies on this path.
            runtime = runtime_class(messages) if runtime_class else self.runtime_class(messages)
            try:
                if get_answer_from_stdout:
                    # Capture everything the snippet prints.
                    program_io = io.StringIO()
                    with redirect_stdout(program_io):
                        runtime.exec_code("\n".join(code))
                    program_io.seek(0)
                    result = program_io.read()
                elif answer_symbol:
                    # Read the result from a named variable.
                    runtime.exec_code("\n".join(code))
                    result = runtime._global_vars.get(answer_symbol, "")
                elif answer_expr:
                    # Evaluate an expression for the result.
                    runtime.exec_code("\n".join(code))
                    result = runtime.eval_code(answer_expr)
                else:
                    # Default: exec all but the last line, eval the last line.
                    if len(code) > 1:
                        runtime.exec_code("\n".join(code[:-1]))
                        result = runtime.eval_code(code[-1])
                    else:
                        runtime.exec_code("\n".join(code))
                        result = ""
                # Attach any captured figures to the result dict.
                captured_figures = runtime.captured_figures
                if captured_figures:
                    result = {
                        'text': result,
                        'images': captured_figures
                    } if result else {'images': captured_figures}
                else:
                    result = {'text': result} if result else {}
                report = "Done"
            except Exception as e:
                result = {
                    'error': str(e),
                    'traceback': traceback.format_exc()
                }
                report = f"Error: {str(e)}"
            return result, report

    def apply(self, code, messages):
        """Execute a single snippet; returns (result, report)."""
        return self.batch_apply([code], messages)[0]

    @staticmethod
    def truncate(s, max_length=400):
        """
        Truncate long output to ~max_length characters, keeping both ends.
        Dict results (text + images) are truncated in their 'text' field only.
        Note: dict inputs are modified in place and returned.
        """
        if isinstance(s, dict):
            # If it is a dict (with images), truncate only the text part
            if 'text' in s:
                half = max_length // 2
                if len(s['text']) > max_length:
                    s['text'] = s['text'][:half] + "..." + s['text'][-half:]
            return s
        else:
            half = max_length // 2
            if isinstance(s, str) and len(s) > max_length:
                s = s[:half] + "..." + s[-half:]
            return s

    def batch_apply(self, batch_code, messages):
        """
        Execute a batch of snippet strings.

        Returns:
            List of (result, report) tuples, one per snippet, in order, with
            text fields stripped and truncated.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        all_exec_results = []
        # Only show a progress bar for large batches.
        if len(all_code_snippets) > 100:
            progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
        else:
            progress_bar = None
        for code in all_code_snippets:
            try:
                result = self.execute(
                    code,
                    messages=messages,
                    get_answer_from_stdout=self.get_answer_from_stdout,
                    runtime_class=self.runtime_class,
                    answer_symbol=self.answer_symbol,
                    answer_expr=self.answer_expr,
                )
                all_exec_results.append(result)
            except TimeoutError as error:
                print(error)
                all_exec_results.append(("", "Timeout Error"))
            except Exception as error:
                print(f"Error in batch_apply: {error}")
                all_exec_results.append(("", f"Error: {str(error)}"))
            if progress_bar is not None:
                progress_bar.update(1)
        if progress_bar is not None:
            progress_bar.close()
        batch_results = []
        for code, (res, report) in zip(all_code_snippets, all_exec_results):
            # Normalize and truncate the results for downstream consumers.
            if isinstance(res, dict):
                # Result carries images: only its text part is stripped/truncated.
                if 'text' in res:
                    res['text'] = str(res['text']).strip()
                    res['text'] = self.truncate(res['text'])
                report = str(report).strip()
                report = self.truncate(report)
            else:
                # Normal text result
                res = str(res).strip()
                res = self.truncate(res)
                report = str(report).strip()
                report = self.truncate(report)
            batch_results.append((res, report))
        return batch_results

    def reset(self, messages=None):
        """Reset executor state (fresh runtime namespace in the worker)."""
        if self.use_process_isolation and self.persistent_worker:
            self.persistent_worker.reset_runtime(messages, self.runtime_class)

    def __del__(self):
        """Clean up resources."""
        # Bug fix: if __init__ raised before assigning persistent_worker,
        # accessing the attribute directly would raise AttributeError during
        # garbage collection; getattr tolerates partially-built instances.
        worker = getattr(self, 'persistent_worker', None)
        if worker:
            worker.terminate()
# Populate the runtime registry now that every runtime class is defined;
# PersistentWorker ships class *names* across the process boundary and uses
# this mapping to resolve them back to classes.
PersistentWorker.RUNTIME_REGISTRY.update({
    'ImageRuntime': ImageRuntime,
    'DateRuntime': DateRuntime,
    'ColorObjectRuntime': ColorObjectRuntime,
    'GenericRuntime': GenericRuntime,
})
\ No newline at end of file
import pandas as pd
import numpy as np
from typing import List, Callable, Optional
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.operators.code import CodeAutoGeneratedSampleEvaluator
@OPERATOR_REGISTRY.register()
class CodeAutoGeneratedFilter(OperatorABC):
    """
    Filters auto-generated code files using CodeAutoGeneratedSampleEvaluator
    scores to ensure only human-written code is retained.
    """

    def __init__(self, min_score: float = 1.0, max_score: float = 1.0, is_generated_func: Optional[Callable[[], bool]] = None):
        """
        Initialize the operator with evaluator and thresholds.

        Args:
            min_score: Minimum auto-generation score threshold
            max_score: Maximum auto-generation score threshold
            is_generated_func: Optional external detection function
        """
        self.min_score = min_score
        self.max_score = max_score
        self.logger = get_logger()
        self.logger.info(f"Initializing {self.__class__.__name__} with min_score: {self.min_score} and max_score: {self.max_score}...")
        self.scorer = CodeAutoGeneratedSampleEvaluator(is_generated_func)

    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provide operator functionality description and parameter documentation.
        """
        if lang == "zh":
            return (
                "基于CodeAutoGeneratedSampleEvaluator的得分过滤自动生成的代码文件,确保只保留人工编写的代码。\n\n"
                "评估指标:\n"
                "- 自动生成标记数量:检测文件前5行中的自动生成标记\n"
                "- 检测标记:'auto-generated', 'autogenerated', 'automatically generated'等\n"
                "- 综合自动生成得分:0-1,1表示非自动生成\n"
                "- 支持外部检测函数进行额外验证\n\n"
                "输入参数:\n"
                "- input_key: 输入字段名(需要包含'lines'列)\n"
                "- output_key: 输出标签字段名 (默认: 'auto_generated_filter_label')\n"
                "- min_score: 最小自动生成得分阈值 (默认: 1.0)\n"
                "- max_score: 最大自动生成得分阈值 (默认: 1.0)\n"
                "- is_generated_func: 可选的外部检测函数\n\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留自动生成得分在指定范围内的代码样本\n"
                "- 返回包含自动生成得分标签字段名的列表"
            )
        else:
            return (
                "Filter auto-generated code files using scores from CodeAutoGeneratedSampleEvaluator to ensure only human-written code is retained.\n\n"
                "Evaluation Metrics:\n"
                "- Auto-generation marker count: Detect markers in first 5 lines\n"
                "- Detect markers: 'auto-generated', 'autogenerated', 'automatically generated', etc.\n"
                "- Comprehensive auto-generation score: 0-1, 1 means not auto-generated\n"
                "- Support external detection functions for additional validation\n\n"
                "Input Parameters:\n"
                "- input_key: Input field name (requires 'lines' column)\n"
                "- output_key: Output label field name (default: 'auto_generated_filter_label')\n"
                "- min_score: Minimum auto-generation score threshold (default: 1.0)\n"
                "- max_score: Maximum auto-generation score threshold (default: 1.0)\n"
                "- is_generated_func: Optional external detection function\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only code samples with auto-generation scores within specified range\n"
                "- List containing auto-generation score label field name"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validate that the configured input column exists in the DataFrame.

        Raises:
            ValueError: If the input column is missing.
        """
        # Bug fix: this previously read `self.input_lines_key`, an attribute
        # that is never assigned anywhere in the class, so any call raised
        # AttributeError. `run` stores the column name as `self.input_key`.
        required_keys = [self.input_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            raise ValueError(f"AutogenFilter missing required columns: {missing}")

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str,
        output_key: str = "auto_generated_filter_label"
    ) -> List[str]:
        """
        Execute auto-generated code detection and filtering using evaluator scores.

        Args:
            storage: Data storage object
            input_key: Field name containing code lines
            output_key: Key name for output label

        Returns:
            List[str]: List containing output key name
        """
        self.input_key = input_key
        self.output_key = output_key
        self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_key} and output_key: {self.output_key}...")
        dataframe = storage.read("dataframe")
        # Consistency fix: validation was defined but never invoked; fail fast
        # if the input column is absent.
        self._validate_dataframe(dataframe)
        scores = self.scorer.eval(dataframe, self.input_key)
        # Write score columns back using the frame's own index labels; the
        # previous enumerate-based positions misaligned rows whenever the
        # index was not the default 0..n-1.
        for row_label, score_dict in zip(dataframe.index, scores):
            for key, value in score_dict.items():
                dataframe.at[row_label, key] = value
        # Keep rows whose score lies inside [min_score, max_score]; NaN scores
        # pass through (rows the evaluator could not score are not dropped).
        score_filter = (dataframe["CodeAutoGeneratedScore"] >= self.min_score) & (dataframe["CodeAutoGeneratedScore"] <= self.max_score)
        nan_filter = np.isnan(dataframe["CodeAutoGeneratedScore"])
        metric_filter = score_filter | nan_filter
        results = metric_filter.astype(int)
        self.logger.debug(f"Filtered by auto-generated score, {np.sum(results)} data remained")
        dataframe[f"{self.output_key}"] = metric_filter.astype(int)
        filtered_dataframe = dataframe[results == 1]
        storage.write(filtered_dataframe)
        self.logger.info(f"Filtering completed. Total records passing filter: {len(filtered_dataframe)}.")
        return [self.output_key]
\ No newline at end of file
import pandas as pd
import numpy as np
from typing import List, Dict, Any
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.operators.code.eval.code_document_quality_sample_evaluator import CodeDocumentQualitySampleEvaluator
import re
from collections import Counter
@OPERATOR_REGISTRY.register()
class CodeDocumentQualityFilter(OperatorABC):
"""
CodeDocumentQualityFilter applies comprehensive document-level quality filtering
rules using CodeDocumentQualitySampleEvaluator scores to remove low-quality code and text samples.
"""
def __init__(self, min_score: float = 1.0, max_score: float = 1.0, thresholds: Dict[str, Any] = None):
    """
    Set up thresholds, logging, and the underlying quality evaluator.

    Args:
        min_score: Minimum document quality score threshold
        max_score: Maximum document quality score threshold
        thresholds: Optional thresholds dictionary to override default thresholds
    """
    # Logger first so the configuration message below can be emitted.
    self.logger = get_logger()
    self.min_score = min_score
    self.max_score = max_score
    self.logger.info(f"Initializing {self.__class__.__name__} with min_score: {self.min_score} and max_score: {self.max_score}...")
    # The evaluator performs the actual per-sample quality scoring.
    self.scorer = CodeDocumentQualitySampleEvaluator(thresholds)
@staticmethod
def get_desc(lang: str = "en"):
if lang == "zh":
return (
"基于CodeDocumentQualitySampleEvaluator的得分应用综合文档级质量过滤规则,移除低质量代码和文本样本。\n\n"
"评估指标:\n"
"- 内容长度:字符数、词数、行数范围检查\n"
"- 重复模式:重复行比例、2-10gram重复比例\n"
"- 字符组成:花括号比例、全大写单词比例\n"
"- 文本熵值:单字符熵值检查\n"
"- 综合文档质量得分:0-1,1表示通过所有质量检查\n\n"
"输入参数:\n"
"- input_key: 输入字段名(需要包含'text'、'filename'、'language'列)\n"
"- output_key: 输出标签字段名 (默认: 'doc_quality_filter_label')\n"
"- min_score: 最小文档质量得分阈值 (默认: 1.0)\n"
"- max_score: 最大文档质量得分阈值 (默认: 1.0)\n"
"- thresholds: 可选的阈值字典,用于覆盖默认阈值\n\n"
"输出参数:\n"
"- 过滤后的DataFrame,仅保留文档质量得分在指定范围内的样本\n"
"- 返回包含文档质量得分标签字段名的列表"
)
else:
return (
"Apply comprehensive document-level quality filtering rules using scores from CodeDocumentQualitySampleEvaluator to remove low-quality code and text samples.\n\n"
"Evaluation Metrics:\n"
"- Content length: character/word/line count range checks\n"
"- Repetition patterns: duplicate line ratio, 2-10gram repetition ratios\n"
"- Character composition: curly bracket ratio, all-caps word ratio\n"
"- Text entropy: unigram entropy checks\n"
"- Comprehensive document quality score: 0-1, 1 means passes all quality checks\n\n"
"Input Parameters:\n"
"- input_key: Input field name (requires 'text', 'filename', 'language' columns)\n"
"- output_key: Output label field name (default: 'doc_quality_filter_label')\n"
"- min_score: Minimum document quality score threshold (default: 1.0)\n"
"- max_score: Maximum document quality score threshold (default: 1.0)\n"
"- thresholds: Optional thresholds dictionary to override default thresholds\n\n"
"Output Parameters:\n"
"- Filtered DataFrame containing only samples with document quality scores within specified range\n"
"- List containing document quality score label field name"
)
def _num_chars(self, text: str) -> int:
return len(text)
def _num_words(self, text: str) -> int:
return len(re.findall(r'\w+', text))
def _num_lines(self, text: str) -> int:
return len(text.splitlines())
def _frac_duplicate_lines(self, lines: List[str]) -> float:
if not lines:
return 0.0
line2count = Counter([l.strip() for l in lines if l.strip()])
count = sum([v for v in line2count.values() if v != 1])
total = sum([v for v in line2count.values()])
return count / total if total else 0.0
def _frac_curly_bracket(self, text: str) -> float:
total = len(text)
if total == 0:
return 0.0
count = text.count('{') + text.count('}')
return count / total
def _frac_all_caps_words(self, text: str) -> float:
words = re.findall(r'\b\w+\b', text)
if not words:
return 0.0
all_caps = [w for w in words if w.isupper() and len(w) > 1]
return len(all_caps) / len(words)
def _entropy_unigram(self, text: str) -> float:
words = re.findall(r'\b\w+\b', text)
if not words:
return 0.0
freq = Counter(words)
total = sum(freq.values())
import math
entropy = -sum((c/total) * math.log2(c/total) for c in freq.values())
return entropy
def _frac_duplicate_ngrams(self, text: str, n: int = 5) -> float:
words = re.findall(r'\b\w+\b', text)
if len(words) < n:
return 0.0
ngrams = [' '.join(words[i:i+n]) for i in range(len(words)-n+1)]
ngram2count = Counter(ngrams)
count = sum([v for v in ngram2count.values() if v != 1])
total = sum([v for v in ngram2count.values()])
return count / total if total else 0.0
def _num_sentences(self, text: str) -> int:
SENT_PATTERN = re.compile(r'\b[^.!?。!?؟]+[.!?。!?؟]*', flags=re.UNICODE)
return len(SENT_PATTERN.findall(text))
def _num_lines(self, text: str) -> int:
return len(text.splitlines())
def _mean_word_length(self, text: str) -> float:
words = re.findall(r'\b\w+\b', text)
if not words:
return 0.0
return sum(len(w) for w in words) / len(words)
def _frac_full_bracket(self, text: str) -> float:
words = re.findall(r'\b\w+\b', text)
if not words:
return 0.0
count = sum(1 for w in words if w in ('【', '】'))
return count / len(words)
def _frac_lines_end_with_readmore(self, text: str) -> float:
ellipsis = ("...", "…", '全文', '详情', '详细', '更多', 'المزيد', 'تفاصيل', 'اقرأ المزيد', 'もっと', '詳細', 'もっと読む')
lines = text.splitlines()
if not lines:
return 0.0
total_ellipsis_lines = sum(
any(l.rstrip().rstrip(']】)>》').endswith(e) for e in ellipsis)
for l in lines
)
return total_ellipsis_lines / len(lines)
def _frac_lines_start_with_bullet(self, text: str) -> float:
# Common bullet symbols
bullets = ("-", "*", "•", "·", "●", "▪", "‣", "⁃", "◦", "‧", "﹒", "・", "∙", "‣", "➤", "➢", "➣", "➥", "➦", "➧", "➨", "➩", "➪", "➫", "➬", "➭", "➮", "➯", "➱", "➲", "➳", "➵", "➸", "➺", "➻", "➼", "➽", "➾")
lines = text.splitlines()
if not lines:
return 0.0
total_bullet_lines = sum(any(l.lstrip().startswith(b) for b in bullets) for l in lines)
return total_bullet_lines / len(lines)
def _frac_words_unique(self, text: str) -> float:
words = re.findall(r'\b\w+\b', text)
if not words:
return 0.0
return len(set(words)) / len(words)
def _frac_replacement_symbols(self, text: str) -> float:
total = len(text)
if total == 0:
return 0.0
return text.count('�') / total
def _mean_sentence_length(self, text: str) -> float:
sentences = re.split(r'\.|\?|\!|\n', text)
if not sentences:
return 0.0
return sum(len(s) for s in sentences) / len(sentences)
def _frac_chars_url_html(self, text: str) -> float:
total = len(text)
if total == 0:
return 0.0
link_pattern = r'\(https?://\S+\)'
html_tag_pattern = r'<.*?>'
link_list = re.findall(link_pattern, text)
html_tag_list = re.findall(html_tag_pattern, text)
url_char_num = sum(len(link) for link in link_list)
html_tag_char_num = sum(len(tag) for tag in html_tag_list)
return (url_char_num + html_tag_char_num) / total
def _frac_chars_alphabet(self, text: str, lang: str = 'en') -> float:
if not text or lang != 'en':
return 0.0
return sum(c.isalpha() for c in text) / len(text)
def _frac_chars_digital(self, text: str) -> float:
if not text:
return 0.0
return sum(c.isdigit() for c in text) / len(text)
def _frac_chars_whitespace(self, text: str) -> float:
if not text:
return 0.0
return len(re.findall(r'\s', text)) / len(text)
def _frac_chars_hex_words(self, text: str) -> float:
total = len(text)
if total == 0:
return 0.0
count = sum(len(e) for e in re.findall(r'\b0[xX][0-9a-fA-F]+\b', text))
return count / total
def _is_code_related_filename(self, filename: str) -> bool:
related = ["readme", "notes", "todo", "description", "cmakelists"]
name = filename.split('.')[0].lower()
return (
"requirement" in name or name in related or name == "read.me"
)
def _apply_rules(self, row: pd.Series, thresholds: Dict[str, Any]) -> bool:
text = row.get('text', row.get('content', ''))
filename = row.get('filename', '')
lang = row.get('language', 'en')
lines = text.splitlines()
# Rule 1: min/max chars
num_chars = self._num_chars(text)
if num_chars < thresholds['min_num_chars'] or num_chars > thresholds['max_num_chars']:
return False
# Rule 2: min/max words
num_words = self._num_words(text)
if num_words < thresholds['min_num_words'] or num_words > thresholds['max_num_words']:
return False
# Rule 3: duplicate lines
frac_dup_lines = self._frac_duplicate_lines(lines)
if frac_dup_lines > thresholds['max_frac_duplicate_lines']:
return False
# Rule 4: duplicate n-grams (2~10)
for n in range(2, 11):
key = f'max_frac_duplicate_{n}gram'
if key in thresholds:
frac_dup_ngram = self._frac_duplicate_ngrams_n(text, n=n)
if frac_dup_ngram > thresholds[key]:
return False
# Rule 5: curly bracket ratio
frac_curly = self._frac_curly_bracket(text)
if frac_curly > thresholds['max_frac_curly_bracket']:
return False
# Rule 6: all caps words
frac_caps = self._frac_all_caps_words(text)
if frac_caps > thresholds['max_frac_all_caps_words']:
return False
# Rule 7: unigram entropy
entropy = self._entropy_unigram(text)
if entropy < thresholds['min_entropy_unigram']:
return False
# 其它规则同前
return True
def run(self, storage: DataFlowStorage, input_key: str, output_key: str = "doc_quality_filter_label") -> List[str]:
"""
Applies document-level quality filtering rules using evaluator scores.
"""
self.input_key = input_key
self.output_key = output_key
self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_key} and output_key: {self.output_key}...")
dataframe = storage.read("dataframe")
scores = self.scorer.eval(dataframe, self.input_key)
# Add scores to dataframe
for idx, score_dict in enumerate(scores):
for key, value in score_dict.items():
dataframe.at[idx, key] = value
# Apply filtering based on CodeDocumentQualityScore
results = np.ones(len(dataframe), dtype=int)
score_filter = (dataframe["CodeDocumentQualityScore"] >= self.min_score) & (dataframe["CodeDocumentQualityScore"] <= self.max_score)
nan_filter = np.isnan(dataframe["CodeDocumentQualityScore"])
metric_filter = score_filter | nan_filter
results = results & metric_filter.astype(int)
self.logger.debug(f"Filtered by document quality score, {np.sum(results)} data remained")
dataframe[f"{self.output_key}"] = metric_filter.astype(int)
filtered_dataframe = dataframe[results == 1]
storage.write(filtered_dataframe)
self.logger.info(f"Filtering completed. Total records passing filter: {len(filtered_dataframe)}.")
return [self.output_key]
import re
import pandas as pd
import numpy as np
from typing import List
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.operators.code.eval.code_encoded_data_sample_evaluator import CodeEncodedDataSampleEvaluator
@OPERATOR_REGISTRY.register()
class CodeEncodedDataFilter(OperatorABC):
    """
    CodeEncodedDataFilter filters code samples based on encoded data patterns using
    CodeEncodedDataSampleEvaluator scores. It removes binary content and auto-generated code.
    """

    def __init__(self, min_score: float = 1.0, max_score: float = 1.0):
        """
        Initialize the operator with evaluator and thresholds.

        Args:
            min_score: Minimum encoded-data score required to keep a sample.
            max_score: Maximum encoded-data score allowed to keep a sample.
        """
        self.min_score = min_score
        self.max_score = max_score
        self.logger = get_logger()
        self.logger.info(f"Initializing {self.__class__.__name__} with min_score: {self.min_score} and max_score: {self.max_score}...")
        self.scorer = CodeEncodedDataSampleEvaluator()

    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provide operator functionality description and parameter documentation.
        """
        if lang == "zh":
            return (
                "基于CodeEncodedDataSampleEvaluator的得分过滤代码样本,移除二进制内容和自动生成代码。\n\n"
                "评估指标:\n"
                "- Base64编码数据比例:检测连续64+字符的Base64字符串\n"
                "- 十六进制数据比例:检测8+个连续的十六进制对\n"
                "- Unicode转义序列比例:检测8+个连续的\\uXXXX序列\n"
                "- 综合编码数据得分:0-1,1表示通过检查\n\n"
                "输入参数:\n"
                "- input_key: 输入字段名(需要包含'text'列)\n"
                "- output_key: 输出标签字段名 (默认: 'encoded_data_filter_label')\n"
                "- min_score: 最小编码数据得分阈值 (默认: 1.0)\n"
                "- max_score: 最大编码数据得分阈值 (默认: 1.0)\n\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留编码数据得分在指定范围内的代码样本\n"
                "- 返回包含编码数据得分标签字段名的列表"
            )
        else:
            return (
                "Filter code samples using scores from CodeEncodedDataSampleEvaluator to remove binary content and auto-generated code.\n\n"
                "Evaluation Metrics:\n"
                "- Base64 encoded data ratio: Detect 64+ consecutive Base64 characters\n"
                "- Hexadecimal data ratio: Detect 8+ consecutive hex pairs\n"
                "- Unicode escape sequence ratio: Detect 8+ consecutive \\uXXXX sequences\n"
                "- Comprehensive encoded data score: 0-1, 1 means passes checks\n\n"
                "Input Parameters:\n"
                "- input_key: Input field name (requires 'text' column)\n"
                "- output_key: Output label field name (default: 'encoded_data_filter_label')\n"
                "- min_score: Minimum encoded data score threshold (default: 1.0)\n"
                "- max_score: Maximum encoded data score threshold (default: 1.0)\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only code samples with encoded data scores within specified range\n"
                "- List containing encoded data score label field name"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validate DataFrame to ensure required columns exist.

        Raises:
            ValueError: If the configured input column is missing.
        """
        # Fix: previously read the never-assigned attribute ``self.input_text_key``
        # (AttributeError on call) and reported the wrong operator name
        # ("DataPatternFilter") in the error message.
        required_keys = [self.input_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            raise ValueError(f"{self.__class__.__name__} missing required columns: {missing}")

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str,
        output_key: str = "encoded_data_filter_label"
    ) -> List[str]:
        """
        Execute data pattern detection and filtering using evaluator scores.

        Args:
            storage: Data storage object
            input_key: Field name containing code text
            output_key: Key name for output label

        Returns:
            List[str]: List containing output key name
        """
        self.input_key = input_key
        self.output_key = output_key
        self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_key} and output_key: {self.output_key}...")
        dataframe = storage.read("dataframe")
        # Fail fast with a clear message when the input column is absent
        # (previously this validation helper existed but was never invoked).
        self._validate_dataframe(dataframe)
        scores = self.scorer.eval(dataframe, self.input_key)
        # Add per-sample score dicts to the dataframe as columns.
        for idx, score_dict in enumerate(scores):
            for key, value in score_dict.items():
                dataframe.at[idx, key] = value
        # Apply filtering based on CodeEncodedDataScore; NaN scores are kept.
        results = np.ones(len(dataframe), dtype=int)
        score_filter = (dataframe["CodeEncodedDataScore"] >= self.min_score) & (dataframe["CodeEncodedDataScore"] <= self.max_score)
        nan_filter = np.isnan(dataframe["CodeEncodedDataScore"])
        metric_filter = score_filter | nan_filter
        results = results & metric_filter.astype(int)
        self.logger.debug(f"Filtered by encoded data score, {np.sum(results)} data remained")
        dataframe[f"{self.output_key}"] = metric_filter.astype(int)
        filtered_dataframe = dataframe[results == 1]
        storage.write(filtered_dataframe)
        self.logger.info(f"Filtering completed. Total records passing filter: {len(filtered_dataframe)}.")
        return [self.output_key]
\ No newline at end of file
import pandas as pd
from typing import List, Set
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
@OPERATOR_REGISTRY.register()
class CodeFileTypeContentFilter(OperatorABC):
    """
    CodeFileTypeContentFilter filters code samples based on file types and content characteristics,
    applying different rules for different file formats to ensure quality and relevance.

    This filter directly applies filtering rules without using evaluator scores:
    - Removes oversized text files (>512 lines for Text/JSON/YAML files)
    - Removes HTML files with insufficient visible text content
    - Removes text files with inappropriate filenames (not documentation-related)
    - Keeps files that meet format-specific quality criteria
    """
    # File types that require size checking
    SIZE_CHECK_TYPES: Set[str] = {
        "text", "json", "yaml", "web ontology language",
        "graphviz", "dot"
    }
    # Valid filename set (extension stripped, lower-cased) for Text files
    VALID_TEXT_NAMES: Set[str] = {
        "readme", "notes", "todo", "description", "cmakelists"
    }

    def __init__(self):
        """
        Initialize the operator and set up the logger.
        """
        self.logger = get_logger()

    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provide operator functionality description and parameter documentation.
        """
        if lang == "zh":
            return (
                "基于文件类型和内容特征直接过滤代码样本,针对不同文件格式应用特定规则。\n\n"
                "过滤规则:\n"
                "- Text/JSON/YAML/Graphviz文件:行数 > 512 行\n"
                "- HTML文件:可见文本长度 < 100字符 或 可见文本比例 < 20%\n"
                "- Text文件:文件名不符合文档规范(非readme/notes/todo等)\n\n"
                "输入参数:\n"
                "- input_key: 输入字段名(需要包含'filetype'、'filename'、'line_count'等列)\n"
                "- output_key: 输出标签字段名 (默认: 'file_type_content_filter_label')\n\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留符合文件类型规则的样本\n"
                "- 返回包含输出标签字段名的列表"
            )
        else:
            return (
                "Filter code samples based on file types and content characteristics, applying specific rules for different file formats.\n\n"
                "Filtering Rules:\n"
                "- Text/JSON/YAML/Graphviz files: line count > 512\n"
                "- HTML files: visible text length < 100 chars OR visible text ratio < 20%\n"
                "- Text files: filename doesn't follow documentation conventions (not readme/notes/todo etc.)\n\n"
                "Input Parameters:\n"
                "- input_key: Input field name (requires 'filetype', 'filename', 'line_count' columns)\n"
                "- output_key: Output label field name (default: 'file_type_content_filter_label')\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only samples meeting file type rules\n"
                "- List containing output label field name"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validate DataFrame to ensure required columns exist.

        Raises:
            ValueError: If any required column is missing.
        """
        required_keys = ["filetype", "filename", "line_count"]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            raise ValueError(f"{self.__class__.__name__} missing required columns: {missing}")

    def _is_large_file(self, row: pd.Series) -> bool:
        """
        Check if the file is large (line count > 512).
        """
        return row.get("line_count", 0) > 512

    def _is_html_valid(self, row: pd.Series) -> bool:
        """
        Check if HTML file meets visible text requirements:
        at least 100 visible characters AND at least 20% visible-text ratio.
        """
        visible_text_len = row.get("visible_text_length", 0)
        total_code_len = row.get("total_code_length", 1)
        # max(..., 1) guards against division by zero for empty files.
        ratio = visible_text_len / max(total_code_len, 1)
        return visible_text_len >= 100 and ratio >= 0.2

    def _is_text_filename_valid(self, filename: str) -> bool:
        """
        Check if Text filename meets documentation-naming requirements.
        """
        filename_lower = filename.lower()
        name_without_ext = filename_lower.rsplit('.', 1)[0]
        return (
            "requirement" in filename_lower
            or name_without_ext in self.VALID_TEXT_NAMES
        )

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str,
        output_key: str = "file_type_content_filter_label"
    ) -> List[str]:
        """
        Execute file type filtering operation.

        Args:
            storage: Data storage object
            input_key: Key name for input data
            output_key: Key name for output label

        Returns:
            List[str]: List containing output key name
        """
        self.input_key = input_key
        self.output_key = output_key
        self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_key} and output_key: {self.output_key}...")
        # 1. Read data
        dataframe = storage.read("dataframe")
        if dataframe.empty:
            self.logger.warning("Input data is empty, skipping processing.")
            storage.write(dataframe)
            return [self.output_key]
        # 2. Validate data
        self._validate_dataframe(dataframe)
        # 3. Define per-row filtering logic
        def filter_row(row: pd.Series) -> bool:
            # NOTE(review): assumes 'filetype'/'filename' values are strings;
            # a NaN value here would raise AttributeError — confirm upstream.
            filetype = row.get("filetype", "").lower()
            filename = row.get("filename", "")
            if filetype in self.SIZE_CHECK_TYPES:
                return not self._is_large_file(row)
            elif filetype == "html":
                return self._is_html_valid(row)
            elif filetype == "text":
                return self._is_text_filename_valid(filename)
            # Unknown file types are kept unchanged.
            return True
        # 4. Apply filtering and add 0/1 label column
        filter_mask = dataframe.apply(filter_row, axis=1)
        dataframe[self.output_key] = filter_mask.astype(int)
        filtered_df = dataframe[filter_mask].reset_index(drop=True)
        # 5. Count results (removed the unused ``original_count`` local)
        filtered_count = len(filtered_df)
        self.logger.info(f"Filtering completed. Total records passing filter: {filtered_count}.")
        # 6. Write back results
        storage.write(filtered_df)
        return [self.output_key]
\ No newline at end of file
import pandas as pd
import numpy as np
from typing import List
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.operators.code.eval.code_length_sample_evaluator import CodeLengthSampleEvaluator
@OPERATOR_REGISTRY.register()
class CodeLengthSampleFilter(OperatorABC):
    """
    CodeLengthSampleFilter filters code samples based on length characteristics using
    CodeLengthSampleEvaluator scores. It removes oversized files and poorly formatted code.
    """

    def __init__(self, min_score: float = 1.0, max_score: float = 1.0):
        """
        Set up the length filter.

        Args:
            min_score: Lowest CodeLengthScore a sample may have and still pass.
            max_score: Highest CodeLengthScore a sample may have and still pass.
        """
        self.min_score = min_score
        self.max_score = max_score
        self.logger = get_logger()
        self.logger.info(f"Initializing {self.__class__.__name__} with min_score: {self.min_score} and max_score: {self.max_score}...")
        self.scorer = CodeLengthSampleEvaluator()

    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provide operator functionality description and parameter documentation.
        """
        if lang == "zh":
            return (
                "基于CodeLengthSampleEvaluator的得分过滤代码样本,移除超大文件和格式不良的代码。\n\n"
                "评估指标:\n"
                "- 总行数:检查是否超过100,000行\n"
                "- 平均行长:普通语言>100字符,特殊语言>100,000字符\n"
                "- 最大行长:普通语言>1,000字符\n\n"
                "输入参数:\n"
                "- input_key: 输入字段名(需要包含'lines'和'language'列)\n"
                "- output_key: 输出标签字段名 (默认: 'length_filter_label')\n"
                "- min_score: 最小长度得分阈值 (默认: 1.0)\n"
                "- max_score: 最大长度得分阈值 (默认: 1.0)\n\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留长度得分在指定范围内的代码样本\n"
                "- 返回包含长度得分标签字段名的列表"
            )
        else:
            return (
                "Filter code samples using scores from CodeLengthSampleEvaluator to remove oversized files and poorly formatted code.\n\n"
                "Evaluation Metrics:\n"
                "- Total lines: Check if exceeds 100,000 lines\n"
                "- Average line length: Normal languages >100 chars, special languages >100,000 chars\n"
                "- Maximum line length: Normal languages >1,000 chars\n\n"
                "Input Parameters:\n"
                "- input_key: Input field name (requires 'lines' and 'language' columns)\n"
                "- output_key: Output label field name (default: 'length_filter_label')\n"
                "- min_score: Minimum length score threshold (default: 1.0)\n"
                "- max_score: Maximum length score threshold (default: 1.0)\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only code samples with length scores within specified range\n"
                "- List containing length score label field name"
            )

    def run(self, storage: DataFlowStorage, input_key: str, output_key: str = 'length_filter_label'):
        """
        Score every sample with the length evaluator, keep rows whose
        CodeLengthScore lies in [min_score, max_score] (NaN scores are kept),
        tag each row with a 0/1 label, and write the surviving rows back.

        Args:
            storage: Data storage object to read from and write to.
            output_key: Name of the 0/1 label column added to the DataFrame.
            input_key: Field name the evaluator reads the code text from.

        Returns:
            A list holding the label column name.
        """
        self.input_key = input_key
        self.output_key = output_key
        self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_key} and output_key: {self.output_key}...")
        dataframe = storage.read("dataframe")
        # Merge each per-sample score dict into the frame, column by column.
        for row_idx, metrics in enumerate(self.scorer.eval(dataframe, self.input_key)):
            for metric_name, metric_value in metrics.items():
                dataframe.at[row_idx, metric_name] = metric_value
        # Keep in-range scores; a NaN score is deliberately not rejected.
        in_range = (dataframe["CodeLengthScore"] >= self.min_score) & (dataframe["CodeLengthScore"] <= self.max_score)
        missing_score = np.isnan(dataframe["CodeLengthScore"])
        keep_mask = in_range | missing_score
        results = keep_mask.astype(int)
        self.logger.debug(f"Filtered by length score, {np.sum(results)} data remained")
        dataframe[self.output_key] = keep_mask.astype(int)
        filtered_dataframe = dataframe[results == 1]
        storage.write(filtered_dataframe)
        self.logger.info(f"Filtering completed. Total records passing filter: {len(filtered_dataframe)}.")
        return [self.output_key]
\ No newline at end of file
import pandas as pd
import numpy as np
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.core import LLMServingABC
from dataflow.operators.code.eval.code_quality_sample_evaluator import CodeQualitySampleEvaluator
@OPERATOR_REGISTRY.register()
class CodeQualityScoreFilter(OperatorABC):
    """
    CodeQualityScoreFilter filters code samples based on LLM-generated quality scores
    from CodeQualitySampleEvaluator. It evaluates code correctness, completeness, clarity,
    best practices, and efficiency, then filters out samples below the specified threshold.

    This filter uses evaluator scores to filter:
    - Removes code with syntax errors or logical issues
    - Removes incomplete or poorly structured code
    - Removes code that doesn't follow best practices
    - Keeps code with quality scores within specified range
    """

    def __init__(self, llm_serving: LLMServingABC, min_score: int = 7, max_score: int = 10):
        """
        Initializes the operator with LLM serving and evaluator.

        Args:
            llm_serving: LLM serving backend passed to the quality evaluator.
            min_score: Minimum quality score required to keep a sample.
            max_score: Maximum quality score allowed to keep a sample.
        """
        self.min_score = min_score
        self.max_score = max_score
        self.logger = get_logger()
        self.logger.info(f"Initializing {self.__class__.__name__} with min_score: {self.min_score} and max_score: {self.max_score}...")
        self.scorer = CodeQualitySampleEvaluator(llm_serving)

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "基于LLM生成的代码质量分数过滤代码样本,评估正确性、完整性、清晰度、最佳实践和效率。\n\n"
                "评估维度:\n"
                "- 正确性:代码语法和逻辑是否正确\n"
                "- 完整性:代码是否完整实现功能\n"
                "- 清晰度:代码是否清晰易懂\n"
                "- 最佳实践:是否遵循编程最佳实践\n"
                "- 效率:代码执行效率如何\n\n"
                "输入参数:\n"
                "- input_code_key: 输入代码字段名\n"
                "- input_instruction_key: 输入指令字段名\n"
                "- output_score_key: 输出打分字段名 (默认: 'quality_score')\n"
                "- output_feedback_key: 输出反馈字段名 (默认: 'quality_feedback')\n"
                "- output_key: 输出过滤标签字段名 (默认: 'quality_score_filter_label')\n"
                "- min_score: 最小质量分数阈值 (默认: 7)\n"
                "- max_score: 最大质量分数阈值 (默认: 10)\n\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留质量分数在指定范围内的代码样本\n"
                "- 返回包含质量分数标签字段名的列表"
            )
        else:
            return (
                "Filter code samples based on LLM-generated quality scores evaluating correctness, completeness, clarity, best practices, and efficiency.\n\n"
                "Evaluation Dimensions:\n"
                "- Correctness: syntax and logic accuracy\n"
                "- Completeness: functional completeness\n"
                "- Clarity: code readability and understandability\n"
                "- Best Practices: adherence to programming standards\n"
                "- Efficiency: execution performance\n\n"
                "Input Parameters:\n"
                "- input_code_key: Input code column name\n"
                "- input_instruction_key: Input instruction column name\n"
                "- output_score_key: Output score column name (default: 'quality_score')\n"
                "- output_feedback_key: Output feedback column name (default: 'quality_feedback')\n"
                "- output_key: Output filter label column name (default: 'quality_score_filter_label')\n"
                "- min_score: Minimum quality score threshold (default: 7)\n"
                "- max_score: Maximum quality score threshold (default: 10)\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only code samples with quality scores within specified range\n"
                "- List containing quality score label field names"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validates the DataFrame to ensure the required columns exist.

        Raises:
            ValueError: If the code or instruction column is missing.
        """
        required_keys = [self.input_code_key, self.input_instruction_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s) for CodeQualityScoreFilter: {missing}")

    def _apply_score_filter(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """
        Apply the score-based filtering logic.

        NOTE(review): ``run`` currently duplicates this logic inline instead of
        calling this helper; kept for backward compatibility with any external
        callers — consider consolidating.

        Args:
            dataframe: Input DataFrame

        Returns:
            Filtered DataFrame
        """
        # Filter based on score range
        score_filter = (dataframe[self.output_score_key] >= self.min_score) & (dataframe[self.output_score_key] <= self.max_score)
        # Also keep samples with failed parsing (score = 0)
        nan_filter = dataframe[self.output_score_key] == 0
        final_filter = score_filter | nan_filter
        return dataframe[final_filter]

    def run(self, storage: DataFlowStorage, input_instruction_key: str, input_code_key: str, output_score_key = "quality_score", output_feedback_key = "quality_feedback",output_key: str = 'quality_score_filter_label'):
        """
        Score (if needed) and filter samples by LLM quality score.

        Args:
            storage: Data storage object to read from and write to.
            input_instruction_key: Column holding the instruction text.
            input_code_key: Column holding the code to evaluate.
            output_score_key: Column to store/read the quality score.
            output_feedback_key: Column to store the evaluator feedback.
            output_key: Name of the 0/1 filter label column.

        Returns:
            List containing the output label column name.
        """
        self.input_code_key = input_code_key
        self.input_instruction_key = input_instruction_key
        self.output_score_key = output_score_key
        self.output_key = output_key
        self.output_feedback_key = output_feedback_key
        self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_code_key} and output_key: {self.output_key}...")
        dataframe = storage.read("dataframe")
        # Use existing quality_score if available, otherwise evaluate via LLM.
        if self.output_score_key not in dataframe.columns:
            scores, feedbacks = self.scorer.eval(dataframe, self.input_instruction_key, self.input_code_key)
            dataframe[self.output_score_key] = scores
            dataframe[self.output_feedback_key] = feedbacks
        # Apply filtering based on quality score; score == 0 marks samples
        # whose LLM response failed to parse — they are deliberately kept.
        results = np.ones(len(dataframe), dtype=int)
        score_filter = (dataframe[self.output_score_key] >= self.min_score) & (dataframe[self.output_score_key] <= self.max_score)
        nan_filter = dataframe[self.output_score_key] == 0  # Keep failed parsing samples
        metric_filter = score_filter | nan_filter
        results = results & metric_filter.astype(int)
        self.logger.debug(f"Filtered by quality score, {np.sum(results)} data remained")
        dataframe[self.output_key] = metric_filter.astype(int)
        filtered_dataframe = dataframe[results == 1]
        # Fix: dropped the unused ``output_file`` local that captured
        # storage.write()'s return value without ever using it.
        storage.write(filtered_dataframe)
        self.logger.info(f"Filtering completed. Total records passing filter: {len(filtered_dataframe)}.")
        return [self.output_key]
import pandas as pd
from typing import List, Literal
# Assuming these are the correct import paths for your framework
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
# TODO: remove and add to general filter
@OPERATOR_REGISTRY.register()
class CodeGenericScoreFilter(OperatorABC):
    """
    CodeGenericScoreFilter is a generic score-based filtering operator that filters
    datasets based on numerical score columns. It provides flexible comparison
    methods to remove samples that don't meet specified threshold criteria.

    This filter directly applies score-based filtering without using evaluator scores:
    - Removes samples with scores below minimum threshold
    - Removes samples with scores above maximum threshold
    - Removes samples that don't meet specific score criteria
    - Keeps samples that meet the specified threshold criteria
    """

    def __init__(self, score_threshold: int = 8, filter_method: Literal["greater", "greater_equal", "less", "less_equal", "equal"] = "greater_equal"):
        """
        Initializes the operator.

        Args:
            score_threshold: Numerical threshold compared against the score column.
            filter_method: Comparison applied between score and threshold.
        """
        self.logger = get_logger()
        self.score_threshold = score_threshold
        self.filter_method = filter_method

    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provides a description of the operator's function and parameters.
        """
        if lang == "zh":
            return (
                "基于数值分数列直接过滤数据集,提供灵活的阈值比较方法。\n\n"
                "比较方法:\n"
                "- greater_equal: 分数 >= 阈值\n"
                "- greater: 分数 > 阈值\n"
                "- less_equal: 分数 <= 阈值\n"
                "- less: 分数 < 阈值\n"
                "- equal: 分数 = 阈值\n\n"
                "输入参数:\n"
                "- input_key: 包含分数的字段名\n"
                "- output_key: 输出标签字段名 (默认: 'generic_score_filter_label')\n"
                "- score_threshold: 分数阈值 (默认: 8)\n"
                "- filter_method: 比较方法 (默认: 'greater_equal')\n\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留符合分数条件的样本\n"
                "- 返回包含输出标签字段名的列表"
            )
        else:  # Default to English
            return (
                "Filter datasets based on numerical score columns with flexible threshold comparison methods.\n\n"
                "Comparison Methods:\n"
                "- greater_equal: score >= threshold\n"
                "- greater: score > threshold\n"
                "- less_equal: score <= threshold\n"
                "- less: score < threshold\n"
                "- equal: score = threshold\n\n"
                "Input Parameters:\n"
                "- input_key: Field name containing the score\n"
                "- output_key: Output label field name (default: 'generic_score_filter_label')\n"
                "- score_threshold: Numerical threshold for filtering (default: 8)\n"
                "- filter_method: Comparison method to use (default: 'greater_equal')\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only samples meeting score criteria\n"
                "- List containing output label field name"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validates the DataFrame to ensure the required score column exists
        and is numeric.

        Raises:
            ValueError: If the score column is missing.
            TypeError: If the score column is not numeric.
        """
        required_keys = [self.input_score_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s) for ScoreFilter: {missing}")
        # Also check if the column is numeric
        if not pd.api.types.is_numeric_dtype(dataframe[self.input_score_key]):
            raise TypeError(f"Column '{self.input_score_key}' for ScoreFilter must be of a numeric type.")

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str,
        output_key: str = "generic_score_filter_label"
    ) -> List[str]:
        """
        Execute the filtering process.

        Reads data from storage, applies the filter based on the score
        (using the threshold and method configured in __init__),
        and writes the filtered data back to storage.

        Args:
            storage: Data storage object
            input_key: Field name containing the score
            output_key: Key name for output label

        Returns:
            List[str]: List containing output key name
        """
        self.input_key = input_key
        self.output_key = output_key
        self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_key} and output_key: {self.output_key}...")
        # Store key for use in helper methods
        self.input_score_key = input_key
        # 1. Read data from the current step
        dataframe = storage.read("dataframe")
        if dataframe.empty:
            self.logger.warning("Input dataframe is empty. Skipping.")
            storage.write(dataframe)
            return [self.output_key]
        # 2. Validate the data
        self._validate_dataframe(dataframe)
        # 3. Apply the filter logic and add label
        if self.filter_method == "greater_equal":
            filter_mask = dataframe[self.input_score_key] >= self.score_threshold
        elif self.filter_method == "greater":
            filter_mask = dataframe[self.input_score_key] > self.score_threshold
        elif self.filter_method == "less_equal":
            filter_mask = dataframe[self.input_score_key] <= self.score_threshold
        elif self.filter_method == "less":
            filter_mask = dataframe[self.input_score_key] < self.score_threshold
        elif self.filter_method == "equal":
            filter_mask = dataframe[self.input_score_key] == self.score_threshold
        else:
            # This case should ideally not be hit due to Literal type hint, but is good for robustness.
            # Fix: previously referenced the undefined local name ``filter_method``,
            # which raised NameError instead of the intended ValueError.
            raise ValueError(f"Unsupported filter_method: '{self.filter_method}'")
        dataframe[self.output_key] = filter_mask.astype(int)
        filtered_df = dataframe[filter_mask]
        filtered_count = len(filtered_df)
        self.logger.info(f"Filtering completed. Total records passing filter: {filtered_count}.")
        # 4. Write the results back to storage
        storage.write(filtered_df)
        # 5. Return output key
        return [self.output_key]
\ No newline at end of file
import pandas as pd
import numpy as np
from typing import List
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.operators.code.eval.code_text_composition_sample_evaluator import CodeTextCompositionSampleEvaluator
@OPERATOR_REGISTRY.register()
class CodeTextCompositionFilter(OperatorABC):
    """
    CodeTextCompositionFilter filters code samples based on character composition using
    CodeTextCompositionSampleEvaluator scores. It removes binary files, encrypted content,
    and other non-readable text.
    """
    def __init__(self, min_score: float = 1.0, max_score: float = 1.0):
        """
        Initialize the operator with evaluator and thresholds.

        Args:
            min_score: Minimum composition score required to keep a sample (default: 1.0).
            max_score: Maximum composition score allowed to keep a sample (default: 1.0).
        """
        self.min_score = min_score
        self.max_score = max_score
        self.logger = get_logger()
        self.logger.info(f"Initializing {self.__class__.__name__} with min_score: {self.min_score} and max_score: {self.max_score}...")
        self.scorer = CodeTextCompositionSampleEvaluator()
    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provide operator functionality description and parameter documentation.
        """
        if lang == "zh":
            return (
                "基于CodeTextCompositionSampleEvaluator的得分过滤代码样本,移除二进制文件、加密内容和不可读文本。\n\n"
                "评估指标:\n"
                "- 字母字符比例:普通语言需要>=25%\n"
                "- 字母数字字符比例:汇编语言需要>=25%\n"
                "- 综合字符组成得分:0-1,1表示通过检查\n\n"
                "输入参数:\n"
                "- input_key: 输入字段名(需要包含'text'和'language'列)\n"
                "- output_key: 输出标签字段名 (默认: 'text_composition_filter_label')\n"
                "- min_score: 最小字符组成得分阈值 (默认: 1.0)\n"
                "- max_score: 最大字符组成得分阈值 (默认: 1.0)\n\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留字符组成得分在指定范围内的代码样本\n"
                "- 返回包含字符组成得分标签字段名的列表"
            )
        else:
            return (
                "Filter code samples using scores from CodeTextCompositionSampleEvaluator to remove binary files, encrypted content, and non-readable text.\n\n"
                "Evaluation Metrics:\n"
                "- Alphabetic character ratio: Normal languages require >=25%\n"
                "- Alphanumeric character ratio: Assembly languages require >=25%\n"
                "- Comprehensive composition score: 0-1, 1 means passes checks\n\n"
                "Input Parameters:\n"
                "- input_key: Input field name (requires 'text' and 'language' columns)\n"
                "- output_key: Output label field name (default: 'text_composition_filter_label')\n"
                "- min_score: Minimum composition score threshold (default: 1.0)\n"
                "- max_score: Maximum composition score threshold (default: 1.0)\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only code samples with composition scores within specified range\n"
                "- List containing composition score label field name"
            )
    def run(self, storage: DataFlowStorage, input_key: str, output_key: str = 'text_composition_filter_label'):
        """
        Score each sample's character composition and keep only rows whose score lies
        within [min_score, max_score]; rows with a NaN score are deliberately kept.

        Args:
            storage: Data storage object to read the dataframe from and write results to.
            input_key: Input field name (the evaluator requires 'text' and 'language' columns).
            output_key: Output label field name (default: 'text_composition_filter_label').

        Returns:
            List[str]: List containing the output label field name.
        """
        self.input_key = input_key
        self.output_key = output_key
        self.logger.info(f"Running {self.__class__.__name__} with input_key: {self.input_key} and output_key: {self.output_key}...")
        dataframe = storage.read("dataframe")
        scores = self.scorer.eval(dataframe, self.input_key)
        # BUGFIX: `scores` is positional (enumerate), but `.at[idx, key]` addresses rows
        # by index *label*. With a non-default index the original code wrote scores to
        # the wrong rows (or created new ones); reset the index so position == label.
        dataframe = dataframe.reset_index(drop=True)
        for idx, score_dict in enumerate(scores):
            for key, value in score_dict.items():
                dataframe.at[idx, key] = value
        # Apply filtering based on CodeTextCompositionScore; NaN scores pass through.
        composition = dataframe["CodeTextCompositionScore"]
        score_filter = (composition >= self.min_score) & (composition <= self.max_score)
        metric_filter = score_filter | composition.isna()
        self.logger.debug(f"Filtered by composition score, {int(metric_filter.sum())} data remained")
        dataframe[self.output_key] = metric_filter.astype(int)
        filtered_dataframe = dataframe[metric_filter]
        storage.write(filtered_dataframe)
        self.logger.info(f"Filtering completed. Total records passing filter: {len(filtered_dataframe)}.")
        return [self.output_key]
\ No newline at end of file
import pandas as pd
from typing import List
# Assuming these are the correct import paths for your framework
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC # For type hinting if needed
from dataflow.core import LLMServingABC
from dataflow.prompts.code import CodeCodeToInstructionGeneratorPrompt, DiyCodePrompt
from typing import Union
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
@prompt_restrict(
    CodeCodeToInstructionGeneratorPrompt,
    DiyCodePrompt
)
@OPERATOR_REGISTRY.register()
class CodeCodeToInstructionGenerator(OperatorABC):
    """
    Reverse-engineers a human-readable instruction from a given code snippet via an LLM.

    This is the first stage of a 'self-instruct' style data synthesis pipeline for code:
    existing code goes in, the instruction that could have produced it comes out.
    """
    def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeCodeToInstructionGeneratorPrompt, DiyCodePrompt, DIYPromptABC] = None):
        """
        Initializes the operator with a language model serving endpoint.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        # Resolve the prompt template: a raw string becomes a DIY prompt,
        # None falls back to the stock code-to-instruction prompt.
        if isinstance(prompt_template, str):
            prompt_template = DiyCodePrompt(prompt_template)
        elif prompt_template is None:
            prompt_template = CodeCodeToInstructionGeneratorPrompt()
        self.prompt_template = prompt_template
    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provides a description of the operator's function and parameters.
        """
        if lang == "zh":
            return (
                "该算子用于分析代码片段并反向生成可能产生该代码的人类指令。\n\n"
                "输入参数:\n"
                "- input_key: 包含原始代码片段的字段名 (默认: 'code')\n"
                "输出参数:\n"
                "- output_key: 用于存储生成指令的字段名 (默认: 'generated_instruction')\n"
            )
        else:  # Default to English
            return (
                "This operator analyzes a code snippet and reverse-engineers a human instruction "
                "that could have produced it.\n\n"
                "Input Parameters:\n"
                "- input_key: Field name containing the raw code snippet (default: 'code')\n"
                "Output Parameters:\n"
                "- output_key: Field name to store the generated instruction (default: 'generated_instruction')\n"
            )
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Ensures the input column is present and the output column is not already taken.
        """
        columns = dataframe.columns
        missing = [key for key in [self.input_key] if key not in columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        conflict = [key for key in [self.output_key] if key in columns]
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
    def _build_prompts(self, dataframe: pd.DataFrame) -> List[str]:
        """
        Renders one LLM prompt per row from the code column.
        """
        rendered = []
        for code_snippet in dataframe[self.input_key]:
            rendered.append(self.prompt_template.build_prompt(code=code_snippet))
        return rendered
    def _parse_instruction(self, response: str) -> str:
        """
        Normalize an LLM reply into a clean instruction string.

        The prompt is designed so the model emits only the instruction; parsing
        therefore reduces to trimming surrounding whitespace.

        Args:
            response: Raw response string from the LLM

        Returns:
            Clean instruction string without extra whitespace
        """
        return response.strip()
    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "code",
        output_key: str = "generated_instruction"
    ) -> List[str]:
        """
        Executes the instruction synthesis process.

        Reads the dataframe from storage, synthesizes one instruction per code
        snippet, stores them under `output_key`, and writes the dataframe back.

        Returns:
            A list containing the name of the newly created output column.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        prompts = self._build_prompts(dataframe)
        raw_responses = self.llm_serving.generate_from_input(prompts)
        dataframe[self.output_key] = [self._parse_instruction(raw) for raw in raw_responses]
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated instructions saved to {output_file}")
        return [self.output_key]
\ No newline at end of file
import pandas as pd
from typing import List
import random
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.prompts.code import CodeInstructionGeneratePrompt, DiyCodePrompt
from typing import Union
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
@prompt_restrict(
    CodeInstructionGeneratePrompt,
)
@OPERATOR_REGISTRY.register()
class CodeInstructionGenerator(OperatorABC):
    """
    CodeInstructionGenerator is an operator that leverages a Large Language Model to generate
    human-readable instructions based on few-shot examples sampled from a data pool. The operator
    creates new instructions that are similar in difficulty and style to the provided examples.
    This is a critical step in a 'self-instruct' style data synthesis pipeline, designed to expand
    and enhance instruction datasets for programming tasks.
    """
    def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeInstructionGeneratePrompt, DIYPromptABC]=None, num_few_shot: int = 3, num_generate: int = 10):
        """
        Initializes the operator with a language model serving endpoint.
        Args:
            llm_serving: LLM serving instance
            prompt_template: Custom prompt template (optional; defaults to CodeInstructionGeneratePrompt)
            num_few_shot: Number of few-shot examples to use (default: 3)
            num_generate: Number of new instructions to generate per run (default: 10)
        """
        self.logger = get_logger()
        self.num_generate = num_generate
        self.llm_serving = llm_serving
        self.num_few_shot = num_few_shot
        # BUGFIX: the original ignored the `prompt_template` argument and always
        # instantiated the default prompt; honor a caller-supplied template.
        if prompt_template is None:
            prompt_template = CodeInstructionGeneratePrompt()
        self.prompt_template = prompt_template
    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provides a description of the operator's function and parameters.
        """
        if lang == "zh":
            return (
                "该算子用于生成新的指令,从数据池中随机抽取few-shot样本,生成类似难度的指令。\n\n"
                "输入参数:\n"
                "- input_key: 包含原始指令的字段名 (默认: 'prompt')\n"
                "输出参数:\n"
                "- output_key: 用于存储生成指令的字段名 (默认: 'generated_instruction')\n"
            )
        else:
            return (
                "This operator generates new instructions by sampling few-shot examples from the data pool "
                "to create instructions of similar difficulty.\n\n"
                "Input Parameters:\n"
                "- input_key: Field name containing the original instructions (default: 'prompt')\n"
                "Output Parameters:\n"
                "- output_key: Field name to store the generated instruction (default: 'generated_instruction')\n"
            )
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validates the DataFrame to ensure required columns exist and the pool is
        large enough to supply `num_few_shot` examples.
        """
        required_keys = [self.input_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if len(dataframe) < self.num_few_shot:
            raise ValueError(f"数据池样本数量({len(dataframe)})少于few-shot数量({self.num_few_shot})")
    def _sample_few_shot_examples(self, dataframe: pd.DataFrame) -> List[dict]:
        """
        Randomly sample few-shot examples from the data pool.
        Args:
            dataframe: data pool
        Returns:
            List of few-shot examples with instruction
        """
        num_samples = min(self.num_few_shot, len(dataframe))
        sampled_indices = random.sample(range(len(dataframe)), num_samples)
        few_shot_examples = []
        for idx in sampled_indices:
            # iloc keeps sampling positional, independent of the dataframe's index labels
            row = dataframe.iloc[idx]
            instruction = row[self.input_key]
            few_shot_examples.append({
                'instruction': instruction,
            })
        return few_shot_examples
    def _build_prompts(self, dataframe: pd.DataFrame, num_generate: int) -> List[str]:
        """
        Build `num_generate` prompts, each containing a freshly-sampled set of
        few-shot examples from the data pool.
        Args:
            dataframe: Data pool
            num_generate: The number of prompts to be generated.
        Returns:
            List of prompts
        """
        prompts = []
        for i in range(num_generate):
            few_shot_examples = self._sample_few_shot_examples(dataframe)
            prompt = self.prompt_template.build_prompt(
                few_shot_examples=few_shot_examples
            )
            prompts.append(prompt)
        return prompts
    def _parse_instruction(self, response: str) -> str:
        """
        Parse the LLM's raw response to extract the clean instruction.
        Args:
            response: Raw response string from the LLM
        Returns:
            Clean instruction string without extra whitespace
        """
        return response.strip()
    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "prompt",
        output_key: str = "generated_instruction",
    ) -> List[str]:
        """
        Executes the instruction synthesis process.
        Reads data from the data pool, generates the specified number of new instructions, and saves them to a new DataFrame.
        Args:
            storage: DataFlow storage instance
            input_key: Field name containing the original instructions
            output_key: Field name to store generated instructions
        Returns:
            A list containing the name of the output column.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        # Fixed seed keeps the few-shot sampling reproducible across runs.
        random.seed(42)
        formatted_prompts = self._build_prompts(dataframe, self.num_generate)
        responses = self.llm_serving.generate_from_input(formatted_prompts)
        instructions = [self._parse_instruction(r) for r in responses]
        new_dataframe = pd.DataFrame({
            self.output_key: instructions
        })
        output_file = storage.write(new_dataframe)
        self.logger.info(f"Generated {len(instructions)} new instructions with {self.num_few_shot} few-shot examples each")
        self.logger.info(f"Results saved to {output_file}")
        return [self.output_key]
\ No newline at end of file
import pandas as pd
from typing import List
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.prompts.code import CodeInstructionEnhancement, DiyCodePrompt
from typing import Union
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
@prompt_restrict(
    CodeInstructionEnhancement,
    DiyCodePrompt
)
@OPERATOR_REGISTRY.register()
class CodeEnhancementInstructionGenerator(OperatorABC):
    """
    CodeEnhancementInstructionGenerator is an operator that uses an LLM to standardize
    and normalize instructions into a consistent format for code generation tasks.
    It rewrites original instructions into standardized English instruction + code block format.
    """
    def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeInstructionEnhancement, DiyCodePrompt, DIYPromptABC] = None):
        """
        Initializes the operator with a language model serving endpoint.

        Args:
            llm_serving: LLM serving instance used to rewrite the instructions.
            prompt_template: Custom prompt template; a plain string is wrapped in
                DiyCodePrompt, None falls back to CodeInstructionEnhancement.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        # Initialize prompt template
        if prompt_template is None:
            prompt_template = CodeInstructionEnhancement()
        elif isinstance(prompt_template, str):
            prompt_template = DiyCodePrompt(prompt_template)
        self.prompt_template = prompt_template
    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provides a description of the operator's function and parameters.
        """
        if lang == "zh":
            # BUGFIX: the zh text previously documented input_key as a code field
            # with default 'code'; run() actually consumes message lists with
            # default 'messages' (now consistent with the English description).
            return (
                "该算子用于增强人类指令,将不同输出格式的任务统一为生成完整函数。\n\n"
                "输入参数:\n"
                "- input_key: 包含原始指令消息的字段名 (默认: 'messages')\n"
                "输出参数:\n"
                "- output_key: 用于存储生成指令的字段名 (默认: 'generated_instruction')\n"
            )
        else:
            return (
                "This operator enhances human instructions by unifying tasks with different "
                "output formats into complete function generation tasks.\n\n"
                "Input Parameters:\n"
                "- input_key: Field name containing the original instruction messages (default: 'messages')\n"
                "Output Parameters:\n"
                "- output_key: Field name to store the enhanced instruction (default: 'generated_instruction')\n"
            )
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validates the DataFrame to ensure required columns exist and output columns don't conflict.
        """
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
    def _build_prompts(self, dataframe: pd.DataFrame) -> List[str]:
        """
        Builds one prompt per row from the first 'HUMAN' message in each message list.
        """
        def get_human_instruction(messages):
            """Extract human instruction from message list."""
            # Returns the first message whose role is 'HUMAN'; empty string if none.
            for item in messages:
                if item.get('role') == 'HUMAN':
                    return item.get('content', '')
            return ''
        return [
            self.prompt_template.build_prompt(instruction=get_human_instruction(row[self.input_key]))
            for _, row in dataframe.iterrows()
        ]
    def _parse_instruction(self, response: str) -> str:
        """
        Parse the LLM's raw response to extract the enhanced instruction.
        Args:
            response: Raw response string from the LLM
        Returns:
            Clean instruction string without extra whitespace
        """
        return response.strip()
    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "messages",
        output_key: str = "generated_instruction"
    ) -> List[str]:
        """
        Executes the instruction synthesis process.
        Reads data from storage, enhances instructions for each message,
        and writes the updated data back to storage.
        Returns:
            A list containing the name of the newly created output column.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        formatted_prompts = self._build_prompts(dataframe)
        responses = self.llm_serving.generate_from_input(formatted_prompts)
        instructions = [self._parse_instruction(r) for r in responses]
        dataframe[self.output_key] = instructions
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated instructions saved to {output_file}")
        return [self.output_key]
\ No newline at end of file
import pandas as pd
import re
from typing import List
# Assuming these are the correct import paths for your framework
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.prompts.code import CodeInstructionToCodeGeneratorPrompt, DiyCodePrompt
from typing import Union
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
@prompt_restrict(
    CodeInstructionToCodeGeneratorPrompt,
    DiyCodePrompt
)
@OPERATOR_REGISTRY.register()
class CodeInstructionToCodeGenerator(OperatorABC):
    """
    CodeInstructionToCodeGenerator is an operator that takes a natural language instruction and
    uses an LLM to generate a corresponding code snippet. This is the second step
    in a 'self-instruct' style data synthesis pipeline for code.
    """
    def __init__(self, llm_serving: LLMServingABC, prompt_template: Union[CodeInstructionToCodeGeneratorPrompt, DiyCodePrompt, DIYPromptABC] = None):
        """
        Initializes the operator with a language model serving endpoint.

        Args:
            llm_serving: LLM serving instance used to generate the code.
            prompt_template: Custom prompt template; a plain string is wrapped in
                DiyCodePrompt, None falls back to CodeInstructionToCodeGeneratorPrompt.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        # Initialize prompt template
        if prompt_template is None:
            prompt_template = CodeInstructionToCodeGeneratorPrompt()
        elif isinstance(prompt_template, str):
            prompt_template = DiyCodePrompt(prompt_template)
        self.prompt_template = prompt_template
    @staticmethod
    def get_desc(lang: str = "en"):
        """
        Provides a description of the operator's function and parameters.
        """
        if lang == "zh":
            return (
                "该算子根据给定的人类指令生成相应的代码片段。\n\n"
                "输入参数:\n"
                "- input_key: 包含人类指令的字段名 (默认: 'instruction')\n"
                "输出参数:\n"
                "- output_key: 用于存储生成代码的字段名 (默认: 'generated_code')\n"
            )
        else:  # Default to English
            return (
                "This operator generates a code snippet based on a given natural language instruction.\n\n"
                "Input Parameters:\n"
                "- input_key: Field name containing the human instruction (default: 'instruction')\n"
                "Output Parameters:\n"
                "- output_key: Field name to store the generated code (default: 'generated_code')\n"
            )
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """
        Validates the DataFrame to ensure required columns exist and output columns don't.
        """
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
    def _build_prompts(self, dataframe: pd.DataFrame) -> List[str]:
        """
        Builds a list of prompts for the LLM based on the input instructions.
        """
        prompts = [
            self.prompt_template.build_prompt(instruction=row[self.input_key])
            for _, row in dataframe.iterrows()
        ]
        return prompts
    def _parse_code(self, response: str) -> str:
        """
        Parse the LLM's raw response to extract only the code.
        Removes potential markdown code blocks and leading/trailing whitespace.
        Args:
            response: Raw response string from the LLM
        Returns:
            Clean code string without markdown formatting
        """
        # BUGFIX: match the FIRST fenced code block with an optional language tag.
        # The original pattern ```(?:python\n)?(.*)``` was greedy — when a response
        # contained several fenced blocks it captured everything from the first
        # fence to the last — and it only recognized a literal "python" tag,
        # leaking other tags (```cpp, ```js, ...) into the extracted code.
        code_block_match = re.search(r"```(?:[A-Za-z0-9_+\-]*\n)?(.*?)```", response, re.DOTALL)
        if code_block_match:
            # If a markdown block is found, extract its content
            return code_block_match.group(1).strip()
        else:
            # Otherwise, assume the whole response is code and just strip it
            return response.strip()
    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "instruction",
        output_key: str = "generated_code"
    ) -> List[str]:
        """
        Executes the code generation process.
        It reads data from storage, generates code for each instruction,
        and writes the updated data back to storage.
        Returns:
            A list containing the name of the newly created output column.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        formatted_prompts = self._build_prompts(dataframe)
        responses = self.llm_serving.generate_from_input(formatted_prompts)
        codes = [self._parse_code(r) for r in responses]
        dataframe[self.output_key] = codes
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated code saved to {output_file}")
        return [self.output_key]
\ No newline at end of file
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static imports: seen only by type checkers and IDEs, never executed at runtime.
    from .generate.func_call_generators import (
        ScenarioExtractGenerator,
        ScenarioExpandGenerator,
        AtomTaskGenerator,
        SequentialTaskGenerator,
        ParaSeqTaskGenerator,
        FunctionGenerator,
        MultiTurnConversationGenerator,
    )
    from .generate.consistent_chat_generator import ConsistentChatGenerator
    from .eval.func_call_conversation_sample_evaluator import FuncCallConversationSampleEvaluator
    from .filter.composition_task_filter import CompositionTaskFilter
else:
    import sys
    from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking

    # At runtime the operators are imported lazily: the import structure is derived
    # from the TYPE_CHECKING block above, and the module object is replaced by a
    # LazyLoader that resolves attributes on first access.
    # (The previously commented-out hand-written _import_structure mapping was dead
    # code superseded by generate_import_structure_from_type_checking and removed.)
    cur_path = "dataflow/operators/conversations/"
    _import_structure = generate_import_structure_from_type_checking(__file__, cur_path)
    sys.modules[__name__] = LazyLoader(__name__, cur_path, _import_structure)
\ No newline at end of file
import re
import pandas as pd
import numpy as np
from tqdm import tqdm
from dataflow.core import OperatorABC, LLMServingABC
from dataflow.core.prompt import prompt_restrict
from dataflow.utils.storage import DataFlowStorage
from dataflow.prompts.func_call import ConversationEvalPrompt
from dataflow.logger import get_logger
from dataflow.utils.registry import OPERATOR_REGISTRY
@prompt_restrict(
    ConversationEvalPrompt
)
@OPERATOR_REGISTRY.register()
class FuncCallConversationSampleEvaluator(OperatorABC):
    """
    Scores multi-turn conversation samples with an LLM judge: each conversation is
    rendered into a scoring prompt, the LLM replies with JSON, and the parsed
    'score' value is written back to the DataFrame (0 on parse failure).
    """
    def __init__(self, llm_serving: LLMServingABC):
        """
        Args:
            llm_serving: LLM serving object implementing LLMServingABC.
        """
        self.llm_serving = llm_serving
        self.prompt = ConversationEvalPrompt()
        self.logger = get_logger()
    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "对对话样本进行打分评估:使用 LLM 服务根据预设评分提示词对每条对话进行评分,并将结果写回数据流。\n"
                "输入参数:\n"
                "- llm_serving:LLM 服务对象,需实现 LLMServingABC 接口\n"
                "- input_conversation_key:DataFrame 中对话内容字段名,默认 'conversations'\n"
                "- output_score_key:评分结果输出字段名,默认 'score'\n"
                "处理流程:\n"
                "- 读取存储中的 DataFrame\n"
                "- 将每条对话重组为评分提示词并调用 LLM 生成评分(JSON)\n"
                "- 解析 JSON,提取 'score' 字段写入 DataFrame;解析失败则回退为 0\n"
                "输出参数:\n"
                "- 包含评分结果列的 DataFrame\n"
                "- 包含输出字段名的列表(仅 'score' 或自定义的输出列名)"
            )
        elif lang == "en":
            return (
                "Evaluate conversation samples with an LLM-based scorer: the operator formats each "
                "conversation into a scoring prompt, calls the LLM, parses the JSON response, and writes the score back.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC\n"
                "- input_conversation_key: column name for conversations in the DataFrame, default 'conversations'\n"
                "- output_score_key: column name for the score output, default 'score'\n"
                "Process:\n"
                "- Read the DataFrame from storage\n"
                "- Reformat each conversation into a scoring prompt and call the LLM (expects JSON)\n"
                "- Parse the JSON to extract 'score'; fallback to 0 on parse errors\n"
                "Output:\n"
                "- DataFrame with a score column added\n"
                "- A list containing the output field name (e.g., 'score')"
            )
        else:
            return "Evaluate conversation samples with an LLM-based scorer and write the parsed 'score' back to the DataFrame."
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one scoring prompt per conversation row."""
        formatted_prompts = []
        for conversation in tqdm(dataframe[self.input_conversation_key], desc="Reformatting prompts..."):
            formatted_prompts.append(self.prompt.build_prompt(conversation=conversation))
        return formatted_prompts
    def clean_json_block(self, s: str) -> str:
        """Strip a surrounding ```json ... ``` markdown fence from an LLM reply."""
        s = s.strip()
        if s.startswith("```"):
            # Remove the enclosing ``` / ```json fence markers.
            s = s.strip("`")
            s = s.replace("json\n", "", 1)  # drop a leading "json\n" language tag
        if s.endswith("```"):
            s = s[:-3]
        return s.strip()
    def json_validate(self, llm_outputs):
        """Parse each LLM reply as JSON and collect its 'score'; 0 on any failure."""
        import json
        scores = []
        for item in llm_outputs:
            score = 0
            try:
                data = json.loads(self.clean_json_block(item))
                score = data['score']
            except Exception:
                # Best-effort scoring: malformed model output degrades to score 0.
                self.logger.debug(f"json loading error in item {item}")
            scores.append(score)
        return scores
    def run(self, storage: DataFlowStorage, input_conversation_key: str = "conversations", output_score_key = "score"):
        """
        Read the DataFrame from storage, score every conversation with the LLM,
        add the score column, and persist the result.

        Returns:
            list: a single-element list with the output score column name.
        """
        self.input_conversation_key = input_conversation_key
        self.output_score_key = output_score_key
        dataframe = storage.read("dataframe")
        llm_inputs = self._reformat_prompt(dataframe)
        llm_outputs = self.llm_serving.generate_from_input(llm_inputs)
        dataframe[self.output_score_key] = self.json_validate(llm_outputs)
        # BUGFIX: the original called storage.write(dataframe) twice in a row,
        # persisting the same dataframe two times; a single write suffices.
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [self.output_score_key]
import re
import pandas as pd
import numpy as np
from tqdm import tqdm
from dataflow.core import OperatorABC, LLMServingABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.prompts.func_call import CompositionTaskFilterPrompt
from dataflow.logger import get_logger
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.core.prompt import prompt_restrict
@prompt_restrict(
    CompositionTaskFilterPrompt
)
@OPERATOR_REGISTRY.register()
class CompositionTaskFilter(OperatorABC):
    """
    Uses an LLM to judge whether a composition task is feasible and complete given
    its sub-tasks, labels each row, and keeps only the executable ones.
    """
    def __init__(self, llm_serving: LLMServingABC):
        """
        Args:
            llm_serving: LLM serving object implementing LLMServingABC.
        """
        self.logger = get_logger()
        self.prompt = CompositionTaskFilterPrompt()
        self.llm_serving = llm_serving
        self.logger.info(f"Initializing {self.__class__.__name__}...")
    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "根据组合任务及其子任务,使用LLM服务判断组合任务是否具备可行性与完备性,从而进行可运行任务的筛选。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- input_composition_task_key:组合任务字段名\n"
                "- input_sub_tasks_keys:子任务字段名列表(如原子任务、并行任务、后继任务等)\n"
                "- output_key:可运行标签的输出字段名,默认'runable_label'\n"
                "输出参数:\n"
                "- 仅包含可运行组合任务的数据DataFrame\n"
                "- 包含输出字段名的列表(可运行标签字段)"
            )
        elif lang == "en":
            return (
                "Evaluate the feasibility and completeness of a composition task based on its sub-tasks using an LLM service, and filter out unexecutable tasks.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- input_composition_task_key: Field name for the composition task\n"
                "- input_sub_tasks_keys: List of field names for sub-tasks (e.g., atomic, parallel, subsequent tasks)\n"
                "- output_key: Field name for the executability label, default 'runable_label'\n"
                "Output Parameters:\n"
                "- DataFrame containing only executable composition tasks\n"
                "- List containing the output field name (executability label)"
            )
        else:
            return "Filter composition tasks for feasibility and completeness using LLM service."
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one judge prompt per row from the task and its sub-task record."""
        formatted_prompts = []
        # Each sub-task record is the row restricted to the configured sub-task columns.
        for task, sub_tasks in tqdm(zip(dataframe[self.input_composition_task_key], dataframe[self.input_sub_tasks_keys].to_dict(orient='records')), desc="Reformatting prompts..."):
            formatted_prompts.append(self.prompt.build_prompt(task=task, sub_tasks=sub_tasks))
        return formatted_prompts
    def run(self, storage: DataFlowStorage, input_composition_task_key: str, input_sub_tasks_keys: list[str], output_key: str = "runable_label"):
        """
        Label every composition task as executable (1) or not (0) based on the
        LLM's <ans>yes|no</ans> verdict, keep only the executable rows, and
        persist the filtered DataFrame.

        Returns:
            list: a single-element list with the output label column name.
        """
        self.input_composition_task_key = input_composition_task_key
        self.input_sub_tasks_keys = input_sub_tasks_keys
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        llm_inputs = self._reformat_prompt(dataframe)
        self.logger.debug(f"One of formatted prompts: {llm_inputs[0]}")
        llm_outputs = self.llm_serving.generate_from_input(llm_inputs)
        self.logger.debug(f"One of LLM outputs: {llm_outputs[0]}")
        labels = []
        for item in llm_outputs:
            # Missing or malformed verdicts conservatively count as "no" (0).
            match = re.search(r"<ans>(yes|no)</ans>", item.strip(), re.IGNORECASE)
            if match:
                labels.append(1 if match.group(1).lower() == "yes" else 0)
            else:
                labels.append(0)
        dataframe[self.output_key] = labels
        dataframe = dataframe[dataframe[self.output_key] > 0]
        # BUGFIX: the original called storage.write(dataframe) twice in a row,
        # persisting the same dataframe two times; a single write suffices.
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [self.output_key]
import random
import json
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
from dataflow.core import LLMServingABC
from dataflow.prompts.general_text import ConsistentQueryPrompt, ConsistentResponsePrompt
from dataflow.core.prompt import prompt_restrict
@prompt_restrict(
ConsistentQueryPrompt,
ConsistentResponsePrompt
)
@OPERATOR_REGISTRY.register()
class ConsistentChatGenerator(OperatorABC):
    def __init__(self, llm_serving: LLMServingABC = None, num_dialogs_per_intent = 20, num_turns_per_dialog = 6, temperature = 0.9):
        """
        Set up the generator with its LLM backend and dialog-generation parameters.

        Args:
            llm_serving: LLM serving object implementing LLMServingABC.
            num_dialogs_per_intent: Number of dialogs generated per intent (default: 20).
            num_turns_per_dialog: Number of turns per generated dialog (default: 6).
            temperature: Sampling temperature controlling output randomness (default: 0.9).
        """
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.llm_serving = llm_serving
        self.num_dialogs_per_intent = num_dialogs_per_intent # Based on the topic_dict in the existing prompt, it is recommended to set the value to below 1000 (which can generate 9000 conversation data). Otherwise, it is recommended to add more topic_dict in dataflow.prompts.general_text.ConsistentChatPrompt to increase data richness
        self.num_turns_per_dialog = num_turns_per_dialog
        self.temperature = temperature
        # Stage-1 prompt generates user queries; stage-2 prompt generates responses.
        self.query_prompt = ConsistentQueryPrompt()
        self.response_prompt = ConsistentResponsePrompt()
        self.logger.info(f'{self.__class__.__name__} initialized.')
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"根据预置主题和人类意图,两阶段从0合成多轮对话格式数据(合成数量大于9000时建议增加标签数量)。\n"
"输入参数:\n"
"- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
"- num_dialogs_per_intent:每个意图生成的对话数量,默认20\n"
"- num_turns_per_dialog:每个对话的轮次数量,默认6\n"
"- temperature:生成温度,控制输出随机性,默认0.9\n"
"输出参数:\n"
"- 包含category和conversation字段的DataFrame,其中conversation为多轮对话列表"
)
elif lang == "en":
return (
"Two-stage generation of multi-turn dialogue data from scratch based on predefined topics and human intents (for over 9000 samples, consider increasing the number of tags).\n"
"Input Parameters:\n"
"- llm_serving: LLM serving object implementing LLMServingABC interface\n"
"- num_dialogs_per_intent: Number of dialogs generated per intent, default 20\n"
"- num_turns_per_dialog: Number of turns per dialog, default 6\n"
"- temperature: Sampling temperature for generation, default 0.9\n"
"Output Parameters:\n"
"- DataFrame containing 'category' and 'conversation' fields, where conversation is a list of multi-turn dialogues"
)
else:
return "Two-stage generation of multi-turn dialogue data from scratch based on predefined topics and human intents."
def run(self, storage: DataFlowStorage):
# Step 1: Generate all queries using LLM
all_query_prompts = self.query_prompt.build_prompt(num_dialogs_per_intent=self.num_dialogs_per_intent)
# Step 2: Generate queries by calling llm_serving once
self.logger.info("Generating queries...")
queries_list = self.llm_serving.generate_from_input(user_inputs=all_query_prompts)
valid_queries = []
cnt = 0
for queries_str in queries_list:
try:
if not isinstance(queries_str, str):
raise ValueError("Invalid response type")
clean_queries_str = queries_str.replace("```json", "").replace("```", "").strip()
queries = json.loads(clean_queries_str)
valid_queries.append(queries)
except (json.JSONDecodeError, ValueError) as e:
cnt += 1
self.logger.debug(f'Json parse failed counts: {cnt} (Model generation error)')
continue
all_response_prompts = []
for queries in valid_queries:
category = queries.get("category")
turns = queries.get("turns")
all_response_prompts.append(self.response_prompt.build_prompt(topic=category, queries=turns))
self.logger.info("Generating responses...")
responses_list = self.llm_serving.generate_from_input(user_inputs=all_response_prompts)
final_queries = []
final_responses = []
cnt = 0
for query, responses_str in zip(valid_queries, responses_list):
try:
if not isinstance(responses_str, str):
raise ValueError("Invalid response type")
clean_responses_str = responses_str.replace("```json", "").replace("```", "").strip()
responses = json.loads(clean_responses_str)
final_queries.append(query)
final_responses.append(responses)
except (json.JSONDecodeError, ValueError) as e:
cnt += 1
self.logger.debug(f'Json parse failed counts: {cnt} (Model generation error): {str(e)}')
continue
formatted_data = []
for query_data, response_data in zip(final_queries, final_responses):
if isinstance(response_data, dict):
response_data = response_data.get('responses', [])
try:
category = query_data['category']
turns = query_data['turns']
# Ensure the number of turns matches the number of responses
num_user_turns = len(turns)
num_assistant_responses = len(response_data)
if num_user_turns > num_assistant_responses:
turns = turns[:num_assistant_responses]
conversation = []
for i in range(len(turns)):
conversation.append({"role": "user", "value": turns[i]})
if i < len(response_data):
conversation.append({"role": "assistant", "value": response_data[i]['response']})
# Ensure conversation does not end with a user message
if conversation and conversation[-1]["role"] == "user":
conversation.pop()
if not conversation:
continue
formatted_data.append({
"category": category,
"conversation": conversation
})
except Exception as e:
self.logger.debug(f"Error processing category '{query_data.get('category', 'Unknown')}': {e}")
continue
self.logger.info(f'Number of synthesized dialogues: {len(formatted_data)}')
df = pd.DataFrame(formatted_data)
storage.write(df)
self.logger.info(f'Number of synthesized dialogues: {len(df)} written to storage as DataFrame')
return df
import re
import pandas as pd
import numpy as np
from tqdm import tqdm
from dataflow.core import OperatorABC, LLMServingABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.prompts.func_call import (
ExtractScenarioPrompt,
ExpandScenarioPrompt,
FuncAtomicTaskGeneratePrompt,
SequentialTaskGeneratePrompt,
ParathenSeqTaskGeneratePrompt,
FuncGeneratePrompt,
ConversationUserPrompt,
ConversationAssistantPrompt,
ConversationToolPrompt,
)
from dataflow.logger import get_logger
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.core.prompt import prompt_restrict
@prompt_restrict(ExtractScenarioPrompt)
@OPERATOR_REGISTRY.register()
class ScenarioExtractGenerator(OperatorABC):
    """Operator that derives a scenario description from each conversation
    row through a single batched LLM call."""
    def __init__(self, llm_serving: LLMServingABC):
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.prompt = ExtractScenarioPrompt()
        self.logger.info(f"Initializing {self.__class__.__name__}...")
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "从对话内容中提取场景信息,使用LLM服务分析对话并生成场景描述。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- input_chat_key:对话内容字段名\n"
                "- output_key:输出场景字段名,默认'scenario'\n"
                "输出参数:\n"
                "- 包含提取场景信息的DataFrame\n"
                "- 包含输出字段名的列表"
            )
        if lang == "en":
            return (
                "Extract scenario information from conversation content using LLM service to analyze dialogues and generate scenario descriptions.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- input_chat_key: Field name for conversation content\n"
                "- output_key: Field name for output scenario, default 'scenario'\n"
                "Output Parameters:\n"
                "- DataFrame containing extracted scenario information\n"
                "- List containing output field name"
            )
        return "Extract scenario information from conversation content using LLM service."
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one extraction prompt per conversation row."""
        prompts = []
        for conversation in tqdm(dataframe[self.input_chat_key], desc="Reformatting prompts..."):
            prompts.append(self.prompt.build_prompt(conversation=conversation))
        return prompts
    def run(self, storage: DataFlowStorage, input_chat_key: str, output_key: str = "scenario"):
        """Read the dataframe, annotate it with scenarios, persist it, and
        return the list of output column names."""
        self.input_chat_key = input_chat_key
        self.output_key = output_key
        frame = storage.read("dataframe")
        scenarios = self.llm_serving.generate_from_input(self._reformat_prompt(frame))
        frame[self.output_key] = scenarios
        saved_path = storage.write(frame)
        self.logger.info(f"Results saved to {saved_path}")
        return [self.output_key]
@prompt_restrict(
    ExpandScenarioPrompt
)
@OPERATOR_REGISTRY.register()
class ScenarioExpandGenerator(OperatorABC):
    """Operator that rewrites each original scenario into an alternative one
    via the LLM serving backend."""
    def __init__(self, llm_serving: LLMServingABC):
        self.logger = get_logger()
        self.prompt = ExpandScenarioPrompt()
        self.llm_serving = llm_serving
        self.logger.info(f"Initializing {self.__class__.__name__}...")
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        zh_desc = (
            "基于原始场景生成新的替代场景,使用LLM服务重写或改写原有场景内容。\n"
            "输入参数:\n"
            "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
            "- input_scenario_key:原始场景字段名\n"
            "- output_key:生成的新场景字段名,默认'modified_scenario'\n"
            "输出参数:\n"
            "- 包含生成新场景的DataFrame\n"
            "- 包含输出字段名的列表"
        )
        en_desc = (
            "Generate new or alternative scenarios based on the original scenario using LLM service. The original content is rewritten or reimagined to create a different version.\n"
            "Input Parameters:\n"
            "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
            "- input_scenario_key: Field name for the original scenario\n"
            "- output_key: Field name for the new scenario, default 'modified_scenario'\n"
            "Output Parameters:\n"
            "- DataFrame containing newly generated scenarios\n"
            "- List containing output field name"
        )
        if lang == "zh":
            return zh_desc
        elif lang == "en":
            return en_desc
        return "Generate new scenarios using LLM service based on original inputs."
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one rewrite prompt per original-scenario row."""
        source_column = dataframe[self.input_scenario_key]
        return [
            self.prompt.build_prompt(scenario=scenario)
            for scenario in tqdm(source_column, desc="Reformatting prompts...")
        ]
    def run(self, storage: DataFlowStorage, input_scenario_key: str, output_key: str = "modified_scenario"):
        """Generate modified scenarios for every row, persist the dataframe,
        and return the list of output column names."""
        self.input_scenario_key = input_scenario_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        prompts = self._reformat_prompt(dataframe)
        generations = self.llm_serving.generate_from_input(prompts)
        dataframe[self.output_key] = generations
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [self.output_key]
@prompt_restrict(
    FuncAtomicTaskGeneratePrompt
)
@OPERATOR_REGISTRY.register()
class AtomTaskGenerator(OperatorABC):
    """Operator that turns each scenario row into an atomic task using a
    single batched LLM call."""
    def __init__(self, llm_serving: LLMServingABC):
        self.logger = get_logger()
        self.prompt = FuncAtomicTaskGeneratePrompt()
        self.llm_serving = llm_serving
        self.logger.info(f"Initializing {self.__class__.__name__}...")
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "根据输入的场景信息,使用LLM服务生成对应的原子任务。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- input_scenario_key:场景字段名\n"
                "- output_key:原子任务的输出字段名,默认'atom_task'\n"
                "输出参数:\n"
                "- 包含原子任务的DataFrame\n"
                "- 包含输出字段名的列表"
            )
        if lang == "en":
            return (
                "Generate atomic task based on the input scenario using an LLM service.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- input_scenario_key: Field name for the scenario\n"
                "- output_key: Field name for the atomic task output, default 'atom_task'\n"
                "Output Parameters:\n"
                "- DataFrame containing the atomic tasks\n"
                "- List containing output field name"
            )
        return "Generate atomic tasks from scenario using LLM service."
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one atomic-task prompt per scenario row."""
        rows = tqdm(dataframe[self.input_scenario_key], desc="Reformatting prompts...")
        return [self.prompt.build_prompt(scenario=row) for row in rows]
    def run(self, storage: DataFlowStorage, input_scenario_key: str, output_key: str = "atom_task"):
        """Generate an atomic task for every scenario, persist the dataframe,
        and return the list of output column names."""
        self.input_scenario_key = input_scenario_key
        self.output_key = output_key
        frame = storage.read("dataframe")
        atom_tasks = self.llm_serving.generate_from_input(self._reformat_prompt(frame))
        frame[self.output_key] = atom_tasks
        destination = storage.write(frame)
        self.logger.info(f"Results saved to {destination}")
        return [self.output_key]
@prompt_restrict(
    SequentialTaskGeneratePrompt
)
@OPERATOR_REGISTRY.register()
class SequentialTaskGenerator(OperatorABC):
    """
    For each atomic task, asks the LLM to produce a subsequent task and a
    composition task, then extracts both from the raw output via the
    '### Subsequent Task:' / '### Composition Task:' markers.
    """
    def __init__(self, llm_serving: LLMServingABC):
        """
        Args:
            llm_serving: serving backend implementing LLMServingABC.
        """
        self.logger = get_logger()
        self.prompt = SequentialTaskGeneratePrompt()
        self.llm_serving = llm_serving
        self.logger.info(f"Initializing {self.__class__.__name__}...")
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one generation prompt per atomic-task row."""
        formatted_prompts = [self.prompt.build_prompt(task=item) for item in tqdm(dataframe[self.input_task_key], desc="Reformatting prompts...")]
        return formatted_prompts
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "根据输入的原子任务,使用LLM服务生成该任务的后继任务和两者的组合任务。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- input_task_key:原子任务字段名\n"
                "- output_subsequent_task_key:后继任务输出字段名,默认'subsequent_task'\n"
                "- output_composition_task_key:组合任务输出字段名,默认'composition_task'\n"
                "输出参数:\n"
                "- 包含后继任务和组合任务的DataFrame\n"
                "- 输出字段名的列表(后继任务字段和组合任务字段)"
            )
        elif lang == "en":
            return (
                "Generate the subsequent task and a composition task based on the input atomic task using an LLM service.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- input_task_key: Field name for the atomic task\n"
                "- output_subsequent_task_key: Field name for the subsequent task output, default 'subsequent_task'\n"
                "- output_composition_task_key: Field name for the composition task output, default 'composition_task'\n"
                "Output Parameters:\n"
                "- DataFrame containing both subsequent and composition tasks\n"
                "- List containing the names of the output fields"
            )
        else:
            return "Generate subsequent and composition tasks from atomic task using LLM service."
    def run(self, storage: DataFlowStorage, input_task_key: str, output_subsequent_task_key: str = "subsequent_task", output_composition_task_key: str = "composition_task"):
        """
        Annotate the stored dataframe with subsequent and composition tasks.

        Returns:
            List of the two output column names.
        """
        self.input_task_key = input_task_key
        self.output_subsequent_task_key = output_subsequent_task_key
        self.output_composition_task_key = output_composition_task_key
        dataframe = storage.read("dataframe")
        llm_inputs = self._reformat_prompt(dataframe)
        llm_outputs = self.llm_serving.generate_from_input(llm_inputs)
        subsequent_tasks, composition_tasks = [], []
        for item in llm_outputs:
            # Regex extraction of the marked sections; None when a marker is
            # missing from the model output.
            match_subsequent = re.search(r"### Subsequent Task: (.*?)\n", item)
            match_composition = re.search(r"### Composition Task: (.*?)$", item)
            subsequent_tasks.append(match_subsequent.group(1) if match_subsequent else None)
            composition_tasks.append(match_composition.group(1) if match_composition else None)
        dataframe[self.output_subsequent_task_key] = subsequent_tasks
        dataframe[self.output_composition_task_key] = composition_tasks
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        # Fix: previously returned the local parameter `output_composition_task_key`
        # instead of the instance attribute; same value today, but fragile.
        return [self.output_subsequent_task_key, self.output_composition_task_key]
@prompt_restrict(
    ParathenSeqTaskGeneratePrompt
)
@OPERATOR_REGISTRY.register()
class ParaSeqTaskGenerator(OperatorABC):
    """
    For each atomic task, asks the LLM to produce a parallel task, a
    subsequent task, and a composition task combining all three, extracted
    from the raw output via '### ... Task:' markers.
    """
    def __init__(self, llm_serving: LLMServingABC):
        """
        Args:
            llm_serving: serving backend implementing LLMServingABC.
        """
        self.logger = get_logger()
        self.prompt = ParathenSeqTaskGeneratePrompt()
        self.llm_serving = llm_serving
        self.logger.info(f"Initializing {self.__class__.__name__}...")
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "基于原子任务,使用LLM服务生成三个任务类型:并行任务、后继任务以及这三者的组合任务。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- input_task_key:原子任务字段名\n"
                "- output_parallel_task_key:并行任务输出字段名,默认'parallel_task'\n"
                "- output_subsequent_task_key:后继任务输出字段名,默认'subsequent_task'\n"
                "- output_composition_task_key:组合任务输出字段名,默认'composition_task'\n"
                "输出参数:\n"
                "- 包含并行任务、后继任务与组合任务的DataFrame\n"
                "- 输出字段名列表(并行任务、后继任务、组合任务)"
            )
        elif lang == "en":
            return (
                "Based on a given atomic task, this operator uses an LLM service to generate three task types: "
                "a parallel task, a subsequent task, and a composition task combining them.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- input_task_key: Field name for the atomic task\n"
                "- output_parallel_task_key: Field name for the parallel task, default 'parallel_task'\n"
                "- output_subsequent_task_key: Field name for the subsequent task, default 'subsequent_task'\n"
                "- output_composition_task_key: Field name for the composition task, default 'composition_task'\n"
                "Output Parameters:\n"
                "- DataFrame containing parallel, subsequent, and composition tasks\n"
                "- List containing the output field names"
            )
        else:
            return "Generate parallel, subsequent, and composition tasks based on an atomic task using LLM service."
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one generation prompt per atomic-task row."""
        formatted_prompts = [self.prompt.build_prompt(task=item) for item in tqdm(dataframe[self.input_task_key], desc="Reformatting prompts...")]
        return formatted_prompts
    def run(self, storage: DataFlowStorage, input_task_key: str, output_parallel_task_key: str = "parallel_task", output_subsequent_task_key: str = "subsequent_task", output_composition_task_key: str = "composition_task"):
        """
        Annotate the stored dataframe with parallel, subsequent, and
        composition tasks.

        Returns:
            List of the three output column names.
        """
        self.input_task_key = input_task_key
        self.output_parallel_task_key = output_parallel_task_key
        self.output_subsequent_task_key = output_subsequent_task_key
        self.output_composition_task_key = output_composition_task_key
        dataframe = storage.read("dataframe")
        llm_inputs = self._reformat_prompt(dataframe)
        llm_outputs = self.llm_serving.generate_from_input(llm_inputs)
        parallel_tasks, subsequent_tasks, composition_tasks = [], [], []
        for item in llm_outputs:
            # Regex extraction of the three marked sections; None when a
            # marker is missing from the model output.
            match_parallel = re.search(r"### Parallel Task: (.*?)\n", item)
            match_subsequent = re.search(r"### Subsequent Task: (.*?)\n", item)
            match_composition = re.search(r"### Composition Task: (.*?)$", item)
            # Bug fix: the original did `parallel_tasks = None` on a missing
            # parallel marker, clobbering the accumulator list (and crashing
            # on the subsequent .append) instead of recording a None entry.
            parallel_tasks.append(match_parallel.group(1) if match_parallel else None)
            subsequent_tasks.append(match_subsequent.group(1) if match_subsequent else None)
            composition_tasks.append(match_composition.group(1) if match_composition else None)
        dataframe[self.output_parallel_task_key] = parallel_tasks
        dataframe[self.output_subsequent_task_key] = subsequent_tasks
        dataframe[self.output_composition_task_key] = composition_tasks
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        # Fix: return the instance attribute for the composition key (the
        # original referenced the local parameter for the third entry).
        return [self.output_parallel_task_key, self.output_subsequent_task_key, self.output_composition_task_key]
@prompt_restrict(
    FuncGeneratePrompt
)
@OPERATOR_REGISTRY.register()
class FunctionGenerator(OperatorABC):
    """
    Generates a function list for each composition task together with its
    sub-tasks, via a single batched LLM call.
    """
    def __init__(self, llm_serving: LLMServingABC):
        """
        Args:
            llm_serving: serving backend implementing LLMServingABC.
        """
        self.logger = get_logger()
        self.prompt = FuncGeneratePrompt()
        self.llm_serving = llm_serving
        self.logger.info(f"Initializing {self.__class__.__name__}...")
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "基于组合任务及其相关子任务,使用LLM服务生成对应的函数列表。"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- input_composition_task_key:组合任务字段名\n"
                "- input_sub_tasks_keys:子任务字段名列表(如原子任务、并行任务、后继任务等)\n"
                "- output_key:函数列表输出字段名,默认'functions'\n"
                "输出参数:\n"
                "- 包含函数定义或函数列表的DataFrame\n"
                "- 输出字段名的列表(函数列表字段)"
            )
        elif lang == "en":
            return (
                "Generate a list of functions based on a composition task and its associated sub-tasks using an LLM service. "
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- input_composition_task_key: Field name for the composition task\n"
                "- input_sub_tasks_keys: List of field names for sub-tasks (e.g., atomic, parallel, subsequent tasks)\n"
                "- output_key: Field name for the generated functions, default 'functions'\n"
                "Output Parameters:\n"
                "- DataFrame containing the generated functions or function list\n"
                "- List containing the output field name"
            )
        else:
            return "Generate functions from composition and sub-tasks using LLM service."
    def _reformat_prompt(self, dataframe: pd.DataFrame):
        """Build one prompt per row from the composition task plus a
        {column_name: value} dict of that row's sub-task columns."""
        formatted_prompts = []
        for task, sub_tasks in tqdm(zip(dataframe[self.input_composition_task_key], dataframe[self.input_sub_tasks_keys].to_dict(orient='records')), desc="Reformatting prompts..."):
            formatted_prompts.append(self.prompt.build_prompt(task=task, sub_tasks=sub_tasks))
        return formatted_prompts
    def run(self, storage: DataFlowStorage, input_composition_task_key: str, input_sub_tasks_keys: list[str], output_key: str = "functions"):
        """
        Annotate the stored dataframe with generated functions.

        Returns:
            List containing the single output column name.
        """
        self.input_composition_task_key = input_composition_task_key
        self.input_sub_tasks_keys = input_sub_tasks_keys
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        llm_inputs = self._reformat_prompt(dataframe)
        llm_outputs = self.llm_serving.generate_from_input(llm_inputs)
        dataframe[self.output_key] = llm_outputs
        # Bug fix: the dataframe was previously written to storage twice in a
        # row; a single write suffices.
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [self.output_key]
@prompt_restrict(
    ConversationUserPrompt,
    ConversationAssistantPrompt,
    ConversationToolPrompt
)
@OPERATOR_REGISTRY.register()
class MultiTurnConversationGenerator(OperatorABC):
    """
    Simulates multi-turn conversations for each composition task using three
    LLM-driven agents: User (opens the dialogue), Assistant (answers or emits
    <func_call> requests), and Tool (answers func calls).  Rows whose
    assistant reached a <final>...</final> answer within five turns are kept
    and written to storage.
    """
    def __init__(self, llm_serving: LLMServingABC):
        """
        Args:
            llm_serving: serving backend implementing LLMServingABC.
        """
        self.llm_serving = llm_serving
        self.user_prompt = ConversationUserPrompt()
        self.assistant_prompt = ConversationAssistantPrompt()
        self.tool_prompt = ConversationToolPrompt()
        self.logger = get_logger()
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "根据组合任务及其子任务函数,使用LLM服务模拟多轮对话过程,"
                "由User、Assistant和Tool三个Agent协同生成完整的对话数据。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- input_task_key:任务字段名(组合任务)\n"
                "- input_sub_tasks_keys:子任务字段名列表\n"
                "- input_functions_key:子任务函数字段名\n"
                "- output_conversations_key:输出对话字段名,默认'conversations'\n"
                "输出参数:\n"
                "- 包含已完成的多轮对话记录的DataFrame\n"
                "- 输出字段名(对话字段名)"
            )
        elif lang == "en":
            return (
                "Simulate multi-turn conversations based on composition tasks and their sub-task functions using an LLM service.\n"
                "The process involves three agents: User, Assistant, and Tool, interacting to complete the conversation.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- input_task_key: Field name for the main task (composition task)\n"
                "- input_sub_tasks_keys: List of field names for sub-tasks\n"
                "- input_functions_key: Field name containing sub-task functions\n"
                "- output_conversations_key: Field name for storing the generated conversations, default 'conversations'\n"
                "Output Parameters:\n"
                "- DataFrame containing multi-turn conversations with completed sessions\n"
                "- Output field name for the conversation content"
            )
        else:
            return "Generate multi-turn dialogues from composition tasks and functions using user, assistant, and tool agents."
    def _reformat_user_agent_prompt(self, dataframe: pd.DataFrame):
        """Build the User-agent opening prompt, one per task row."""
        user_agent_prompts = []
        for item in tqdm(dataframe[self.input_task_key], desc="Reformatting prompts..."):
            user_agent_prompts.append(self.user_prompt.build_prompt(task=item))
        return user_agent_prompts
    def _reformat_assistant_agent_prompt(self, user_agent_prompts: list[str], dataframe: pd.DataFrame):
        """Pair each Assistant-agent system prompt (built from the row's
        sub-tasks and functions) with the User agent's opening message."""
        assistant_agent_sys_prompts = []
        for sub_tasks, functions in zip(dataframe[self.input_sub_tasks_keys].to_dict(orient='records'), dataframe[self.input_functions_key]):
            assistant_agent_sys_prompts.append(self.assistant_prompt.build_prompt(sub_task=sub_tasks, sub_task_func=functions))
        assistant_agent_user_inputs = user_agent_prompts
        inputs = [[{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_input}] for sys_prompt, user_input in zip(assistant_agent_sys_prompts, assistant_agent_user_inputs)]
        return inputs
    def _reformat_tool_agent_prompt(self, func_calls: list[str]):
        """Build one Tool-agent prompt per extracted <func_call> string."""
        tool_agent_prompts = []
        for func_call in func_calls:
            tool_agent_prompts.append(self.tool_prompt.build_prompt(function=func_call))
        return tool_agent_prompts
    def run(self, storage: DataFlowStorage, input_task_key: str, input_sub_tasks_keys: list[str], input_functions_key: list[str], output_conversations_key: str = "conversations"):
        """
        Drive the three-agent simulation for up to five assistant turns per
        row, keep only rows that produced a <final> answer, and persist them.

        Returns:
            The output conversations column name.
        """
        self.input_task_key = input_task_key
        self.input_sub_tasks_keys = input_sub_tasks_keys
        self.input_functions_key = input_functions_key
        self.output_user_agent_response_key = "user_response"
        self.output_key = output_conversations_key
        dataframe = storage.read("dataframe")
        # Stage 1: User agent produces the opening message for each task.
        user_agent_prompts = self._reformat_user_agent_prompt(dataframe)
        user_agent_responses = self.llm_serving.generate_from_input(user_agent_prompts)
        dataframe[self.output_user_agent_response_key] = user_agent_responses
        turns = 0
        # completed_label[i] == 1 once row i reached a <final> answer;
        # valid_label[i] == 1 marks rows dropped for malformed (non-str) output.
        completed_label = [0] * len(dataframe)
        valid_label = [0] * len(dataframe)
        cur_conversations = self._reformat_assistant_agent_prompt(user_agent_responses, dataframe)
        while True:
            assistant_agent_inputs = cur_conversations
            not_completed_idxs = np.where(np.array(completed_label) == 0)[0]
            valid_idxs = np.where(np.array(valid_label) == 0)[0]
            cur_chatting_idxs = np.intersect1d(not_completed_idxs, valid_idxs)
            cur_chatting_conversations = [assistant_agent_inputs[idx] for idx in cur_chatting_idxs]
            # Stage 2: Assistant agent answers every still-active conversation.
            assistant_agent_outputs = self.llm_serving.generate_from_conversations(cur_chatting_conversations)
            new_assistant_agent_outputs = list(zip(cur_chatting_idxs, assistant_agent_outputs))
            func_call_pattern = r"<func_call>(.*?)</func_call>"
            func_calls = []
            final_answer_pattern = r"<final>(.*?)</final>"
            for idx, text in new_assistant_agent_outputs:
                if not isinstance(text, str):
                    # Bug fix: the original passed `text` as a lazy-% argument
                    # to a message with no placeholder, which breaks logging
                    # formatting; use an explicit %s.
                    self.logger.warning("Warning: 'text' is not a string: %s", text)
                    valid_label[idx] = 1
                    continue
                # Bug fix: the original re-ran this exact search a second time
                # right after the isinstance branch; one search suffices.
                final_match = re.search(final_answer_pattern, text, re.DOTALL)
                if final_match:
                    completed_label[idx] = 1
                    self.logger.info(f"Final answer found: {idx}")
                    continue
                # No final answer: collect the (possibly empty) func call so the
                # Tool agent can respond; "" keeps index alignment.
                func_match = re.search(func_call_pattern, text, re.DOTALL)
                if func_match:
                    result = func_match.group(1)
                    func_calls.append(f"<func_call>{result}</func_call>")
                else:
                    func_calls.append("")
            # Append the assistant turn to every conversation just queried.
            for item, text in zip(cur_chatting_conversations, assistant_agent_outputs):
                item.append({"role": "assistant", "content": text})
            # Recompute the active set (drops rows just completed/invalidated);
            # its order matches func_calls built above.
            not_completed_idxs = np.where(np.array(completed_label) == 0)[0]
            valid_idxs = np.where(np.array(valid_label) == 0)[0]
            cur_chatting_idxs = np.intersect1d(not_completed_idxs, valid_idxs)
            cur_chatting_conversations = [assistant_agent_inputs[idx] for idx in cur_chatting_idxs]
            # Stage 3: Tool agent answers each extracted func call.
            tool_agent_inputs = self._reformat_tool_agent_prompt(func_calls)
            tool_agent_outputs = self.llm_serving.generate_from_input(tool_agent_inputs)
            for item, text in zip(cur_chatting_conversations, tool_agent_outputs):
                item.append({"role": "assistant", "content": text})
            turns += 1
            if turns >= 5:
                break
        self.logger.info(f"Bad answer {np.where(np.array(completed_label) == 0)[0]}")
        dataframe[self.output_key] = cur_conversations
        # Keep only rows whose assistant produced a <final> answer.
        dataframe = dataframe[np.array(completed_label) == 1]
        storage.write(dataframe)
        return self.output_key
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment