Commit 97e8278b authored by zzg_666's avatar zzg_666

Adapt backend to vllm

import re
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class HtmlEntityRefiner(OperatorABC):
def __init__(self, html_entities: list = [
"nbsp", "lt", "gt", "amp", "quot", "apos", "hellip", "ndash", "mdash",
"lsquo", "rsquo", "ldquo", "rdquo"
]):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
# Use the custom HTML entity list passed in as an argument, falling back to the default list above
self.html_entities = html_entities
# Build a regular expression that matches every configured HTML entity
# in the following forms:
# 1. &entity;
# 2. ＆entity;  (full-width &)
# 3. &entity;  (Chinese semicolon)
# 4. ＆entity;  (full-width & + Chinese semicolon)
entity_patterns = []
for entity in self.html_entities:
# &entity;
entity_patterns.append(fr'&{entity};')
# ＆entity;  (full-width &)
entity_patterns.append(fr'＆{entity};')
# &entity;  (Chinese semicolon)
entity_patterns.append(fr'&{entity};')
# ＆entity;  (full-width & + Chinese semicolon)
entity_patterns.append(fr'＆{entity};')
# Compile the combined pattern
self.html_entity_regex = re.compile('|'.join(entity_patterns))
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"去除文本中的HTML实体,包括标准实体(如 、<)和各种变体形式(全角符号、中文分号等)。支持自定义需要移除的HTML实体列表。"
"输入参数:\n"
"- html_entities:需要移除的HTML实体列表,默认为包含常见实体的列表\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 包含移除HTML实体后文本的DataFrame\n"
"- 返回输入字段名,用于后续算子引用"
)
elif lang == "en":
return (
"Remove HTML entities from text, including standard entities (e.g.,  , <) and various variants (full-width symbols, Chinese semicolons, etc.). \n"
"Supports custom list of HTML entities to be removed.\n"
"Input Parameters:\n"
"- html_entities: List of HTML entities to remove, default contains common entities\n"
"- input_key: Field name for input text\n\n"
"Output Parameters:\n"
"- DataFrame containing text with HTML entities removed\n"
"- Returns input field name for subsequent operator reference"
)
else:
return (
"HtmlEntityRefiner removes HTML entities and their variants from text."
)
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = original_text
# 使用正则表达式替换所有匹配的HTML实体为空字符串
refined_text = self.html_entity_regex.sub('', refined_text)
# 检查文本是否被修改
if original_text != refined_text:
item = refined_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
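# Illustrative sketch: what the compiled entity regex removes, assuming the default
# entity handling above (standard and full-width variants). Standalone when run directly.
if __name__ == "__main__":
    refiner = HtmlEntityRefiner(html_entities=["nbsp", "lt", "gt"])
    print(refiner.html_entity_regex.sub("", "A&nbsp;B ＆lt;C &gt;D"))  # -> "AB C D"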
import re
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class HtmlUrlRemoverRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"去除文本中的URL链接和HTML标签,净化文本内容。使用正则表达式匹配并移除各种形式的URL和HTML标签。"
"输入参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 包含净化后文本的DataFrame\n"
"- 返回输入字段名,用于后续算子引用"
)
elif lang == "en":
return (
"Remove URL links and HTML tags from text to clean content. Uses regular expressions to match and remove various forms of URLs and HTML tags.\n"
"Input Parameters:\n"
"- input_key: Field name for input text\n\n"
"Output Parameters:\n"
"- DataFrame containing cleaned text\n"
"- Returns input field name for subsequent operator reference"
)
else:
return (
"HtmlUrlRemoverRefiner cleans text by removing URLs and HTML tags."
)
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = original_text
# Remove URLs
refined_text = re.sub(r'https?:\/\/\S+[\r\n]*', '', refined_text, flags=re.MULTILINE)
# Remove HTML tags
refined_text = re.sub(r'<.*?>', '', refined_text)
if original_text != refined_text:
item = refined_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
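# Illustrative sketch: the two substitutions applied in run(), shown on a small sample string.
if __name__ == "__main__":
    sample = "See <b>docs</b> at https://example.com/page for details."
    no_url = re.sub(r'https?:\/\/\S+[\r\n]*', '', sample, flags=re.MULTILINE)
    print(re.sub(r'<.*?>', '', no_url))  # -> "See docs at  for details."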
import re
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class LowercaseRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"将文本字段中的所有大写字符转换为小写,统一文本格式。对指定字段的文本内容进行全小写处理。"
"输入参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 包含小写转换后文本的DataFrame\n"
"- 返回输入字段名,用于后续算子引用"
)
elif lang == "en":
return (
"Convert all uppercase characters in text fields to lowercase to unify text format. Applies full lowercase processing to text content of specified fields.\n"
"Input Parameters:\n"
"- input_key: Field name for input text\n\n"
"Output Parameters:\n"
"- DataFrame containing lowercase converted text\n"
"- Returns input field name for subsequent operator reference"
)
else:
return (
"LowercaseRefiner converts text fields to lowercase."
)
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
lower_text = original_text.lower()
if original_text != lower_text:
item = lower_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {lower_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
import spacy
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
ENTITY_LABELS = {
"PERSON": "[PERSON]",
"ORG": "[ORG]",
"GPE": "[GPE]",
"LOC": "[LOC]",
"PRODUCT": "[PRODUCT]",
"EVENT": "[EVENT]",
"DATE": "[DATE]",
"TIME": "[TIME]",
"MONEY": "[MONEY]",
"PERCENT": "[PERCENT]",
"QUANTITY": "[QUANTITY]",
"ORDINAL": "[ORDINAL]",
"CARDINAL": "[CARDINAL]",
"NORP": "[NORP]",
"FAC": "[FAC]",
"LAW": "[LAW]",
"LANGUAGE": "[LANGUAGE]",
"WORK_OF_ART": "[WORK_OF_ART]",
"URL": "[URL]",
"EMAIL": "[EMAIL]"
}
@OPERATOR_REGISTRY.register()
class NERRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
self.nlp = spacy.load("en_core_web_sm")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"使用命名实体识别(NER)技术识别并屏蔽文本中的特定实体。使用spaCy的'en_core_web_sm'模型识别实体,并将其替换为对应的实体类型标签。"
"输入参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 包含实体屏蔽后文本的DataFrame\n"
"- 返回输入字段名,用于后续算子引用"
)
elif lang == "en":
return (
"Mask specific entities in text using Named Entity Recognition (NER) technology. Uses spaCy's 'en_core_web_sm' model to identify entities \n"
"and replace them with corresponding entity type tags.\n"
"Input Parameters:\n"
"- input_key: Field name for input text\n\n"
"Output Parameters:\n"
"- DataFrame containing text with masked entities\n"
"- Returns input field name for subsequent operator reference"
)
else:
return (
"NERRefiner masks specific entities in text using Named Entity Recognition."
)
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = original_text
doc = self.nlp(refined_text)
for ent in doc.ents:
if ent.label_ in ENTITY_LABELS:
refined_text = refined_text.replace(ent.text, ENTITY_LABELS[ent.label_])
if original_text != refined_text:
item = refined_text
modified = True
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
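# Illustrative sketch: the masking performed in run(), assuming the spaCy model
# 'en_core_web_sm' is installed (python -m spacy download en_core_web_sm).
if __name__ == "__main__":
    nlp = spacy.load("en_core_web_sm")
    text = "Alice joined Google in 2020."
    doc = nlp(text)
    for ent in doc.ents:
        if ent.label_ in ENTITY_LABELS:
            text = text.replace(ent.text, ENTITY_LABELS[ent.label_])
    print(text)  # typically "[PERSON] joined [ORG] in [DATE]."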
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer
from presidio_analyzer.nlp_engine import TransformersNlpEngine
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class PIIAnonymizeRefiner(OperatorABC):
def __init__(self, lang='en', device='cuda', model_cache_dir='./dataflow_cache', model_name='dslim/bert-base-NER', ):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
self.lang = lang
self.device = device
self.model_cache_dir = model_cache_dir
self.model_name = model_name
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.model_cache_dir)
self.model = AutoModelForTokenClassification.from_pretrained(self.model_name, cache_dir=self.model_cache_dir).to(self.device)
model_config = [{
"lang_code": self.lang,
"model_name": {
"spacy": "en_core_web_sm",
"transformers": self.model_name
}
}]
self.nlp_engine = TransformersNlpEngine(models=model_config)
self.analyzer = AnalyzerEngine(nlp_engine=self.nlp_engine)
self.anonymizer = AnonymizerEngine()
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"使用Presidio和BERT-NER模型识别并匿名化文本中的个人身份信息(PII)。支持多种PII类型的检测和匿名化处理。"
"输入参数:\n"
"- lang:语言代码,默认为'en'\n"
"- device:运行设备,默认为'cuda'\n"
"- model_cache_dir:模型缓存目录,默认为'./dataflow_cache'\n"
"- model_name:NER模型名称,默认为'dslim/bert-base-NER'\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 包含匿名化后文本的DataFrame\n"
"- 返回输入字段名,用于后续算子引用"
)
elif lang == "en":
return (
"Identify and anonymize Personally Identifiable Information (PII) in text using Presidio and BERT-NER models. Supports detection and anonymization of various PII types.\n"
"Input Parameters:\n"
"- lang: Language code, default is 'en'\n"
"- device: Running device, default is 'cuda'\n"
"- model_cache_dir: Model cache directory, default is './dataflow_cache'\n"
"- model_name: NER model name, default is 'dslim/bert-base-NER'\n"
"- input_key: Field name for input text\n\n"
"Output Parameters:\n"
"- DataFrame containing anonymized text\n"
"- Returns input field name for subsequent operator reference"
)
else:
return (
"PIIAnonymizeRefiner identifies and anonymizes PII in text using NLP models."
)
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
anonymized_count = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
results = self.analyzer.analyze(original_text, language=self.lang)
anonymized_text = self.anonymizer.anonymize(original_text, results)
if original_text != anonymized_text.text:
item = anonymized_text.text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {anonymized_text.text[:30]}...")
refined_data.append(item)
if modified:
anonymized_count += 1
self.logger.debug(f"Item modified, total modified so far: {anonymized_count}")
self.logger.info(f"Refining Complete. Total modified items: {anonymized_count}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
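# Illustrative sketch: the analyze/anonymize calls used in run(), applied to a single string.
# Assumes the presidio packages and the 'dslim/bert-base-NER' weights are available locally
# or can be downloaded; runs on CPU here.
if __name__ == "__main__":
    refiner = PIIAnonymizeRefiner(device="cpu")
    text = "My name is John Smith and my phone number is 212-555-0199."
    results = refiner.analyzer.analyze(text, language=refiner.lang)
    print(refiner.anonymizer.anonymize(text, results).text)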
import re
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class ReferenceRemoverRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__}...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"删除文本中未闭合的引用标签和引用链接,包括<ref>标签和{{cite}}模板的各种完整和不完整形式。净化文本中的引用标记。"
"输入参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 包含移除引用标记后文本的DataFrame\n"
"- 返回输入字段名,用于后续算子引用"
)
elif lang == "en":
return (
"Remove unclosed reference tags and citation links from text, including various complete and incomplete forms of <ref> tags and {{cite}} templates. \n"
"Cleans reference markers from text.\n"
"Input Parameters:\n"
"- input_key: Field name for input text\n\n"
"Output Parameters:\n"
"- DataFrame containing text with reference markers removed\n"
"- Returns input field name for subsequent operator reference"
)
else:
return (
"ReferenceRemoverRefiner removes reference tags and citation links from text."
)
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
numbers = 0
# Patterns to remove - a more comprehensive version
# 1. All <ref> tags and their content (including various incomplete forms)
ref_pattern = re.compile(
r'<ref\b[^>]*>.*?</ref>|' # complete ref tags
r'<ref\b[^>]*>[^<]*$|' # unclosed ref tags
r'<ref\b[^>]*>.*?/br' # ref tags followed by /br
)
# 2. All {{cite}} templates and their content (including various incomplete forms)
cite_pattern = re.compile(
r'\{\{cite\s+\w+\|[^}]*\}\}|' # complete cite templates
r'\{\{cite\s+\w+\|[^}]*$' # incomplete cite templates (not closed)
)
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = original_text
# Remove all ref tags (closed or unclosed)
refined_text, ref_count = ref_pattern.subn('', refined_text)
# Remove all cite templates (complete or incomplete)
refined_text, cite_count = cite_pattern.subn('', refined_text)
# Keep the refined text only if anything was actually removed
if ref_count > 0 or cite_count > 0:
item = refined_text
modified = True
self.logger.debug(f"Item modified, removed {ref_count} ref tags and {cite_count} cite templates")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
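# Illustrative sketch: what the two patterns compiled in run() strip from wiki-style markup.
if __name__ == "__main__":
    ref_pattern = re.compile(r'<ref\b[^>]*>.*?</ref>|<ref\b[^>]*>[^<]*$|<ref\b[^>]*>.*?/br')
    cite_pattern = re.compile(r'\{\{cite\s+\w+\|[^}]*\}\}|\{\{cite\s+\w+\|[^}]*$')
    sample = "Fact.<ref>source</ref> More text {{cite web|url=example}} end."
    cleaned = cite_pattern.sub('', ref_pattern.sub('', sample))
    print(cleaned)  # -> "Fact. More text  end."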
import contractions
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemoveContractionsRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于扩展文本中的英语缩写词,将缩写形式转换为完整形式(例如将\"can't\"扩展为\"cannot\")。\n"
"使用contractions库进行缩写词扩展,提高文本标准化程度。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含扩展缩写词后的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator expands English contractions in text, converting abbreviated forms to full forms (e.g., \"can't\"\"cannot\").\n"
"Uses the contractions library for abbreviation expansion to improve text standardization.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with expanded contractions\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Expands English contractions in text to improve standardization."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
expanded_text = contractions.fix(original_text)
if original_text != expanded_text:
item = expanded_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {expanded_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
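# Illustrative sketch: the contractions.fix call used in run().
if __name__ == "__main__":
    print(contractions.fix("I can't go because it's late."))  # -> "I cannot go because it is late."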
import re
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemoveEmojiRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.refiner_name = 'RemoveEmojiRefiner'
self.logger.info(f"Initializing {self.__class__.__name__} ...")
# Emoji pattern for matching emojis in the text
self.emoji_pattern = re.compile(
"["
u"\U0001F600-\U0001F64F" # Emoticons
u"\U0001F300-\U0001F5FF" # Miscellaneous symbols and pictographs
u"\U0001F680-\U0001F6FF" # Transport and map symbols
u"\U0001F1E0-\U0001F1FF" # Flags
u"\U00002702-\U000027B0" # Dingbats
"]+",
flags=re.UNICODE
)
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于去除文本中的Unicode图像表情符号,包括表情符号、杂项符号、交通符号、旗帜等各类图像符号。\n"
"通过正则表达式匹配Unicode表情符号范围,实现高效过滤。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含去除表情符号的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator removes Unicode emojis from text, including emoticons, symbols, transport icons, flags and other image symbols.\n"
"Achieves efficient filtering by matching Unicode emoji ranges with regular expressions.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with emojis removed\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Removes Unicode emojis from text using regular expression matching."
def run(self, storage: DataFlowStorage, input_key: str):
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
self.logger.info(f"Running {self.__class__.__name__} with input_key = {input_key}...")
for item in tqdm(dataframe[input_key], desc=f"Implementing {self.refiner_name}"):
modified = False
original_text = item
no_emoji_text = self.emoji_pattern.sub(r'', original_text)
if original_text != no_emoji_text:
item = no_emoji_text
modified = True
self.logger.debug(f"Modified text for key '{input_key}': Original: {original_text[:30]}... -> Refined: {no_emoji_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
dataframe[input_key] = refined_data
storage.write(dataframe)
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
return [input_key]
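# Illustrative sketch: the emoji pattern compiled above applied to a short string.
if __name__ == "__main__":
    refiner = RemoveEmojiRefiner()
    print(refiner.emoji_pattern.sub("", "Great job 😀🚀!"))  # -> "Great job !"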
import re
from tqdm import tqdm
from dataflow.utils.storage import DataFlowStorage
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemoveExtraSpacesRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于移除文本中的多余空格,将连续的多个空格替换为单个空格,并去除文本前后的空白字符。\n"
"通过字符串分割和连接实现空格标准化,提高文本格式一致性。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含标准化空格的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator removes extra spaces from text, replacing consecutive spaces with single spaces and trimming leading/trailing whitespace.\n"
"Achieves space standardization through string splitting and joining to improve text format consistency.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with standardized spacing\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Removes extra spaces and normalizes spacing in text."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = " ".join(original_text.split()) # Remove extra spaces
if original_text != refined_text:
item = refined_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
dataframe[self.input_key] = refined_data
storage.write(dataframe)
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
return [self.input_key]
import re
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemoveImageRefsRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.image_pattern = re.compile(
r'!\[\]\(images\/[0-9a-fA-F]\.jpg\)|'
r'[a-fA-F0-9]+\.[a-zA-Z]{3,4}\)|'
r'!\[\]\(images\/[a-f0-9]|'
r'图\s+\d+-\d+:[\u4e00-\u9fa5a-zA-Z0-9]+|'
r'(?:[0-9a-zA-Z]+){7,}|' # pattern 5: long alphanumeric runs
r'(?:[一二三四五六七八九十零壹贰叁肆伍陆柒捌玖拾佰仟万亿]+){5,}|' # pattern 6: runs of Chinese numerals
r"u200e|"
r"&#247;|\? :|"
r"[�□]|\{\/U\}|"
r"U\+26[0-F][0-D]|U\+273[3-4]|U\+1F[3-6][0-4][0-F]|U\+1F6[8-F][0-F]"
)
self.logger.info(f"Initializing {self.__class__.__name__} ...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于去除文本中的图片引用格式,包括Markdown图片链接、图片编号、特殊符号组合等图像引用模式。\n"
"通过多模式正则表达式匹配,识别并移除多种图片引用格式。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含去除图片引用的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator removes image reference formats from text, including Markdown image links, image numbers, special symbol combinations and other image reference patterns.\n"
"Identifies and removes multiple image reference formats through multi-pattern regular expression matching.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with image references removed\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Removes image reference formats from text using regular expressions."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
# Remove all image reference formats
cleaned_text = self.image_pattern.sub('', original_text)
if original_text != cleaned_text:
item = cleaned_text
modified = True
# Debug log: show the before/after comparison
self.logger.debug(f"Modified text for key '{self.input_key}':")
self.logger.debug(f"Original: {original_text[:100]}...")
self.logger.debug(f"Refined : {cleaned_text[:100]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
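# Illustrative sketch: a Markdown image reference matched by the first alternative
# of the pattern above.
if __name__ == "__main__":
    refiner = RemoveImageRefsRefiner()
    print(refiner.image_pattern.sub("", "Before ![](images/a.jpg) after"))  # -> "Before  after"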
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemoveNumberRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于移除文本中的数字字符,包括0-9的阿拉伯数字。\n"
"通过字符过滤实现数字移除,保留纯文本内容。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含去除数字的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator removes numeric characters from text, including 0-9 Arabic numerals.\n"
"Implements number removal through character filtering to retain pure text content.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with numbers removed\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Removes numeric characters from text through character filtering."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
no_number_text = ''.join([char for char in original_text if not char.isdigit()])
if original_text != no_number_text:
item = no_number_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {no_number_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
import string
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemovePunctuationRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
self.punct_to_remove = string.punctuation
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于移除文本中的标点符号,包括英文标点符号集合中的所有符号。\n"
"使用string.punctuation定义的标点集合进行过滤,实现文本去标点处理。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含去除标点的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator removes punctuation from text, including all symbols in the English punctuation set.\n"
"Uses the punctuation set defined by string.punctuation for filtering to achieve text punctuation removal.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with punctuation removed\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Removes punctuation from text using string.punctuation set."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
no_punct_text = original_text.translate(str.maketrans('', '', self.punct_to_remove))
if original_text != no_punct_text:
item = no_punct_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {no_punct_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
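# Illustrative sketch: the translate/maketrans call used in run().
if __name__ == "__main__":
    table = str.maketrans('', '', string.punctuation)
    print("Hello, world! (test)".translate(table))  # -> "Hello world test"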
import re
import string
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemoveRepetitionsPunctuationRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
self.punct_to_remove = string.punctuation
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于移除文本中重复的标点符号,例如将\"!!!\"变为\"!\"\",,\"变为\",\"\n"
"通过正则表达式匹配连续重复的标点符号,替换为单个符号。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含标准化标点的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator removes repeated punctuation in text, e.g., changing \"!!!\" to \"!\", \",,\" to \",\".\n"
"Matches consecutive repeated punctuation using regular expressions and replaces with single symbols.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with normalized punctuation\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Removes repeated punctuation in text using regular expressions."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
no_extra_punct_text = re.sub(r'([^\w\s_])\1+|(_)\2+', r'\1\2', original_text)
if original_text != no_extra_punct_text:
item = no_extra_punct_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {no_extra_punct_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
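# Illustrative sketch: the substitution used in run() collapses runs of the same
# punctuation mark (and of underscores) into a single character.
if __name__ == "__main__":
    print(re.sub(r'([^\w\s_])\1+|(_)\2+', r'\1\2', "Wow!!! Really??"))  # -> "Wow! Really?"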
import nltk
import os
from nltk.corpus import stopwords
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class RemoveStopwordsRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
# Add the NLTK data path from the environment, if one is configured
if 'NLTK_DATA' in os.environ:
nltk.data.path.insert(0, os.environ['NLTK_DATA'])
# Download the stopwords corpus only if it is not already available
try:
nltk.data.find('corpora/stopwords')
except LookupError:
nltk.download('stopwords')
def remove_stopwords(self, text):
words = text.split()
stopwords_list = set(stopwords.words('english'))
refined_words = [word for word in words if word.lower() not in stopwords_list]
return " ".join(refined_words)
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于移除文本中的英语停用词(如\"the\"\"is\"\"in\"等无实际意义的高频词汇)。\n"
"使用NLTK库的stopwords语料库进行停用词过滤,提高文本特征密度。\n"
"输入参数:\n"
"- model_cache_dir:模型缓存目录,默认为'./dataflow_cache'\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含去除停用词的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator removes English stopwords from text (e.g., high-frequency words with little meaning like \"the\", \"is\", \"in\").\n"
"Uses NLTK library's stopwords corpus for stopword filtering to improve text feature density.\n"
"Input Parameters:\n"
"- model_cache_dir: Model cache directory, default is './dataflow_cache'\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with stopwords removed\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Removes English stopwords from text using NLTK's stopwords corpus."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = self.remove_stopwords(original_text)
if original_text != refined_text:
item = refined_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
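# Illustrative sketch: the filtering done by remove_stopwords(), assuming the NLTK
# stopwords corpus is available (the constructor downloads it if missing).
if __name__ == "__main__":
    refiner = RemoveStopwordsRefiner()
    print(refiner.remove_stopwords("This is a test of the refiner"))  # -> "test refiner"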
import re
import os
import requests
from tqdm import tqdm
from symspellpy.symspellpy import SymSpell, Verbosity
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class SpellingCorrectionRefiner(OperatorABC):
def __init__(self, max_edit_distance: int = 2, prefix_length: int = 7, dictionary_path: str = "frequency_dictionary_en_82_765.txt"):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
self.max_edit_distance = max_edit_distance # Default to 2 if not specified
self.prefix_length = prefix_length # Default to 7 if not specified
self.dictionary_path = dictionary_path
# If dictionary is not found locally, download it
if not os.path.exists(self.dictionary_path):
self.download_dictionary()
self.sym_spell = SymSpell(max_dictionary_edit_distance=self.max_edit_distance, prefix_length=self.prefix_length)
term_index = 0
count_index = 1
if not self.sym_spell.load_dictionary(self.dictionary_path, term_index, count_index):
self.logger.error(f"Error loading dictionary at {self.dictionary_path}")
else:
self.logger.info(f"Successfully loaded dictionary at {self.dictionary_path}")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于通过SymSpell算法对文本中的拼写错误进行纠正,支持自定义编辑距离和词典路径。\n"
"若本地词典不存在则自动下载,使用近似字符串匹配实现拼写纠错功能。\n"
"输入参数:\n"
"- max_edit_distance:最大编辑距离,默认为2\n"
"- prefix_length:前缀长度,默认为7\n"
"- dictionary_path:词典路径,默认为'frequency_dictionary_en_82_765.txt'\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含纠正拼写错误的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator corrects spelling errors in text using the SymSpell algorithm, supporting custom edit distance and dictionary path.\n"
"Automatically downloads dictionary if not locally available, using approximate string matching for spelling correction.\n"
"Input Parameters:\n"
"- max_edit_distance: Maximum edit distance, default is 2\n"
"- prefix_length: Prefix length, default is 7\n"
"- dictionary_path: Dictionary path, default is 'frequency_dictionary_en_82_765.txt'\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with corrected spelling\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Corrects spelling errors in text using the SymSpell algorithm with configurable parameters."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = self.spelling_checks(original_text)
if original_text != refined_text:
item = refined_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
def spelling_checks(self, text):
correct_result = []
for word in text.split():
suggestions = self.sym_spell.lookup(word, Verbosity.CLOSEST, self.max_edit_distance)
corrected_word = suggestions[0].term if suggestions else word
correct_result.append(corrected_word)
return " ".join(correct_result)
def download_dictionary(self):
url = 'https://raw.githubusercontent.com/mammothb/symspellpy/master/symspellpy/frequency_dictionary_en_82_765.txt'
try:
print("Downloading dictionary...")
response = requests.get(url)
response.raise_for_status()
with open(self.dictionary_path, 'wb') as file:
file.write(response.content)
print(f"Dictionary downloaded and saved to {self.dictionary_path}")
except requests.exceptions.RequestException as e:
print(f"Error downloading dictionary: {e}")
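# Illustrative sketch: the SymSpell lookup wrapped by spelling_checks(). Assumes the
# frequency dictionary exists locally or can be downloaded by the constructor.
if __name__ == "__main__":
    refiner = SpellingCorrectionRefiner()
    print(refiner.spelling_checks("speling mistakes are anoying"))  # e.g. "spelling mistakes are annoying"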
import nltk
from nltk.stem import PorterStemmer, WordNetLemmatizer
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class StemmingLemmatizationRefiner(OperatorABC):
def __init__(self, method: str = "stemming"):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
self.method = method.lower()
if self.method not in ["stemming", "lemmatization"]:
raise ValueError("Invalid method. Choose 'stemming' or 'lemmatization'.")
nltk.download('wordnet')
nltk.download('omw-1.4')
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于对文本进行词干提取或词形还原处理,将词语转换为其基本形式。\n"
"支持两种处理方式:Porter词干提取(stemming)和WordNet词形还原(lemmatization),可通过参数选择。\n"
"输入参数:\n"
"- method:处理方法,可选'stemming'或'lemmatization',默认为'stemming'\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含词干/词形还原后的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator applies stemming or lemmatization to text, converting words to their base forms.\n"
"Supports two processing methods: Porter stemming and WordNet lemmatization, selectable via parameter.\n"
"Input Parameters:\n"
"- method: Processing method, optional 'stemming' or 'lemmatization', default is 'stemming'\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with stemming/lemmatization applied\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Applies stemming or lemmatization to text using NLTK's PorterStemmer or WordNetLemmatizer."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
dataframe = storage.read("dataframe")
numbers = 0
refined_data = []
stemmer = PorterStemmer()
lemmatizer = WordNetLemmatizer()
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
if self.method == "stemming":
refined_text = " ".join([stemmer.stem(word) for word in original_text.split()])
elif self.method == "lemmatization":
refined_text = " ".join([lemmatizer.lemmatize(word) for word in original_text.split()])
if original_text != refined_text:
item = refined_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
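# Illustrative sketch: how the two methods differ, using the same NLTK classes as run();
# assumes the WordNet data downloaded in __init__ is present.
if __name__ == "__main__":
    words = "running studies better".split()
    print(" ".join(PorterStemmer().stem(w) for w in words))           # -> "run studi better"
    print(" ".join(WordNetLemmatizer().lemmatize(w) for w in words))  # -> "running study better"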
import re
from datetime import datetime
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.registry import OPERATOR_REGISTRY
@OPERATOR_REGISTRY.register()
class TextNormalizationRefiner(OperatorABC):
def __init__(self):
self.logger = get_logger()
self.logger.info(f"Initializing {self.__class__.__name__} ...")
@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于规范化文本中的日期格式和货币格式,统一为标准表示形式。\n"
"日期格式统一转换为'YYYY-MM-DD'形式,货币格式转换为'金额 USD'形式,提高数据一致性。\n"
"输入参数:\n"
"- 无初始化参数\n"
"运行参数:\n"
"- input_key:输入文本字段名\n"
"输出参数:\n"
"- 处理后的DataFrame,包含格式规范化的文本\n"
"- 返回包含输入字段名的列表,用于后续算子引用"
)
elif lang == "en":
return (
"This operator normalizes date formats and currency formats in text to standard representations.\n"
"Unifies date formats to 'YYYY-MM-DD' and currency formats to 'amount USD' to improve data consistency.\n"
"Input Parameters:\n"
"- No initialization parameters\n"
"Runtime Parameters:\n"
"- input_key: Input text field name\n"
"Output Parameters:\n"
"- Processed DataFrame containing text with normalized formats\n"
"- List containing input field name for subsequent operator reference"
)
else:
return "Normalizes date and currency formats in text to standard representations."
def run(self, storage: DataFlowStorage, input_key: str):
self.input_key = input_key
dataframe = storage.read("dataframe")
self.logger.info(f"Running {self.__class__.__name__} with input_key = {self.input_key}...")
numbers = 0
refined_data = []
for item in tqdm(dataframe[self.input_key], desc=f"Implementing {self.__class__.__name__}"):
modified = False
original_text = item
refined_text = original_text
refined_text = re.sub(r'(\d{1,2})[/.](\d{1,2})[/.](\d{2,4})', r'\3-\2-\1', refined_text)
date_patterns = [
(r'\b(\w+)\s+(\d{1,2}),\s+(\d{4})\b', '%B %d, %Y'),
(r'\b(\d{1,2})\s+(\w+)\s+(\d{4})\b', '%d %B %Y')
]
for pattern, date_format in date_patterns:
match = re.search(pattern, refined_text)
if match:
date_str = match.group(0)
try:
parsed_date = datetime.strptime(date_str, date_format)
refined_text = refined_text.replace(date_str, parsed_date.strftime('%Y-%m-%d'))
except ValueError:
pass
refined_text = re.sub(r'\$\s?(\d+)', r'\1 USD', refined_text)
if original_text != refined_text:
item = refined_text
modified = True
self.logger.debug(f"Modified text for key '{self.input_key}': Original: {original_text[:30]}... -> Refined: {refined_text[:30]}...")
refined_data.append(item)
if modified:
numbers += 1
self.logger.debug(f"Item modified, total modified so far: {numbers}")
self.logger.info(f"Refining Complete. Total modified items: {numbers}")
dataframe[self.input_key] = refined_data
output_file = storage.write(dataframe)
return [self.input_key]
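# Illustrative sketch: two of the substitutions applied in run(), shown on a single string.
if __name__ == "__main__":
    text = "Paid $250 on 12/05/2023 for the licence."
    text = re.sub(r'(\d{1,2})[/.](\d{1,2})[/.](\d{2,4})', r'\3-\2-\1', text)
    text = re.sub(r'\$\s?(\d+)', r'\1 USD', text)
    print(text)  # -> "Paid 250 USD on 2023-05-12 for the licence."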
from typing import TYPE_CHECKING
if TYPE_CHECKING:
# filter
from .generate.kbc_chunk_generator import KBCChunkGenerator
from .generate.kbc_chunk_generator_batch import KBCChunkGeneratorBatch
from .generate.file_or_url_to_markdown_converter import FileOrURLToMarkdownConverter
from .generate.file_or_url_to_markdown_converter_batch import FileOrURLToMarkdownConverterBatch
from .generate.kbc_text_cleaner import KBCTextCleaner
from .generate.kbc_text_cleaner_batch import KBCTextCleanerBatch
from .generate.mathbook_question_extract import MathBookQuestionExtract
# from .generate.kbc_multihop_qa_generator import KBCMultiHopQAGenerator
from .generate.kbc_multihop_qa_generator_batch import KBCMultiHopQAGeneratorBatch
from .generate.qa_extract import QAExtractor
else:
import sys
from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking
cur_path = "dataflow/operators/knowledge_cleaning/"
_import_structure = generate_import_structure_from_type_checking(__file__, cur_path)
sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/knowledge_cleaning/", _import_structure)
import pandas as pd
from typing import Literal
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
import os
from pathlib import Path
from trafilatura import fetch_url, extract
import requests
def _parse_file_with_mineru(raw_file: str, output_file: str, mineru_backend: Literal["vlm-sglang-engine", "pipeline"] = "vlm-sglang-engine") -> str:
"""
Uses MinerU to parse PDF/image files (pdf/png/jpg/jpeg/webp/gif) into Markdown files.
Internally, the parsed outputs for each item are stored in a structured directory:
'intermediate_dir/pdf_name/MinerU_Version[mineru_backend]'.
This directory stores various MinerU parsing outputs, and you can customize
which content to extract based on your needs.
Args:
raw_file: Input file path, supports .pdf/.png/.jpg/.jpeg/.webp/.gif
output_file: Output directory for the MinerU results; the final Markdown path is derived inside it
mineru_backend: Sets the backend engine for MinerU. Options include:
- "pipeline": Traditional pipeline processing (MinerU1)
- "vlm-sglang-engine": New engine based on multimodal language models (MinerU2) (default recommended)
Choose the appropriate backend based on your needs. Defaults to "vlm-sglang-engine".
For more details, refer to the MinerU GitHub: https://github.com/opendatalab/MinerU.
Returns:
output_file: Path to the Markdown file
"""
try:
import mineru
except ImportError:
raise Exception(
"""
MinerU is not installed in this environment yet.
Please refer to https://github.com/opendatalab/mineru to install.
Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
Please make sure you have GPU on your machine.
"""
)
logger=get_logger()
os.environ['MINERU_MODEL_SOURCE'] = "local" # Optional: load MinerU models from a local source
MinerU_Version = {"pipeline": "auto", "vlm-sglang-engine": "vlm"}
raw_file = Path(raw_file)
pdf_name = raw_file.stem
intermediate_dir = output_file
try:
return_code = os.system(
f"mineru -p {raw_file} -o {intermediate_dir} -b {mineru_backend} --source local"
)
if return_code != 0:
raise RuntimeError(f"MinerU execution failed with return code: {return_code}")
except Exception as e:
raise RuntimeError(f"Failed to process file with MinerU: {str(e)}")
# Directory for storing raw data, including various MinerU parsing outputs.
# You can customize which content to extract based on your needs.
PerItemDir = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend])
output_file = os.path.join(PerItemDir, f"{pdf_name}.md")
logger.info(f"Markdown saved to: {output_file}")
return output_file
def _parse_doc_to_md(input_file: str, output_file: str):
"""
support conversion of doc/ppt/pptx/pdf files to markdowns
"""
try:
from magic_doc.docconv import DocConverter
except ImportError:
raise Exception(
"""
Fairy-doc is not installed in this environment yet.
Please refer to https://github.com/opendatalab/magic-doc to install.
Or you can just execute 'apt-get/yum/brew install libreoffice' and 'pip install fairy-doc[gpu]' to fix this error.
please make sure you have gpu on your machine.
"""
)
logger=get_logger()
converter = DocConverter(s3_config=None)
markdown_content, time_cost = converter.convert(input_file, conv_timeout=300)
logger.info("time cost: ", time_cost)
with open(output_file, "w",encoding='utf-8') as f:
f.write(markdown_content)
return output_file
def _parse_xml_to_md(raw_file:str=None, url:str=None, output_file:str=None):
logger=get_logger()
if(url):
downloaded=fetch_url(url)
elif(raw_file):
with open(raw_file, "r", encoding='utf-8') as f:
downloaded=f.read()
else:
raise Exception("Please provide at least one of file path and url string.")
try:
result=extract(downloaded, output_format="markdown", with_metadata=True)
logger.info(f"Extracted content is written into {output_file}")
with open(output_file,"w", encoding="utf-8") as f:
f.write(result)
except Exception as e:
logger.error("Error during extract this file or link: ", e)
return output_file
def is_pdf_url(url):
try:
# Send a HEAD request so only the response headers are fetched, not the file itself
response = requests.head(url, allow_redirects=True)
# Treat the URL as a PDF when the response Content-Type is application/pdf
if response.status_code == 200 and response.headers.get('Content-Type') == 'application/pdf':
return True
else:
print(f"Content-Type: {response.headers.get('Content-Type')}")
return False
except requests.exceptions.RequestException:
# Return False if the request fails
print("Request failed")
return False
def download_pdf(url, save_path):
try:
# Send a GET request to download the PDF file
response = requests.get(url, stream=True)
# Make sure the response body is actually a PDF
if response.status_code == 200 and response.headers.get('Content-Type') == 'application/pdf':
# Save the PDF locally
with open(save_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
print(f"PDF saved to {save_path}")
else:
print("The URL did not return a valid PDF file.")
except requests.exceptions.RequestException as e:
print(f"Error downloading PDF: {e}")
@OPERATOR_REGISTRY.register()
class FileOrURLToMarkdownConverter(OperatorABC):
"""
mineru_backend sets the backend engine for MinerU. Options include:
- "pipeline": Traditional pipeline processing (MinerU1)
- "vlm-sglang-engine": New engine based on multimodal language models (MinerU2) (default recommended)
Choose the appropriate backend based on your needs. Defaults to "vlm-sglang-engine".
For more details, refer to the MinerU GitHub: https://github.com/opendatalab/MinerU.
"""
def __init__(
self,
url: str = None,
raw_file: str = None,
intermediate_dir: str = "intermediate",
lang: str = "en",
mineru_backend: Literal["vlm-sglang-engine", "pipeline"] = "vlm-sglang-engine",
):
self.logger = get_logger()
self.intermediate_dir=intermediate_dir
os.makedirs(self.intermediate_dir, exist_ok=True)
self.lang=lang
self.mineru_backend = mineru_backend
self.url = url
self.raw_file = raw_file
@staticmethod
def get_desc(lang: str = "zh"):
"""
Return the operator's functional description (based on what run() implements).
"""
if lang == "zh":
return (
"知识提取算子:支持从多种文件格式中提取结构化内容并转换为标准Markdown\n"
"核心功能:\n"
"1. PDF文件:使用MinerU解析引擎提取文本/表格/公式,保留原始布局\n"
"2. Office文档(DOC/PPT等):通过DocConverter转换为Markdown格式\n"
"3. 网页内容(HTML/XML):使用trafilatura提取正文并转为Markdown\n"
"4. 纯文本(TXT/MD):直接透传不做处理\n"
"特殊处理:\n"
"- 自动识别中英文文档(lang参数)\n"
"- 支持本地文件路径和URL输入\n"
"- 生成中间文件到指定目录(intermediate_dir)"
)
else:  # default: English
return (
"Knowledge Extractor: Converts multiple file formats to structured Markdown\n"
"Key Features:\n"
"1. PDF: Uses MinerU engine to extract text/tables/formulas with layout preservation\n"
"2. Office(DOC/PPT): Converts to Markdown via DocConverter\n"
"3. Web(HTML/XML): Extracts main content using trafilatura\n"
"4. Plaintext(TXT/MD): Directly passes through without conversion\n"
"Special Handling:\n"
"- Auto-detects Chinese/English documents(lang param)\n"
"- Supports both local files and URLs\n"
"- Generates intermediate files to specified directory(intermediate_dir)"
)
def run(self, storage: DataFlowStorage, input_key="", output_key=""):
self.logger.info("Starting extraction...")
self.logger.info("If you're providing a URL or a large file, this may take a while. Please wait...")
# Handle extraction from URL
if self.url:
if is_pdf_url(self.url):
pdf_save_path = output_file = os.path.join(
os.path.dirname(storage.first_entry_file_name),
"raw/crawled.pdf"
)
self.logger.info(f"Downloading PDF from {self.url} to {pdf_save_path}")
download_pdf(self.url, pdf_save_path)
self.raw_file=pdf_save_path
self.logger.info(f"pdf file has been fetched and saved to {pdf_save_path}")
else:
output_file = os.path.join(
os.path.dirname(storage.first_entry_file_name),
"raw/crawled.md"
)
output_file = _parse_xml_to_md(url=self.url, output_file=output_file)
self.logger.info(f"Primary extracted result written to: {output_file}")
return output_file
# Handle supported file types
# Extract file name and extension
raw_file_name = os.path.splitext(os.path.basename(self.raw_file))[0]
raw_file_suffix = os.path.splitext(self.raw_file)[1].lower()
raw_file_suffix_no_dot = raw_file_suffix.lstrip(".")
# Define default output path
output_file = os.path.join(
self.intermediate_dir,
f"{raw_file_name}_{raw_file_suffix_no_dot}.md"
)
if raw_file_suffix in [".pdf", ".png", ".jpg", ".jpeg", ".webp", ".gif"]:
self.logger.info(f"Using MinerU backend: {self.mineru_backend}")
# Use MinerU backend for PDF and image files
output_file = _parse_file_with_mineru(
raw_file=self.raw_file,
output_file=self.intermediate_dir,
mineru_backend=self.mineru_backend
)
elif raw_file_suffix in [".doc", ".docx", ".ppt", ".pptx"]:
# .doc format is currently not supported
if raw_file_suffix == ".doc":
raise Exception(
"Function under maintenance. Please convert your file to PDF format first."
)
# Convert .docx/.ppt/.pptx to Markdown with the DocConverter-based helper defined above
output_file = _parse_doc_to_md(input_file=self.raw_file, output_file=output_file)
elif raw_file_suffix in [".html", ".xml"]:
# Use XML/HTML parser for HTML and XML files
output_file = _parse_xml_to_md(raw_file=self.raw_file, output_file=output_file)
elif raw_file_suffix in [".txt", ".md"]:
# Plain text and markdown files require no processing
output_file = self.raw_file
else:
# Unsupported file type
raise Exception(f"Unsupported file type: {raw_file_suffix}")
self.logger.info(f"Primary extracted result written to: {output_file}")
return output_file
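# Illustrative usage sketch: how this operator might typically be driven. The `storage`
# object below is hypothetical -- any DataFlowStorage implementation whose
# first_entry_file_name points at a writable location would do; MinerU must be installed
# for PDF/image input, and the input path shown is an assumed example.
#
#   converter = FileOrURLToMarkdownConverter(
#       raw_file="./example.pdf",              # hypothetical local input
#       intermediate_dir="intermediate",
#       mineru_backend="vlm-sglang-engine",
#   )
#   md_path = converter.run(storage)           # returns the path of the generated Markdown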