Commit 97e8278b authored by zzg_666's avatar zzg_666
Browse files

适配后端vllm

parents
Pipeline #3071 canceled with stages
from dataflow.utils.registry import PROMPT_REGISTRY
from dataflow.core.prompt import PromptABC
'''
A collection of prompts for the general reasoning operator.
'''
@PROMPT_REGISTRY.register()
class GeneralAnswerGeneratorPrompt(PromptABC):
'''
The prompt for the answer generator.
'''
def __init__(self):
pass
def build_prompt(self, question: str) -> str:
"""
for general reasoning answer generation
"""
prompt = (
r'''You are an intelligent chatbot designed for producing the answer to the given reasoning task.
Remember: DO NOT output anything else, only output the answer you generate.
Generate a solution to the given task strictly following this format:
1. Identify key components and premises of the task
2. Apply relevant principles, theorems, or methods with step-by-step derivation or argument
3. Perform any necessary calculations or logical checks with intermediate verification
4. Present the final answer or conclusion in a clear, unambiguous notation
Format Requirements:
- Prefix each step with "→" (use the actual arrow symbol, not its Unicode escape sequence)
- Ensure all symbols and special characters are presented using appropriate markup (e.g., LaTeX commands for mathematical symbols, code formatting for code snippets)
Example Template:
Task: Analyze the time complexity of the following sorting algorithm and prove its correctness.
Solution:
1. Identify components:
→ Algorithm uses divide-and-conquer to split the list in half
→ Merging step compares elements pairwise
2. Apply principles:
→ Recurrence: T(n) = 2T(n/2) + O(n)
→ By Master Theorem, T(n) = O(n log n)
3. Verification:
→ Check base case T(1) = O(1)
→ Inductive step holds for n = 2^k
4. Conclusion:
→ The algorithm runs in \\boxed{O(n\\log n)} time and correctly sorts any input list.
Here is the given task you need to solve:
'''
)
return prompt + question + r'''Your response must start directly with "Solution:" without any preamble. Finish your response immediately after the solution.'''
@PROMPT_REGISTRY.register()
class GeneralQuestionSynthesisPrompt(PromptABC):
'''
The prompt for the question synthesis.
'''
def __init__(self):
pass
def build_prompt(self, items: str, question: str) -> str:
prompt = f"""
Create a new, high‑quality reasoning task from the original by applying some of the following transformations (focus on all transformations of "{items}"):
1. Alter any quantitative or qualitative elements (numbers, dates, variables, data types, code snippets), ensuring the new task remains coherent and solvable.
2. Change the task type or domain: e.g. switch from calculation to proof, from mathematical derivation to algorithm design, from text analysis to code debugging, or vice versa.
3. Reframe the scenario in a different real‑world or abstract context (e.g. finance, engineering, language translation, data processing, robotics), incorporating relevant domain details.
4. Introduce new premises or constraints that require separate consideration or conditional logic in the solution.
5. Increase complexity by adding multiple interdependent steps, branching cases, or requiring integration of diverse skills (e.g. math + coding + reasoning).
6. Vary the output format: require a formal proof, pseudocode, annotated explanation, or numeric answer as appropriate.
Here is the original task:
{question}
Generate a fully self‑contained new task inspired by the above. Start directly with the task statement; do NOT include any framing phrases like “Here is a new task inspired by…”. End your response immediately after the task description.
"""
return prompt
@PROMPT_REGISTRY.register()
class GeneralQuestionFilterPrompt(PromptABC):
def __init__(self):
pass
def build_prompt(self, question: str) -> str:
prompt = f"""You are given a reasoning task. Follow these four steps in order and stop at the first failure:
0. First, verify the input contains only a single clear reasoning task (no extra instructions like “rewrite”, “translate”, or a provided answer); if not, output judgement_test=false.
1. Check spelling, grammar, and formatting (e.g. code indentation, LaTeX, Markdown), without interpreting semantics.
2. For each minimal premise (cannot be further decomposed), verify it does not violate commonsense, domain facts, or task requirements (e.g. “half a person” is invalid; magical operations allowed only if explicitly assumed); if invalid, fail.
3. Check for any contradictions among premises or in the reasoning process, or if the final result is clearly unreasonable or unsolvable; if so, fail.
4. If all above pass, check whether there is enough information to complete the task; missing necessary conditions ⇒ fail, redundant details are acceptable.
After these steps, output exactly:
{{
"judgement_test": true/false,
"error_type": "<error description or null>"
}}
You may include your chain of thought, but the final output must be the JSON above.
Here is the content to evaluate:
-------------------------------
{question}
-------------------------------
"""
return prompt
\ No newline at end of file
from dataflow.utils.registry import PROMPT_REGISTRY
from dataflow.core.prompt import PromptABC
'''
A collection of prompts for the math reasoning operator.
'''
@PROMPT_REGISTRY.register()
class MathAnswerGeneratorPrompt(PromptABC):
'''
The prompt for the answer generator.
'''
def __init__(self):
pass
def build_prompt(self, question: str) -> str:
"""
为给定数学题目生成系统提示信息
"""
prompt = (
r'''You are an intelligent chatbot designed for writing the answer of the given math question.
Remember: DO NOT output anything else, only output the answer you make.
Generate a solution of a given math problem strictly following this format:
1. Identify key components of the problem
2. Apply theorems/formulas with step-by-step derivation
3. Perform calculations with intermediate verification
4. Final answer in \boxed{} notation
Format Requirements:
- Prefix each step with "→" (use the actual arrow symbol, not its Unicode escape sequence)
- Ensure all symbols and special characters are presented using LaTeX commands where appropriate (e.g., ≥ as \\geq, ÷ as \\div)
Example Template:
Problem: Find the minimum value of function f(x) = x³ - 3x² + 4 on interval [-1, 3]
Solution:
1. Find critical points:
→ f'(x) = 3x² - 6x
→ Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2
2. Evaluate function at critical points and endpoints:
→ f(-1) = (-1)^3 - 3(-1)^2 + 4 = -1 -3 +4 = 0.0000
→ f(0) = 0³ - 3(0)² +4 = 4.0000
→ f(2) = 8 - 12 +4 = 0.0000
→ f(3) = 27 - 27 +4 = 4.0000
3. Compare values:
→ Minimum occurs at x=-1 and x=2
Verification:
→ Second derivative test: f''(x) = 6x-6
→ f''(-1) = -12 < 0 (local max)
→ f''(2) = 6 > 0 (local min)
\boxed{0}
Here is the given problem you need to solve:
'''
)
return prompt + question + r'''Your response must directly start with "Solution:" without any preamble, After the answer is generated finish your response right away.'''
@PROMPT_REGISTRY.register()
class MathQuestionSynthesisPrompt(PromptABC):
'''
The prompt for the question synthesis.
'''
def __init__(self):
pass
def build_prompt(self, items: str, question: str) -> str:
prompt = f"""
Create a new reasonable and solvable math problem from the original problem by applying some of the following transformations(focus on all the transformations of "{items}"):
1. Alter numerical values or expressions, ensuring the new problem remains solvable.
2. Modify the problem type: introduce concepts like ratios or percentages, switch between derivatives and integrals, change the question from finding an area to finding a perimeter, etc.
3. Contextualize the problem within a real-world scenario, such as incorporating various payment methods or deferred payments with interest.
4. Add additional premises that require considering an extra factor separately in solving the problem.
5. Increase the complexity of the problem by introducing multiple conditions that necessitate case-by-case analysis for a solution.
Here is the problem from the user:
{question}
Write another problem inspired by this one.
Not only change the problem scenario, but also try to create a new problem that requires another approach to solve.
Start directly with the problem statement and DO NOT include any phrases such as "Here is a new problem inspired by a given one".
After the problem is generated finish your response right away.
"""
return prompt
@PROMPT_REGISTRY.register()
class MathQuestionCategoryPrompt(PromptABC):
'''
The prompt for the question synthesis.
'''
def __init__(self):
pass
def build_prompt(self, question: str) -> str:
prompt = f"""
You are a classification assistant specialized in mathematics. Your task is to classify the given text into one primary category and one secondary category according to the following taxonomy. Do not output any extra explanation. Return only a JSON object with the keys "primary_category" and "secondary_category".
Taxonomy:
1. Foundations and Logic
- 1.1 Mathematical Logic and Set Theory
- 1.2 Basic Theory, Formalization, and History & Education
2. Algebra and Number Theory
- 2.1 Linear Algebra and Group Theory
- 2.2 Ring Theory, Field Theory, and Polynomial Algebra
- 2.3 Commutative Algebra and Homological/Categorical Methods
- 2.4 Number Theory
- 2.5 Algebraic Geometry
3. Analysis and Differential Equations
- 3.1 Real Analysis, Measure Theory, and Functional Analysis
- 3.2 Complex Analysis and Special Functions
- 3.3 Differential Equations and Dynamical Systems
- 3.4 Integral Transforms, Integral Equations, and Difference Equations
- 3.5 Harmonic Analysis
4. Geometry and Topology
- 4.1 Euclidean, Analytic, and Convex/Discrete Geometry
- 4.2 Differential Geometry and Manifold Theory
- 4.3 Topology and Algebraic Topology
5. Probability, Statistics, and Discrete Mathematics
- 5.1 Probability Theory and Stochastic Processes
- 5.2 Mathematical Statistics
- 5.3 Combinatorics and Graph Theory
6. Applied and Computational Mathematics
- 6.1 Numerical Analysis and Computational Methods
- 6.2 Optimal Control, Variational Methods, and Optimization
- 6.3 Operations Research and Game Theory
- 6.4 Systems Theory and Control
- 6.5 Computer Science and Algorithms
- 6.6 Mathematical Physics and Engineering Mathematics
- 6.7 Information and Communication
- 6.8 Biomathematics
7. Arithmetic
- 7.1 Basic Arithmetic and Number Operations
- 7.2 Word Problems and Real-Life Applications
Classify the following text into one primary category and one secondary category based on the taxonomy above. The text is:
{question}
"""
return prompt
@PROMPT_REGISTRY.register()
class MathQuestionDifficultyPrompt(PromptABC):
'''
The prompt for the question synthesis.
'''
def __init__(self):
pass
def build_prompt(self, question: str) -> str:
prompt = r"""
# CONTEXT #
I am a teacher, and I have some high-level olympiad math problems.
I want to evaluate the difficulty of these math problems. There are some references available regarding the difficulty of the problems:
<difficulty reference>
For reference, here are some sample problems from each of the difficulty levels 1-10:
1: Jamie counted the number of edges of a cube, Jimmy counted the numbers of corners, and Judy counted the number of faces. They then added the three numbers. What was the resulting sum? (2003 AMC 8, Problem 1)
1: How many integer values of $x$ satisfy $|x| < 3\pi$? (2021 Spring AMC 10B, Problem 1)
1.5: A number is called flippy if its digits alternate between two distinct digits. For example, $2020$ and $37373$ are flippy, but $3883$ and $123123$ are not. How many five-digit flippy numbers are divisible by $15?$ (2020 AMC 8, Problem 19)
2: A fair $6$-sided die is repeatedly rolled until an odd number appears. What is the probability that every even number appears at least once before the first occurrence of an odd number? (2021 Spring AMC 10B, Problem 18)
2.5: $A$, $B$, $C$ are three piles of rocks. The mean weight of the rocks in $A$ is $40$ pounds, the mean weight of the rocks in $B$ is $50$ pounds, the mean weight of the rocks in the combined piles $A$ and $B$ is $43$ pounds, and the mean weight of the rocks in the combined piles $A$ and $C$ is $44$ pounds. What is the greatest possible integer value for the mean in pounds of the rocks in the combined piles $B$ and $C$? (2013 AMC 12A, Problem 16)
3: Triangle $ABC$ with $AB=50$ and $AC=10$ has area $120$. Let $D$ be the midpoint of $\overline{AB}$, and let $E$ be the midpoint of $\overline{AC}$. The angle bisector of $\angle BAC$ intersects $\overline{DE}$ and $\overline{BC}$ at $F$ and $G$, respectively. What is the area of quadrilateral $FDBG$? (2018 AMC 10A, Problem 24)
3.5: Find the number of integer values of $k$ in the closed interval $[-500,500]$ for which the equation $\log(kx)=2\log(x+2)$ has exactly one real solution. (2017 AIME II, Problem 7)
4: Define a sequence recursively by $x_0=5$ and
\[x_{n+1}=\frac{x_n^2+5x_n+4}{x_n+6}\]
for all nonnegative integers $n.$ Let $m$ be the least positive integer such that
\[x_m\leq 4+\frac{1}{2^{20}}.\]
In which of the following intervals does $m$ lie?
$\textbf{(A) } [9,26] \qquad\textbf{(B) } [27,80] \qquad\textbf{(C) } [81,242]\qquad\textbf{(D) } [243,728] \qquad\textbf{(E) } [729,\infty)$
(2019 AMC 10B, Problem 24 and 2019 AMC 12B, Problem 22)
4.5: Find, with proof, all positive integers $n$ for which $2^n + 12^n + 2011^n$ is a perfect square. (USAJMO 2011/1)
5: Find all triples $(a, b, c)$ of real numbers such that the following system holds:
\[
a+b+c=\frac{1}{a}+\frac{1}{b}+\frac{1}{c},
\]
\[
a^2+b^2+c^2=\frac{1}{a^2}+\frac{1}{b^2}+\frac{1}{c^2}.
\]
(JBMO 2020/1)
5.5: Triangle $ABC$ has $\angle BAC = 60^{\circ}$, $\angle CBA \leq 90^{\circ}$, $BC=1$, and $AC \geq AB$. Let $H$, $I$, and $O$ be the orthocenter, incenter, and circumcenter of $\triangle ABC$, respectively. Assume that the area of pentagon $BCOIH$ is the maximum possible. What is $\angle CBA$? (2011 AMC 12A, Problem 25)
6: Let $\triangle ABC$ be an acute triangle with circumcircle $\omega,$ and let $H$ be the intersection of the altitudes of $\triangle ABC.$ Suppose the tangent to the circumcircle of $\triangle HBC$ at $H$ intersects $\omega$ at points $X$ and $Y$ with $HA=3,\ HX=2,$ and $HY=6.$ The area of $\triangle ABC$ can be written in the form $m\sqrt{n},$ where $m$ and $n$ are positive integers, and $n$ is not divisible by the square of any prime. Find $m+n.$ (2020 AIME I, Problem 15)
6.5: Rectangles $BCC_1B_2,$ $CAA_1C_2,$ and $ABB_1A_2$ are erected outside an acute triangle $ABC.$ Suppose that
\[\angle BC_1C+\angle CA_1A+\angle AB_1B=180^{\circ}.\]
Prove that lines $B_1C_2,$ $C_1A_2,$ and $A_1B_2$ are concurrent. (USAMO 2021/1, USAJMO 2021/2)
7: We say that a finite set $\mathcal{S}$ in the plane is balanced if, for any two different points $A$, $B$ in $\mathcal{S}$, there is a point $C$ in $\mathcal{S}$ such that $AC=BC$. We say that $\mathcal{S}$ is centre-free if for any three points $A$, $B$, $C$ in $\mathcal{S}$, there is no point $P$ in $\mathcal{S}$ such that $PA=PB=PC$.
Show that for all integers $n\geq 3$, there exists a balanced set consisting of $n$ points.
Determine all integers $n\geq 3$ for which there exists a balanced centre-free set consisting of $n$ points.
(IMO 2015/1)
7.5: Let $\mathbb{Z}$ be the set of integers. Find all functions $f : \mathbb{Z} \rightarrow \mathbb{Z}$ such that
\[
xf(2f(y)-x)+y^2f(2x-f(y))=\frac{f(x)^2}{x}+f(yf(y))
\]
for all $x, y \in \mathbb{Z}$ with $x \neq 0$. (USAMO 2014/2)
8: For each positive integer $n$, the Bank of Cape Town issues coins of denomination $\frac1n$. Given a finite collection of such coins (of not necessarily different denominations) with total value at most $99+\frac{1}{2}$, prove that it is possible to split this collection into $100$ or fewer groups, such that each group has total value at most $1$. (IMO 2014/5)
8.5: Let $I$ be the incentre of acute triangle $ABC$ with $AB\neq AC$. The incircle $\omega$ of $ABC$ is tangent to sides $BC, CA$, and $AB$ at $D, E,$ and $F$, respectively. The line through $D$ perpendicular to $EF$ meets $\omega$ at $R$. Line $AR$ meets $\omega$ again at $P$. The circumcircles of triangle $PCE$ and $PBF$ meet again at $Q$.
Prove that lines $DI$ and $PQ$ meet on the line through $A$ perpendicular to $AI$. (IMO 2019/6)
9: Let $k$ be a positive integer and let $S$ be a finite set of odd prime numbers. Prove that there is at most one way (up to rotation and reflection) to place the elements of $S$ around the circle such that the product of any two neighbors is of the form $x^2+x+k$ for some positive integer $x$. (IMO 2022/3)
9.5: An anti-Pascal triangle is an equilateral triangular array of numbers such that, except for the numbers in the bottom row, each number is the absolute value of the difference of the two numbers immediately below it. For example, the following is an anti-Pascal triangle with four rows which contains every integer from $1$ to $10$.
\[
\begin{array}{ c@{\hspace{4pt}}c@{\hspace{4pt}} c@{\hspace{4pt}}c@{\hspace{2pt}}c@{\hspace{2pt}}c@{\hspace{4pt}}c }
& & & 4 & & & \\
& & 2 & & 6 & & \\
& 5 & & 7 & & 1 & \\
8 & & 3 & & 10 & & 9 \\
\end{array}
\]
Does there exist an anti-Pascal triangle with $2018$ rows which contains every integer from $1$ to $1 + 2 + 3 + \dots + 2018$? (IMO 2018/3)
10: Prove that there exists a positive constant $c$ such that the following statement is true: Consider an integer $n > 1$, and a set $\mathcal S$ of $n$ points in the plane such that the distance between any two different points in $\mathcal S$ is at least 1. It follows that there is a line $\ell$ separating $\mathcal S$ such that the distance from any point of $\mathcal S$ to $\ell$ is at least $cn^{-1/3}$.
(A line $\ell$ separates a set of points S if some segment joining two points in $\mathcal S$ crosses $\ell$.) (IMO 2020/6)
## Some known difficulty ratings of the competitions.
</difficulty reference>
# OBJECTIVE #
A. Summarize the math problem in a brief sentence, describing the concepts involved in the math problem.
B. Based on the source of the given problem, as well as the difficulty of the problems referenced in these materials and the solution to the current problem, please provide
an overall difficulty score for the current problem. The score should be a number between 1 and 10, with increments of 0.5, and should align perfectly with the materials.
# STYLE #
Data report.
# TONE #
Professional, scientific.
# AUDIENCE #
Students. Enable them to better understand the difficulty of the math problems.
# RESPONSE: MARKDOWN REPORT #
## Summarization
[Summarize the math problem in a brief paragraph.]
## Difficulty
[Rate the difficulty of the math problem and give the reason.]
# ATTENTION #- Add "=== report over ===" at the end of the report.
<example math problem>
The problem requires finding the missing value in the equation
\[
\frac{1}{9}+\frac{1}{18}=\frac{1}{\square}.
\]
In other words, determine the number that should replace the square such that the sum of the fractions on the left equals the fraction on the right.
</example math problem>
## Summarization
The problem requires finding a value that makes the equation $\\frac{1}{9}+\\frac{1}{18}=\\frac{1}{\\square}$.
This involves adding two fractions and determining the equivalent fraction.
## Difficulty
Rating: 1
Reason: This problem is straightforward and primarily involves basic fraction addition, making it suitable for early middle school students.
=== report over ===
</example math problem>
Let $\mathcal{P}$ be a convex polygon with $n$ sides, $n\ge3$. Any set of $n - 3$ diagonals of $\mathcal{P}$ that do not intersect in the interior of the polygon determine a triangulation of $\mathcal{P}$ into $n - 2$ triangles. If $\mathcal{P}$ is regular and there is a triangulation of $\mathcal{P}$ consisting of only isosceles triangles, find all the possible values of $n$.
</example math problem>
## Summarization
The problem asks for the possible values of $n$ for a regular n-sided polygon that can be completely triangulated into isosceles triangles using non-intersecting diagonals.
The solution involves analyzing the properties of the diagonals forming isosceles triangles and deducing that $n$ can be expressed in terms of powers of 2.
## Difficulty
Rating: 7
Reason: The problem involves understanding properties of isosceles triangles in the context of polygon triangulation and requires critical reasoning to establish
relationships between the number of sides and powers of 2, making it more complex than typical undergraduate-level problems.
=== report over ===
<math problem>
[Question]: \n
"""
return prompt + question
@PROMPT_REGISTRY.register()
class MathQuestionFilterPrompt(PromptABC):
'''
The prompt for the question filter.
'''
def __init__(self):
pass
def build_prompt(self, question: str) -> str:
"""Constructs an evaluation prompt with four progressive checks"""
prompt = f"""You are given a mathematical problem. Follow these four steps in order and stop at the first failure:
0. Firstly check if it is only a math problem, if it has other instruction confused the model such as "rewrite" or has answer or other strange instruction, then judged as failure. If it is not a math problem, then the judgement_test is false.
1. Check only for spelling, grammar, and LaTeX formatting correctness. Do not interpret semantic meaning.
2. For each minimal condition stated in the problem (that cannot be further decomposed), check if it violates the mathematical domain or objective facts (for example, 'half a person' is incorrect). Note: Magical operations are acceptable if the necessary assumption is explicitly stated. Average values (e.g., 15.5 items per minute) are acceptable.
3. Check whether the problem-solving process contains any contradictions. This includes any two minimal conditions contradicting each other or if the final solution would be unreasonable (including unsolvable).
4. If the steps above pass, check if there are enough conditions provided in the problem to answer the target question. Redundant conditions that do not affect the problem - solving process are considered reasonable. Both analytical and numerical solutions are considered valid unless otherwise specified.
After performing these steps in sequence, output your final judgment in JSON format with exactly the following keys:
{{
"judgement_test": true/false,
"error_type": "<error description or null>"
}}
You may include your chain-of-thought, but the final answer must be the JSON object above.
Here is the problem to evaluate:
-------------------------------
{question}
-------------------------------
"""
return prompt
@PROMPT_REGISTRY.register()
class MathQuestionSequentialFusionGeneratorPrompt(PromptABC):
def __init__(self):
pass
def build_system_prompt(self):
system_prompt = ""
return system_prompt
def build_prompt(self, input_question_1, input_question_2):
prompt = f"""
# Role: Mathematical Problem Merger
## Profile
Your role is to merge "#Problem 1#" and "#Problem 2#" into a combined problem.
## Guidelines
Step 1: Identify input and output variables in both problems. Determine mathematical relationships and constraints in each
problem. Locate variables between "#Problem 1#" and "#Problem 2#" that can form sequential dependencies.
Step 2: Formulate a comprehensive plan to merge the two problems by using "#Problem 1#"’s output variable to
replace an input variable of "#Problem 2#"’s. Merge contextual elements by embedding both problems within a unified
real-world scenario or extended narrative, aligning units and measurement systems.
Step 3: Create a single "#New Problem#" where solving "#Problem 1#" is a prerequisite for "#Problem
## Output Format
Please reply strictly in the following format:
#Elements Identified#:
#Plan#:
#New Problem#:
## Input
### #Problem 1#
{input_question_1}
### #Problem 2#
{input_question_2}
2#". Explicitly state variable dependencies and which variable is replaced. Adjust numerical ranges to maintain arithmetic
consistency. The "#New Problem#" should contain no supplementary explanation or note.
## Output
"""
return prompt
@PROMPT_REGISTRY.register()
class MathQuestionParallelFusionGeneratorPrompt(PromptABC):
def __init__(self):
pass
def build_system_prompt(self):
system_prompt = ""
return system_prompt
def build_prompt(self, input_question_1, input_question_2):
prompt = f"""
# Role: Mathematical Problem Synthesizer
## Profile Your role is to organically integrate "#Problem 1#" and "#Problem 2#" to create a novel problem that
requires advanced synthesis of their mathematical essence.
## Guidelines
Step 1: Conduct deep structural analysis of both problems by identifying their fundamental mathematical operations,
contextual frameworks, and cognitive patterns. Extract the underlying logical architectures while preserving their distinctive
solution pathways.
Step 2: Develop an innovative fusion mechanism by discovering non-obvious mathematical connections between
the problems’ core concepts. Construct a multidimensional scenario that naturally embeds both original contexts through
temporal sequencing, spatial superposition, or conceptual analogy. Engineer hybrid parameters that inherit characteristics
from both source problems while introducing emergent properties.
Step 3: Formulate the synthesized problem through strategic recombination of mathematical elements, ensuring
the new problem requires concurrent application of both original solution strategies. Introduce controlled complexity
problems’ answers.
## Output Format
Please reply strictly in the following format:
#Core Elements#:
#Synthesis Method#:
#New Problem#:
## Input
### #Problem 1#
{input_question_1}
### #Problem 2#
{input_question_2}
through cross-domain constraints and self-verification mechanisms that establish mathematical consistency with both source
## Output
"""
return prompt
@PROMPT_REGISTRY.register()
class MathQuestionConditionFusionGeneratorPrompt(PromptABC):
def __init__(self):
pass
def build_system_prompt(self):
system_prompt = ""
return system_prompt
def build_prompt(self, input_question_1, input_question_2):
prompt = f"""
# Role: Problem Integrator
## Profile
Create a real-world problem where the solution requires solving both "#Problem 1#" and "#Problem 2#" independently.
**Ensure the the final answer is either from "#Problem 1#" or "#Problem 2#", depends on the "#New Question#"**.
## Guidelines
Step 1: Analyze "#Problem 1#" and "#Problem 2#" and make sure that the output variables they ask about are of the same
type. If they are different (for example, one asks about time and the other asks about price), modify one of the problem so that
it asks about the same variable as the other.
Step 2: Design a unified problem scenario that combines "#Problem 1#" and "#Problem 2#". Introduce a "#New Question#",
which must be related with both "#Problem 1#" and "#Problem 2#". Ensure that final answer of the "#New Question#" must
either come from "#Problem 1#" or "#Problem 2#". This means that the "#New Question#" should be an **comparison**
and **selection** of the previous answers, not their **combination**. There are some examples for the "#New Question#":
1. Who sells the most items?
2. Howmuch money does the top earner make?
3. Which is the cheaper plan?
4. Someone has 200 dollor, which item can he afford?
phrases "#Problem 1#" and "#Problem 2#" in the generated "#New Problem#".
## Output Format
Please reply strictly in the following format:
#Analysis#:
#New Question#:
#New Problem#:
## Input
### #Problem 1#
{input_question_1}
### #Problem 2#
{input_question_2}
Step 3: Provide the "#New Problem#", which combine "#Problem 1#", "#Problem 2#", and "#New Question#" in a unified
real-world scenario. Don’t contain solution of "#Problem 1#" and "#Problem 2#" in "#New Problem#".
## Output
"""
return prompt
@PROMPT_REGISTRY.register()
class MathQuestionEvaluatorPrompt(PromptABC):
def __init__(self):
pass
def build_system_prompt(self):
system_prompt = ""
return system_prompt
def build_prompt(self, input_question):
prompt = f"""
# Role: Mathematics Grading Teacher
## Profile
You are a senior mathematics grading teacher in university, very skilled in high difficulty fields such as Intermediate Algebra,
Precalculus, Prealgebra, Number Theory, Geometry, Counting & Probability, Algebra and so on.
## Guidelines
Your task is to act as an impartial judge to evaluate the statement completeness and correctness of math problem according to
the following rules:
1. Assess the clarity and accuracy of the definition of each math problem. Ensure that the problem statement provides
sufficient information, conditions, and constraints.
2. Consider whether the problem allows for multiple interpretations or if further clarification is needed.
3. Evaluate the clarity of mathematical notation and terminology used in the problem.
## Output Format
Please reply strictly in the following format:
#Judgement#:
#Explanation#:
## Input
{input_question}
4. Evaluate whether the math problem is solvable. If the math problem meet the rules above, output "True" in "#Judge
ment#", else "False". You should also give your explanation in "#Explanation#".
## Output
"""
return prompt
\ No newline at end of file
import json
from dataflow.utils.registry import PROMPT_REGISTRY
from dataflow.core.prompt import PromptABC
"""
A collection of prompts for the Text2QA pipelines operator
"""
@PROMPT_REGISTRY.register()
class Text2QAAutoPromptGeneratorPrompt(PromptABC):
'''
The prompt for the AutoPromptGenerator.
'''
def __init__(self):
pass
def build_prompt(self, seed_data: str) -> str:
prompt = f'''You will be given a piece of seed data, which may consist of a paragraph, dialogue, or any other form of text containing potential question-answer information.
Your task is to analyze this seed data carefully and generate as much non-repeat clear and effective prompt as you can that can be used to instruct a language model to extract a single high-quality question-answer (QA) pair suitable for reinforcement learning (RL) training from this piece of data.
The generated prompt should:
Clearly describe the type and format of input the model will receive;
Explicitly ask for the extraction of a relevant QA pair;
Optionally include instructions about the desired style, level of detail, or coverage;
Be written in natural, precise English that could be directly used with another LLM;
Be strictly the prompt used to extract QA pairs, not the QA pairs themselves.
Your prompts should contain the following instructions:
The question should be clear, focused, and unambiguous, such that it targets specific factual content from the input;
The answer should be a few words that are concise, factual and directly verifiable from the source rather than a whole sentence, enabling accurate reward computation in the RL pipeline;
Both the question and answer should be simple enough to facilitate evaluation and automatic feedback.
Don't include any additional explanations or comments in your output.
Don't repeat the seed data in your output.
Your output format should be in a list as follow:
["PROMPT_1","PROMPT_2",...]
Here is the seed data you need to analyze and generate a prompt for:\n{seed_data}'''
return prompt
@PROMPT_REGISTRY.register()
class Text2QASeedQuestionGeneratorPrompt(PromptABC):
'''
The prompt for the Text2QAGenerator.
'''
def __init__(self):
pass
def build_prompt(self) -> str:
prompt = f'''"Format:\nQ: ...\nA: ..." + "\nSeed data:\n"'''
return prompt
@PROMPT_REGISTRY.register()
class Text2QAQuestionQualityPrompt(PromptABC):
'''
The prompt for the question quality scorer.
'''
def __init__(self):
pass
def build_prompt(self) -> str:
prompt = '''You are an expert question quality evaluator. Given a single question from a QA dataset, your job is to assess the **clarity and meaningfulness** of the question. Specifically, judge whether the question is clearly defined, unambiguous, and worth asking in a real-world or task-specific context.
Assign a score from 1 to 5 based on the following rubric:
5 = Very clear and meaningful question, well-posed
4 = Clear but slightly underspecified or too general
3 = Somewhat unclear or poorly scoped, but understandable
2 = Ambiguous, vague, or unnatural
1 = Nonsensical or meaningless
Output format:
**Grading**: [1-5]
**Feedback**: Explain your score. Mention if the question is ambiguous, overly broad, or lacks practical purpose. Suggest how to improve clarity or specificity if needed.
'''
return prompt
@PROMPT_REGISTRY.register()
class Text2QAAnswerAlignmentPrompt(PromptABC):
'''
The prompt for the RAG answer alignment scorer.
'''
def __init__(self):
pass
def build_prompt(self) -> str:
prompt = '''You are a response alignment evaluator. Your task is to assess whether a given answer **directly and clearly addresses the given question**.
Assign a score from 1 to 5 based on the following rubric:
5 = Fully and directly answers the question
4 = Mostly addresses the question, with minor gaps or irrelevant additions
3 = Partially answers the question but omits key aspects
2 = Barely addresses the question or is off-topic
1 = Completely unrelated to the question
Output format:
**Grading**: [1-5]
**Feedback**: Justify your score. Point out if the answer is evasive, incomplete, or misaligned. Suggest ways to better match the response to the question.
'''
return prompt
@PROMPT_REGISTRY.register()
class Text2QAAnswerVerifiabilityPrompt(PromptABC):
'''
The prompt for the RAG answer verifiability scorer.
'''
def __init__(self):
pass
def build_prompt(self) -> str:
prompt = '''You are an evaluator tasked with assessing how **easily verifiable** an answer is. You must determine whether the correctness of the answer can be **conveniently and unambiguously judged** — for example, whether it is fact-based, precise, and not subjective or vague.
Assign a score from 1 to 5 based on the following rubric:
5 = Very easy to verify; answer is objective, concrete, and unambiguous
4 = Mostly verifiable, with minor ambiguities
3 = Verifiable in parts, but some subjectivity or fuzziness
2 = Hard to verify; answer is vague, speculative, or opinion-based
1 = Unverifiable or meaningless
Output format:
**Grading**: [1-5]
**Feedback**: Explain your score. Identify elements that make verification easier or harder. Suggest rephrasing or grounding techniques to improve verifiability.
'''
return prompt
@PROMPT_REGISTRY.register()
class Text2QADownstreamValuePrompt(PromptABC):
'''
The prompt for the RAG downstream value scorer.
'''
def __init__(self):
pass
def build_prompt(self) -> str:
prompt = '''You are a task relevance evaluator. Given a QA pair, assess how well this data point could **support a downstream task** such as classification, dialogue, retrieval, summarization, or knowledge grounding.
Assign a score from 1 to 5 based on the following rubric:
5 = Highly valuable for downstream tasks; question and answer are precise and informative
4 = Useful with minor limitations
3 = Moderately helpful; limited in informativeness or specificity
2 = Of little value; vague or too generic to help the model learn
1 = Useless or irrelevant for any downstream learning objective
Output format:
**Grading**: [1-5]
**Feedback**: Describe how the QA pair does or does not benefit potential downstream tasks. If relevant, suggest how to make it more useful for training.
'''
return prompt
import textwrap
from typing import Dict, Literal
from dataflow.utils.registry import PROMPT_REGISTRY
from dataflow.core.prompt import PromptABC
@PROMPT_REGISTRY.register()
class Text2MultiHopQAGeneratorPrompt(PromptABC):
'''
多跳问答生成器(严格JSON格式输出)
根据语言参数提供完全独立的专业提示模板
'''
def __init__(self, lang: str = "en"):
self.lang = lang
self.system_text = self.build_system_prompt()
def build_system_prompt(self) -> str:
"""构建专业级多跳问答提示"""
if self.lang == "en":
return textwrap.dedent("""\
You are a professional multi-hop QA specialist with strict protocols:
█ Core Requirements
1. Must identify 2-3 interrelated facts in context
2. Design complex questions requiring cross-fact reasoning
3. Reasoning chains must:
- Contain 2-3 logical steps (numbered)
- Show clear causal/progressive relationships
- Each step must reference specific facts
4. Final answer must synthesize all reasoning conclusions
5. Focus solely on the main text and avoid synthesizing Q&A based on content found in links, references, or other supplementary sources.
█ Output Specifications
1. Only pure JSON in this structure:
{
"question": "Multi-fact reasoning question",
"reasoning_steps": [
{"step": "First step (must use Fact 1)"},
{"step": "Second step (must link Fact 2)"}
],
"answer": "Synthesized final answer",
"supporting_facts": ["Verbatim Fact 1", "Verbatim Fact 2"],
"type": "domain_tag"
}
2. Supporting facts must:
- Be verbatim from context
- Directly support corresponding steps
- No paraphrasing allowed
█ Demonstration
Context:
"Photosynthesis converts CO2 to oxygen. This process sustains plant growth. Plants form the base of food chains."
Valid Output:
{
"question": "How does photosynthesis impact ecosystems?",
"reasoning_steps": [
{"step": "Photosynthesis produces oxygen"},
{"step": "Plants using photosynthesis form food chain bases"}
],
"answer": "It provides oxygen and sustains ecosystem food chains",
"supporting_facts": [
"Photosynthesis converts CO2 to oxygen",
"Plants form the base of food chains"
],
"type": "biology"
}
█ Rejection Criteria
Reject if:
- Fewer than 2 reasoning steps
- Unreferenced supporting facts exist
- Any non-JSON content appears
""")
else:
return textwrap.dedent("""\
您是专业的多跳问答生成专家,必须严格遵循以下专业标准:
█ 核心要求
1. 必须识别上下文中的2-3个关联事实
2. 设计需要跨事实推理的复杂问题
3. 推理链必须满足:
- 至少包含2-3个逻辑步骤
- 每个步骤明确标注序号
- 步骤间存在因果或递进关系
4. 最终答案必须整合所有推理结论
5. 只关注正文内容,避免根据链接、参考文献等附加信息合成问答。
█ 输出规范
1. 仅允许输出以下结构的纯JSON:
{
"question": "需要跨事实推理的问题",
"reasoning_steps": [
{"step": "第一推理步骤(必须引用事实1)"},
{"step": "第二推理步骤(必须关联事实2)"}
],
"answer": "整合所有步骤的最终答案",
"supporting_facts": ["原文事实1", "原文事实2"],
"type": "领域标签"
}
2. 支撑事实必须:
- 从上下文逐字提取
- 与推理步骤严格对应
- 不得改写或概括
█ 示例
上下文:
"量子纠缠现象由爱因斯坦提出质疑。后来贝尔实验证实了其真实性。该现象是量子计算的基础。"
合格输出:
{
"question": "为什么量子纠缠现象对量子计算很重要?",
"reasoning_steps": [
{"step": "贝尔实验证实了量子纠缠的真实性"},
{"step": "该现象是量子计算的基础"}
],
"answer": "因为量子纠缠被证实真实且是量子计算的基础",
"supporting_facts": [
"后来贝尔实验证实了其真实性",
"该现象是量子计算的基础"
],
"type": "量子物理"
}
█ 违规处理
以下情况将拒绝输出:
- 推理步骤少于2步
- 存在未引用的支撑事实
- JSON外出现任何附加文本
""")
def build_prompt(self, text: str) -> str:
"""生成完全专业化的用户提示"""
if self.lang == "en":
user_prompt = textwrap.dedent(f"""\
Generate professional multi-hop QA from:
Context:
{text}
Strict requirements:
1. Extract exactly 2-3 interrelated facts
2. Question must demonstrate cross-fact reasoning
3. Use this exact JSON structure (include all quotes/braces):
{{
"question": "...",
"reasoning_steps": [
{{"step": "Must explicitly use Fact 1"}},
{{"step": "Must explicitly link Fact 2"}}
],
"answer": "...",
"supporting_facts": ["Verbatim Fact 1", "Verbatim Fact 2"],
"type": "..."
}}
""")
else:
user_prompt = textwrap.dedent(f"""\
请基于以下上下文生成专业级多跳问答:
上下文:
{text}
严格按照以下要求执行:
1. 必须从上述上下文中提取2-3个关联事实
2. 问题需体现跨事实推理的复杂性
3. 使用此精确JSON结构(包括所有引号和括号):
{{
"question": "...",
"reasoning_steps": [
{{"step": "必须明确引用事实1"}},
{{"step": "必须明确关联事实2"}}
],
"answer": "...",
"supporting_facts": ["事实1原文", "事实2原文"],
"type": "..."
}}
""")
return user_prompt
\ No newline at end of file
'''
A collection of prompts for the text2sql operator.
'''
import random
from re import template
import numpy as np
import json
from typing import List
from dataflow.utils.registry import PROMPT_REGISTRY
from dataflow.core.prompt import PromptABC
@PROMPT_REGISTRY.register()
class SQLConsistencyFilterPrompt(PromptABC):
def __init__(self):
pass
def build_prompt(self, question: str, sql: str, db_details: str) -> str:
prompt = f"""
**Task Overview**
Determine if the SQL query correctly answers the given question based on the provided schema.
**Question**
{question}
**SQL**
{sql}
**Schema**
{db_details}
**Evaluation Criteria**
1. **Logical Alignment**: Does the SQL query logically address what the question is asking?
2. **Schema Compliance**: Are the tables, columns, and relationships used correctly according to the schema?
3. **Completeness**: Does the SQL capture all necessary conditions and requirements from the question?
4. **Correctness**: Are there any logical errors that would prevent getting the correct answer?
**Output Format**:
The conclusion should be enclosed in a code block:
```
<Conclusion> YES/NO </Conclusion>
```
**Decision Rules**
- YES: SQL correctly implements the question requirements
- NO: SQL has logical errors or doesn't address the question properly
- When uncertain about edge cases, explain the uncertainty in analysis but still provide a definitive YES/NO
**Answer**
Let's proceed step by step.
"""
return prompt
@PROMPT_REGISTRY.register()
class Text2SQLCotGeneratorPrompt(PromptABC):
def __init__(self):
pass
def build_prompt(self, schema_str: str, question: str, sql: str, evidence: str) -> str:
if evidence:
question_with_evidence = question + "\n" + evidence
else:
question_with_evidence = question
prompt = f"""
You are a senior data analyst specializing in SQL. Your task is to translate a natural language question into an executable SQLite query, providing a detailed reasoning trace.
You will also receive a reference solution from a colleague, which may or may not be correct. This extra information intends to help you generate your answer, but you are asked not to mention the reference solution in any form.
The reference solution might include:
1. Unnecessary table and column selections.
2. Incorrect or excessive joins.
3. Misalignment with the question.
4. Opportunities for simplification.
Ensure the SQL query is presented in a Markdown code block with proper syntax highlighting, like this:
```sql
SELECT * FROM table;
```
[Database Schema]:
{schema_str}
[Natural Language Question]:
{question_with_evidence}
[Reference Solution]:
```sql
{sql}
```
Provide your step-by-step text-to-SQL solution here.
"""
return prompt
@PROMPT_REGISTRY.register()
class SelectSQLGeneratorPrompt(PromptABC):
def __init__(self):
self.simple_criterion = '''**Criteria:**
Simple SQL queries may satisfy one or more of the following criteria:
- Simple queries should select data from a single table only.
- Basic aggregate functions are permitted, such as `COUNT`, `SUM`, `AVG`, `MIN`, `MAX`.
- No joins are allowed; the query must operate on a single table.
**Example of Simple SQL Query:**
```sql
SELECT name, department_name
FROM employees
WHERE level > 5
ORDER BY age DESC;
```'''
self.moderate_criterion = '''**Criteria:**
Moderate SQL queries may satisfy one or more of the following criteria:
- Involves table joins, such as `JOIN`, `INNER JOIN`, `LEFT JOIN`, `CROSS JOIN`, etc.
- Includes subqueries within the `SELECT` or `WHERE` clauses.
- Utilizes aggregate functions alongside a `GROUP BY` clause.
- Contains complex `WHERE` conditions, including `IN`, `BETWEEN`, `LIKE`.
- Incorporate a `HAVING` clause to filter aggregated results.
- Uses aggregate functions like `COUNT`, `SUM`, `AVG`, `MIN`, `MAX`, etc.
**Example of Moderate SQL Query:**
```sql
SELECT e.name, d.department_name, AVG(s.salary) AS average_salary
FROM employees e
INNER JOIN departments d ON e.department_id = d.department_id
LEFT JOIN salaries s ON e.employee_id = s.employee_id
WHERE e.age > 30 AND e.status = 'active'
GROUP BY e.name, d.department_name
HAVING AVG(s.salary) > 50000;
```'''
self.complex_criterion = '''**Criteria:**
Complex SQL queries may satisfy one or more of the following criteria:
- Contains complex nested subqueries.
- Utilizes multiple types of joins, including self-joins.
- Includes window functions, such as `ROW_NUMBER`, `RANK`, etc.
- Uses Common Table Expressions (CTEs) for improved readability.
- Combines multiple aggregate functions.
- Involves complex `WHERE` and `HAVING` clauses with multiple conditions.
- Utilizes advanced functions and operators.
**Example of Complex SQL Query:**
```sql
WITH EmployeeCTE AS (
SELECT employee_id, name, department_id, ROW_NUMBER() OVER (PARTITION BY department_id ORDER BY salary DESC) AS rank
FROM employees
)
SELECT e.name, d.department_name
FROM EmployeeCTE e
INNER JOIN departments d ON e.department_id = d.department_id
WHERE e.rank <= 3;
```'''
self.highly_complex_criterion = '''**Criteria:**
Highly complex SQL queries may satisfy one or more of the following criteria:
- Includes multiple Common Table Expressions (CTEs) for readability.
- Combines nested subqueries and various joins.
- Utilizes recursive CTEs for hierarchical or recursive queries.
- Extensively uses advanced window functions.
- May involve `UNION` or `UNION ALL` to combine result sets.
- Implements complex logic with advanced analytical functions.
- Employs a wide range of SQL clauses and conditions.
- Utilizes a broad spectrum of SQL functions and advanced features.
**Example of Highly Complex SQL Query:**
```sql
WITH RECURSIVE EmployeeHierarchy AS (
SELECT employee_id, name, manager_id, department_id, 1 as level
FROM employees
WHERE manager_id IS NULL
UNION ALL
SELECT e.employee_id, e.name, e.manager_id, e.department_id, eh.level + 1
FROM employees e
JOIN EmployeeHierarchy eh ON e.manager_id = eh.employee_id
),
DepartmentSalaries AS (
SELECT eh.employee_id, eh.name, eh.level, d.department_name, s.salary, d.department_id
FROM EmployeeHierarchy eh
INNER JOIN departments d ON eh.department_id = d.department_id
INNER JOIN salaries s ON eh.employee_id = s.employee_id
),
DepartmentStats AS (
SELECT
d.department_id,
COUNT(e.employee_id) AS employee_count,
AVG(s.salary) AS average_salary
FROM employees e
INNER JOIN salaries s ON e.employee_id = s.employee_id
INNER JOIN departments d ON e.department_id = d.department_id
GROUP BY d.department_id
)
SELECT ds.name, ds.level,
SUM(ds.salary) OVER (PARTITION BY ds.department_id ORDER BY ds.level, ds.name) AS cumulative_salary
FROM DepartmentSalaries ds
INNER JOIN DepartmentStats dstat ON ds.department_id = dstat.department_id
ORDER BY ds.level, ds.name;
```'''
self.complexity2criterion = {
"Simple": self.simple_criterion,
"Moderate": self.moderate_criterion,
"Complex": self.complex_criterion,
"Highly Complex": self.highly_complex_criterion
}
self.complexity2criterion = {
"Simple": self.simple_criterion,
"Moderate": self.moderate_criterion,
"Complex": self.complex_criterion,
"Highly Complex": self.highly_complex_criterion
}
self.functions = [
"ABS(X) \nDescription: The ABS(X) function returns the absolute value of the numeric argument X. Abs(X) returns NULL if X is NULL. Abs(X) returns 0.0 if X is a string or blob that cannot be converted to a numeric value. If X is the integer -9223372036854775808 then ABS(X) throws an integer overflow error since there is no equivalent positive 64-bit two complement value. ",
"CHANGES() \nDescription: The CHANGES() function returns the number of database rows that were changed or inserted or deleted by the most recently completed INSERT, DELETE, or UPDATE statement, exclusive of statements in lower-level triggers. The CHANGES() SQL function is a wrapper around thesqlite3_changes64()C/C++ function and hence follows the same rules for counting changes. ",
"CHAR(X1,X2,...,XN) \nDescription: The CHAR(X1,X2,...,XN) function returns a string composed of characters having the unicode code point values of integers X1 through XN, respectively. ",
"COALESCE(X,Y,...) \nDescription: The COALESCE() function returns a copy of its first non-NULL argument, or NULL if all arguments are NULL. Coalesce() must have at least 2 arguments. ",
"CONCAT(X,...) \nDescription: The CONCAT(...) function returns a string which is the concatenation of the string representation of all of its non-NULL arguments. If all arguments are NULL, then CONCAT() returns an empty string. ",
"CONCAT_WS(SEP,X,...) \nDescription: The CONCAT_WS(SEP,...) function returns a string that is the concatenation of all non-null arguments beyond the first argument, using the text value of the first argument as a separator. If the first argument is NULL, then CONCAT_WS() returns NULL. If all arguments other than the first are NULL, then CONCAT_WS() returns an empty string. ",
"FORMAT(FORMAT,...) \nDescription: The FORMAT(FORMAT,...) SQL function works like thesqlite3_mprintf()C-language function and the printf() function from the standard C library. The first argument is a format string that specifies how to construct the output string using values taken from subsequent arguments. If the FORMAT argument is missing or NULL then the result is NULL. The %n format is silently ignored and does not consume an argument. The %p format is an alias for %X. The %z format is interchangeable with %s. If there are too few arguments in the argument list, missing arguments are assumed to have a NULL value, which is translated into 0 or 0.0 for numeric formats or an empty string for %s. See thebuilt-in printf()documentation for additional information. ",
"GLOB(X,Y) \nDescription: The GLOB(X,Y) function is equivalent to the expression \"Y GLOB X\". Note that the X and Y arguments are reversed in the GLOB() function relative to the infixGLOBoperator. Y is the string and X is the pattern. So, for example, the following expressions are equivalent:name GLOB '*helium*' GLOB('*helium*',name)If thesqlite3_create_function()interface is used to override the GLOB(X,Y) function with an alternative implementation then theGLOBoperator will invoke the alternative implementation. ",
"HEX(X) \nDescription: The HEX() function interprets its argument as a BLOB and returns a string which is the upper-case hexadecimal rendering of the content of that blob.If the argumentXin \"hex(X)\" is an integer or floating point number, then \"interprets its argument as a BLOB\" means that the binary number is first converted into a UTF8 text representation, then that text is interpreted as a BLOB. Hence, \"hex(12345678)\" renders as \"3132333435363738\" not the binary representation of the integer value \"0000000000BC614E\".See also:unhex() ",
"IFNULL(X,Y) \nDescription: The IFNULL() function returns a copy of its first non-NULL argument, or NULL if both arguments are NULL. Ifnull() must have exactly 2 arguments. The IFNULL() function is equivalent tocoalesce()with two arguments. ",
"IIF(X,Y,Z) \nDescription: The IIF(X,Y,Z) function returns the value Y if X is true, and Z otherwise. The IIF(X,Y,Z) function is logically equivalent to and generates the samebytecodeas theCASE expression\"CASE WHEN X THEN Y ELSE Z END\". ",
"INSTR(X,Y) \nDescription: The INSTR(X,Y) function finds the first occurrence of string Y within string X and returns the number of prior characters plus 1, or 0 if Y is nowhere found within X. Or, if X and Y are both BLOBs, then INSTR(X,Y) returns one more than the number bytes prior to the first occurrence of Y, or 0 if Y does not occur anywhere within X. If both arguments X and Y to INSTR(X,Y) are non-NULL and are not BLOBs then both are interpreted as strings. If either X or Y are NULL in INSTR(X,Y) then the result is NULL. ",
"LAST_INSERT_ROWID() \nDescription: The LAST_INSERT_ROWID() function returns theROWIDof the last row insert from the database connection which invoked the function. The LAST_INSERT_ROWID() SQL function is a wrapper around thesqlite3_last_insert_rowid()C/C++ interface function. ",
"LENGTH(X) \nDescription: For a string value X, the LENGTH(X) function returns the number of characters (not bytes) in X prior to the first NUL character. Since SQLite strings do not normally contain NUL characters, the LENGTH(X) function will usually return the total number of characters in the string X. For a blob value X, LENGTH(X) returns the number of bytes in the blob. If X is NULL then LENGTH(X) is NULL. If X is numeric then LENGTH(X) returns the length of a string representation of X.Note that for strings, the LENGTH(X) function returns thecharacterlength of the string, not the byte length. The character length is the number of characters in the string. The character length is always different from the byte length for UTF-16 strings, and can be different from the byte length for UTF-8 strings if the string contains multi-byte characters. Use theoctet_length()function to find the byte length of a string.For BLOB values, LENGTH(X) always returns the byte-length of the BLOB.For string values, LENGTH(X) must read the entire string into memory in order to compute the character length. But for BLOB values, that is not necessary as SQLite knows how many bytes are in the BLOB. Hence, for multi-megabyte values, the LENGTH(X) function is usually much faster for BLOBs than for strings, since it does not need to load the value into memory. ",
"LIKE(X,Y) or LIKE(X,Y,Z) \nDescription: The LIKE() function is used to implement the \"Y LIKE X [ESCAPE Z]\" expression. If the optional ESCAPE clause is present, then the LIKE() function is invoked with three arguments. Otherwise, it is invoked with two arguments only. Note that the X and Y parameters are reversed in the LIKE() function relative to the infixLIKEoperator. X is the pattern and Y is the string to match against that pattern. Hence, the following expressions are equivalent:name LIKE '%neon%' LIKE('%neon%',name)Thesqlite3_create_function()interface can be used to override the LIKE() function and thereby change the operation of theLIKEoperator. When overriding the LIKE() function, it may be important to override both the two and three argument versions of the LIKE() function. Otherwise, different code may be called to implement theLIKEoperator depending on whether or not an ESCAPE clause was specified. ",
"LIKELIHOOD(X,Y) \nDescription: The LIKELIHOOD(X,Y) function returns argument X unchanged. The value Y in LIKELIHOOD(X,Y) must be a floating point constant between 0.0 and 1.0, inclusive. The LIKELIHOOD(X) function is a no-op that the code generator optimizes away so that it consumes no CPU cycles during run-time (that is, during calls tosqlite3_step()). The purpose of the LIKELIHOOD(X,Y) function is to provide a hint to the query planner that the argument X is a boolean that is true with a probability of approximately Y. Theunlikely(X)function is short-hand for LIKELIHOOD(X,0.0625). Thelikely(X)function is short-hand for LIKELIHOOD(X,0.9375). ",
"LIKELY(X) \nDescription: The LIKELY(X) function returns the argument X unchanged. The LIKELY(X) function is a no-op that the code generator optimizes away so that it consumes no CPU cycles at run-time (that is, during calls tosqlite3_step()). The purpose of the LIKELY(X) function is to provide a hint to the query planner that the argument X is a boolean value that is usually true. The LIKELY(X) function is equivalent tolikelihood(X,0.9375). See also:unlikely(X). ",
"LOAD_EXTENSION(X) or LOAD_EXTENSION(X,Y) \nDescription: The LOAD_EXTENSION(X,Y) function loadsSQLite extensionsout of the shared library file named X using the entry point Y. The result of LOAD_EXTENSION() is always a NULL. If Y is omitted then the default entry point name is used. The LOAD_EXTENSION() function raises an exception if the extension fails to load or initialize correctly.The LOAD_EXTENSION() function will fail if the extension attempts to modify or delete an SQL function or collating sequence. The extension can add new functions or collating sequences, but cannot modify or delete existing functions or collating sequences because those functions and/or collating sequences might be used elsewhere in the currently running SQL statement. To load an extension that changes or deletes functions or collating sequences, use thesqlite3_load_extension()C-language API.For security reasons, extension loading is disabled by default and must be enabled by a prior call tosqlite3_enable_load_extension(). ",
"LOWER(X) \nDescription: The LOWER(X) function returns a copy of string X with all ASCII characters converted to lower case. The default built-in LOWER() function works for ASCII characters only. To do case conversions on non-ASCII characters, load the ICU extension. ",
"LTRIM(X) or LTRIM(X,Y) \nDescription: The LTRIM(X,Y) function returns a string formed by removing any and all characters that appear in Y from the left side of X. If the Y argument is omitted, LTRIM(X) removes spaces from the left side of X. ",
"MAX(X,Y,...) \nDescription: The multi-argument MAX() function returns the argument with the maximum value, or return NULL if any argument is NULL. The multi-argument MAX() function searches its arguments from left to right for an argument that defines a collating function and uses that collating function for all string comparisons. If none of the arguments to MAX() define a collating function, then the BINARY collating function is used. Note thatmax()is a simple function when it has 2 or more arguments but operates as anaggregate functionif given only a single argument. ",
"MIN(X,Y,...) \nDescription: The multi-argument MIN() function returns the argument with the minimum value. The multi-argument MIN() function searches its arguments from left to right for an argument that defines a collating function and uses that collating function for all string comparisons. If none of the arguments to MIN() define a collating function, then the BINARY collating function is used. Note thatmin()is a simple function when it has 2 or more arguments but operates as anaggregate functionif given only a single argument. ",
"NULLIF(X,Y) \nDescription: The NULLIF(X,Y) function returns its first argument if the arguments are different and NULL if the arguments are the same. The NULLIF(X,Y) function searches its arguments from left to right for an argument that defines a collating function and uses that collating function for all string comparisons. If neither argument to NULLIF() defines a collating function then the BINARY collating function is used. ",
"OCTET_LENGTH(X) \nDescription: The OCTET_LENGTH(X) function returns the number of bytes in the encoding of text string X. If X is NULL then OCTET_LENGTH(X) returns NULL. If X is a BLOB value, then OCTET_LENGTH(X) is the same aslength(X). If X is a numeric value, then OCTET_LENGTH(X) returns the number of bytes in a text rendering of that number.Because OCTET_LENGTH(X) returns the number of bytes in X, not the number of characters, the value returned depends on the database encoding. The OCTET_LENGTH() function can return different answers for the same input string if the database encoding is UTF16 instead of UTF8.If argument X is a table column and the value is of type text or blob, then OCTET_LENGTH(X) avoids reading the content of X from disk, as the byte length can be computed from metadata. Thus, OCTET_LENGTH(X) is efficient even if X is a column containing a multi-megabyte text or blob value. ",
"PRINTF(FORMAT,...) \nDescription: The PRINTF() SQL function is an alias for theformat() SQL function. The format() SQL function was originally named PRINTF(). But the name was later changed to format() for compatibility with other database engines. The PRINTF() name is retained as an alias so as not to break legacy code. ",
"QUOTE(X) \nDescription: The QUOTE(X) function returns the text of an SQL literal which is the value of its argument suitable for inclusion into an SQL statement. Strings are surrounded by single-quotes with escapes on interior quotes as needed. BLOBs are encoded as hexadecimal literals. Strings with embedded NUL characters cannot be represented as string literals in SQL and hence the returned string literal is truncated prior to the first NUL. ",
"RANDOM() \nDescription: The RANDOM() function returns a pseudo-random integer between -9223372036854775808 and +9223372036854775807. ",
"RANDOMBLOB(N) \nDescription: The RANDOMBLOB(N) function return an N-byte blob containing pseudo-random bytes. If N is less than 1 then a 1-byte random blob is returned.Hint: applications can generate globally unique identifiers using this function together withhex()and/orlower()like this:hex(randomblob(16))lower(hex(randomblob(16))) ",
"REPLACE(X,Y,Z) \nDescription: The REPLACE(X,Y,Z) function returns a string formed by substituting string Z for every occurrence of string Y in string X. TheBINARYcollating sequence is used for comparisons. If Y is an empty string then return X unchanged. If Z is not initially a string, it is cast to a UTF-8 string prior to processing. ",
"ROUND(X) or ROUND(X,Y) \nDescription: The ROUND(X,Y) function returns a floating-point value X rounded to Y digits to the right of the decimal point. If the Y argument is omitted or negative, it is taken to be 0. ",
"RTRIM(X) or RTRIM(X,Y) \nDescription: The RTRIM(X,Y) function returns a string formed by removing any and all characters that appear in Y from the right side of X. If the Y argument is omitted, RTRIM(X) removes spaces from the right side of X. ",
"SIGN(X) \nDescription: The SIGN(X) function returns -1, 0, or +1 if the argument X is a numeric value that is negative, zero, or positive, respectively. If the argument to SIGN(X) is NULL or is a string or blob that cannot be losslessly converted into a number, then SIGN(X) returns NULL. ",
"SOUNDEX(X) \nDescription: The SOUNDEX(X) function returns a string that is the soundex encoding of the string X. The string \"?000\" is returned if the argument is NULL or contains no ASCII alphabetic characters. This function is omitted from SQLite by default. It is only available if theSQLITE_SOUNDEXcompile-time option is used when SQLite is built. ",
"SQLITE_COMPILEOPTION_GET(N) \nDescription: The SQLITE_COMPILEOPTION_GET() SQL function is a wrapper around thesqlite3_compileoption_get()C/C++ function. This routine returns the N-th compile-time option used to build SQLite or NULL if N is out of range. See also thecompile_options pragma. ",
"SQLITE_COMPILEOPTION_USED(X) \nDescription: The SQLITE_COMPILEOPTION_USED() SQL function is a wrapper around thesqlite3_compileoption_used()C/C++ function. When the argument X to SQLITE_COMPILEOPTION_USED(X) is a string which is the name of a compile-time option, this routine returns true (1) or false (0) depending on whether or not that option was used during the build. ",
"SQLITE_OFFSET(X) \nDescription: The SQLITE_OFFSET(X) function returns the byte offset in the database file for the beginning of the record from which value would be read. If X is not a column in an ordinary table, then SQLITE_OFFSET(X) returns NULL. The value returned by SQLITE_OFFSET(X) might reference either the original table or an index, depending on the query. If the value X would normally be extracted from an index, the SQLITE_OFFSET(X) returns the offset to the corresponding index record. If the value X would be extracted from the original table, then SQLITE_OFFSET(X) returns the offset to the table record.The SQLITE_OFFSET(X) SQL function is only available if SQLite is built using the-DSQLITE_ENABLE_OFFSET_SQL_FUNCcompile-time option. ",
"SQLITE_SOURCE_ID() \nDescription: The SQLITE_SOURCE_ID() function returns a string that identifies the specific version of the source code that was used to build the SQLite library. The string returned by SQLITE_SOURCE_ID() is the date and time that the source code was checked in followed by the SHA3-256 hash for that check-in. This function is an SQL wrapper around thesqlite3_sourceid()C interface. ",
"SQLITE_VERSION() \nDescription: The SQLITE_VERSION() function returns the version string for the SQLite library that is running. This function is an SQL wrapper around thesqlite3_libversion()C-interface. ",
"SUBSTR(X,Y,Z) or SUBSTR(X,Y) or SUBSTRING(X,Y,Z) or SUBSTRING(X,Y) \nDescription: The SUBSTR(X,Y,Z) function returns a substring of input string X that begins with the Y-th character and which is Z characters long. If Z is omitted then SUBSTR(X,Y) returns all characters through the end of the string X beginning with the Y-th. The left-most character of X is number 1. If Y is negative then the first character of the substring is found by counting from the right rather than the left. If Z is negative then the abs(Z) characters preceding the Y-th character are returned. If X is a string then characters indices refer to actual UTF-8 characters. If X is a BLOB then the indices refer to bytes.\"substring()\" is an alias for \"substr()\" beginning with SQLite version 3.34. ",
"TOTAL_CHANGES() \nDescription: The TOTAL_CHANGES() function returns the number of row changes caused by INSERT, UPDATE or DELETE statements since the current database connection was opened. This function is a wrapper around thesqlite3_total_changes64()C/C++ interface. ",
"TRIM(X) or TRIM(X,Y) \nDescription: The TRIM(X,Y) function returns a string formed by removing any and all characters that appear in Y from both ends of X. If the Y argument is omitted, TRIM(X) removes spaces from both ends of X. ",
"TYPEOF(X) \nDescription: The TYPEOF(X) function returns a string that indicates thedatatypeof the expression X: \"null\", \"integer\", \"real\", \"text\", or \"blob\". ",
"UNHEX(X) or UNHEX(X,Y) \nDescription: The UNHEX(X,Y) function returns a BLOB value which is the decoding of the hexadecimal string X. If X contains any characters that are not hexadecimal digits and which are not in Y, then UNHEX(X,Y) returns NULL. If Y is omitted, it is understood to be an empty string and hence X must be a pure hexadecimal string. All hexadecimal digits in X must occur in pairs, with both digits of each pair beginning immediately adjacent to one another, or else UNHEX(X,Y) returns NULL. If either parameter X or Y is NULL, then UNHEX(X,Y) returns NULL. The X input may contain an arbitrary mix of upper and lower case hexadecimal digits. Hexadecimal digits in Y have no affect on the translation of X. Only characters in Y that are not hexadecimal digits are ignored in X.See also:hex() ",
"UNICODE(X) \nDescription: The UNICODE(X) function returns the numeric unicode code point corresponding to the first character of the string X. If the argument to UNICODE(X) is not a string then the result is undefined. ",
"UNLIKELY(X) \nDescription: The UNLIKELY(X) function returns the argument X unchanged. The UNLIKELY(X) function is a no-op that the code generator optimizes away so that it consumes no CPU cycles at run-time (that is, during calls tosqlite3_step()). The purpose of the UNLIKELY(X) function is to provide a hint to the query planner that the argument X is a boolean value that is usually not true. The UNLIKELY(X) function is equivalent tolikelihood(X, 0.0625). ",
"UPPER(X) \nDescription: The UPPER(X) function returns a copy of input string X in which all lower-case ASCII characters are converted to their upper-case equivalent. ",
"ZEROBLOB(N) \nDescription: The ZEROBLOB(N) function returns a BLOB consisting of N bytes of 0x00. SQLite manages these zeroblobs very efficiently. Zeroblobs can be used to reserve space for a BLOB that is later written usingincremental BLOB I/O. This SQL function is implemented using thesqlite3_result_zeroblob()routine from the C/C++ interface. ",
"AVG(X) \nDescription: The AVG() function returns the average value of all non-NULLXwithin a group. String and BLOB values that do not look like numbers are interpreted as 0. The result of AVG() is always a floating point value whenever there is at least one non-NULL input even if all inputs are integers. The result of AVG() is NULL if there are no non-NULL inputs. The result of AVG() is computed astotal()/count()so all of the constraints that apply tototal()also apply to AVG(). ",
"COUNT(X) or COUNT(*) \nDescription: The COUNT(X) function returns a count of the number of times thatXis not NULL in a group. The COUNT(*) function (with no arguments) returns the total number of rows in the group. ",
"GROUP_CONCAT(X) or GROUP_CONCAT(X,Y) or STRING_AGG(X,Y) \nDescription: The GROUP_CONCAT() function returns a string which is the concatenation of all non-NULL values ofX. If parameterYis present then it is used as the separator between instances ofX.A comma (\",\") is used as the separator ifYis omitted.The string_agg(X,Y) function is an alias for GROUP_CONCAT(X,Y). String_agg() is compatible with PostgreSQL and SQL-Server and GROUP_CONCAT() is compatible with MySQL.The order of the concatenated elements is arbitrary unless an ORDER BY argument is included immediately after the last parameter. ",
"MAX(X) \nDescription: The MAX() aggregate function returns the maximum value of all values in the group. The maximum value is the value that would be returned last in an ORDER BY on the same column. Aggregate MAX() returns NULL if and only if there are no non-NULL values in the group. ",
"MIN(X) \nDescription: The MIN() aggregate function returns the minimum non-NULL value of all values in the group. The minimum value is the first non-NULL value that would appear in an ORDER BY of the column. Aggregate MIN() returns NULL if and only if there are no non-NULL values in the group. ",
"SUM(X) or TOTAL(X) \nDescription: The SUM() and TOTAL() aggregate functions return the sum of all non-NULL values in the group. If there are no non-NULL input rows then SUM() returns NULL but TOTAL() returns 0.0. NULL is not normally a helpful result for the sum of no rows but the SQL standard requires it and most other SQL database engines implement SUM() that way so SQLite does it in the same way in order to be compatible. The non-standard TOTAL() function is provided as a convenient way to work around this design problem in the SQL language. ",
"ROW_NUMBER() \nDescription: The number of the row within the current partition. Rows are numbered starting from 1 in the order defined by the ORDER BY clause in the window definition, or in arbitrary order otherwise. ",
"RANK() \nDescription: The row_number() of the first peer in each group - the rank of the current row with gaps. If there is no ORDER BY clause, then all rows are considered peers and this function always returns 1. ",
"DENSE_RANK() \nDescription: The number of the current row's peer group within its partition - the rank of the current row without gaps. Rows are numbered starting from 1 in the order defined by the ORDER BY clause in the window definition. If there is no ORDER BY clause, then all rows are considered peers and this function always returns 1. ",
"PERCENT_RANK() \nDescription: Despite the name, this function always returns a value between 0.0 and 1.0 equal to (rank- 1)/(partition-rows- 1), whererankis the value returned by built-in window function rank() andpartition-rowsis the total number of rows in the partition. If the partition contains only one row, this function returns 0.0. ",
"CUME_DIST() \nDescription: The cumulative distribution. Calculated asrow-number/partition-rows, whererow-numberis the value returned by row_number() for the last peer in the group andpartition-rowsthe number of rows in the partition. ",
"NTILE(N) \nDescription: ArgumentNis handled as an integer. This function divides the partition into N groups as evenly as possible and assigns an integer between 1 andNto each group, in the order defined by the ORDER BY clause, or in arbitrary order otherwise. If necessary, larger groups occur first. This function returns the integer value assigned to the group that the current row is a part of. ",
"LAG(expr) or LAG(expr, offset) or LAG(expr, offset, default) \nDescription: The first form of the LAG() function returns the result of evaluating expressionexpragainst the previous row in the partition. Or, if there is no previous row (because the current row is the first), NULL. ",
"LEAD(expr) or LEAD(expr, offset) or LEAD(expr, offset, default) \nDescription: The first form of the LEAD() function returns the result of evaluating expressionexpragainst the next row in the partition. Or, if there is no next row (because the current row is the last), NULL. ",
"FIRST_VALUE(expr) \nDescription: This built-in window function calculates the window frame for each row in the same way as an aggregate window function. It returns the value ofexprevaluated against the first row in the window frame for each row. ",
"LAST_VALUE(expr) \nDescription: This built-in window function calculates the window frame for each row in the same way as an aggregate window function. It returns the value ofexprevaluated against the last row in the window frame for each row. ",
"NTH_VALUE(expr, N) \nDescription: This built-in window function calculates the window frame for each row in the same way as an aggregate window function. It returns the value ofexprevaluated against the rowNof the window frame. Rows are numbered within the window frame starting from 1 in the order defined by the ORDER BY clause if one is present, or in arbitrary order otherwise. If there is noNth row in the partition, then NULL is returned. ",
"ACOS(X) \nDescription: Return the arccosine of X. The result is in radians. ",
"ACOSH(X) \nDescription: Return the hyperbolic arccosine of X. ",
"ASIN(X) \nDescription: Return the arcsine of X. The result is in radians. ",
"ASINH(X) \nDescription: Return the hyperbolic arcsine of X. ",
"ATAN(X) \nDescription: Return the arctangent of X. The result is in radians. ",
"ATAN2(Y,X) \nDescription: Return the arctangent of Y/X. The result is in radians. The result is placed into correct quadrant depending on the signs of X and Y. ",
"ATANH(X) \nDescription: Return the hyperbolic arctangent of X. ",
"CEIL(X) or CEILING(X) \nDescription: Return the first representable integer value greater than or equal to X. For positive values of X, this routine rounds away from zero. For negative values of X, this routine rounds toward zero. ",
"COS(X) \nDescription: Return the cosine of X. X is in radians. ",
"COSH(X) \nDescription: Return the hyperbolic cosine of X. ",
"DEGREES(X) \nDescription: Convert value X from radians into degrees. ",
"EXP(X) \nDescription: Computee(Euler's number, approximately 2.71828182845905) raised to the power X. ",
"FLOOR(X) \nDescription: Return the first representable integer value less than or equal to X. For positive numbers, this function rounds toward zero. For negative numbers, this function rounds away from zero. ",
"LN(X) \nDescription: Return the natural logarithm of X. ",
"LOG(X) or LOG10(X) or LOG(B,X) \nDescription: Return the base-10 logarithm for X. Or, for the two-argument version, return the base-B logarithm of X.Compatibility note: SQLite works like PostgreSQL in that the LOG() function computes a base-10 logarithm. Most other SQL database engines compute a natural logarithm for LOG(). In the two-argument version of LOG(B,X), the first argument is the base and the second argument is the operand. This is the same as in PostgreSQL and MySQL, but is reversed from SQL Server which uses the second argument as the base and the first argument as the operand. ",
"LOG2(X) \nDescription: Return the logarithm base-2 for the number X. ",
"MOD(X,Y) \nDescription: Return the remainder after dividing X by Y. This is similar to the '%' operator, except that it works for non-integer arguments. ",
"PI() \nDescription: Return an approximation for π. ",
"POW(X,Y) or POWER(X,Y) \nDescription: Compute X raised to the power Y. ",
"RADIANS(X) \nDescription: Convert X from degrees into radians. ",
"SIN(X) \nDescription: Return the sine of X. X is in radians. ",
"SINH(X) \nDescription: Return the hyperbolic sine of X. ",
"SQRT(X) \nDescription: Return the square root of X. NULL is returned if X is negative. ",
"TAN(X) \nDescription: Return the tangent of X. X is in radians. ",
"TANH(X) \nDescription: Return the hyperbolic tangent of X. ",
"TRUNC(X) \nDescription: Return the representable integer in between X and 0 (inclusive) that is furthest away from zero. Or, in other words, return the integer part of X, rounding toward zero. The TRUNC() function is similar toceiling(X)andfloor(X)except that it always rounds toward zero whereas ceiling(X) and floor(X) round up and down, respectively. ",
"DATE(time-value, modifier, modifier, ...) \nDescription: Returns the date as text in this format: YYYY-MM-DD. ",
"TIME(time-value, modifier, modifier, ...) \nDescription: Returns the time as text in formatted as HH:MM:SS or as HH:MM:SS.SSS if the subsec modifier is used. ",
"DATETIME(time-value, modifier, modifier, ...) \nDescription: Returns the date and time formatted as YYYY-MM-DD HH:MM:SS or as YYYY-MM-DD HH:MM:SS.SSS if the subsec modifier is used. ",
"JULIANDAY(time-value, modifier, modifier, ...) \nDescription: Returns the Julian day - the fractional number of days since noon in Greenwich on November 24, 4714 B.C. (Proleptic Gregorian calendar). ",
"UNIXEPOCH(time-value, modifier, modifier, ...) \nDescription: Returns a unix timestamp - the number of seconds since 1970-01-01 00:00:00 UTC. The UNIXEPOCH() function normally returns an integer number of seconds, but with the optional subsec modifier it will return a floating point number which is the fractional number of seconds. ",
"STRFTIME(format, time-value, modifier, modifier, ...) \nDescription: Returns the date formatted according to the format string specified as the first argument. The format string supports the most common substitutions found in the STRFTIME() function from the standard C library plus two new substitutions, %f and %J. ",
"TIMEDIFF(time-value, time-value) \nDescription: Returns a string that describes the amount of time that must be added to B in order to reach time A. The format of the TIMEDIFF() result is designed to be human-readable. "
]
def _sql_func_template(self, sql_funcs: str) -> str:
template = """### SQL Functions
You may consider one or more of the following SQL functions while generating the query:
{sql_funcs}
Important tips:
Except for the functions listed above, you may use any other functions as long as they conform to the syntax of the database engine.
"""
return template.format(sql_funcs=sql_funcs)
def _insert_stmts_template(self, insert_statements: str) -> str:
template = '''### INSERT INTO Statements
Below are several `INSERT INTO` statements. Use these to help generate predicates (i.e., `WHERE` clauses) in your SQL query:
{insert_statements}
'''
return template.format(insert_statements=insert_statements)
def _sql_synthesis_prompt(self, schema_str: str, sql_function_prompt: str, db_value_prompt: str, complexity: str, criterion: str, db_engine: str, column_count: int) -> str:
template = '''**Task Overview**
Create an executable SQL query based on the provided information.
**Database Schema**
{schema_str}
{sql_function_prompt}
{db_value_prompt}
**SQL Query Complexity**
Ensure the SQL query matches the {complexity} level, defined as follows:
{criterion}
**Output Format Requirements**
Enclose the SQL query in a code block:
```sql
-- Your SQL query here
```
**SQL Query Requirements**
1. Use the syntax specific to the {db_engine} database engine.
2. Incorporate advanced functions if appropriate, but they are not mandatory.
3. Address real-world data analysis needs. Avoid trivial or nonsensical queries.
4. (Very important) Ensure the final SQL query selects {column_count} columns.
**Answer**
Let's proceed step by step.
'''
return template.format(
schema_str=schema_str,
sql_function_prompt=sql_function_prompt.strip(),
db_value_prompt=db_value_prompt.strip(),
complexity=complexity,
criterion=criterion.strip(),
db_engine=db_engine,
column_count=column_count
)
def build_prompt(self, insert_statements: List[str], create_statements: List[str], db_engine: str) -> str:
random.seed(42)
complexity = random.sample(["Simple", "Moderate", "Complex", "Highly Complex"], 1)[0]
if len(insert_statements) == 0:
db_value_prompt = ""
else:
if len(insert_statements) > 4:
insert_statements = random.sample(insert_statements, 4)
db_value_prompt = self._insert_stmts_template(
insert_statements="\n\n".join(insert_statements)
)
function_num = random.randint(0, 2)
if function_num == 0:
sql_function_prompt = "### SQL Functions\nYou can use any function supported by the database engine."
else:
sql_funcs = ""
sampled_functions = random.sample(self.functions, function_num)
for idx, func in enumerate(sampled_functions):
sql_funcs += f"Function {idx + 1}:\n{func.strip()}\n"
sql_function_prompt = self._sql_func_template(sql_funcs=sql_funcs)
column_count = np.random.geometric(0.6, 1)[0]
prompt = self._sql_synthesis_prompt(
schema_str="\n\n".join(create_statements),
sql_function_prompt=sql_function_prompt.strip(),
db_value_prompt=db_value_prompt.strip(),
complexity=complexity,
criterion=self.complexity2criterion[complexity].strip(),
db_engine=db_engine,
column_count=column_count
)
return prompt
@PROMPT_REGISTRY.register()
class SelectVecSQLGeneratorPrompt(PromptABC):
def __init__(self):
self.simple_vec_criterion = '''**Criteria:**
Simple KNN queries in SQLite-vec may satisfy one or more of the following criteria:
- Basic vector similarity search on a single table
- Uses simple `MATCH` operator with target vector
- Contains basic `LIMIT` or `AND` clause to restrict results after `MATCH` operator
- No joins or complex filtering beyond the vector search
**Example of Simple KNN Query:**
```sql
SELECT rowid, location_embedding
FROM vec_table
WHERE location_embedding MATCH lembed('all-MiniLM-L6-v2',"572 Main Street Los Angeles, CA 90210 USA")
ORDER BY distance
LIMIT 1;
```'''
self.moderate_vec_criterion = '''**Criteria:**
Moderate KNN queries in SQLite-vec may satisfy one or more of the following criteria:
- Includes simple joins with metadata tables
- Contains basic post-filtering of vector results
- May use multiple vector columns in query
**Example of Moderate KNN Query:**
```sql
SELECT d.doc_id, d.title, d.content
FROM documents d
JOIN categories c ON d.category_id = c.id
WHERE d.content_embedding MATCH lembed('all-MiniLM-L6-v2',"OmniSQL is a unified SQL engine that integrates Vector search and LLM augmentation.")
AND k = 2
AND c.name = 'science'
ORDER BY d.distance;
```'''
self.complex_vec_criterion = '''**Criteria:**
Complex KNN queries in SQLite-vec may satisfy one or more of the following criteria:
- Combines vector search with complex joins
- Uses CTEs to organize vector search logic
- Contains hybrid search (vector + full-text)
- Implements multi-stage filtering of results
- May use window functions with vector results
- Includes complex distance threshold conditions
**Example of Complex KNN Query:**
```sql
WITH HighWDVOATeams AS (
SELECT team_id, team_name
FROM teams
WHERE team_id IN (
SELECT team_id
FROM team_metrics
WHERE wdvoa > 30 AND season = 2019
)
),
SimilarTeams AS (
SELECT team_id, team_name, distance
FROM teams
WHERE team_name_embedding MATCH lembed('all-MiniLM-L6-v2',"Woven Shadows")
ORDER BY distance
LIMIT 5
)
SELECT h.team_name, AVG(p.confidence_level) AS average_confidence
FROM HighWDVOATeams h
JOIN SimilarTeams s ON h.team_id = s.team_id
JOIN game_predictions p ON p.game_id IN (
SELECT game_id
FROM games
WHERE home_team_id = h.team_id OR away_team_id = h.team_id
)
GROUP BY h.team_name
HAVING average_confidence > 0.7;
```'''
self.highly_complex_vec_criterion = '''**Criteria:**
Highly complex KNN queries in SQLite-vec may satisfy one or more of the following criteria:
- Uses multiple CTEs with vector operations
- Combines multiple vector searches in one query
- Implements advanced hybrid search techniques
- Contains recursive vector search patterns
- Uses complex window functions over vector results
- May involve vector aggregation operations
- Implements custom distance calculations
**Example of Highly Complex KNN Query:**
```sql
WITH BettingAnalysis AS (
SELECT
g.game_id,
AVG(bd.betting_spread) AS avg_initial_spread,
COUNT(*) AS total_bets
FROM games g
JOIN betting_data bd ON g.game_id = bd.game_id
GROUP BY g.game_id
),
PredictionAnalysis AS (
SELECT
gp.game_id,
AVG(gp.confidence_level) AS avg_confidence,
SUM(CASE WHEN gp.make_pick = 1 AND g.pick_right = 1 THEN 1 ELSE 0 END) AS correct_predictions
FROM games g
JOIN game_predictions gp ON g.game_id = gp.game_id
GROUP BY gp.game_id
),
TeamPerformance AS (
SELECT
tm.team_id,
tm.season,
AVG(tm.wdvoa) AS avg_wdvoa
FROM team_metrics tm
GROUP BY tm.team_id, tm.season
),
LocationSimilarity AS (
SELECT
g.game_id,
g.location,
vec.distance AS location_similarity
FROM games g
JOIN (
SELECT rowid, distance
FROM games
WHERE location_embedding MATCH lembed('all-MiniLM-L6-v2',"New York")
ORDER BY distance
LIMIT 5
) AS vec ON g.rowid = vec.rowid
)
SELECT
g.game_id,
ba.avg_initial_spread,
pa.avg_confidence,
tp.avg_wdvoa,
ls.location_similarity
FROM games g
JOIN BettingAnalysis ba ON g.game_id = ba.game_id
JOIN PredictionAnalysis pa ON g.game_id = pa.game_id
JOIN TeamPerformance tp ON g.home_team_id = tp.team_id
JOIN LocationSimilarity ls ON g.game_id = ls.game_id
WHERE pa.correct_predictions > 2
ORDER BY g.game_id;
```'''
self.complexity2criterion_vec = {
"Simple": self.simple_vec_criterion,
"Moderate": self.moderate_vec_criterion,
"Complex": self.complex_vec_criterion,
"Highly Complex": self.highly_complex_vec_criterion
}
self.functions = [
"ABS(X) \nDescription: The ABS(X) function returns the absolute value of the numeric argument X. Abs(X) returns NULL if X is NULL. Abs(X) returns 0.0 if X is a string or blob that cannot be converted to a numeric value. If X is the integer -9223372036854775808 then ABS(X) throws an integer overflow error since there is no equivalent positive 64-bit two complement value. ",
"CHANGES() \nDescription: The CHANGES() function returns the number of database rows that were changed or inserted or deleted by the most recently completed INSERT, DELETE, or UPDATE statement, exclusive of statements in lower-level triggers. The CHANGES() SQL function is a wrapper around thesqlite3_changes64()C/C++ function and hence follows the same rules for counting changes. ",
"CHAR(X1,X2,...,XN) \nDescription: The CHAR(X1,X2,...,XN) function returns a string composed of characters having the unicode code point values of integers X1 through XN, respectively. ",
"COALESCE(X,Y,...) \nDescription: The COALESCE() function returns a copy of its first non-NULL argument, or NULL if all arguments are NULL. Coalesce() must have at least 2 arguments. ",
"CONCAT(X,...) \nDescription: The CONCAT(...) function returns a string which is the concatenation of the string representation of all of its non-NULL arguments. If all arguments are NULL, then CONCAT() returns an empty string. ",
"CONCAT_WS(SEP,X,...) \nDescription: The CONCAT_WS(SEP,...) function returns a string that is the concatenation of all non-null arguments beyond the first argument, using the text value of the first argument as a separator. If the first argument is NULL, then CONCAT_WS() returns NULL. If all arguments other than the first are NULL, then CONCAT_WS() returns an empty string. ",
"FORMAT(FORMAT,...) \nDescription: The FORMAT(FORMAT,...) SQL function works like thesqlite3_mprintf()C-language function and the printf() function from the standard C library. The first argument is a format string that specifies how to construct the output string using values taken from subsequent arguments. If the FORMAT argument is missing or NULL then the result is NULL. The %n format is silently ignored and does not consume an argument. The %p format is an alias for %X. The %z format is interchangeable with %s. If there are too few arguments in the argument list, missing arguments are assumed to have a NULL value, which is translated into 0 or 0.0 for numeric formats or an empty string for %s. See thebuilt-in printf()documentation for additional information. ",
"GLOB(X,Y) \nDescription: The GLOB(X,Y) function is equivalent to the expression \"Y GLOB X\". Note that the X and Y arguments are reversed in the GLOB() function relative to the infixGLOBoperator. Y is the string and X is the pattern. So, for example, the following expressions are equivalent:name GLOB '*helium*' GLOB('*helium*',name)If thesqlite3_create_function()interface is used to override the GLOB(X,Y) function with an alternative implementation then theGLOBoperator will invoke the alternative implementation. ",
"HEX(X) \nDescription: The HEX() function interprets its argument as a BLOB and returns a string which is the upper-case hexadecimal rendering of the content of that blob.If the argumentXin \"hex(X)\" is an integer or floating point number, then \"interprets its argument as a BLOB\" means that the binary number is first converted into a UTF8 text representation, then that text is interpreted as a BLOB. Hence, \"hex(12345678)\" renders as \"3132333435363738\" not the binary representation of the integer value \"0000000000BC614E\".See also:unhex() ",
"IFNULL(X,Y) \nDescription: The IFNULL() function returns a copy of its first non-NULL argument, or NULL if both arguments are NULL. Ifnull() must have exactly 2 arguments. The IFNULL() function is equivalent tocoalesce()with two arguments. ",
"IIF(X,Y,Z) \nDescription: The IIF(X,Y,Z) function returns the value Y if X is true, and Z otherwise. The IIF(X,Y,Z) function is logically equivalent to and generates the samebytecodeas theCASE expression\"CASE WHEN X THEN Y ELSE Z END\". ",
"INSTR(X,Y) \nDescription: The INSTR(X,Y) function finds the first occurrence of string Y within string X and returns the number of prior characters plus 1, or 0 if Y is nowhere found within X. Or, if X and Y are both BLOBs, then INSTR(X,Y) returns one more than the number bytes prior to the first occurrence of Y, or 0 if Y does not occur anywhere within X. If both arguments X and Y to INSTR(X,Y) are non-NULL and are not BLOBs then both are interpreted as strings. If either X or Y are NULL in INSTR(X,Y) then the result is NULL. ",
"LAST_INSERT_ROWID() \nDescription: The LAST_INSERT_ROWID() function returns theROWIDof the last row insert from the database connection which invoked the function. The LAST_INSERT_ROWID() SQL function is a wrapper around thesqlite3_last_insert_rowid()C/C++ interface function. ",
"LENGTH(X) \nDescription: For a string value X, the LENGTH(X) function returns the number of characters (not bytes) in X prior to the first NUL character. Since SQLite strings do not normally contain NUL characters, the LENGTH(X) function will usually return the total number of characters in the string X. For a blob value X, LENGTH(X) returns the number of bytes in the blob. If X is NULL then LENGTH(X) is NULL. If X is numeric then LENGTH(X) returns the length of a string representation of X.Note that for strings, the LENGTH(X) function returns thecharacterlength of the string, not the byte length. The character length is the number of characters in the string. The character length is always different from the byte length for UTF-16 strings, and can be different from the byte length for UTF-8 strings if the string contains multi-byte characters. Use theoctet_length()function to find the byte length of a string.For BLOB values, LENGTH(X) always returns the byte-length of the BLOB.For string values, LENGTH(X) must read the entire string into memory in order to compute the character length. But for BLOB values, that is not necessary as SQLite knows how many bytes are in the BLOB. Hence, for multi-megabyte values, the LENGTH(X) function is usually much faster for BLOBs than for strings, since it does not need to load the value into memory. ",
"LIKE(X,Y) or LIKE(X,Y,Z) \nDescription: The LIKE() function is used to implement the \"Y LIKE X [ESCAPE Z]\" expression. If the optional ESCAPE clause is present, then the LIKE() function is invoked with three arguments. Otherwise, it is invoked with two arguments only. Note that the X and Y parameters are reversed in the LIKE() function relative to the infixLIKEoperator. X is the pattern and Y is the string to match against that pattern. Hence, the following expressions are equivalent:name LIKE '%neon%' LIKE('%neon%',name)Thesqlite3_create_function()interface can be used to override the LIKE() function and thereby change the operation of theLIKEoperator. When overriding the LIKE() function, it may be important to override both the two and three argument versions of the LIKE() function. Otherwise, different code may be called to implement theLIKEoperator depending on whether or not an ESCAPE clause was specified. ",
"LIKELIHOOD(X,Y) \nDescription: The LIKELIHOOD(X,Y) function returns argument X unchanged. The value Y in LIKELIHOOD(X,Y) must be a floating point constant between 0.0 and 1.0, inclusive. The LIKELIHOOD(X) function is a no-op that the code generator optimizes away so that it consumes no CPU cycles during run-time (that is, during calls tosqlite3_step()). The purpose of the LIKELIHOOD(X,Y) function is to provide a hint to the query planner that the argument X is a boolean that is true with a probability of approximately Y. Theunlikely(X)function is short-hand for LIKELIHOOD(X,0.0625). Thelikely(X)function is short-hand for LIKELIHOOD(X,0.9375). ",
"LIKELY(X) \nDescription: The LIKELY(X) function returns the argument X unchanged. The LIKELY(X) function is a no-op that the code generator optimizes away so that it consumes no CPU cycles at run-time (that is, during calls tosqlite3_step()). The purpose of the LIKELY(X) function is to provide a hint to the query planner that the argument X is a boolean value that is usually true. The LIKELY(X) function is equivalent tolikelihood(X,0.9375). See also:unlikely(X). ",
"LOAD_EXTENSION(X) or LOAD_EXTENSION(X,Y) \nDescription: The LOAD_EXTENSION(X,Y) function loadsSQLite extensionsout of the shared library file named X using the entry point Y. The result of LOAD_EXTENSION() is always a NULL. If Y is omitted then the default entry point name is used. The LOAD_EXTENSION() function raises an exception if the extension fails to load or initialize correctly.The LOAD_EXTENSION() function will fail if the extension attempts to modify or delete an SQL function or collating sequence. The extension can add new functions or collating sequences, but cannot modify or delete existing functions or collating sequences because those functions and/or collating sequences might be used elsewhere in the currently running SQL statement. To load an extension that changes or deletes functions or collating sequences, use thesqlite3_load_extension()C-language API.For security reasons, extension loading is disabled by default and must be enabled by a prior call tosqlite3_enable_load_extension(). ",
"LOWER(X) \nDescription: The LOWER(X) function returns a copy of string X with all ASCII characters converted to lower case. The default built-in LOWER() function works for ASCII characters only. To do case conversions on non-ASCII characters, load the ICU extension. ",
"LTRIM(X) or LTRIM(X,Y) \nDescription: The LTRIM(X,Y) function returns a string formed by removing any and all characters that appear in Y from the left side of X. If the Y argument is omitted, LTRIM(X) removes spaces from the left side of X. ",
"MAX(X,Y,...) \nDescription: The multi-argument MAX() function returns the argument with the maximum value, or return NULL if any argument is NULL. The multi-argument MAX() function searches its arguments from left to right for an argument that defines a collating function and uses that collating function for all string comparisons. If none of the arguments to MAX() define a collating function, then the BINARY collating function is used. Note thatmax()is a simple function when it has 2 or more arguments but operates as anaggregate functionif given only a single argument. ",
"MIN(X,Y,...) \nDescription: The multi-argument MIN() function returns the argument with the minimum value. The multi-argument MIN() function searches its arguments from left to right for an argument that defines a collating function and uses that collating function for all string comparisons. If none of the arguments to MIN() define a collating function, then the BINARY collating function is used. Note thatmin()is a simple function when it has 2 or more arguments but operates as anaggregate functionif given only a single argument. ",
"NULLIF(X,Y) \nDescription: The NULLIF(X,Y) function returns its first argument if the arguments are different and NULL if the arguments are the same. The NULLIF(X,Y) function searches its arguments from left to right for an argument that defines a collating function and uses that collating function for all string comparisons. If neither argument to NULLIF() defines a collating function then the BINARY collating function is used. ",
"OCTET_LENGTH(X) \nDescription: The OCTET_LENGTH(X) function returns the number of bytes in the encoding of text string X. If X is NULL then OCTET_LENGTH(X) returns NULL. If X is a BLOB value, then OCTET_LENGTH(X) is the same aslength(X). If X is a numeric value, then OCTET_LENGTH(X) returns the number of bytes in a text rendering of that number.Because OCTET_LENGTH(X) returns the number of bytes in X, not the number of characters, the value returned depends on the database encoding. The OCTET_LENGTH() function can return different answers for the same input string if the database encoding is UTF16 instead of UTF8.If argument X is a table column and the value is of type text or blob, then OCTET_LENGTH(X) avoids reading the content of X from disk, as the byte length can be computed from metadata. Thus, OCTET_LENGTH(X) is efficient even if X is a column containing a multi-megabyte text or blob value. ",
"PRINTF(FORMAT,...) \nDescription: The PRINTF() SQL function is an alias for theformat() SQL function. The format() SQL function was originally named PRINTF(). But the name was later changed to format() for compatibility with other database engines. The PRINTF() name is retained as an alias so as not to break legacy code. ",
"QUOTE(X) \nDescription: The QUOTE(X) function returns the text of an SQL literal which is the value of its argument suitable for inclusion into an SQL statement. Strings are surrounded by single-quotes with escapes on interior quotes as needed. BLOBs are encoded as hexadecimal literals. Strings with embedded NUL characters cannot be represented as string literals in SQL and hence the returned string literal is truncated prior to the first NUL. ",
"RANDOM() \nDescription: The RANDOM() function returns a pseudo-random integer between -9223372036854775808 and +9223372036854775807. ",
"RANDOMBLOB(N) \nDescription: The RANDOMBLOB(N) function return an N-byte blob containing pseudo-random bytes. If N is less than 1 then a 1-byte random blob is returned.Hint: applications can generate globally unique identifiers using this function together withhex()and/orlower()like this:hex(randomblob(16))lower(hex(randomblob(16))) ",
"REPLACE(X,Y,Z) \nDescription: The REPLACE(X,Y,Z) function returns a string formed by substituting string Z for every occurrence of string Y in string X. TheBINARYcollating sequence is used for comparisons. If Y is an empty string then return X unchanged. If Z is not initially a string, it is cast to a UTF-8 string prior to processing. ",
"ROUND(X) or ROUND(X,Y) \nDescription: The ROUND(X,Y) function returns a floating-point value X rounded to Y digits to the right of the decimal point. If the Y argument is omitted or negative, it is taken to be 0. ",
"RTRIM(X) or RTRIM(X,Y) \nDescription: The RTRIM(X,Y) function returns a string formed by removing any and all characters that appear in Y from the right side of X. If the Y argument is omitted, RTRIM(X) removes spaces from the right side of X. ",
"SIGN(X) \nDescription: The SIGN(X) function returns -1, 0, or +1 if the argument X is a numeric value that is negative, zero, or positive, respectively. If the argument to SIGN(X) is NULL or is a string or blob that cannot be losslessly converted into a number, then SIGN(X) returns NULL. ",
"SOUNDEX(X) \nDescription: The SOUNDEX(X) function returns a string that is the soundex encoding of the string X. The string \"?000\" is returned if the argument is NULL or contains no ASCII alphabetic characters. This function is omitted from SQLite by default. It is only available if theSQLITE_SOUNDEXcompile-time option is used when SQLite is built. ",
"SQLITE_COMPILEOPTION_GET(N) \nDescription: The SQLITE_COMPILEOPTION_GET() SQL function is a wrapper around thesqlite3_compileoption_get()C/C++ function. This routine returns the N-th compile-time option used to build SQLite or NULL if N is out of range. See also thecompile_options pragma. ",
"SQLITE_COMPILEOPTION_USED(X) \nDescription: The SQLITE_COMPILEOPTION_USED() SQL function is a wrapper around thesqlite3_compileoption_used()C/C++ function. When the argument X to SQLITE_COMPILEOPTION_USED(X) is a string which is the name of a compile-time option, this routine returns true (1) or false (0) depending on whether or not that option was used during the build. ",
"SQLITE_OFFSET(X) \nDescription: The SQLITE_OFFSET(X) function returns the byte offset in the database file for the beginning of the record from which value would be read. If X is not a column in an ordinary table, then SQLITE_OFFSET(X) returns NULL. The value returned by SQLITE_OFFSET(X) might reference either the original table or an index, depending on the query. If the value X would normally be extracted from an index, the SQLITE_OFFSET(X) returns the offset to the corresponding index record. If the value X would be extracted from the original table, then SQLITE_OFFSET(X) returns the offset to the table record.The SQLITE_OFFSET(X) SQL function is only available if SQLite is built using the-DSQLITE_ENABLE_OFFSET_SQL_FUNCcompile-time option. ",
"SQLITE_SOURCE_ID() \nDescription: The SQLITE_SOURCE_ID() function returns a string that identifies the specific version of the source code that was used to build the SQLite library. The string returned by SQLITE_SOURCE_ID() is the date and time that the source code was checked in followed by the SHA3-256 hash for that check-in. This function is an SQL wrapper around thesqlite3_sourceid()C interface. ",
"SQLITE_VERSION() \nDescription: The SQLITE_VERSION() function returns the version string for the SQLite library that is running. This function is an SQL wrapper around thesqlite3_libversion()C-interface. ",
"SUBSTR(X,Y,Z) or SUBSTR(X,Y) or SUBSTRING(X,Y,Z) or SUBSTRING(X,Y) \nDescription: The SUBSTR(X,Y,Z) function returns a substring of input string X that begins with the Y-th character and which is Z characters long. If Z is omitted then SUBSTR(X,Y) returns all characters through the end of the string X beginning with the Y-th. The left-most character of X is number 1. If Y is negative then the first character of the substring is found by counting from the right rather than the left. If Z is negative then the abs(Z) characters preceding the Y-th character are returned. If X is a string then characters indices refer to actual UTF-8 characters. If X is a BLOB then the indices refer to bytes.\"substring()\" is an alias for \"substr()\" beginning with SQLite version 3.34. ",
"TOTAL_CHANGES() \nDescription: The TOTAL_CHANGES() function returns the number of row changes caused by INSERT, UPDATE or DELETE statements since the current database connection was opened. This function is a wrapper around thesqlite3_total_changes64()C/C++ interface. ",
"TRIM(X) or TRIM(X,Y) \nDescription: The TRIM(X,Y) function returns a string formed by removing any and all characters that appear in Y from both ends of X. If the Y argument is omitted, TRIM(X) removes spaces from both ends of X. ",
"TYPEOF(X) \nDescription: The TYPEOF(X) function returns a string that indicates thedatatypeof the expression X: \"null\", \"integer\", \"real\", \"text\", or \"blob\". ",
"UNHEX(X) or UNHEX(X,Y) \nDescription: The UNHEX(X,Y) function returns a BLOB value which is the decoding of the hexadecimal string X. If X contains any characters that are not hexadecimal digits and which are not in Y, then UNHEX(X,Y) returns NULL. If Y is omitted, it is understood to be an empty string and hence X must be a pure hexadecimal string. All hexadecimal digits in X must occur in pairs, with both digits of each pair beginning immediately adjacent to one another, or else UNHEX(X,Y) returns NULL. If either parameter X or Y is NULL, then UNHEX(X,Y) returns NULL. The X input may contain an arbitrary mix of upper and lower case hexadecimal digits. Hexadecimal digits in Y have no affect on the translation of X. Only characters in Y that are not hexadecimal digits are ignored in X.See also:hex() ",
"UNICODE(X) \nDescription: The UNICODE(X) function returns the numeric unicode code point corresponding to the first character of the string X. If the argument to UNICODE(X) is not a string then the result is undefined. ",
"UNLIKELY(X) \nDescription: The UNLIKELY(X) function returns the argument X unchanged. The UNLIKELY(X) function is a no-op that the code generator optimizes away so that it consumes no CPU cycles at run-time (that is, during calls tosqlite3_step()). The purpose of the UNLIKELY(X) function is to provide a hint to the query planner that the argument X is a boolean value that is usually not true. The UNLIKELY(X) function is equivalent tolikelihood(X, 0.0625). ",
"UPPER(X) \nDescription: The UPPER(X) function returns a copy of input string X in which all lower-case ASCII characters are converted to their upper-case equivalent. ",
"ZEROBLOB(N) \nDescription: The ZEROBLOB(N) function returns a BLOB consisting of N bytes of 0x00. SQLite manages these zeroblobs very efficiently. Zeroblobs can be used to reserve space for a BLOB that is later written usingincremental BLOB I/O. This SQL function is implemented using thesqlite3_result_zeroblob()routine from the C/C++ interface. ",
"AVG(X) \nDescription: The AVG() function returns the average value of all non-NULLXwithin a group. String and BLOB values that do not look like numbers are interpreted as 0. The result of AVG() is always a floating point value whenever there is at least one non-NULL input even if all inputs are integers. The result of AVG() is NULL if there are no non-NULL inputs. The result of AVG() is computed astotal()/count()so all of the constraints that apply tototal()also apply to AVG(). ",
"COUNT(X) or COUNT(*) \nDescription: The COUNT(X) function returns a count of the number of times thatXis not NULL in a group. The COUNT(*) function (with no arguments) returns the total number of rows in the group. ",
"GROUP_CONCAT(X) or GROUP_CONCAT(X,Y) or STRING_AGG(X,Y) \nDescription: The GROUP_CONCAT() function returns a string which is the concatenation of all non-NULL values ofX. If parameterYis present then it is used as the separator between instances ofX.A comma (\",\") is used as the separator ifYis omitted.The string_agg(X,Y) function is an alias for GROUP_CONCAT(X,Y). String_agg() is compatible with PostgreSQL and SQL-Server and GROUP_CONCAT() is compatible with MySQL.The order of the concatenated elements is arbitrary unless an ORDER BY argument is included immediately after the last parameter. ",
"MAX(X) \nDescription: The MAX() aggregate function returns the maximum value of all values in the group. The maximum value is the value that would be returned last in an ORDER BY on the same column. Aggregate MAX() returns NULL if and only if there are no non-NULL values in the group. ",
"MIN(X) \nDescription: The MIN() aggregate function returns the minimum non-NULL value of all values in the group. The minimum value is the first non-NULL value that would appear in an ORDER BY of the column. Aggregate MIN() returns NULL if and only if there are no non-NULL values in the group. ",
"SUM(X) or TOTAL(X) \nDescription: The SUM() and TOTAL() aggregate functions return the sum of all non-NULL values in the group. If there are no non-NULL input rows then SUM() returns NULL but TOTAL() returns 0.0. NULL is not normally a helpful result for the sum of no rows but the SQL standard requires it and most other SQL database engines implement SUM() that way so SQLite does it in the same way in order to be compatible. The non-standard TOTAL() function is provided as a convenient way to work around this design problem in the SQL language. ",
"ROW_NUMBER() \nDescription: The number of the row within the current partition. Rows are numbered starting from 1 in the order defined by the ORDER BY clause in the window definition, or in arbitrary order otherwise. ",
"RANK() \nDescription: The row_number() of the first peer in each group - the rank of the current row with gaps. If there is no ORDER BY clause, then all rows are considered peers and this function always returns 1. ",
"DENSE_RANK() \nDescription: The number of the current row's peer group within its partition - the rank of the current row without gaps. Rows are numbered starting from 1 in the order defined by the ORDER BY clause in the window definition. If there is no ORDER BY clause, then all rows are considered peers and this function always returns 1. ",
"PERCENT_RANK() \nDescription: Despite the name, this function always returns a value between 0.0 and 1.0 equal to (rank- 1)/(partition-rows- 1), whererankis the value returned by built-in window function rank() andpartition-rowsis the total number of rows in the partition. If the partition contains only one row, this function returns 0.0. ",
"CUME_DIST() \nDescription: The cumulative distribution. Calculated asrow-number/partition-rows, whererow-numberis the value returned by row_number() for the last peer in the group andpartition-rowsthe number of rows in the partition. ",
"NTILE(N) \nDescription: ArgumentNis handled as an integer. This function divides the partition into N groups as evenly as possible and assigns an integer between 1 andNto each group, in the order defined by the ORDER BY clause, or in arbitrary order otherwise. If necessary, larger groups occur first. This function returns the integer value assigned to the group that the current row is a part of. ",
"LAG(expr) or LAG(expr, offset) or LAG(expr, offset, default) \nDescription: The first form of the LAG() function returns the result of evaluating expressionexpragainst the previous row in the partition. Or, if there is no previous row (because the current row is the first), NULL. ",
"LEAD(expr) or LEAD(expr, offset) or LEAD(expr, offset, default) \nDescription: The first form of the LEAD() function returns the result of evaluating expressionexpragainst the next row in the partition. Or, if there is no next row (because the current row is the last), NULL. ",
"FIRST_VALUE(expr) \nDescription: This built-in window function calculates the window frame for each row in the same way as an aggregate window function. It returns the value ofexprevaluated against the first row in the window frame for each row. ",
"LAST_VALUE(expr) \nDescription: This built-in window function calculates the window frame for each row in the same way as an aggregate window function. It returns the value ofexprevaluated against the last row in the window frame for each row. ",
"NTH_VALUE(expr, N) \nDescription: This built-in window function calculates the window frame for each row in the same way as an aggregate window function. It returns the value ofexprevaluated against the rowNof the window frame. Rows are numbered within the window frame starting from 1 in the order defined by the ORDER BY clause if one is present, or in arbitrary order otherwise. If there is noNth row in the partition, then NULL is returned. ",
"ACOS(X) \nDescription: Return the arccosine of X. The result is in radians. ",
"ACOSH(X) \nDescription: Return the hyperbolic arccosine of X. ",
"ASIN(X) \nDescription: Return the arcsine of X. The result is in radians. ",
"ASINH(X) \nDescription: Return the hyperbolic arcsine of X. ",
"ATAN(X) \nDescription: Return the arctangent of X. The result is in radians. ",
"ATAN2(Y,X) \nDescription: Return the arctangent of Y/X. The result is in radians. The result is placed into correct quadrant depending on the signs of X and Y. ",
"ATANH(X) \nDescription: Return the hyperbolic arctangent of X. ",
"CEIL(X) or CEILING(X) \nDescription: Return the first representable integer value greater than or equal to X. For positive values of X, this routine rounds away from zero. For negative values of X, this routine rounds toward zero. ",
"COS(X) \nDescription: Return the cosine of X. X is in radians. ",
"COSH(X) \nDescription: Return the hyperbolic cosine of X. ",
"DEGREES(X) \nDescription: Convert value X from radians into degrees. ",
"EXP(X) \nDescription: Computee(Euler's number, approximately 2.71828182845905) raised to the power X. ",
"FLOOR(X) \nDescription: Return the first representable integer value less than or equal to X. For positive numbers, this function rounds toward zero. For negative numbers, this function rounds away from zero. ",
"LN(X) \nDescription: Return the natural logarithm of X. ",
"LOG(X) or LOG10(X) or LOG(B,X) \nDescription: Return the base-10 logarithm for X. Or, for the two-argument version, return the base-B logarithm of X.Compatibility note: SQLite works like PostgreSQL in that the LOG() function computes a base-10 logarithm. Most other SQL database engines compute a natural logarithm for LOG(). In the two-argument version of LOG(B,X), the first argument is the base and the second argument is the operand. This is the same as in PostgreSQL and MySQL, but is reversed from SQL Server which uses the second argument as the base and the first argument as the operand. ",
"LOG2(X) \nDescription: Return the logarithm base-2 for the number X. ",
"MOD(X,Y) \nDescription: Return the remainder after dividing X by Y. This is similar to the '%' operator, except that it works for non-integer arguments. ",
"PI() \nDescription: Return an approximation for π. ",
"POW(X,Y) or POWER(X,Y) \nDescription: Compute X raised to the power Y. ",
"RADIANS(X) \nDescription: Convert X from degrees into radians. ",
"SIN(X) \nDescription: Return the sine of X. X is in radians. ",
"SINH(X) \nDescription: Return the hyperbolic sine of X. ",
"SQRT(X) \nDescription: Return the square root of X. NULL is returned if X is negative. ",
"TAN(X) \nDescription: Return the tangent of X. X is in radians. ",
"TANH(X) \nDescription: Return the hyperbolic tangent of X. ",
"TRUNC(X) \nDescription: Return the representable integer in between X and 0 (inclusive) that is furthest away from zero. Or, in other words, return the integer part of X, rounding toward zero. The TRUNC() function is similar toceiling(X)andfloor(X)except that it always rounds toward zero whereas ceiling(X) and floor(X) round up and down, respectively. ",
"DATE(time-value, modifier, modifier, ...) \nDescription: Returns the date as text in this format: YYYY-MM-DD. ",
"TIME(time-value, modifier, modifier, ...) \nDescription: Returns the time as text in formatted as HH:MM:SS or as HH:MM:SS.SSS if the subsec modifier is used. ",
"DATETIME(time-value, modifier, modifier, ...) \nDescription: Returns the date and time formatted as YYYY-MM-DD HH:MM:SS or as YYYY-MM-DD HH:MM:SS.SSS if the subsec modifier is used. ",
"JULIANDAY(time-value, modifier, modifier, ...) \nDescription: Returns the Julian day - the fractional number of days since noon in Greenwich on November 24, 4714 B.C. (Proleptic Gregorian calendar). ",
"UNIXEPOCH(time-value, modifier, modifier, ...) \nDescription: Returns a unix timestamp - the number of seconds since 1970-01-01 00:00:00 UTC. The UNIXEPOCH() function normally returns an integer number of seconds, but with the optional subsec modifier it will return a floating point number which is the fractional number of seconds. ",
"STRFTIME(format, time-value, modifier, modifier, ...) \nDescription: Returns the date formatted according to the format string specified as the first argument. The format string supports the most common substitutions found in the STRFTIME() function from the standard C library plus two new substitutions, %f and %J. ",
"TIMEDIFF(time-value, time-value) \nDescription: Returns a string that describes the amount of time that must be added to B in order to reach time A. The format of the TIMEDIFF() result is designed to be human-readable. "
]
def _sql_func_template(self, sql_funcs: str) -> str:
template = """### SQL Functions
You may consider one or more of the following SQL functions while generating the query:
{sql_funcs}
Important tips:
Except for the functions listed above, you may use any other functions as long as they conform to the syntax of the database engine.
"""
return template.format(sql_funcs=sql_funcs)
def _insert_stmts_template(self, insert_statements: str) -> str:
template = '''### INSERT INTO Statements
Below are several `INSERT INTO` statements. Use these to help generate predicates (i.e., `WHERE` clauses) in your SQL query:
{insert_statements}
'''
return template.format(insert_statements=insert_statements)
def _sql_synthesis_prompt(self, schema_str: str, sql_function_prompt: str, db_value_prompt: str, complexity: str, criterion: str, db_engine: str, column_count: int, embedding_model: str) -> str:
template = '''**Task Overview**
Create an executable VecSQL query based on the provided information.
**Database Schema**
{schema_str}
{sql_function_prompt}
{db_value_prompt}
**SQL Query Complexity**
Ensure the SQL query matches the {complexity} level, defined as follows:
{criterion}
**Output Format Requirements**
Enclose the SQL query in a code block:
```sql
-- Your SQL query here
```
**SQL Query Requirements**
0. (MANDATORY) The generated SQL query MUST be a VecSQL query. It MUST contain a vector similarity search using the `MATCH lembed(...)` syntax. This search MUST be performed on a column that ends with `_embedding`.
1. Use the syntax specific to the {db_engine} database engine.
2. Incorporate advanced functions if appropriate, but they are not mandatory.
3. Address real-world data analysis needs. Avoid trivial or nonsensical queries.
4. (Very important) Ensure the final SQL query selects {column_count} columns.
5. (Very important) Always verify that every column name you reference in a query exists in the tables you're querying. Before executing a query, ensure all referenced columns (e.g., column1, table1.id) are valid and spelled correctly.
**SQL extension**
Extension: sqlite_vec and sqlite_lembed.
There are a few Requirements you should Comply with in addition, else you can ignore these requirements below.
1. When generating SQL queries, you should prioritize utilizing KNN searches whenever contextually appropriate. However, you have to avoid unnecessary/forced KNN implementations for:
--Traditional relational data queries (especially for columns like: id, age, price)
--Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate
2. Only vector type(like: float[?]) support KNN queries and the name of vector column often end with "_embedding". So, you can use knn queries to search when the column name you need to search for ends with "_embedding" or when the column name with "_embedding" is also in the list.
3. In any complexity level, you can choose to use KNN queries if need.
4. When using KNN queries, you have to add LIMIT or 'And k = ?' constraint but do not use them all in the same statement. This rule is very important, do not forget to add LIMIT or 'And k = ?' constraint after MATCH operator.
5. The lembed function is used to transform a string into a vector, whose type and size match the corresponding column type in the data table. The function has two parameters, the first parameter is the name of the embedding model used (default value: {embedding_model}), and the second parameter is the content of the string type you want to convert. So, you should generate some words or sentences with specific semantic information based on name, type and comment of this column. For example, you can generate "The Daily Grind Coffee Shop\n 456 Oak Avenue\n Springfield, IL 62704\n USA" when this column name is Location_embedding, column type is float[384] and comment of column is "the embedding of location".
6. The lembed function's second parameter MUST be a SPECIFIC semantic description.
- For location_embedding: Generate REAL addresses (e.g. "Stadium: 123 Main St, Boston, MA. Capacity: 50,000. Home team: Patriots")
- For columns containing semantically meaningful data (e.g., descriptions), generate rich, contextually appropriate information. For columns without meaningful content (e.g., placeholder names), avoid creating semantically dense output to facilitate fuzzy matching operations.
- For name_embedding: You should generate variationsof the original names (e.g., altered spellings, phonetic approximations, or intentionally obfuscated words/characters) to enable Subsequent fuzzy matching to identify semantically similar names. Importantly, never generate redundant information. For example, you can generate "Lige", but do not generate "Ligand Lige", "Ligand example name", "Ligand similar to Aspirin" and "Ligand name variation".
Examples:
✅ Correct:
name_embedding MATCH lembed('all-MiniLM-L6-v2', "Kri")
❌ Wrong:
name_embedding MATCH lembed('all-MiniLM-L6-v2', "A leading publisher based in Germany specializing in
scientific journals and books.")
- For text_embedding: Use ACTUAL and meaningful sentences (e.g. "Harper Lee's To Kill a Mockingbirdis a timeless exploration of racial injustice and moral growth, seen through the innocent yet perceptive eyes of Scout Finch. With its powerful themes, unforgettable characters like Atticus Finch, and Lee's poignant prose, the novel remains a searing critique of society's failures and a testament to the courage of standing for what's right.")
- NEVER use vague words and generic phrases like "a book review"
Examples:
✅ Correct:
lembed('all-MiniLM-L6-v2', "To Kill a Mockingbird")
❌ Wrong:
lembed('all-MiniLM-L6-v2', "name of a famous book")
7. When using MATCH, please fill in a vector using function lembed after MATCH that matches the column type (with the same dimension and type). Using details are in examples.
8. The distance column is an implicitly generated metric that appears when performing vector similarity searches (using the MATCH operator) in SQLite vector extensions like sqlite-vec. If using JOIN operator, you have to clarify which table that distance belongs to.
9. A SELECT statement should have no more than one MATCH operation. However, each subquery within a SELECT statement could also have no more than one MATCH operation, independent of the parent query."
10. When performing a KNN/vector similarity search (e.g., using MATCH or lembed), always specify a LIMIT or k=N constraint directly on the vector search operation, even if the outer query already has a LIMIT. The vector search requires its own result cap to avoid ambiguity in ranking and performance issues.
11. When both LIMIT and k operations are available for vector search queries, prioritize using k operation for Broader Compatibility.
Key Points:
--Vector search needs its own LIMIT/k - The outer LIMIT applies to the final filtered results, not the initial similarity search.
--LIMIT operator should follow closely after "ORDER BY distance".
❌ Wrong Example:
```sql
SELECT a.codec_name
FROM audio_codecs a
JOIN quality_levels q ON a.codec_id = q.quality_id
WHERE a.description_embedding MATCH lembed('all-MiniLM-L6-v2', "High efficiency audio codec with low latency and optimal bandwidth")
AND q.quality_name = 'HighQuality'
LIMIT 1;
```
✅ Correct Example:
```sql
SELECT a.codec_name
FROM audio_codecs a
JOIN quality_levels q ON a.codec_id = q.quality_id
WHERE a.description_embedding MATCH lembed('all-MiniLM-L6-v2', "High efficiency audio codec with low latency and optimal bandwidth") LIMIT 1
AND q.quality_name = 'HighQuality';
```
--When using JOIN operations, you need to ensure that k does not cause ambiguity in the query. In most cases, the k parameter logically belongs to the same table as the column used in the MATCH clause. So, when the column referenced in the MATCH clause includes a table qualifier (e.g., table1.embedding), the k parameter must be explicitly bound to the same table.
❌ Wrong Example:
```sql
SELECT s.stock_id, s.symbol
FROM stocks s
JOIN exchanges e ON s.exchange_id = e.exchange_id
WHERE s.sector_embedding MATCH lembed('all-MiniLM-L6-v2', "Tech industry sector in the USA")
AND e.country = 'USA'
AND k = 5
ORDER BY s.stock_id;
```
✅ Correct Example:
```sql
SELECT s.stock_id, s.symbol
FROM stocks s
JOIN exchanges e ON s.exchange_id = e.exchange_id
WHERE s.sector_embedding MATCH lembed('all-MiniLM-L6-v2', "Tech industry sector in the USA")
AND e.country = 'USA'
AND s.k = 5
ORDER BY s.stock_id;
```
12. ​Avoids runtime errors - Many vector databases (e.g., SQLite with sqlite-vss, pgvector) enforce this requirement strictly.
13. Only a single 'ORDER BY distance' clause is allowed on vec0 KNN queries, not on other columns.
***Example of KNN queries of sqlite-vec***
first example(type of vector_embedding is float[384]):
```sql
SELECT rowid, distance
FROM vec_table
WHERE vector_embedding MATCH lembed({embedding_model},"vector of sun")
ORDER BY distance
LIMIT 1;
```
second example(type of sentence_embedding is float[384]):
```sql
select
movie_id,
title,
genre,
num_reviews,
mean_rating,
distance
from vec_movies
where sentence_embedding match lembed({embedding_model},"This is a great movie!")
and genre = 'scifi'
and num_reviews between 100 and 500
and mean_rating > 3.5
and k = 5;
```
third example(type of vector_embedding is float[384]):
```sql
select rowid, name1, name2, age, vec_to_json
from v
where vector_embedding match lembed({embedding_model},"aaa and xxx are good friends, whose age is 18.")
and k = 1
and name1 in ('alex', 'brian', 'craig')
and name2 in ('Rick', 'Morty')
and age in (21, 18);
```
**Answer**
Let's proceed step by step.
'''
return template.format(
schema_str=schema_str,
sql_function_prompt=sql_function_prompt.strip(),
db_value_prompt=db_value_prompt.strip(),
complexity=complexity,
criterion=criterion.strip(),
db_engine=db_engine,
column_count=column_count,
embedding_model=embedding_model
)
def build_prompt(self, insert_statements: List[str], create_statements: List[str], db_engine: str) -> tuple[str, str]:
random.seed(42)
complexity = random.sample(list(self.complexity2criterion_vec.keys()), 1)[0]
if len(insert_statements) == 0:
db_value_prompt = ""
else:
if len(insert_statements) > 4:
insert_statements = random.sample(insert_statements, 4)
db_value_prompt = self._insert_stmts_template(
insert_statements="\n\n".join(insert_statements)
)
function_num = random.randint(0, 2)
if function_num == 0:
sql_function_prompt = "### SQL Functions\nYou can use any function supported by the database engine."
else:
sql_funcs = ""
sampled_functions = random.sample(self.functions, min(function_num, len(self.functions)))
for idx, func in enumerate(sampled_functions):
sql_funcs += f"Function {idx + 1}:\n{func.strip()}\n"
sql_function_prompt = self._sql_func_template(sql_funcs=sql_funcs)
column_count = np.random.geometric(0.6, 1)[0]
prompt = self._sql_synthesis_prompt(
schema_str="\n\n".join(create_statements),
sql_function_prompt=sql_function_prompt.strip(),
db_value_prompt=db_value_prompt.strip(),
complexity=complexity,
criterion=self.complexity2criterion_vec[complexity].strip(),
db_engine=db_engine,
column_count=column_count,
embedding_model="\'all-MiniLM-L6-v2\'"
)
return prompt
@PROMPT_REGISTRY.register()
class Text2SQLQuestionGeneratorPrompt(PromptABC):
def __init__(self):
self.style2desc = {
"Formal": '''**Formal Style**
- Uses standard grammar and vocabulary.
- Example: Find all students older than 18 years and return their home addresses.''',
"Colloquial": '''**Colloquial Style**
- Employs informal vocabulary and expressions.
- Example: Hey! Could you help me find all the students who are over 18? I'd love to know their names and where they live.''',
"Imperative": '''**Imperative Style**
- Uses command or directive sentences.
- Example: Could you please gather all the students who are older than 18? I really need to know their names and where they live!''',
"Interrogative": '''**Interrogative Style**
- Uses question forms.
- Example: Could you tell me which students are older than 18 and what their home addresses are?''',
"Descriptive": '''**Descriptive Style**
- Uses detailed descriptions with contextual information.
- Example: I want to know the names and home addresses of all students older than 18.''',
"Concise": '''**Concise Style**
- Use short sentences.
- Example: Students older than 18, return their names and addresses.''',
"Vague": '''**Vague Style**
- Includes ambiguous vocabulary requiring inference.
- Example: What are the names and addresses of those older students? (External Knowledge: 'older students' refers to age >= 18.)''',
"Metaphorical": '''**Metaphorical Style**
- Uses metaphors or metaphorical expressions.
- Example: Find the names and addresses of those who have reached adulthood. (External Knowledge: 'reached adulthood' refers to age >= 18.)'''
}
self.steps_wo_ek = '''1. **Explain the SQL Query:** Provide a detailed explanation of what the query does.
2. **Generate a Question:** Formulate a natural language question based on the SQL query and explanation.'''
self.steps_w_ek = '''1. **Explain the SQL Query:** Provide a detailed explanation of what the query does.
2. **Generate a Question:** Formulate a natural language question based on the SQL query and explanation.
3. **External Knowledge:** For Vague or Metaphorical styles, include external knowledge to enhance clarity.'''
self.steps_multi_round = '''1. **Explain the SQL Query:** Provide a detailed explanation of what the query does.
2. **Generate a Dialogue:** Create a conversation between the User and the Assistant based on the SQL query and its explanation.'''
self.guidelines_wo_ek = '''1. Clearly describe the columns being selected by the SQL query. For example:
- "SELECT * ... FROM ..." means "Find all ...";
- "SELECT f.check_date, f.status, f.remarks, c.year, c.year_min, c.year_max, c.year_average, c.data_quality_score FROM ..." means "Return the check dates, statuses, remarks, years, minimum years, maximum years, average years, and quality scores for ...".
2. Ensure the natural language question accurately captures the semantics of the SQL query, including conditions such as predicates, `ORDER BY`, and `LIMIT` clauses.'''
self.guidelines_w_ek = '''1. Clearly describe the columns being selected by the SQL query. For example:
- "SELECT * ... FROM ..." means "Find all ...";
- "SELECT f.check_date, f.status, f.remarks, c.year, c.year_min, c.year_max, c.year_average, c.data_quality_score FROM ..." means "Return the check dates, statuses, remarks, years, minimum years, maximum years, average years, and quality scores for ...".
2. Ensure the natural language question accurately captures the semantics of the SQL query, including conditions such as predicates, `ORDER BY`, and `LIMIT` clauses.
3. If necessary, incorporate external knowledge using multiple entries separated by semicolons (";"). These can include formulas, common sense, domain-specific knowledge, or extended context, such as information from long documents. Each entry should be concise.'''
self.guidelines_multi_round = '''1. Clearly describe the columns being selected by the SQL query. For example:
- "SELECT * ... FROM ..." means "Find all ...";
- "SELECT f.check_date, f.status, f.remarks, c.year, c.year_min, c.year_max, c.year_average, c.data_quality_score FROM ..." means "Return the check dates, statuses, remarks, years, minimum years, maximum years, average years, and quality scores for ...".
2. Ensure the conversation accurately captures the semantics of the SQL query, including conditions such as predicates, `ORDER BY`, and `LIMIT` clauses.'''
self.output_format_wo_ek = '''Please structure your response as follows:
[EXPLANATION-START]
(SQL Explanation)
[EXPLANATION-END]
[QUESTION-START]
(Natural Language Question)
[QUESTION-END]
- **SQL Explanation**: Provide a clear and detailed explanation of the SQL query, enclosed within [EXPLANATION-START] and [EXPLANATION-END].
- **Natural Language Question**: Translate the SQL query into a natural language question, enclosed within [QUESTION-START] and [QUESTION-END].'''
self.output_format_w_ek = '''Please structure your response as follows:
[EXPLANATION-START]
(SQL Explanation)
[EXPLANATION-END]
[QUESTION-START]
(Natural Language Question)
[QUESTION-END]
[EXTERNAL-KNOWLEDGE-START]
(External Knowledge)
[EXTERNAL-KNOWLEDGE-END]
- **SQL Explanation**: Provide a clear and detailed explanation of the SQL query, enclosed within [EXPLANATION-START] and [EXPLANATION-END].
- **Natural Language Question**: Translate the SQL query into a natural language question, enclosed within [QUESTION-START] and [QUESTION-END].
- **External Knowledge**: Include any relevant external knowledge if applicable, enclosed within [EXTERNAL-KNOWLEDGE-START] and [EXTERNAL-KNOWLEDGE-END]. Leave this section blank if not needed.'''
self.output_format_multi_round = '''Please structure your response as follows:
[EXPLANATION-START]
(SQL Explanation)
[EXPLANATION-END]
[QUESTION-START]
(Natural Language Question, in the format of [{"User": ...}, {"Assistant": ...}, {"User": ...}, ....])
[QUESTION-END]
- **SQL Explanation**: Provide a clear and detailed explanation of the SQL query, enclosed within [EXPLANATION-START] and [EXPLANATION-END].
- **Natural Language Question**: Convert the SQL query into a multi-round dialogue, enclosed within [QUESTION-START] and [QUESTION-END]. Represent this as a list that captures multiple rounds of conversation between the User and the Assistant.'''
self.instruction_wo_ek = "Based on the above information, follow the reasoning steps to generate the explanation and the question corresponding to the SQL query."
self.instruction_w_ek = "Based on the above information, follow the reasoning steps to generate the explanation, the question, and the external knowledge corresponding to the SQL query."
self.instruction_multi_round = "Based on the above information, follow the reasoning steps to generate the explanation and the dialogue corresponding to the SQL query."
def _question_synthesis_prompt(self, style_desc, engine, column_info, sql, steps, guidelines, output_format, instruction):
template = '''**Task Overview**
Your task is to create a high-quality natural language question based on a given SQL query and other information.
**Style**
The natural language question should follow this style:
{style_desc}
**Database Engine**
{engine}
**Column Information**
Below are column names and their corresponding descriptions:
{column_info}
**SQL Query**
Given SQL query:
```sql
{sql}
```
**Reasoning Steps**
{steps}
**Guidelines**
{guidelines}
**Output Format**
{output_format}
**Insturction**
{instruction}
'''
return template.format(
style_desc = style_desc,
engine = engine,
column_info = column_info,
sql = sql,
steps = steps,
guidelines = guidelines,
output_format = output_format,
instruction = instruction
)
def build_prompt(self, sql, db_id, db_id2column_info, db_type) -> str:
random.seed(42)
styles = ["Formal", "Colloquial", "Imperative", "Interrogative", "Descriptive", "Concise", "Vague", "Metaphorical"]
style_name = random.sample(styles, 1)[0]
column_name2column_desc = db_id2column_info[db_id]
used_column_name2column_desc = dict()
for column_name, column_desc in column_name2column_desc.items():
if column_name.lower() in sql.lower():
used_column_name2column_desc[column_name] = column_desc
if style_name in ["Vague", "Metaphorical"]:
steps = self.steps_w_ek
guidelines = self.guidelines_w_ek
instruction = self.instruction_w_ek
output_format = self.output_format_w_ek
else:
steps = self.steps_wo_ek
guidelines = self.guidelines_wo_ek
instruction = self.instruction_wo_ek
output_format = self.output_format_wo_ek
prompt = self._question_synthesis_prompt(
style_desc=self.style2desc[style_name].strip(),
engine=db_type,
column_info=json.dumps(used_column_name2column_desc, indent=2, ensure_ascii=False).strip(),
sql=sql.strip(),
steps=steps.strip(),
guidelines=guidelines.strip(),
output_format=output_format.strip(),
instruction=instruction.strip()
)
return prompt
@PROMPT_REGISTRY.register()
class Text2VecSQLQuestionGeneratorPrompt(PromptABC):
def __init__(self):
pass
def _get_style2desc(self):
template = {
"Formal": '''**Formal Style**
- Uses standard grammar and vocabulary.
- Example: Find all students older than 18 years and return their home addresses.
- Vector Example: Find the three articles most closely related to Stable Diffusion and return them.''',
"Colloquial": '''**Colloquial Style**
- Employs informal vocabulary and expressions.
- Example: Hey! Could you help me find all the students who are over 18? I'd love to know their names and where they live.
- Vector Example: Hey there! Can you grab me the top 3 articles that are most closely related to Stable Diffusion?''',
"Imperative": '''**Imperative Style**
- Uses command or directive sentences.
- Example: Could you please gather all the students who are older than 18? I really need to know their names and where they live!
- Vector Example: Please find the three articles most closely related to Stable Diffusion and return their name.''',
"Interrogative": '''**Interrogative Style**
- Uses question forms.
- Example: Could you tell me which students are older than 18 and what their home addresses are?
- Vector Example: Could you show me the 3 articles that most have to do with Stable Diffusion?''',
"Descriptive": '''**Descriptive Style**
- Uses detailed descriptions with contextual information.
- Example: I want to know the names and home addresses of all students older than 18.
- Vector Example: I need to find articles that most closely related to Stable Diffusion, returning the top 3 matches sorted by cosine similarity.''',
"Concise": '''**Concise Style**
- Use short sentences.
- Example: Students older than 18, return their names and addresses.
- Vector Example: Top 3 related articles to Stable Diffusion.''',
"Vague": '''**Vague Style**
- Includes ambiguous vocabulary requiring inference.
- Example: What are the names and addresses of those older students? (External Knowledge: 'older students' refers to age >= 18.)
- Vector Example: Find a few articles have to do with Stable Diffusion. (External Knowledge: 'a few' refers to vector similarity search with k=3 limit)''',
"Metaphorical": '''**Metaphorical Style**
- Uses metaphors or metaphorical expressions.
- Example: Find the names and addresses of those who have reached adulthood. (External Knowledge: 'reached adulthood' refers to age >= 18.)
- Vector Example: Find a few articles have to do with SD in ai. (External Knowledge: 'SD in ai' refers to Stable Diffusion)''',
"Multi-turn Dialogue": '''**Multi-turn Dialogue Style**
- This involves a dialogue to clarify the user's query needs.
- Example: [{"User": "I want to query some student information."}, {"Assistant": "Which students' information would you like to query?"}, {"User": "Students older than 18."}, {"Assistant": "What other information would you like to know about them?"}, {"User": "Names and addresses."}, {"Assistant": "Is there anything else you need?"}, {"User": "No."}, {"Assistant": "OK, I will help you translate your request into an SQL query."}]
- Vector Example:
User: "I'm looking for some articles."
Assistant: "How many articles would you like to find and What field of paper are you looking for?"
User: "About 3, and they are related to Stable Diffusion."
Assistant: "I'll search for 3 articles that most closely related to Stable Diffusion."'''
}
return template
def _get_steps_wo_ek(self):
template = '''1. **Explain the SQL Query:** Provide a detailed explanation of what the query does, including any vector search operations.
2. **Generate a Question:** Formulate a natural language question based on the SQL query and explanation.'''
return template
def _get_steps_w_ek(self):
template = '''1. **Explain the SQL Query:** Provide a detailed explanation of what the query does, including any vector search operations.
2. **Generate a Question:** Formulate a natural language question based on the SQL query and explanation.
3. **External Knowledge:** For Vague or Metaphorical styles, include external knowledge to enhance clarity, especially for vector operations.'''
return template
def _get_steps_multi_round(self):
template = '''1. **Explain the SQL Query:** Provide a detailed explanation of what the query does, including any vector search operations.
2. **Generate a Dialogue:** Create a conversation between the User and the Assistant based on the SQL query and its explanation, ensuring vector operations are properly discussed.'''
return template
def _get_guidelines_wo_ek(self):
template = '''1. Clearly describe the columns being selected by the SQL query. For example:
- "SELECT * ... FROM ..." means "Find all ...";
- "SELECT f.check_date, f.status, f.remarks, c.year, c.year_min, c.year_max, c.year_average, c.data_quality_score FROM ..." means "Return the check dates, statuses, remarks, years, minimum years, maximum years, average years, and quality scores for ...".
- "SELECT rowid, vec FROM vec_table WHERE vec MATCH lembed(_,"xxx") ORDER BY distance LIMIT 2;" means "Return two of the rowid and vec that most related to xxx from vec_table, ordered by similarity distance".
- "SELECT rowid, vec FROM vec_table WHERE vec MATCH lembed(_,"xxx") AND k = 2;" means "Return two of the rowid and vec that most related to xxx from vec_table, ordered by similarity distance".
- For vector searches: Always mention the LIMIT value or K value when explaining MATCH operations.
2. Ensure the natural language question accurately captures:
- All conditions including vector similarity searches
- ORDER BY clauses (especially for distance/similarity)
- LIMIT and K clauses
- Any window functions or complex joins'''
return template
def _get_guidelines_w_ek(self):
template = '''1. Clearly describe the columns being selected by the SQL query (same as above).
2. Ensure the natural language question captures all query semantics (same as above).
3. For vector searches, include these common external knowledge points:
- "MATCH" operator performs approximate nearest neighbor (ANN) search;
- "k=N" specifies the number of similar items to return;
- Vectors are compared using Euclidean distance (L2 norm) by default;
- Similarity increases as distance decreases;
- Include any domain-specific knowledge about the vector meaning.'''
return template
def _get_guidelines_multi_round(self):
template = '''1. Clearly describe the columns being selected by the SQL query (same as above).
2. Ensure the dialogue naturally covers:
- The purpose of the vector search;
- How many similar items are needed (LIMIT);
- What the target vector represents;
- Any additional filtering or sorting requirements.'''
return template
def _get_vector_question_guidelines(self):
template = '''
**Guiding Principles for High-Quality Vector Questions:**
Your ultimate goal is to create a question that a real human would ask. To do this, you must internalize the following principles:
1. **Translate Mechanism into Intent (The Golden Rule!)**: A vector search (`MATCH ... LIMIT N`) is a technical mechanism to find the "top N" or "best N" examples of something. Your question **MUST** reflect the user's *intent*, not the mechanism.
* **Prohibited Phrases**: Avoid technical jargon that describes the search process. Do NOT use phrases like: "most closely related to", "semantically similar to", "based on similarity", "concept of", "field of".
* **Approved Phrasing**: Instead, use natural language that implies ranking or quality. Use words like: "top 5", "best", "most representative", "leading", or simply state the entity directly. For example, a search for the top 5 professors should be phrased as "Who are the top 5 professors?", not "Who are the 5 professors most similar to the concept of a professor?".
2. **Identify and Preserve Key Entities**: Within the `lembed()` text, identify the core keywords (e.g., "Professor", "Computer Science", "leadership skills"). These **MUST** be present in the final question to ensure it is specific and meaningful.
3. **Rephrase Naturally, Do Not Copy Verbatim**: While preserving key entities, change the overall sentence structure to fit the requested style (Formal, Colloquial, etc.). Do not copy the entire `lembed()` string word-for-word.
---
**Examples of Correct vs. Incorrect Behavior:**
**Example 1: Being Natural--like a real human would ask**
* **Input VecSQL**: `... WHERE p.hasPosition_embedding MATCH lembed('all-MiniLM-L6-v2', "Professor of Computer Science") AND k = 5 ...`
* **BAD Question**: `"Identify five professors whose roles most closely match the concept of teaching computer science at a professorial level..."` or `"Please provide the IDs, names, and course levels for the five professors who have positions most closely related to the field of computer science, ordered by similarity."`
* **Reasoning**: This is the classic mistake. It describes the vector search mechanism ("most closely related to") instead of asking a direct, human-like question. Too verbose and abstract. "concept of teaching computer science at a professorial level" or "who have positions most closely related to" is an unnatural way to say "Computer Science Professor".
* **GOOD Question**: `"Please provide the IDs and course levels for the 5 professors who specialize in computer science."` or `"Identify five computer science professors and list the levels of the courses they teach."`
* **Reasoning**: Correctly uses the key entities "Professor" and "Computer Science" in a formal and direct manner.
**Example 2: Preserving Entities**
* **Input VecSQL**: `... WHERE p.hasPosition_embedding MATCH lembed('all-MiniLM-L6-v2', "Professor of Mathematics") LIMIT 5 ...`
* **BAD Question**: `"Could you please find the top 5 individuals most semantically related to a specialized academic teaching role..."`
* **Reasoning**: Completely fails by losing the essential key entities **"Professor"** and **"Mathematics"**. The question is now uselessly vague.
* **GOOD Question**: `"Get me the top 5 people who are very professional in mathematics, and show me their course levels."`
* **Reasoning**: Preserves the critical entities in a natural, imperative sentence.
**Example 3: Being Natural**
* **Input VecSQL**: `... WHERE performance_embedding MATCH lembed('all-MiniLM-L6-v2', "exceptional performance with leadership skills") LIMIT 1;`
* **BAD Question**: `"Hey, could you help me find the employee whose performance is most closely related to being a standout leader?"`
* **Reasoning**: "most closely related to being..." is clunky and not like a real human would ask. It also loses the "exceptional performance" aspect.
* **GOOD Question**: `"Hey, can you find the employee who performance very well and with great leadership ability'? I need their SSN."` or `"Who's our top employee showing both great performance and leadership? Grab their SSN for me."`
* **Reasoning**: Sounds like a real person talking and naturally incorporates the key concepts.
---
'''
def _get_output_format_wo_ek(self):
template = '''Please structure your response as follows:
[EXPLANATION-START]
(SQL Explanation)
[EXPLANATION-END]
[QUESTION-START]
(Natural Language Question)
[QUESTION-END]
- **SQL Explanation**: Provide a clear and detailed explanation of the SQL query, enclosed within [EXPLANATION-START] and [EXPLANATION-END].
- **Natural Language Question**: Translate the SQL query into a natural language question, enclosed within [QUESTION-START] and [QUESTION-END].'''
return template
def _get_output_format_w_ek(self):
template = '''Please structure your response as follows:
[EXPLANATION-START]
(SQL Explanation)
[EXPLANATION-END]
[QUESTION-START]
(Natural Language Question)
[QUESTION-END]
[EXTERNAL-KNOWLEDGE-START]
(External Knowledge)
[EXTERNAL-KNOWLEDGE-END]
- **SQL Explanation**: Provide a clear and detailed explanation of the SQL query, enclosed within [EXPLANATION-START] and [EXPLANATION-END].
- **Natural Language Question**: Translate the SQL query into a natural language question, enclosed within [QUESTION-START] and [QUESTION-END].
- **External Knowledge**: Include any relevant external knowledge if applicable, enclosed within [EXTERNAL-KNOWLEDGE-START] and [EXTERNAL-KNOWLEDGE-END]. Leave this section blank if not needed.'''
return template
def _get_output_format_multi_round(self):
template = '''Please structure your response as follows:
[EXPLANATION-START]
(SQL Explanation)
[EXPLANATION-END]
[QUESTION-START]
(Natural Language Question, in the format of [{"User": ...}, {"Assistant": ...}, {"User": ...}, ....])
[QUESTION-END]
- **SQL Explanation**: Provide a clear and detailed explanation of the SQL query, enclosed within [EXPLANATION-START] and [EXPLANATION-END].
- **Natural Language Question**: Convert the SQL query into a multi-round dialogue, enclosed within [QUESTION-START] and [QUESTION-END]. Represent this as a list that captures multiple rounds of conversation between the User and the Assistant.'''
return template
def _get_instruction_wo_ek(self):
template = '''Based on the above information:
1. Analyze the SQL query, paying special attention to any vector operations
2. Generate a clear explanation covering all query elements
3. Formulate a precise natural language question
4. Verify all vector operations (MATCH, LIMIT, ORDER BY distance) or (MATCH, And k = ?) are properly represented'''
return template
def _get_instruction_w_ek(self):
template = '''Based on the above information:
1. Analyze the SQL query, especially vector operations
2. Generate explanation covering all elements
3. Formulate precise question
4. Add relevant external knowledge about vector operations
5. Verify all vector elements are properly represented'''
return template
def _get_instruction_multi_round(self):
template = '''Based on the above information:
1. Analyze the SQL query, especially vector operations
2. Generate explanation covering all elements
3. Create natural dialogue that explores vector search parameters
4. Ensure LIMIT, target vector and distance sorting are discussed'''
return template
def _question_synthesis_prompt(self, using_knn, style_desc, engine, extension, column_info, sql, steps, guidelines, output_format, instruction):
template = '''**Task Overview**
Your task is to create a high-quality natural language question based on a given SQL query and other information.
{using_knn}
**Style**
The natural language question should follow this style:
{style_desc}
**Database Engine**
{engine}
**Database Extension**
{extension}
**Column Information**
Below are column names and their corresponding descriptions:
{column_info}
**SQL Query**
Given SQL query:
```sql
{sql}
```
**Reasoning Steps**
{steps}
**Guidelines**
{guidelines}
**Output Format**
{output_format}
**Insturction**
{instruction}
'''
return template.format(
using_knn = using_knn,
style_desc = style_desc,
engine = engine,
extension = extension,
column_info = column_info,
sql = sql,
steps = steps,
guidelines = guidelines,
output_format = output_format,
instruction = instruction
)
def build_prompt(self, input_sql, input_db_id, db_id2column_info, db_type) -> str:
random.seed(42)
styles = ["Formal", "Colloquial", "Imperative", "Interrogative", "Descriptive", "Concise", "Vague", "Metaphorical"]
style_name = random.sample(styles, 1)[0]
column_name2column_desc = db_id2column_info[input_db_id]
used_column_name2column_desc = dict()
for column_name, column_desc in column_name2column_desc.items():
if column_name.lower() in input_sql.lower():
used_column_name2column_desc[column_name] = column_desc
if style_name in ["Vague", "Metaphorical"]:
steps = self._get_steps_w_ek()
guidelines = self._get_guidelines_w_ek()
instruction = self._get_instruction_w_ek()
output_format = self._get_output_format_w_ek()
else:
steps = self._get_steps_wo_ek()
guidelines = self._get_guidelines_wo_ek()
instruction = self._get_instruction_wo_ek()
output_format = self._get_output_format_wo_ek()
using_knn = "Extension includes sqlite-vec, you have to use KNN queries of it."
extension="sqlite_vec and sqlite_lembed"
prompt = self._question_synthesis_prompt(
using_knn=using_knn,
style_desc=self._get_style2desc()[style_name].strip(),
engine=db_type,
extension=extension,
column_info=json.dumps(used_column_name2column_desc, indent=2, ensure_ascii=False).strip(),
sql=input_sql.strip(),
steps=steps.strip(),
guidelines=guidelines.strip(),
output_format=output_format.strip(),
instruction=instruction.strip()
)
return prompt
@PROMPT_REGISTRY.register()
class SQLVariationGeneratorPrompt(PromptABC):
def __init__(self):
self.variation_type_prompts = [
'''
Data Value Transformations
- Alter filter conditions, date ranges, or numerical thresholds
- Change sorting criteria or limit values
- Modify aggregation boundaries (e.g., GROUP BY different time periods)
''',
'''Query Structure Modifications
- Convert aggregation queries to window functions or vice versa
- Change from simple queries to subqueries or CTEs
- Transform JOINs to EXISTS/IN clauses or vice versa
- Switch between correlated and non-correlated subqueries
''',
'''Business Logic Changes
- Adapt the query for different business scenarios (sales → inventory, customers → suppliers)
- Modify to handle different data granularities (daily → monthly, individual → grouped)
- Change the analytical perspective (profit analysis → cost analysis)
- Alter the metrics being calculated (sum → average, count → percentage)
''',
'''Complexity Enhancements
- Add extra filtering conditions or business rules
- Introduce additional table joins
- Include conditional logic with CASE statements
- Add data validation or quality checks
''',
'''Advanced SQL Features
- Implement complex window functions with partitioning
- Create queries requiring UNION/INTERSECT/EXCEPT operations
- Add recursive CTEs for hierarchical data
- Include pivot/unpivot operations
''',
'''Performance and Optimization
- Add performance optimization hints
- Restructure for better index usage
- Convert to more efficient query patterns
- Add appropriate WHERE clause optimizations
''',
]
def _insert_stmts_template(self, insert_statements):
template = '''### INSERT INTO Statements
Below are several `INSERT INTO` statements. Use these to help generate predicates (i.e., `WHERE` clauses) in your SQL query:
{insert_statements}
'''
return template.format(insert_statements=insert_statements)
def _sql_variation_prompt(self, original_sql, schema_str, db_value_prompt, variation_prompt, db_engine):
template = """**Task Overview**
Create a new reasonable and executable SQL query by applying the specified transformations to the original query.
**Database Engine**
{db_engine}
**Database Schema**
{schema_str}
{db_value_prompt}
**Original SQL Query**
```sql
{original_sql}
```
**Transformation Instructions**
{variation_prompt}
**Requirements**
1. The new query must be syntactically correct for {db_engine}
2. All referenced tables/columns must exist in the provided schema
3. Ensure the query is executable
**Output Format**
The transformed SQL query should be enclosed in a code block:
```sql
-- Your transformed SQL query here
```
**Answer**
Let's proceed step by step.
"""
return template.format(
variation_prompt=variation_prompt,
schema_str=schema_str,
db_value_prompt=db_value_prompt,
original_sql=original_sql,
db_engine=db_engine
)
def build_prompt(self, original_sql, create_statements, insert_statements, db_engine) -> str:
random.seed(42)
if len(insert_statements) == 0:
db_value_prompt = ""
else:
if len(insert_statements) > 4:
insert_statements = random.sample(insert_statements, 4)
db_value_prompt = self._insert_stmts_template(
insert_statements="\n\n".join(insert_statements)
)
variation_type = random.randint(0, 5)
variation_prompt = self.variation_type_prompts[variation_type]
prompt = self._sql_variation_prompt(
original_sql=original_sql,
schema_str="\n\n".join(create_statements),
db_value_prompt=db_value_prompt.strip(),
variation_prompt=variation_prompt.strip(),
db_engine=db_engine
)
return prompt
@PROMPT_REGISTRY.register()
class Text2SQLPromptGeneratorPrompt(PromptABC):
def __init__(self):
pass
def build_prompt(self, db_details: str, question: str, evidence: str, db_engine: str) -> str:
if evidence:
question_and_evidence = f"{evidence}\n{question}"
else:
question_and_evidence = question
template = """Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
{db_engine}
Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
{question_and_evidence}
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
sql
-- Your SQL query
Take a deep breath and think step by step to find the correct SQL query.
"""
prompt = template.format(db_details=db_details, question_and_evidence=question_and_evidence, db_engine=db_engine)
return prompt
@PROMPT_REGISTRY.register()
class Text2VecSQLPromptGeneratorPrompt(PromptABC):
def __init__(self):
pass
def _get_sqlite_vec_description(self):
return """There are a few Requirements you should Comply with in addition, else you can ignore these requirements below.
1. When generating SQL queries, you should prioritize utilizing KNN searches whenever contextually appropriate. However, you have to avoid unnecessary/forced KNN implementations for:
--Traditional relational data queries (especially for columns like: id, age, price)
--Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate
2. Only vector type(like: float[?]) support KNN queries and the name of vector column often end with "_embedding". So, you can use knn queries to search when the column name you need to search for ends with "_embedding" or when the column name with "_embedding" is also in the list.
3. In any complexity level, you can choose to use KNN queries if need.
4. When using KNN queries, you have to add LIMIT or 'And k = ?' constraint but do not use them all in the same statement. This rule is very important, do not forget to add LIMIT or 'And k = ?' constraint after MATCH operator.
5. The lembed function is used to transform a string into a vector, whose type and size match the corresponding column type in the data table. The function has two parameters, the first parameter is the name of the embedding model used (default value: {embedding_model}), and the second parameter is the content of the string type you want to convert. So, you should generate some words or sentences with specific semantic information based on name, type and comment of this column. For example, you can generate "The Daily Grind Coffee Shop\n 456 Oak Avenue\n Springfield, IL 62704\n USA" when this column name is Location_embedding, column type is float[384] and comment of column is "the embedding of location".
6. The lembed function's second parameter MUST be a SPECIFIC semantic description.
- For location_embedding: Generate REAL addresses (e.g. "Stadium: 123 Main St, Boston, MA. Capacity: 50,000. Home team: Patriots")
- For columns containing semantically meaningful data (e.g., descriptions), generate rich, contextually appropriate information. For columns without meaningful content (e.g., placeholder names), avoid creating semantically dense output to facilitate fuzzy matching operations.
- For name_embedding: You should generate variationsof the original names (e.g., altered spellings, phonetic approximations, or intentionally obfuscated words/characters) to enable Subsequent fuzzy matching to identify semantically similar names. Importantly, never generate redundant information. For example, you can generate "Lige", but do not generate "Ligand Lige", "Ligand example name", "Ligand similar to Aspirin" and "Ligand name variation".
Examples:
✅ Correct:
name_embedding MATCH lembed('all-MiniLM-L6-v2', "Kri")
❌ Wrong:
name_embedding MATCH lembed('all-MiniLM-L6-v2', "A leading publisher based in Germany specializing in
scientific journals and books.")
- For text_embedding: Use ACTUAL and meaningful sentences (e.g. "Harper Lee’s To Kill a Mockingbirdis a timeless exploration of racial injustice and moral growth, seen through the innocent yet perceptive eyes of Scout Finch. With its powerful themes, unforgettable characters like Atticus Finch, and Lee’s poignant prose, the novel remains a searing critique of society’s failures and a testament to the courage of standing for what’s right.")
- NEVER use vague words and generic phrases like "a book review"
Examples:
✅ Correct:
lembed('all-MiniLM-L6-v2', "To Kill a Mockingbird")
❌ Wrong:
lembed('all-MiniLM-L6-v2', "name of a famous book")
7. When using MATCH, please fill in a vector using function lembed after MATCH that matches the column type (with the same dimension and type). Using details are in examples.
8. The distance column is an ​​implicitly generated metric​​ that appears when performing vector similarity searches (using the MATCH operator) in SQLite vector extensions like sqlite-vec. If using JOIN operator, you have to clarify which table that distance belongs to.
9. A SELECT statement should have no more than one MATCH operation. However, each subquery within a SELECT statement could also have no more than one MATCH operation, independent of the parent query."
10. When performing a KNN/vector similarity search (e.g., using MATCH or lembed), always specify a LIMIT or k=N constraint directly on the vector search operation, even if the outer query already has a LIMIT. The vector search requires its own result cap to avoid ambiguity in ranking and performance issues.
11. When both LIMIT and k operations are available for vector search queries, prioritize using k operation for ​​Broader Compatibility.
Key Points:
​--​Vector search needs its own LIMIT/k​​ – The outer LIMIT applies to the final filtered results, not the initial similarity search.
--LIMIT operator should follow closely after "ORDER BY distance".
❌ Wrong Example:
```sql
SELECT a.codec_name
FROM audio_codecs a
JOIN quality_levels q ON a.codec_id = q.quality_id
WHERE a.description_embedding MATCH lembed('all-MiniLM-L6-v2', "High efficiency audio codec with low latency and optimal bandwidth")
AND q.quality_name = 'HighQuality'
LIMIT 1;
```
✅ Correct Example:
```sql
SELECT a.codec_name
FROM audio_codecs a
JOIN quality_levels q ON a.codec_id = q.quality_id
WHERE a.description_embedding MATCH lembed('all-MiniLM-L6-v2', "High efficiency audio codec with low latency and optimal bandwidth") LIMIT 1
AND q.quality_name = 'HighQuality';
```
--When using JOIN operations, you need to ensure that k does not cause ambiguity in the query. In most cases, the k parameter logically belongs to the same table as the column used in the MATCH clause. So, when the column referenced in the MATCH clause includes a table qualifier (e.g., table1.embedding), the k parameter must be explicitly bound to the same table.
❌ Wrong Example:
```sql
SELECT s.stock_id, s.symbol
FROM stocks s
JOIN exchanges e ON s.exchange_id = e.exchange_id
WHERE s.sector_embedding MATCH lembed('all-MiniLM-L6-v2', "Tech industry sector in the USA")
AND e.country = 'USA'
AND k = 5
ORDER BY s.stock_id;
```
✅ Correct Example:
```sql
SELECT s.stock_id, s.symbol
FROM stocks s
JOIN exchanges e ON s.exchange_id = e.exchange_id
WHERE s.sector_embedding MATCH lembed('all-MiniLM-L6-v2', "Tech industry sector in the USA")
AND e.country = 'USA'
AND s.k = 5
ORDER BY s.stock_id;
```
12. ​Avoids runtime errors​​ – Many vector databases (e.g., SQLite with sqlite-vss, pgvector) enforce this requirement strictly.
13. Only a single 'ORDER BY distance' clause is allowed on vec0 KNN queries, not on other columns.
***Example of KNN queries of sqlite-vec***
first example(type of vector_embedding is float[384]):
```sql
SELECT rowid, distance
FROM vec_table
WHERE vector_embedding MATCH lembed({embedding_model},"vector of sun")
ORDER BY distance
LIMIT 1;
```
second example(type of sentence_embedding is float[384]):
```sql
select
movie_id,
title,
genre,
num_reviews,
mean_rating,
distance
from vec_movies
where sentence_embedding match lembed({embedding_model},"This is a great movie!")
and genre = 'scifi'
and num_reviews between 100 and 500
and mean_rating > 3.5
and k = 5;
```
third example(type of vector_embedding is float[384]):
```sql
select rowid, name1, name2, age, vec_to_json
from v
where vector_embedding match lembed({embedding_model},"aaa and xxx are good friends, whose age is 18.")
and k = 1
and name1 in ('alex', 'brian', 'craig')
and name2 in ('Rick', 'Morty')
and age in (21, 18);
```"""
def build_prompt(self, db_details: str, question: str, evidence: str, db_engine: str) -> str:
if evidence:
question_and_evidence = f"{evidence}\n{question}"
else:
question_and_evidence = question
template = """Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
{db_engine}
Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
{question_and_evidence}
Database Extension:
{db_extension_description}
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
sql
-- Your SQL query
Take a deep breath and think step by step to find the correct SQL query.
"""
prompt = template.format(db_details=db_details, question_and_evidence=question_and_evidence, db_engine=db_engine, db_extension_description=self._get_sqlite_vec_description())
return prompt
import os
import time
from typing import List
from dataflow.core import LLMServingABC
from dataflow import get_logger
class LocalEmbeddingServing(LLMServingABC):
"""
Generate embeddings using local sentence-transformer models.
Interface style aligns with APILLMServing_request and supports parallel computation through max_workers.
"""
def __init__(self,
model_name: str = 'all-MiniLM-L6-v2',
device: str = None,
max_workers: int = 2,
max_retries: int = 3
):
"""
Initialize local embedding model.
Args:
model_name (str): Name of the sentence-transformer model to load.
device (str): Preferred device type ('cpu' or 'cuda'). If None, will auto-detect.
max_workers (int): Number of parallel workers.
- 1 (default): No parallelism, runs on single device.
- > 1: Enable parallel mode.
- On CPU: Will enable multiprocessing mode (using all available cores).
- On multi-GPU: Will use min(max_workers, local GPU count) GPUs for data parallelism.
max_retries (int): Maximum number of retries when errors occur.
"""
self.logger = get_logger()
self.model_name = model_name
self.max_workers = max_workers
self.max_retries = max_retries
self._model = None
self._torch = None
self._SentenceTransformer = None
self._initialized = False
self._device = device
def _ensure_dependencies_available(self):
if self._torch is not None and self._SentenceTransformer is not None:
return
try:
import torch
from sentence_transformers import SentenceTransformer
self._torch = torch
self._SentenceTransformer = SentenceTransformer
except ImportError:
raise ImportError(
"The 'embedding' optional dependencies are required but not installed.\n"
"Please run: pip install 'open-dataflow[vectorsql]'"
)
def _initialize_model(self):
if self._initialized:
return
self._ensure_dependencies_available()
self._execution_strategy = "single_device"
self._target_devices = None
device = self._device
if device is None:
device = 'cuda' if self._torch.cuda.is_available() else 'cpu'
if self.max_workers > 1:
if 'cpu' in device:
cpu_count = os.cpu_count() or 1
if cpu_count > 1:
self._execution_strategy = "cpu_parallel"
self.logger.info(f"CPU parallel mode enabled (max_workers > 1).")
self.logger.warning(f"Note: sentence-transformers will use all {cpu_count} available CPU cores, "
"max_workers parameter only enables this mode, does not limit core count.")
else:
self.logger.warning("max_workers > 1 but only 1 CPU core detected, falling back to single device mode.")
elif 'cuda' in device:
gpu_count = self._torch.cuda.device_count()
if gpu_count > 1:
num_gpus_to_use = min(self.max_workers, gpu_count)
if num_gpus_to_use > 1:
self._execution_strategy = "gpu_parallel"
self._target_devices = [f'cuda:{i}' for i in range(num_gpus_to_use)]
self.logger.info(f"Multi-GPU parallel mode enabled. Will use {num_gpus_to_use} GPUs (min({self.max_workers}, {gpu_count})).")
self.logger.info(f"Target devices: {self._target_devices}")
else:
self.logger.warning("max_workers > 1 but final GPU count for computation is 1, falling back to single device mode.")
else:
self.logger.warning("max_workers > 1 but only 1 GPU detected, falling back to single device mode.")
if self._execution_strategy == "gpu_parallel":
self.primary_device = self._target_devices[0]
else:
self.primary_device = 'cuda:0' if 'cuda' in device and self._torch.cuda.is_available() else 'cpu'
self.logger.info(f"Loading model '{self.model_name}' to primary device '{self.primary_device}'...")
self._model = self._SentenceTransformer(self.model_name, device=self.primary_device)
self.logger.info("Model loaded successfully.")
self._initialized = True
@property
def model(self):
if not self._initialized:
self._initialize_model()
return self._model
def start_serving(self) -> None:
self.logger.info("LocalEmbeddingServing: No need to start independent service, model is already in memory.")
def generate_embedding_from_input(self,
texts: List[str],
batch_size: int = 32
) -> List[List[float]]:
"""
Generate embeddings for a list of texts, including retry logic and parallel execution logic.
"""
if not self._initialized:
self._initialize_model()
last_exception = None
for attempt in range(self.max_retries):
try:
self.logger.info(f"Starting to generate embeddings for {len(texts)} texts (attempt {attempt + 1}/{self.max_retries})...")
if self._execution_strategy == "gpu_parallel":
pool = self.model.start_multi_process_pool(target_devices=self._target_devices)
embeddings = self.model.encode_multi_process(texts, pool=pool, batch_size=batch_size)
self.model.stop_multi_process_pool(pool)
elif self._execution_strategy == "cpu_parallel":
pool = self.model.start_multi_process_pool()
embeddings = self.model.encode_multi_process(texts, pool=pool, batch_size=batch_size)
self.model.stop_multi_process_pool(pool)
else: # Single device mode
embeddings = self.model.encode(texts, batch_size=batch_size, show_progress_bar=True)
self.logger.info("Embedding generation successful.")
return embeddings.tolist()
except Exception as e:
last_exception = e
self.logger.error(f"Error occurred while generating embeddings (attempt {attempt + 1}): {e}")
if attempt + 1 < self.max_retries:
wait_time = 2 ** attempt
self.logger.info(f"Will retry after {wait_time} seconds...")
time.sleep(wait_time)
else:
self.logger.error("All retries failed.")
raise last_exception
return []
def cleanup(self):
if not self._initialized:
self.logger.info("Model not initialized, nothing to clean up.")
return
self.logger.info(f"Cleaning up resources for model '{self.model_name}'...")
if self._model is not None:
del self._model
self._model = None
if self._torch is not None and self._torch.cuda.is_available():
self._torch.cuda.empty_cache()
self._initialized = False
self.logger.info("Cleanup completed.")
def generate_from_input(self, user_inputs: List[str], system_prompt: str = "") -> List[str]:
self.logger.warning("generate_from_input is not applicable for LocalEmbeddingServing.")
return [None] * len(user_inputs)
def generate_from_conversations(self, conversations: List[List[dict]]) -> List[str]:
self.logger.warning("generate_from_conversations is not applicable for LocalEmbeddingServing.")
return [None] * len(conversations)
from .api_llm_serving_request import APILLMServing_request
from .local_model_llm_serving import LocalModelLLMServing_vllm
from .local_model_llm_serving import LocalModelLLMServing_sglang
from .api_vlm_serving_openai import APIVLMServing_openai
from .google_api_serving import PerspectiveAPIServing
from .lite_llm_serving import LiteLLMServing
from .localhost_llm_api_serving import LocalHostLLMAPIServing_vllm
from .localmodel_lalm_serving import LocalModelLALMServing_vllm
from .LocalSentenceLLMServing import LocalEmbeddingServing
from .light_rag_serving import LightRAGServing
from .api_google_vertexai_serving import APIGoogleVertexAIServing
from .local_model_vlm_serving import LocalVLMServing_vllm
__all__ = [
"APIGoogleVertexAIServing",
"APILLMServing_request",
"LocalModelLLMServing_vllm",
"LocalModelLLMServing_sglang",
"APIVLMServing_openai",
"PerspectiveAPIServing",
"LiteLLMServing",
"LocalModelLALMServing_vllm",
"LocalHostLLMAPIServing_vllm",
"LocalVLMServing_vllm",
]
import os
from pyexpat import model
import time
import logging
import json
import re
import uuid
import tempfile
from mimetypes import guess_type
from pathlib import Path
from typing import Any, List, Optional, Union
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataflow.core import LLMServingABC
import fsspec
from tqdm import tqdm
from pydantic import BaseModel
import pandas as pd
# --- Dependency: Google Vertex AI SDK ---
# Make sure to install the required library:
# pip install "google-cloud-aiplatform>=1.55" pydantic tqdm "pydantic-core<2" google-cloud-bigquery google-genai
try:
# NEW: Correct imports for the modern Vertex AI SDK
import vertexai
from vertexai.generative_models import (
GenerativeModel,
Part,
Tool,
FunctionDeclaration,
GenerationConfig,
GenerationResponse,
)
from google.api_core import exceptions as google_exceptions
from google.cloud import bigquery
# For batch processing
from google import genai
from google.genai.types import CreateBatchJobConfig
except ImportError:
raise ImportError(
"Google Cloud AI Platform library not found or is outdated. "
"Please run 'pip install \"google-cloud-aiplatform>=1.55\" pydantic tqdm google-cloud-bigquery google-genai'"
)
# --- Gemini Client Logic (Updated for modern Vertex AI SDK) ---
class GeminiVertexAIClient:
def __init__(self, project: Optional[str] = None, location: str = 'us-central1'):
"""Initialize Gemini client for Vertex AI."""
# Check required environment variables
google_app_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
google_cloud_location = os.getenv("GOOGLE_CLOUD_LOCATION")
google_cloud_project = os.getenv("GOOGLE_CLOUD_PROJECT")
google_genai_use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI")
self.logger = logging.getLogger(__name__)
# Validate GOOGLE_APPLICATION_CREDENTIALS
if not google_app_credentials:
raise ValueError(
"GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. "
"Please set it to the path of your service account key file, e.g.: "
"export GOOGLE_APPLICATION_CREDENTIALS=\"/path/to/your/key.json\""
)
# Check if credentials file exists
if not os.path.exists(google_app_credentials):
raise ValueError(
f"GOOGLE_APPLICATION_CREDENTIALS file not found: {google_app_credentials}. "
"Please ensure the path is correct."
)
# Log environment variable status
if google_cloud_location:
location = google_cloud_location
self.logger.info(f"Using GOOGLE_CLOUD_LOCATION from environment: {location}")
if google_cloud_project:
if project is None:
project = google_cloud_project
self.logger.info(f"Using GOOGLE_CLOUD_PROJECT from environment: {project}")
else:
self.logger.warning(
f"Project parameter '{project}' provided, but GOOGLE_CLOUD_PROJECT is also set. "
f"Using parameter value '{project}'."
)
if google_genai_use_vertexai:
self.logger.info(f"GOOGLE_GENAI_USE_VERTEXAI is set: {google_genai_use_vertexai}")
vertexai.init(project=project, location=location)
# NOTE: We remove the model instance cache because each model instance will now be
# tied to a specific system_instruction, which is dynamic.
def _prepare_content(self, content: Union[str, Path]) -> List[Part]:
"""Prepares content for the Gemini model. Always returns a list of Parts."""
if isinstance(content, (str, Path)) and os.path.exists(content) and os.path.isfile(content):
mime_type, _ = guess_type(str(content))
if not mime_type:
mime_type = "application/octet-stream"
# Using from_uri is generally more robust for Vertex AI with local files
return [Part.from_uri(uri=str(content), mime_type=mime_type)]
elif isinstance(content, str):
return [Part.from_text(content)]
else:
raise ValueError("Only support text (str) or local file path (str or Path) as input.")
def generate(
self,
system_prompt: str,
content: Union[str, Path],
model: Optional[str] = None,
temperature: float = 0.0,
max_tokens: int = 4096,
response_schema: Optional[Union[type[BaseModel], dict]] = None,
) -> GenerationResponse:
"""Generate response from a Gemini model on Vertex AI."""
model_name = model
# --- MAJOR FIX HERE ---
# The system_instruction must be passed during the model's initialization,
# not to the generate_content() method.
model_instance = GenerativeModel(
model_name,
system_instruction=system_prompt
)
generation_config = GenerationConfig(
temperature=temperature,
max_output_tokens=max_tokens,
)
tools = None
if response_schema is not None:
if isinstance(response_schema, dict):
# 已经是 JSON Schema
schema_dict = response_schema
else:
# 是 BaseModel,转换成 JSON Schema
schema_dict = response_schema.model_json_schema()
function_declaration = FunctionDeclaration(
name="extract_data",
description=f"Extracts structured data according to the provided schema.",
parameters=schema_dict,
)
tools = [Tool(function_declarations=[function_declaration])]
generation_config.response_mime_type = "application/json"
contents = self._prepare_content(content)
# --- MAJOR FIX HERE ---
# Remove the 'system_instruction' argument from this call.
response = model_instance.generate_content(
contents=contents,
generation_config=generation_config,
tools=tools,
)
return response
# --- Main Implementation: GeminiLLMServing ---
class APIGoogleVertexAIServing(LLMServingABC):
"""
LLM Serving class for Google's Gemini models via Vertex AI API.
"""
def __init__(self,
model_name: str = "gemini-2.5-flash",
project: Optional[str] = None,
location: str = 'us-central1',
max_workers: int = 10,
max_retries: int = 5,
temperature: float = 0.0,
max_tokens: int = 4096,
use_batch: bool = False,
batch_wait: bool = True,
batch_dataset: str = "dataflow_batch",
csv_filename: Optional[str] = None,
bq_csv_filename: Optional[str] = None,
):
self.logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
self.model_name = model_name
self.max_workers = max_workers
self.max_retries = max_retries
self.temperature = temperature
self.max_tokens = max_tokens
self.project = project
self.location = location
self.use_batch = use_batch
self.batch_wait = batch_wait
self.batch_dataset = batch_dataset
self.csv_filename = csv_filename
self.bq_csv_filename = bq_csv_filename
try:
self.client = GeminiVertexAIClient(project=project, location=location)
self.logger.info(f"GeminiVertexAIClient initialized successfully for model '{self.model_name}'.")
except ValueError as e:
self.logger.error(f"Failed to initialize GeminiVertexAIClient: {e}")
raise
# Initialize BigQuery client for batch processing
try:
self.bq_client = bigquery.Client(project=project)
self.logger.info("BigQuery client initialized successfully.")
except Exception as e:
self.logger.warning(f"BigQuery client initialization failed: {e}. Batch processing features may not be available.")
self.bq_client = None
# Initialize Google GenAI client for batch processing
try:
self.genai_client = genai.Client()
self.logger.info("Google GenAI client initialized successfully.")
except Exception as e:
self.logger.warning(f"Google GenAI client initialization failed: {e}. Batch processing features may not be available.")
self.genai_client = None
def start_serving(self) -> None:
self.logger.info("GeminiLLMServing: Using Google Cloud API, no local service to start.")
def cleanup(self) -> None:
self.logger.info("GeminiLLMServing: No specific cleanup actions needed for API-based client.")
def load_model(self, model_name_or_path: str, **kwargs: Any):
self.logger.info(f"Switching model from '{self.model_name}' to '{model_name_or_path}'.")
self.model_name = model_name_or_path
def _generate_single_with_retry(self, index: int, user_input: str, system_prompt: str, response_schema: Optional[Union[type[BaseModel], dict]] = None) -> tuple[int, Optional[str]]:
"""Generates a response for a single input with a retry mechanism."""
for attempt in range(self.max_retries):
try:
response = self.client.generate(
system_prompt=system_prompt,
content=user_input,
model=self.model_name,
temperature=self.temperature,
max_tokens=self.max_tokens,
response_schema=response_schema,
)
# NEW: Robust response parsing for both text and function calls
if not response.candidates:
finish_reason = response.prompt_feedback.block_reason.name if response.prompt_feedback else "Unknown"
self.logger.warning(
f"Request {index} was blocked or produced no candidates. Reason: {finish_reason}. Attempt {attempt + 1}/{self.max_retries}."
)
if attempt == self.max_retries - 1:
return index, f"Error: Content blocked by API. Reason: {finish_reason}"
time.sleep(2 ** attempt)
continue
candidate = response.candidates[0]
# Check for safety blocks or other stop reasons
if candidate.finish_reason.name not in ["STOP", "MAX_TOKENS"]:
self.logger.warning(f"Request {index} finished with reason '{candidate.finish_reason.name}'.")
# Check for function call (structured output)
if candidate.content.parts and candidate.content.parts[0].function_call:
function_call = candidate.content.parts[0].function_call
# Convert the structured response to a JSON string
result_data = {key: val for key, val in function_call.args.items()}
return index, json.dumps(result_data, indent=2)
# Otherwise, return the plain text response
return index, response.text
except (google_exceptions.ResourceExhausted, google_exceptions.ServiceUnavailable, google_exceptions.InternalServerError) as e:
self.logger.warning(
f"API rate limit or server error for request {index} (Attempt {attempt + 1}/{self.max_retries}): {e}"
)
if attempt == self.max_retries - 1:
self.logger.error(f"Request {index} failed after {self.max_retries} retries.")
return index, f"Error: API request failed after multiple retries. Details: {e}"
time.sleep(2 ** attempt)
except Exception as e:
self.logger.error(f"An unexpected error occurred for request {index} (Attempt {attempt + 1}/{self.max_retries}): {e}")
if attempt == self.max_retries - 1:
return index, f"Error: An unexpected error occurred. Details: {e}"
time.sleep(2 ** attempt)
return index, None
def generate_from_input(
self,
user_inputs: List[str],
system_prompt: str = "",
response_schema: Optional[Union[type[BaseModel], dict]] = None,
use_batch: Optional[bool] = None,
batch_wait: Optional[bool] = None,
batch_dataset: Optional[str] = None,
csv_filename: Optional[str] = None,
bq_csv_filename: Optional[str] = None,
) -> Union[List[str], str]:
"""
Generates responses for a list of user inputs.
Args:
user_inputs: List of user input strings to process.
system_prompt: System prompt for the model.
response_schema: Optional Pydantic BaseModel or dict for structured output.
use_batch: If True, use batch processing via BigQuery. If False, use parallel real-time generation.
batch_wait: If True (and use_batch=True), wait for batch job to complete and return results.
If False, return the batch job name immediately for later retrieval.
batch_dataset: BigQuery dataset name for batch processing (default: "dataflow_batch").
csv_filename: Optional CSV filename for batch processing. If None, defaults to "batch_{timestamp}_{batch_id}.csv".
Returns:
- If use_batch=False: List of generated responses (same length as user_inputs).
- If use_batch=True and batch_wait=True: List of generated responses from batch job.
- If use_batch=True and batch_wait=False: Batch job resource name (str) for later retrieval.
"""
if use_batch is None:
use_batch = self.use_batch
if batch_wait is None:
batch_wait = self.batch_wait
if batch_dataset is None:
batch_dataset = self.batch_dataset
if csv_filename is None:
csv_filename = self.csv_filename
if bq_csv_filename is None:
bq_csv_filename = self.bq_csv_filename
if use_batch:
return self._generate_with_batch(
user_inputs=user_inputs,
system_prompt=system_prompt,
response_schema=response_schema,
wait_for_completion=batch_wait,
dataset_name=batch_dataset,
csv_filename=csv_filename,
bq_csv_filename=bq_csv_filename
)
else:
return self._generate_with_parallel(
user_inputs=user_inputs,
system_prompt=system_prompt,
response_schema=response_schema
)
def _generate_with_parallel(
self,
user_inputs: List[str],
system_prompt: str,
response_schema: Optional[Union[type[BaseModel], dict]]
) -> List[str]:
"""Internal method: Generates responses using parallel real-time API calls."""
responses = [None] * len(user_inputs)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_index = {
executor.submit(self._generate_single_with_retry, i, user_input, system_prompt, response_schema): i
for i, user_input in enumerate(user_inputs)
}
progress = tqdm(as_completed(future_to_index), total=len(user_inputs), desc="Generating with Gemini (Real-time)")
for future in progress:
index, result = future.result()
responses[index] = result
return responses
def _generate_with_batch(
self,
user_inputs: List[str],
system_prompt: str,
response_schema: Optional[Union[type[BaseModel], dict]],
wait_for_completion: bool,
dataset_name: str,
csv_filename: Optional[str] = None,
bq_csv_filename: Optional[str] = None
) -> Union[List[str], str]:
"""Internal method: Generates responses using batch processing via BigQuery."""
if not self.bq_client:
raise RuntimeError(
"BigQuery client is not initialized. Cannot use batch processing. "
"Please ensure GOOGLE_APPLICATION_CREDENTIALS is set correctly."
)
# Generate CSV filename if not provided
timestamp = int(time.time())
if csv_filename is None:
if bq_csv_filename is None:
batch_id = str(uuid.uuid4())[:8]
bq_csv_filename = f"batch_{timestamp}_{batch_id}.csv"
else:
# Even if user provides a bq_csv_filename, must prefix with timestamp
bq_csv_filename = f"{bq_csv_filename}_{timestamp}"
temp_csv_path = os.path.join(tempfile.gettempdir(), bq_csv_filename)
else:
# Always prefix provided csv_filename with timestamp when copying to temp
base_csv_name = Path(csv_filename).name
bq_csv_filename = f"{base_csv_name}_{timestamp}" if bq_csv_filename is None else f"{bq_csv_filename}_{timestamp}"
temp_csv_path = os.path.join(tempfile.gettempdir(), bq_csv_filename)
try:
# Step 1: Generate CSV for batch processing
self.logger.info(f"Batch mode: Generating CSV with {len(user_inputs)} inputs...")
temp_csv_path = os.path.join(tempfile.gettempdir(), csv_filename)
self.generate_bq_csv(
csv_filename=temp_csv_path,
system_prompt=system_prompt,
user_prompts=user_inputs,
response_schema=response_schema,
max_token=self.max_tokens
)
# Step 2: Upload to BigQuery
self.logger.info("Batch mode: Uploading to BigQuery...")
bq_uri = self.create_bq_table(temp_csv_path, dataset_name=dataset_name)
# Step 3: Submit batch prediction job
self.logger.info("Batch mode: Submitting batch prediction job...")
batch_job_name = self.run_batch_prediction(bq_uri, model=self.model_name)
# Clean up temporary CSV file
if os.path.exists(temp_csv_path):
os.remove(temp_csv_path)
self.logger.info(f"Cleaned up temporary CSV: {temp_csv_path}")
# Step 4: Wait for completion if requested
if wait_for_completion:
self.logger.info("Batch mode: Waiting for batch job to complete...")
results = self._wait_and_retrieve_batch_results(batch_job_name, len(user_inputs))
return results
else:
self.logger.info(f"Batch job submitted: {batch_job_name}. Use retrieve_batch_results() to get results later.")
return batch_job_name
except Exception as e:
# Clean up temporary CSV file on error
if 'temp_csv_path' in locals() and os.path.exists(temp_csv_path):
os.remove(temp_csv_path)
self.logger.error(f"Batch processing failed: {e}")
raise
def _wait_and_retrieve_batch_results(self, batch_job_name: str, expected_count: int) -> List[str]:
"""
Internal method: Waits for a batch job to complete and retrieves results.
Args:
batch_job_name: The resource name of the batch job.
expected_count: Expected number of results.
Returns:
List of generated responses.
"""
if not self.genai_client:
raise RuntimeError("Google GenAI client is not initialized. Cannot retrieve batch results.")
try:
# Get the batch job
batch_job = self.genai_client.batches.get(name=batch_job_name)
# Wait for completion with progress bar
self.logger.info("Waiting for batch job to complete (this may take several minutes)...")
# Poll for completion
max_wait_time = 3600 * 1000 # 1000 hours max
poll_interval = 30 # Check every 30 seconds
elapsed_time = 0
with tqdm(total=expected_count, desc="Batch job progress") as pbar:
while elapsed_time < max_wait_time:
batch_job = self.genai_client.batches.get(name=batch_job_name)
state = batch_job.state
if state == "JOB_STATE_SUCCEEDED":
pbar.update(100 - pbar.n)
self.logger.info("Batch job completed successfully!")
break
elif state in ["JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
raise RuntimeError(f"Batch job failed with state: {state}")
# Update progress bar (estimate based on time)
progress = min(90, int((elapsed_time / max_wait_time) * 100))
pbar.update(progress - pbar.n)
time.sleep(poll_interval)
elapsed_time += poll_interval
if elapsed_time >= max_wait_time:
raise TimeoutError(f"Batch job did not complete within {max_wait_time} seconds")
output_table = batch_job.dest.bigquery_uri.replace("bq://", "")
project, dataset, table = output_table.split(".")
table_id = f"{project}.{dataset}.{table}"
query = f"SELECT * FROM `{table_id}`"
df = self.bq_client.query(query).to_dataframe()
results = self._parse_batch_results(df, expected_count)
return results
except Exception as e:
self.logger.error(f"Failed to retrieve batch results: {e}")
raise
def _parse_batch_results(self, df: pd.DataFrame, expected_count: int) -> List[str]:
"""
Internal method: Parses batch results from DataFrame.
Args:
df: DataFrame containing batch results.
expected_count: Expected number of results.
Returns:
List of generated responses in original order.
"""
results = [None] * expected_count
for _, row in df.iterrows():
try:
# Get the index from the response
if 'index' in df.columns:
idx = int(row['index'])
else:
# Fallback to row number
idx = _
# Extract the response text
if 'response' in df.columns:
response_json = json.loads(row['response'])
# Navigate through the response structure
if 'candidates' in response_json and len(response_json['candidates']) > 0:
candidate = response_json['candidates'][0]
if 'content' in candidate and 'parts' in candidate['content']:
parts = candidate['content']['parts']
if parts and 'text' in parts[0]:
results[idx] = parts[0]['text']
elif parts and 'functionCall' in parts[0]:
# Structured output
results[idx] = json.dumps(parts[0]['functionCall']['args'])
# Fallback: if we couldn't parse, store the raw response
if results[idx] is None and 'response' in df.columns:
results[idx] = row['response']
except Exception as e:
self.logger.warning(f"Failed to parse result at index {idx}: {e}")
results[idx] = f"Error: Failed to parse result"
return results
def retrieve_batch_results(self, batch_job_name: str, expected_count: int) -> List[str]:
"""
Retrieves results from a previously submitted batch job.
Args:
batch_job_name: The resource name of the batch job (returned when use_batch=True and batch_wait=False).
expected_count: Expected number of results.
Returns:
List of generated responses.
Example:
# Submit batch job without waiting
job_name = serving.generate_from_input(inputs, system_prompt, use_batch=True, batch_wait=False)
# Later, retrieve results
results = serving.retrieve_batch_results(job_name, len(inputs))
"""
return self._wait_and_retrieve_batch_results(batch_job_name, expected_count)
# --- Batch Processing Methods ---
def create_bq_dataset(self, dataset_name: str = "polymer") -> None:
"""
Creates a BigQuery dataset if it does not already exist.
Args:
dataset_name: The name of the dataset to create. Defaults to "polymer".
Raises:
RuntimeError: If BigQuery client is not initialized.
"""
if not self.bq_client:
raise RuntimeError("BigQuery client is not initialized. Cannot create dataset.")
dataset_ref = self.bq_client.dataset(dataset_name)
dataset = bigquery.Dataset(dataset_ref)
dataset.location = self.location
try:
self.bq_client.create_dataset(dataset)
self.logger.info(f"Dataset '{dataset_name}' created successfully.")
except Exception as e:
if "Already Exists" in str(e):
self.logger.info(f"Dataset '{dataset_name}' already exists.")
else:
self.logger.error(f"Failed to create dataset '{dataset_name}': {e}")
raise
def generate_bq_csv(
self,
csv_filename: str,
system_prompt: str,
user_prompts: List[str],
doi_list: Optional[List[str]] = None,
response_schema: Optional[Union[type[BaseModel], dict]] = None,
max_token: int = 500
) -> str:
"""
Generates a CSV file for batch processing with the Gemini API.
Args:
csv_filename: The name of the output CSV file.
system_prompt: The system prompt string.
user_prompts: A list of texts to be processed.
doi_list: Optional list of DOIs corresponding to user prompts.
response_schema: An optional Pydantic BaseModel class or dict for the response schema.
max_token: Maximum number of output tokens. Defaults to 500.
Returns:
The name of the generated CSV file.
"""
df = pd.DataFrame({"user_prompt": user_prompts})
def create_batch_request_json(row) -> str:
request_parts = [
{"text": row["user_prompt"]},
]
generation_config = {
"temperature": self.temperature,
"maxOutputTokens": max_token,
"stopSequences": ["\n\n\n\n"],
"responseLogprobs": True,
"logprobs": 10,
}
if response_schema:
generation_config["responseMimeType"] = "application/json"
# Handle both BaseModel and dict schemas
if isinstance(response_schema, dict):
generation_config["responseSchema"] = response_schema
else:
generation_config["responseSchema"] = response_schema.model_json_schema()
else:
generation_config["responseMimeType"] = "text/plain"
return json.dumps(
{
"contents": [
{
"role": "user",
"parts": request_parts,
},
],
"systemInstruction": {
"parts": [{"text": system_prompt}]
},
"generationConfig": generation_config,
}
)
df["request"] = df.apply(create_batch_request_json, axis=1)
if doi_list:
df["doi"] = doi_list
df.to_csv(csv_filename, index_label="index", columns=["doi", "request"])
else:
df.to_csv(csv_filename, index_label="index", columns=["request"])
self.logger.info(f"Generated batch CSV file: {csv_filename}")
return csv_filename
def create_bq_table(self, csv_path: str, dataset_name: str = "polymer") -> str:
"""
Creates a BigQuery table from a CSV file.
Args:
csv_path: The path to the CSV file.
dataset_name: The name of the dataset to create the table in. Defaults to "polymer".
Returns:
The BigQuery URI of the created table (e.g., "bq://project.dataset.table").
Raises:
RuntimeError: If BigQuery client is not initialized.
"""
if not self.bq_client:
raise RuntimeError("BigQuery client is not initialized. Cannot create table.")
self.create_bq_dataset(dataset_name)
timestamp = int(time.time())
table_name = f"{timestamp}_{Path(csv_path).stem}"
table_id = f"{dataset_name}.{table_name}"
job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.CSV,
autodetect=True,
skip_leading_rows=1,
write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
)
with open(csv_path, "rb") as source_file:
job = self.bq_client.load_table_from_file(source_file, table_id, job_config=job_config)
job.result()
self.logger.info(f"Loaded {job.output_rows} rows into {table_id}")
return f"bq://{self.bq_client.project}.{dataset_name}.{table_name}"
def run_batch_prediction(
self,
input_bq_uri: str,
model: str = None,
output_file_path: str = ""
) -> str:
"""
Runs a batch prediction job using the Gemini API.
Args:
input_bq_uri: The BigQuery URI of the input table (e.g., "bq://project.dataset.table")
or a path to a CSV file. If a CSV path is provided, a BigQuery table
will be created from it.
model: The ID of the model to use for prediction. Defaults to the instance's model_name.
output_file_path: Optional. The desired path for the output file. If empty, defaults to
the input path with "_result" appended to the table name.
Returns:
The name/resource path of the created batch prediction job.
Raises:
RuntimeError: If BigQuery client or GenAI client is not initialized.
"""
if not self.bq_client:
raise RuntimeError("BigQuery client is not initialized. Cannot run batch prediction.")
if not self.genai_client:
raise RuntimeError("Google GenAI client is not initialized. Cannot run batch prediction.")
# Check if input_bq_uri is a CSV file path
if input_bq_uri.endswith(".csv") and Path(input_bq_uri).is_file():
self.logger.info(f"CSV file detected: {input_bq_uri}. Creating BigQuery table...")
input_bq_uri = self.create_bq_table(input_bq_uri)
self.logger.info(f"BigQuery table created from CSV: {input_bq_uri}")
# Parse input_bq_uri to get project, dataset, and table name
match = re.match(r"bq://([^.]+)\.([^.]+)\.(.+)", input_bq_uri)
if not match:
raise ValueError(
f"Invalid input_bq_uri format: {input_bq_uri}. "
"Must be a BigQuery URI (bq://project.dataset.table) or a valid CSV file path."
)
project_id = match.group(1)
dataset_name = match.group(2)
input_table_name = match.group(3)
# Construct the output BigQuery URI
if output_file_path:
output_table_name = Path(output_file_path).stem
else:
output_table_name = f"{input_table_name}_result"
output_uri = f"bq://{project_id}.{dataset_name}.{output_table_name}"
# Use the model name from instance or provided parameter
model_name = model or self.model_name
try:
# Create batch prediction job using Google GenAI API
batch_job = self.genai_client.batches.create(
model=model_name,
src=input_bq_uri,
config=CreateBatchJobConfig(dest=output_uri)
)
self.logger.info(f"Batch job '{batch_job.name}' created. State: {batch_job.state}")
return batch_job.name
except Exception as e:
self.logger.error(f"Failed to create batch prediction job: {e}")
raise
def download_bq_table(self, table_id: str, output_file_path: str = "") -> str:
"""
Downloads data from a BigQuery table to a CSV file.
Args:
table_id: The full BigQuery table ID (e.g., "project.dataset.table").
output_file_path: Optional. The path to save the downloaded CSV file.
If empty, defaults to the table name with a .csv extension.
Returns:
The path to the downloaded CSV file.
Raises:
RuntimeError: If BigQuery client is not initialized.
"""
if not self.bq_client:
raise RuntimeError("BigQuery client is not initialized. Cannot download table.")
# Parse table_id to get project, dataset, and table name
parts = table_id.split('.')
if len(parts) != 3:
raise ValueError(
f"Invalid table_id format: {table_id}. Expected 'project.dataset.table'."
)
table_name = parts[2]
if not output_file_path:
output_file_path = f"{table_name}.csv"
try:
table = self.bq_client.get_table(table_id)
rows = self.bq_client.list_rows(table)
df = rows.to_dataframe()
df.to_csv(output_file_path, index=False)
self.logger.info(f"Downloaded {rows.total_rows} rows from {table_id} to {output_file_path}")
return output_file_path
except Exception as e:
self.logger.error(f"Error downloading BigQuery table {table_id}: {e}")
raise
# --- Example Usage (Modified) ---
if __name__ == "__main__":
# IMPORTANT: Before running, ensure you have authenticated with Google Cloud.
# This is typically done by running `gcloud auth application-default login` in your terminal.
# --- Step 1: Check for Authentication Credentials ---
# This is a hard requirement for the API to work.
# export GOOGLE_APPLICATION_CREDENTIALS="path/to/your/private_project.json"
if not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
print("=" * 80)
print("FATAL: GOOGLE_APPLICATION_CREDENTIALS environment variable is not set.")
print("This is required for Vertex AI authentication.")
print("Please authenticate with Google Cloud first by running:")
print("`gcloud auth application-default login`")
print("=" * 80)
exit(1) # Exit if authentication is not configured.
# --- Step 2: Determine Project ID (Optional) ---
# The SDK can often auto-discover the project ID if the environment variable is not set.
gcp_project_id = os.getenv("GCP_PROJECT_ID")
if not gcp_project_id:
print("=" * 80)
print("INFO: GCP_PROJECT_ID environment variable is not set.")
print("The Vertex AI SDK will attempt to automatically discover the project ID from your environment.")
print("Ensure you have set a default project via `gcloud config set project YOUR_PROJECT_ID`.")
print("=" * 80)
# --- Step 3: Run Generation Tests ---
try:
# --- Test Case 1: Normal Text Generation ---
print("\n--- Starting Test 1: Normal Text Generation ---")
gemini_server_text = APIGoogleVertexAIServing(
project=gcp_project_id, # Pass the project_id (can be None)
location='us-central1',
model_name="gemini-2.5-flash",
max_workers=5,
max_retries=3
)
system_prompt_text = "You are a helpful assistant that provides concise and accurate answers."
user_prompts_text = [
"What is the capital of France?",
"Write a short poem about the moon.",
"Explain the concept of photosynthesis in one sentence.",
]
results_text = gemini_server_text.generate_from_input(user_prompts_text, system_prompt_text)
print("--- Generation Complete ---")
for i, (prompt, result) in enumerate(zip(user_prompts_text, results_text)):
print(f"\n[Prompt {i+1}]: {prompt}")
print(f"[Gemini]: {result}")
# --- Test Case 2: Structured Data Extraction (PyDantic) ---
print("\n--- Starting Test 2: Structured Data Extraction (JSON Output) ---")
class UserDetails(BaseModel):
name: str
age: int
city: str
gemini_server_json =APIGoogleVertexAIServing(
project=gcp_project_id, # Pass the project_id (can be None)
location='us-central1',
model_name="gemini-2.5-flash",
)
system_prompt_json = "Extract the user's information from the text and format it as JSON."
user_prompts_json = [
"John Doe is 30 years old and lives in New York.",
"My name is Jane Smith, I am 25, and I reside in London."
]
results_json = gemini_server_json.generate_from_input(user_prompts_json, system_prompt_json, response_schema=UserDetails) # Pass the schema here
print("--- Generation Complete ---")
for i, (prompt, result) in enumerate(zip(user_prompts_json, results_json)):
print(f"\n[Prompt {i+1}]: {prompt}")
print(f"[Gemini JSON]: {result}")
# --- Test Case 3: Structured Data Extraction (Raw JSON Schema) ---
print("\n--- Starting Test 3: Structured Data Extraction (Raw JSON Schema) ---")
json_schema = {
"title": "UserDetails",
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"city": {"type": "string"}
},
"required": ["name", "age", "city"]
}
gemini_server_json_schema = APIGoogleVertexAIServing(
project=gcp_project_id, # Pass the project_id (can be None)
location='us-central1',
model_name="gemini-2.5-flash",
)
system_prompt_json_schema = "Extract the user's information from the text and format it as JSON."
user_prompts_json_schema = [
"Alice Johnson is 28 years old and lives in San Francisco.",
"Bob Brown, aged 35, resides in Toronto."
]
results_json_schema = gemini_server_json_schema.generate_from_input(user_prompts_json_schema, system_prompt_json_schema, response_schema=json_schema)
print("--- Generation Complete ---")
for i, (prompt, result) in enumerate(zip(user_prompts_json_schema, results_json_schema)):
print(f"\n[Prompt {i+1}]: {prompt}")
print(f"[Gemini JSON Schema]: {result}")
# --- Test Case 4: Batch Processing with BigQuery (Submit without waiting) ---
print("\n--- Starting Test 4: Batch Processing (Submit without waiting) ---")
from typing import Literal
class Capital(BaseModel):
capital: Literal["Paris", "Beijing", "London", "Tokyo"]
gemini_server_batch = APIGoogleVertexAIServing(
project=os.getenv("GCP_PROJECT_ID"),
location='us-central1',
model_name="gemini-2.5-flash",
)
system_prompt_batch = "You are a helpful assistant that answers geography questions."
user_prompts_batch = [
"What is the capital of France?",
"What is the capital of China?",
"What is the capital of the United Kingdom?",
"What is the capital of Japan?"
]
try:
# Submit batch job without waiting (batch_wait=False)
batch_job_name = gemini_server_batch.generate_from_input(
user_inputs=user_prompts_batch,
system_prompt=system_prompt_batch,
response_schema=Capital,
use_batch=True,
batch_wait=False # Don't wait for completion
)
print(f"Batch job submitted: {batch_job_name}")
print("Note: Use retrieve_batch_results(batch_job_name, len(inputs)) to get results later.")
except Exception as e:
print(f"Batch processing test (no wait) failed: {e}")
# --- Test Case 5: Batch Processing with BigQuery (Wait for completion) ---
print("\n--- Starting Test 5: Batch Processing (Wait for completion) ---")
print("WARNING: This test will wait for batch job completion, which may take several minutes.")
print("Skipping this test in automated runs. Set ENABLE_BATCH_WAIT_TEST=1 to enable.")
try:
# Submit batch job and wait for completion (batch_wait=True, default)
results_batch = gemini_server_batch.generate_from_input(
user_inputs=user_prompts_batch,
system_prompt=system_prompt_batch,
response_schema=Capital,
use_batch=True,
batch_wait=True # Wait for completion
)
print("--- Batch Generation Complete ---")
for i, (prompt, result) in enumerate(zip(user_prompts_batch, results_batch)):
print(f"\n[Prompt {i+1}]: {prompt}")
print(f"[Gemini Batch]: {result}")
except Exception as e:
print(f"Batch processing test (with wait) failed: {e}")
except google_exceptions.PermissionDenied as e:
print(f"\nERROR: Permission Denied. Details: {e}")
print("Please ensure your account has the 'Vertex AI User' role on the project.")
print("Also, verify that the Vertex AI API is enabled for your project.")
except google_exceptions.NotFound as e:
print(f"\nERROR: Not Found. Details: {e}")
print("This might mean the project ID could not be found or the specified model/location is incorrect.")
except Exception as e:
print(f"\nAn unexpected error occurred: {e}")
import json
import requests
import os
import logging
from ..logger import get_logger
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from dataflow.core import LLMServingABC
import re
import time
class APILLMServing_request(LLMServingABC):
"""Use OpenAI API to generate responses based on input messages.
"""
def start_serving(self) -> None:
self.logger.info("APILLMServing_request: no local service to start.")
return
def __init__(self,
api_url: str = "https://api.openai.com/v1/chat/completions",
key_name_of_api_key: str = "DF_API_KEY",
model_name: str = "gpt-4o",
max_workers: int = 10,
max_retries: int = 5,
temperature = 0.0
):
# Get API key from environment variable or config
self.api_url = api_url
self.model_name = model_name
self.max_workers = max_workers
self.max_retries = max_retries
self.temperature = temperature
self.logger = get_logger()
# config api_key in os.environ global, since safty issue.
self.api_key = os.environ.get(key_name_of_api_key)
if self.api_key is None:
error_msg = f"Lack of `{key_name_of_api_key}` in environment variables. Please set `{key_name_of_api_key}` as your api-key to {api_url} before using APILLMServing_request."
self.logger.error(error_msg)
raise ValueError(error_msg)
def format_response(self, response: dict, is_embedding: bool = False) -> str:
"""Format API response, supporting both embedding and chat completion modes"""
# Handle embedding requests
if is_embedding:
return response.get('data', [{}])[0].get('embedding', [])
# Extract message content
message = response.get('choices', [{}])[0].get('message', {})
content = message.get('content', '')
# Return directly if content is already in think/answer format
if re.search(r'<think>.*?</think>.*?<answer>.*?</answer>', content, re.DOTALL):
return content
# Check for reasoning_content
reasoning_content = message.get('reasoning_content')
# Wrap with think/answer tags if reasoning_content exists and is not empty
if reasoning_content:
return f"<think>{reasoning_content}</think>\n<answer>{content}</answer>"
return content
def api_chat(self, system_info: str, messages: str, model: str):
try:
payload = json.dumps({
"model": model,
"messages": [
{"role": "system", "content": system_info},
{"role": "user", "content": messages}
],
"temperature": self.temperature
})
headers = {
'Authorization': f"Bearer {self.api_key}",
'Content-Type': 'application/json',
'User-Agent': 'Apifox/1.0.0 (https://apifox.com)'
}
# Make a POST request to the API
response = requests.post(self.api_url, headers=headers, data=payload, timeout=60)
if response.status_code == 200:
response_data = response.json()
return self.format_response(response_data)
else:
logging.error(f"API request failed with status {response.status_code}: {response.text}")
return None
except Exception as e:
logging.error(f"API request error: {e}")
return None
def _api_chat_with_id(self, id, payload, model, is_embedding: bool = False, json_schema: dict = None):
try:
if is_embedding:
payload = json.dumps({
"model": model,
"input": payload
})
elif json_schema is None:
payload = json.dumps({
"model": model,
"messages": payload
})
else:
payload = json.dumps({
"model": model,
"messages": payload,
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "custom_response",
"strict": True,
"schema": json_schema
}
}
})
headers = {
'Authorization': f"Bearer {self.api_key}",
'Content-Type': 'application/json',
'User-Agent': 'Apifox/1.0.0 (https://apifox.com)'
}
# Make a POST request to the API
response = requests.post(self.api_url, headers=headers, data=payload, timeout=1800)
if response.status_code == 200:
# logging.info(f"API request successful")
response_data = response.json()
# logging.info(f"API response: {response_data['choices'][0]['message']['content']}")
return id,self.format_response(response_data, is_embedding)
else:
logging.error(f"API request failed with status {response.status_code}: {response.text}")
return id, None
except Exception as e:
logging.error(f"API request error: {e}")
return id, None
def _api_chat_id_retry(self, id, payload, model, is_embedding : bool = False, json_schema: dict = None):
for i in range(self.max_retries):
id, response = self._api_chat_with_id(id, payload, model, is_embedding, json_schema)
if response is not None:
return id, response
time.sleep(2**i)
return id, None
def generate_from_input(self,
user_inputs: list[str],
system_prompt: str = "You are a helpful assistant",
json_schema: dict = None,
) -> list[str]:
responses = [None] * len(user_inputs)
# -- end of subfunction api_chat_with_id --
# 使用 ThreadPoolExecutor 并行处理多个问题
# logging.info(f"Generating {len(questions)} responses")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(
self._api_chat_id_retry,
payload = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
],
model = self.model_name,
json_schema = json_schema,
id = idx,
) for idx, question in enumerate(user_inputs)
]
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating......"):
response = future.result() # (id, response)
responses[response[0]] = response[1]
return responses
def generate_from_conversations(self, conversations: list[list[dict]]) -> list[str]:
responses = [None] * len(conversations)
# -- end of subfunction api_chat_with_id --
# 使用 ThreadPoolExecutor 并行处理多个问题
# logging.info(f"Generating {len(questions)} responses")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(
self._api_chat_id_retry,
payload = dialogue,
model = self.model_name,
id = idx
) for idx, dialogue in enumerate(conversations)
]
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating......"):
response = future.result() # (id, response)
responses[response[0]] = response[1]
return responses
def generate_embedding_from_input(self, texts: list[str]) -> list[list[float]]:
responses = [None] * len(texts)
# -- end of subfunction api_embedding_with_id --
# 使用 ThreadPoolExecutor 并行处理多个问题
# logging.info(f"Generating {len(questions)} responses")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(
self._api_chat_id_retry,
payload = txt,
model = self.model_name,
id = idx,
is_embedding = True
) for idx, txt in enumerate(texts)
]
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating embedding......"):
response = future.result() # (id, response)
responses[response[0]] = response[1]
return responses
def cleanup(self):
# Cleanup resources if needed
logging.info("Cleaning up resources in APILLMServing_request")
# No specific cleanup actions needed for this implementation
pass
import os
import base64
import json
from typing import Any, Dict, List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataflow.core import LLMServingABC
from openai import OpenAI
from tqdm import tqdm
from ..logger import get_logger
class APIVLMServing_openai(LLMServingABC):
"""
Client for interacting with a Vision-Language Model (VLM) via OpenAI's API.
Provides methods for single-image chat, batch image processing, and multi-image analysis,
with support for concurrent requests.
"""
def start_serving(self) -> None:
self.logger.info("APIVLMServing_opneai: no local service to start.")
return
def __init__(
self,
api_url: str = "https://api.openai.com/v1",
key_name_of_api_key: str = "DF_API_KEY",
model_name: str = "o4-mini",
max_workers: int = 10,
timeout: int = 1800,
temperature = 0.0
):
"""
Initialize the OpenAI client and settings.
:param api_url: Base URL of the VLM API endpoint.
:param key_name_of_api_key: Environment variable name for the API key.
:param model_name: Default model name to use for requests.
:param max_workers: Maximum number of threads for concurrent requests.
"""
self.api_url = api_url
self.model_name = model_name
self.max_workers = max_workers
self.logger = get_logger()
self.timeout = timeout
self.temperature = temperature
api_key = os.environ.get(key_name_of_api_key)
if not api_key:
self.logger.error(f"API key not found in environment variable '{key_name_of_api_key}'")
raise EnvironmentError(f"Missing environment variable '{key_name_of_api_key}'")
self.client = OpenAI(
api_key=api_key,
base_url=api_url
)
def _encode_image_to_base64(self, image_path: str) -> Tuple[str, str]:
"""
Read an image file and convert it to a base64-encoded string, returning the image data and MIME format.
:param image_path: Path to the image file.
:return: Tuple of (base64-encoded string, image format, e.g. 'jpeg' or 'png').
:raises ValueError: If the image format is unsupported.
"""
with open(image_path, "rb") as f:
raw = f.read()
b64 = base64.b64encode(raw).decode("utf-8")
ext = image_path.rsplit('.', 1)[-1].lower()
if ext == 'jpg':
fmt = 'jpeg'
elif ext == 'jpeg':
fmt = 'jpeg'
elif ext == 'png':
fmt = 'png'
else:
raise ValueError(f"Unsupported image format: {ext}")
return b64, fmt
def _create_messages(self, content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Wrap content items into the standard OpenAI messages structure.
:param content: List of content dicts (text/image elements).
:return: Messages payload for the API call.
"""
return [{"role": "user", "content": content}]
def _send_chat_request(
self,
model: str,
messages: List[Dict[str, Any]],
timeout: int,
json_schema: dict = None
) -> str:
"""
Send a chat completion request to the OpenAI API and return the generated content.
:param model: Model name for the request.
:param messages: Messages payload constructed by `_create_messages`.
:param timeout: Timeout in seconds for the API call.
:param json_schema: Optional JSON schema for structured output.
:return: Generated text response from the model.
"""
# 准备请求参数
request_params = {
"model": model,
"messages": messages,
"timeout": timeout,
"temperature": self.temperature
}
# 如果提供了 JSON schema,添加 response_format
if json_schema is not None:
request_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "chart_info_response",
"strict": True,
"schema": json_schema
}
}
resp = self.client.chat.completions.create(**request_params)
return resp.choices[0].message.content
def chat_with_one_image(
self,
image_path: str,
text_prompt: str,
model: str = None,
timeout: int = 1800,
json_schema: dict = None
) -> str:
"""
Perform a chat completion using a single image and a text prompt.
:param image_path: Path to the image file.
:param text_prompt: Text prompt to accompany the image.
:param model: (Optional) Model override; defaults to instance `model_name`.
:param timeout: Timeout in seconds for the API call.
:param json_schema: (Optional) JSON schema for structured output.
:return: Model's response as a string.
"""
model = model or self.model_name
b64, fmt = self._encode_image_to_base64(image_path)
content = [
{"type": "text", "text": text_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/{fmt};base64,{b64}"}}
]
messages = self._create_messages(content)
return self._send_chat_request(model, messages, timeout, json_schema)
def chat_with_one_image_with_id(
self,
request_id: Any,
image_path: str,
text_prompt: str,
model: str = None,
timeout: int = 1800,
json_schema: dict = None,
) -> Tuple[Any, str]:
"""
Same as `chat_with_one_image` but returns a tuple of (request_id, response).
:param request_id: Arbitrary identifier for tracking the request.
:param image_path: Path to the image file.
:param text_prompt: Text prompt to accompany the image.
:param model: (Optional) Model override; defaults to instance `model_name`.
:param timeout: Timeout in seconds for the API call.
:return: Tuple of (request_id, model response).
"""
response = self.chat_with_one_image(image_path, text_prompt, model, timeout, json_schema)
return request_id, response
def generate_from_input_one_image(
self,
image_paths: List[str],
text_prompts: List[str],
system_prompt: str = "",
model: str = None,
timeout: int = 1800,
json_schema: dict = None
) -> List[str]:
"""
Batch process single-image chat requests concurrently.
:param image_paths: List of image file paths.
:param text_prompts: List of text prompts (must match length of image_paths).
:param system_prompt: Optional system-level prompt prefixed to each user prompt.
:param model: (Optional) Model override; defaults to instance `model_name`.
:param timeout: Timeout in seconds for each API call.
:param json_schema: (Optional) JSON schema for structured output.
:return: List of model responses preserving input order.
:raises ValueError: If lengths of image_paths and text_prompts differ.
"""
if len(image_paths) != len(text_prompts):
raise ValueError("`image_paths` and `text_prompts` must have the same length")
model = model or self.model_name
prompts = [f"{system_prompt}\n{p}" for p in text_prompts]
responses = [None] * len(image_paths)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = {
executor.submit(
self.chat_with_one_image_with_id,
idx,
path,
prompt,
model,
timeout,
json_schema
): idx
for idx, (path, prompt) in enumerate(zip(image_paths, prompts))
}
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating..."):
idx, res = future.result()
responses[idx] = res
return responses
def analyze_images_with_gpt(
self,
image_paths: List[str],
image_labels: List[str],
system_prompt: str = "",
model: str = None,
timeout: int = 1800,
json_schema: dict = None
) -> str:
"""
Analyze multiple images in a single request with labels.
:param image_paths: List of image file paths.
:param image_labels: Corresponding labels for each image.
:param system_prompt: Overall prompt before listing images.
:param model: (Optional) Model override; defaults to instance `model_name`.
:param timeout: Timeout in seconds for the API call.
:return: Model's combined analysis as text.
"""
if len(image_paths) != len(image_labels):
raise ValueError("`image_paths` and `image_labels` must have the same length")
model = model or self.model_name
content: List[Dict[str, Any]] = []
if system_prompt:
content.append({"type": "text", "text": system_prompt})
for label, path in zip(image_labels, image_paths):
b64, fmt = self._encode_image_to_base64(path)
content.append({"type": "text", "text": f"{label}:"})
content.append({
"type": "image_url",
"image_url": {"url": f"data:image/{fmt};base64,{b64}"}
})
messages = self._create_messages(content)
return self._send_chat_request(model, messages, timeout, json_schema)
def analyze_images_with_gpt_with_id(
self,
image_paths: List[str],
image_labels: List[str],
request_id: Any,
system_prompt: str = "",
model: str = None,
timeout: int = 1800,
json_schema: dict = None
) -> Tuple[Any, str]:
"""
Batch-tracked version of `analyze_images_with_gpt`, returning (request_id, analysis).
:param image_paths: List of image file paths.
:param image_labels: Corresponding labels for each image.
:param request_id: Identifier for tracking the request.
:param system_prompt: Overall prompt before listing images.
:param model: (Optional) Model override; defaults to instance `model_name`.
:param timeout: Timeout in seconds for the API call.
:return: Tuple of (request_id, model's analysis).
"""
result = self.analyze_images_with_gpt(
image_paths,
image_labels,
system_prompt,
model,
timeout,
json_schema
)
self.logger.info(f"Request {request_id} completed")
return request_id, result
def generate_from_input_multi_images(
self,
list_of_image_paths: List[List[str]],
list_of_image_labels: List[List[str]],
system_prompt: str = "",
user_prompts: List[str] = None,
model: str = None,
timeout: int = 1800,
json_schema: dict = None
) -> List[str]:
"""
Concurrently analyze multiple sets of images with labels.
:param list_of_image_paths: List of image path lists.
:param list_of_image_labels: Parallel list of label lists.
:param system_prompt: Prompt prefixed to each batch.
:param model: (Optional) Model override; defaults to instance `model_name`.
:param timeout: Timeout in seconds for each API call.
:return: List of analysis results in input order.
:raises ValueError: If outer lists lengths differ.
"""
if len(list_of_image_paths) != len(list_of_image_labels):
raise ValueError(
"`list_of_image_paths` and `list_of_image_labels` must have the same length"
)
model = model or self.model_name
responses = [None] * len(list_of_image_paths)
if user_prompts == None:
user_prompts = [""] * len(list_of_image_paths)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = {
executor.submit(
self.analyze_images_with_gpt_with_id,
paths,
labels,
idx,
system_prompt + user_prompt,
model,
timeout,
json_schema
): idx
for idx, (paths, labels, user_prompt) in enumerate(
zip(list_of_image_paths, list_of_image_labels, user_prompts)
)
}
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating..."):
idx, res = future.result()
responses[idx] = res
return responses
def cleanup(self) -> None:
"""
Clean up any resources (e.g., close HTTP connections).
"""
self.client.close()
def generate_from_input(self, user_inputs: List[str], system_prompt: str = "Describe the image in detail.", json_schema: dict = None):
"""
user_inputs: List[str], list of picture paths
system_prompt: str, system prompt
return: List[str], list of generated contents
"""
futures = []
result_text_list = [None] * len(user_inputs)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
for idx,user_input in enumerate(user_inputs):
futures.append(executor.submit(self.chat_with_one_image_with_id,
idx,
user_input,
system_prompt,
self.model_name,
self.timeout,
json_schema = json_schema,))
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating"):
idx,res = future.result()
result_text_list[idx] = res
return result_text_list
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from googleapiclient import discovery
from dataflow.core import LLMServingABC
class PerspectiveAPIServing(LLMServingABC):
"""Service adapter for Google Perspective API."""
def __init__(self, max_workers: int = 10):
self.api_key = os.environ.get("GOOGLE_API_KEY")
if self.api_key is None:
raise ValueError("Lack of Google API_KEY")
self.max_workers = max_workers
self.client = discovery.build(
"commentanalyzer",
"v1alpha1",
developerKey=self.api_key,
discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
static_discovery=False,
)
def _call_api(self, text: str) -> float:
"""Invoke the Perspective API for a single text chunk and return toxicity score."""
analyze_request = {
'comment': { 'text': text },
'requestedAttributes': {'TOXICITY': {}}
}
response = self.client.comments().analyze(body=analyze_request).execute()
# extract the first span score
return response['attributeScores']['TOXICITY']['spanScores'][0]['score']['value']
def generate_from_input(self, user_inputs: list[str]) -> list[float]: # type: ignore
"""
Process a list of input texts concurrently and return toxicity scores.
"""
scores: dict[int, float] = {}
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = {executor.submit(self._call_api, text): idx
for idx, text in enumerate(user_inputs)}
for fut in tqdm(as_completed(futures), total=len(futures), desc="Scoring"):
idx = futures[fut]
try:
scores[idx] = fut.result()
except Exception as e:
scores[idx] = float('nan')
return [scores[i] for i in range(len(user_inputs))]
def cleanup(self) -> None:
"""Cleanup any resources or open connections if necessary."""
# No persistent resources to clean up
return
\ No newline at end of file
import os
import time
import re
from typing import List, Optional, Dict, Any, Union, Tuple
from dataflow.core import LLMServingABC
from dataflow.logger import get_logger
import asyncio
from tqdm.asyncio import tqdm_asyncio
WORKING_DIR = "./LightRAG"
async def initialize_rag(
llm_model_name, api_url, api_key, embed_model_name, embed_binding_host, embedding_dim, max_embed_tokens):
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
llm_model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=api_key,
base_url=api_url,
**kwargs,
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dim,
max_token_size=max_embed_tokens,
func=lambda texts: ollama_embed(
texts,
embed_model=embed_model_name,
host=embed_binding_host
),
),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
class LightRAGServing(LLMServingABC):
def __init__(self,
api_url: str = "https://api.openai.com/v1",
key_name_of_api_key: str = "DF_API_KEY",
llm_model_name: str = "gpt-4o",
embed_model_name: str = "bge-m3:latest",
embed_binding_host: str = "http://localhost:11434",
embedding_dim: int = 1024,
max_embed_tokens: int = 8192,
document_list: List[str] = []
):
try:
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status
except ImportError:
raise Exception(
"""
lightrag is not installed in this environment yet.
Please use pip install lightrag-hku.
"""
)
self.rag: LightRAG = None
self.api_url = api_url
self.llm_model_name = llm_model_name
self.embed_model_name = embed_model_name
self.embed_binding_host = embed_binding_host
self.embedding_dim = embedding_dim
self.max_embed_tokens = max_embed_tokens
self.logger = get_logger()
self.document_list = document_list
# config api_key in os.environ global, since safty issue.
self.api_key = os.environ.get(key_name_of_api_key)
if self.api_key is None:
error_msg = f"Lack of `{key_name_of_api_key}` in environment variables. Please set `{key_name_of_api_key}` as your api-key to {api_url} before using APILLMServing_request."
self.logger.error(error_msg)
raise ValueError(error_msg)
@classmethod
async def create(cls, *args, **kwargs) -> "LightRAGServing":
instance = cls(*args, **kwargs)
if instance.rag is None:
instance.rag = await initialize_rag(
instance.llm_model_name,
instance.api_url,
instance.api_key,
instance.embed_model_name,
instance.embed_binding_host,
instance.embedding_dim,
instance.max_embed_tokens
)
try:
instance.logger.info("Loading documents...")
await instance.load_documents(instance.document_list)
instance.logger.info("Documents processing completed.")
except Exception as e:
instance.logger.error(f"Error during documents processing: {e}\n")
return
return instance
def start_serving(self):
pass
async def cleanup(self):
if self.rag:
storages = [
self.rag.text_chunks,
self.rag.full_docs,
self.rag.entities_vdb,
self.rag.relationships_vdb,
self.rag.chunks_vdb,
self.rag.chunk_entity_relation_graph,
self.rag.doc_status,
self.rag.llm_response_cache
]
await asyncio.gather(*[s.drop() for s in storages], return_exceptions=True)
await self.rag.finalize_storages()
async def load_documents(self, document_paths: List[str]):
tasks = []
for path in document_paths:
with open(path, "r", encoding="utf-8") as f:
tasks.append(self.rag.ainsert(f.read()))
await asyncio.gather(*tasks, return_exceptions=True)
async def generate_from_input(self, user_inputs: List[str], system_prompt: str) -> List[str]:
tasks = [
self.rag.aquery(question, system_prompt=system_prompt, param=QueryParam(mode="hybrid"))
for question in user_inputs
]
responses = await tqdm_asyncio.gather(*tasks)
return list(responses)
\ No newline at end of file
import os
import time
import re
from typing import List, Optional, Dict, Any, Union, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from dataflow.core import LLMServingABC
from dataflow.logger import get_logger
class LiteLLMServing(LLMServingABC):
"""
LiteLLM-based serving class that provides unified interface for multiple LLM providers.
Supports OpenAI, Anthropic, Cohere, Azure, AWS Bedrock, Google and many more providers.
This implementation avoids global state pollution by passing configuration parameters
directly to each litellm.completion() call, ensuring thread safety and proper isolation
between different instances. Configuration parameters are immutable after initialization.
Doc: https://docs.litellm.ai/docs/providers
"""
def start_serving(self) -> None:
self.logger.info("LiteLLMServing: no local service to start.")
return
def start_serving(self) -> None:
self.logger.info("LiteLLMServing: no local service to start.")
return
def __init__(self,
api_url: str = "https://api.openai.com/v1/chat/completions",
key_name_of_api_key: str = "DF_API_KEY",
model_name: str = "gpt-4o",
max_workers: int = 10,
max_retries: int = 5,
api_version: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 1024,
top_p: float = 1.0,
timeout: int = 60,
custom_llm_provider: str = None,
**kwargs: Any):
"""
Initialize LiteLLM serving instance.
Args:
api_url: Custom API base URL
key_name_of_api_key: Environment variable name for API key (default: "DF_API_KEY")
model_name: Model name (e.g., "gpt-4o", "claude-3-sonnet", "command-r-plus")
max_workers: Number of concurrent workers for batch processing
max_retries: Number of LLM inference retry chances for each input
api_url: Custom API base URL
key_name_of_api_key: Environment variable name for API key (default: "DF_API_KEY")
model_name: Model name (e.g., "gpt-4o", "claude-3-sonnet", "command-r-plus")
max_workers: Number of concurrent workers for batch processing
max_retries: Number of LLM inference retry chances for each input
api_version: API version for providers that support it
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
top_p: Top-p sampling parameter
timeout: Request timeout in seconds
custom_llm_provider:
Optional custom provider name registered in LiteLLM for routing requests to
non-default backends (e.g., self-hosted OpenAI-compatible APIs or private endpoints).
Example: `"my_local_llm"` or `"company-internal-provider"`.
**kwargs: Additional parameters passed to litellm.completion()
Note:
All configuration parameters are immutable after initialization.
If you need different settings, create a new instance.
"""
# Import litellm at initialization time to support lazy importing
try:
import litellm
self._litellm = litellm
except ImportError:
raise ImportError(
"litellm is not installed. Please install it with: "
"pip install open-dataflow[litellm] or pip install litellm"
)
self.model_name = model_name
self.api_url = api_url
self.api_version = api_version
self.temperature = temperature
self.max_tokens = max_tokens
self.max_retries = max_retries
self.top_p = top_p
self.max_workers = max_workers
self.timeout = timeout
self.kwargs = kwargs
self.logger = get_logger()
# Get API key from environment variable
self.api_key = os.environ.get(key_name_of_api_key)
if self.api_key is None:
error_msg = f"Lack of `{key_name_of_api_key}` in environment variables. Please set `{key_name_of_api_key}` as your api-key before using LiteLLMServing."
self.logger.error(error_msg)
raise ValueError(error_msg)
self.key_name_of_api_key = key_name_of_api_key
if custom_llm_provider is not None:
self.custom_llm_provider = custom_llm_provider
# Validate model by making a test call
self._validate_setup()
self.logger.info(f"LiteLLMServing initialized with model: {model_name}")
def switch_model(self,
model_name: str,
key_name_of_api_key: Optional[str] = None,
api_url: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs: Any):
"""
Switch to a different model with potentially different API configuration.
Args:
model_name: Model name to switch to
key_name_of_api_key: New environment variable name for API key (optional)
api_url: New API base URL (optional)
api_version: New API version (optional)
**kwargs: Additional parameters for the new model
"""
# Update model
self.model_name = model_name
# Update API key if new environment variable provided
if key_name_of_api_key is not None:
self.api_key = os.environ.get(key_name_of_api_key)
if self.api_key is None:
error_msg = f"Lack of `{key_name_of_api_key}` in environment variables. Please set `{key_name_of_api_key}` as your api-key before switching model."
self.logger.error(error_msg)
raise ValueError(error_msg)
self.key_name_of_api_key = key_name_of_api_key
# Update other API configuration if provided
if api_url is not None:
self.api_url = api_url
if api_version is not None:
self.api_version = api_version
# Update other parameters from kwargs
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
else:
self.kwargs[key] = value
# Validate the new configuration
self._validate_setup()
self.logger.success(f"Switched to model: {model_name}")
def format_response(self, response: Dict[str, Any]) -> str:
"""
Format LiteLLM response to include reasoning content in a structured format.
This method handles the standardized LiteLLM response format and extracts
both the main content and any reasoning_content if available.
Args:
response: The response dictionary from LiteLLM
Returns:
Formatted string with think/answer tags if reasoning is present,
otherwise just the content
"""
try:
# Extract the main content
content = response['choices'][0]['message']['content']
# Check if content already has think/answer format
if re.search(r'<think>.*</think>.*<answer>.*</answer>', content, re.DOTALL):
return content
# Try to extract reasoning_content from LiteLLM standardized format
reasoning_content = ""
try:
# LiteLLM provides reasoning_content in the message object
message = response['choices'][0]['message']
if hasattr(message, 'reasoning_content') and message.reasoning_content:
reasoning_content = message.reasoning_content
elif isinstance(message, dict) and 'reasoning_content' in message:
reasoning_content = message['reasoning_content']
except (KeyError, AttributeError):
pass
# Format the response based on whether reasoning content exists
if reasoning_content:
return f"<think>{reasoning_content}</think>\n<answer>{content}</answer>"
else:
return content
except (KeyError, IndexError) as e:
self.logger.error(f"Error formatting response: {e}")
# Return original response as string if formatting fails
return str(response)
def _validate_setup(self):
"""Validate the model and API configuration."""
try:
# Prepare completion parameters
completion_params = {
"model": self.model_name,
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 1,
"timeout": self.timeout
}
# Add optional parameters if provided
if self.api_key:
completion_params["api_key"] = self.api_key
if self.api_url:
completion_params["api_base"] = self.api_url
if self.api_version:
completion_params["api_version"] = self.api_version
if hasattr(self, "custom_llm_provider"):
completion_params["custom_llm_provider"] = self.custom_llm_provider
# Make a minimal test call to validate setup
response = self._litellm.completion(**completion_params)
self.logger.success("LiteLLM setup validation successful")
except Exception as e:
self.logger.error(f"LiteLLM setup validation failed: {e}")
raise ValueError(f"Failed to validate LiteLLM setup: {e}")
def _generate_single(self, user_input: str, system_prompt: str, json_schema: dict = None) -> str:
"""Generate response for a single input with retry logic.
Args:
user_input: User input text
system_prompt: System prompt
Returns:
Generated response string
Raises:
Exception: If generation fails after all retries
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_input}
]
# Prepare completion parameters
completion_params = {
"model": self.model_name,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"timeout": self.timeout,
**self.kwargs
}
# Add optional parameters if provided
if self.api_key:
completion_params["api_key"] = self.api_key
if self.api_url:
completion_params["api_base"] = self.api_url
if self.api_version:
completion_params["api_version"] = self.api_version
if json_schema is not None:
completion_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "structural_response",
"strict": True,
"schema": json_schema
}
}
if hasattr(self, "custom_llm_provider"):
completion_params["custom_llm_provider"] = self.custom_llm_provider
last_error = None
for attempt in range(self.max_retries):
try:
response = self._litellm.completion(**completion_params)
# Convert response to dict format for format_response
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response.dict()
return self.format_response(response_dict)
except Exception as e:
last_error = e
if attempt < self.max_retries - 1:
# Check if error is retryable
error_str = str(e).lower()
if any(retryable in error_str for retryable in
["rate limit", "timeout", "connection", "503", "502", "429"]):
wait_time = min(2 ** attempt, 10) # Exponential backoff with max 10s
self.logger.warning(f"Retryable error, waiting {wait_time}s: {e}")
time.sleep(wait_time)
continue
# Non-retryable error or last attempt
self.logger.error(f"Error generating response (attempt {attempt + 1}/{self.max_retries}): {e}")
break
# Raise the last error instead of returning error string
raise last_error
def generate_from_input(self,
user_inputs: List[str],
system_prompt: str = "You are a helpful assistant",
json_schema: dict = None,
) -> List[str]:
"""
Generate responses for a list of inputs using concurrent processing.
Args:
user_inputs: List of user input strings
system_prompt: System prompt to use for all generations
Returns:
List of generated responses
"""
if not user_inputs:
return []
# Single input case
if len(user_inputs) == 1:
try:
return [self._generate_single(user_inputs[0], system_prompt, json_schema)]
except Exception as e:
# For consistency with batch processing, return error message in list
error_msg = f"Error: {str(e)}"
self.logger.error(f"Failed to generate response: {e}")
return [error_msg]
# Batch processing with threading
responses = [None] * len(user_inputs)
def generate_with_index(idx: int, user_input: str) -> Tuple[int, str]:
try:
response = self._generate_single(user_input, system_prompt,json_schema)
return idx, response
except Exception as e:
# For batch processing, return error message to maintain list structure
error_msg = f"Error: {str(e)}"
self.logger.error(f"Failed to generate response for input {idx}: {e}")
return idx, error_msg
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(generate_with_index, idx, user_input)
for idx, user_input in enumerate(user_inputs)
]
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating"):
idx, response = future.result()
responses[idx] = response
return responses
def generate_embedding_from_input(self, texts: List[str]) -> List[List[float]]:
"""
Generate embeddings for a list of texts.
Args:
texts: List of text strings to embed
Returns:
List of embedding vectors
"""
if not texts:
return []
embeddings = []
# Prepare embedding parameters
embedding_params = {
"model": self.model_name,
"timeout": self.timeout
}
# Add optional parameters if provided
if self.api_key:
embedding_params["api_key"] = self.api_key
if self.api_url:
embedding_params["api_base"] = self.api_url
if self.api_version:
embedding_params["api_version"] = self.api_version
# Process embeddings with retry logic
def embed_with_retry(text: str):
last_error = None
for attempt in range(self.max_retries):
try:
response = self._litellm.embedding(
input=[text],
**embedding_params
)
return response['data'][0]['embedding']
except Exception as e:
last_error = e
if attempt < self.max_retries - 1:
error_str = str(e).lower()
if any(retryable in error_str for retryable in
["rate limit", "timeout", "connection", "503", "502", "429"]):
wait_time = min(2 ** attempt, 10)
self.logger.warning(f"Retryable error in embedding, waiting {wait_time}s: {e}")
time.sleep(wait_time)
continue
self.logger.error(f"Error generating embedding (attempt {attempt + 1}/{self.max_retries}): {e}")
break
raise last_error
# Process in batches for better performance
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [executor.submit(embed_with_retry, text) for text in texts]
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating embeddings"):
try:
embedding = future.result()
embeddings.append(embedding)
except Exception as e:
self.logger.error(f"Failed to generate embedding: {e}")
# Return empty embedding for failed cases to maintain list structure
embeddings.append([])
return embeddings
def get_supported_models(self) -> List[str]:
"""Get list of supported models for the current provider."""
try:
return self._litellm.model_list
except Exception as e:
self.logger.warning(f"Could not retrieve model list: {e}")
return []
def cleanup(self) -> None:
"""Cleanup resources."""
self.logger.info("Cleaning up LiteLLMServing resources")
# LiteLLM doesn't require explicit cleanup since we don't use global state
# Instance variables will be garbage collected when the instance is destroyed
# Clear any references to ensure proper cleanup
self.api_key = None
self.kwargs = None
import os
import torch
import contextlib
import time
from typing import Optional, Union, List, Dict, Any
from dataflow import get_logger
from huggingface_hub import snapshot_download
from dataflow.core import LLMServingABC
from transformers import AutoTokenizer
class LocalModelLLMServing_vllm(LLMServingABC):
'''
A class for generating text using vllm, with model from huggingface or local directory
'''
def __init__(self,
hf_model_name_or_path: str = None,
hf_cache_dir: str = None,
hf_local_dir: str = None,
vllm_tensor_parallel_size: int = 1,
vllm_temperature: float = 0.7,
vllm_top_p: float = 0.9,
vllm_max_tokens: int = 1024,
vllm_top_k: int = 40,
vllm_repetition_penalty: float = 1.0,
vllm_seed: int = None,
vllm_max_model_len: int = None,
vllm_gpu_memory_utilization: float=0.9,
):
self.logger = get_logger()
self.load_model(
hf_model_name_or_path=hf_model_name_or_path,
hf_cache_dir=hf_cache_dir,
hf_local_dir=hf_local_dir,
vllm_tensor_parallel_size=vllm_tensor_parallel_size,
vllm_temperature=vllm_temperature,
vllm_top_p=vllm_top_p,
vllm_max_tokens=vllm_max_tokens,
vllm_top_k=vllm_top_k,
vllm_repetition_penalty=vllm_repetition_penalty,
vllm_seed=vllm_seed,
vllm_max_model_len=vllm_max_model_len,
vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
)
self.backend_initialized = False
def load_model(self,
hf_model_name_or_path: str = None,
hf_cache_dir: str = None,
hf_local_dir: str = None,
vllm_tensor_parallel_size: int = 1,
vllm_temperature: float = 0.7,
vllm_top_p: float = 0.9,
vllm_max_tokens: int = 1024,
vllm_top_k: int = 40,
vllm_repetition_penalty: float = 1.0,
vllm_seed: int = 42,
vllm_max_model_len: int = None,
vllm_gpu_memory_utilization: float=0.9,
):
self.hf_model_name_or_path = hf_model_name_or_path
self.hf_cache_dir = hf_cache_dir
self.hf_local_dir = hf_local_dir
self.vllm_tensor_parallel_size = vllm_tensor_parallel_size
self.vllm_temperature = vllm_temperature
self.vllm_top_p = vllm_top_p
self.vllm_max_tokens = vllm_max_tokens
self.vllm_top_k = vllm_top_k
self.vllm_repetition_penalty = vllm_repetition_penalty
self.vllm_seed = vllm_seed
self.vllm_max_model_len = vllm_max_model_len
self.vllm_gpu_memory_utilization = vllm_gpu_memory_utilization
def start_serving(self):
self.backend_initialized = True
self.logger = get_logger()
if self.hf_model_name_or_path is None:
raise ValueError("hf_model_name_or_path is required")
elif os.path.exists(self.hf_model_name_or_path):
self.logger.info(f"Using local model path: {self.hf_model_name_or_path}")
self.real_model_path = self.hf_model_name_or_path
else:
self.logger.info(f"Downloading model from HuggingFace: {self.hf_model_name_or_path}")
self.real_model_path = snapshot_download(
repo_id=self.hf_model_name_or_path,
cache_dir=self.hf_cache_dir,
local_dir=self.hf_local_dir,
)
# Import vLLM and set up the environment for multiprocessing
# vLLM requires the multiprocessing method to be set to spawn
try:
from vllm import LLM, SamplingParams
except ValueError as ve:
# A ValueError typically indicates a transformers version mismatch
raise ImportError(
"Failed to import vllm due to a ValueError: this is often caused by a transformers version conflict. "
"Please check your transformers package and upgrade or downgrade it to the version required by vllm."
) from ve
except ImportError as ie:
# vllm is not installed
raise ImportError(
"vllm is not installed. Please install it by running:\n"
" pip install open-dataflow[vllm]\n"
"If it is already installed, ensure that the installation environment matches your current runtime environment."
) from ie
# Set the environment variable for vllm to use spawn method for multiprocessing
# See https://docs.vllm.ai/en/v0.7.1/design/multiprocessing.html
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "spawn"
self.sampling_params = SamplingParams(
temperature=self.vllm_temperature,
top_p=self.vllm_top_p,
max_tokens=self.vllm_max_tokens,
top_k=self.vllm_top_k,
repetition_penalty=self.vllm_repetition_penalty,
seed=self.vllm_seed
)
self.llm = LLM(
model=self.real_model_path,
tensor_parallel_size=self.vllm_tensor_parallel_size,
max_model_len=self.vllm_max_model_len,
gpu_memory_utilization=self.vllm_gpu_memory_utilization,
)
self.tokenizer = AutoTokenizer.from_pretrained(self.real_model_path, cache_dir=self.hf_cache_dir)
self.logger.success(f"Model loaded from {self.real_model_path} by vLLM backend")
def generate_from_input(self,
user_inputs: list[str],
system_prompt: str = "You are a helpful assistant",
json_schema: dict = None,
) -> list[str]:
if not self.backend_initialized:
self.start_serving()
full_prompts = []
for question in user_inputs:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
]
full_prompts.append(messages)
full_template = self.tokenizer.apply_chat_template(
full_prompts,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True, # Set to False to strictly disable thinking
)
if json_schema is not None:
try:
from vllm import SamplingParams
from vllm.sampling_params import GuidedDecodingParams
except:
raise ImportError("please install vllm first like 'pip install open-dataflow[vllm]'")
guided_decoding_params = GuidedDecodingParams(
json=json_schema
)
self.sampling_params = SamplingParams(
temperature=self.vllm_temperature,
top_p=self.vllm_top_p,
max_tokens=self.vllm_max_tokens,
top_k=self.vllm_top_k,
repetition_penalty=self.vllm_repetition_penalty,
seed=self.vllm_seed,
guided_decoding=guided_decoding_params
)
responses = self.llm.generate(full_template, self.sampling_params)
return [output.outputs[0].text for output in responses]
def generate_embedding_from_input(self, texts: list[str]) -> list[list[float]]:
if not self.backend_initialized:
self.start_serving()
outputs = self.llm.embed(texts)
return [output.outputs.embedding for output in outputs]
def cleanup(self):
free_mem = torch.cuda.mem_get_info()[0] # 返回可用显存(单位:字节)
total_mem = torch.cuda.get_device_properties(0).total_memory
self.logger.info(f"Free memory: {free_mem / (1024 ** 2):.2f} MB / {total_mem / (1024 ** 2):.2f} MB")
self.logger.info("Cleaning up vLLM backend resources...")
self.backend_initialized = False
from vllm.distributed.parallel_state import (
destroy_model_parallel,
destroy_distributed_environment,
)
del self.llm.llm_engine
del self.llm
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
import gc
gc.collect()
torch.cuda.empty_cache()
import ray
ray.shutdown()
free_mem = torch.cuda.mem_get_info()[0] # 返回可用显存(单位:字节)
total_mem = torch.cuda.get_device_properties(0).total_memory
self.logger.info(f"Free memory: {free_mem / (1024 ** 2):.2f} MB / {total_mem / (1024 ** 2):.2f} MB")
class LocalModelLLMServing_sglang(LLMServingABC):
def __init__(
self,
hf_model_name_or_path: str = None,
hf_cache_dir: str = None,
hf_local_dir: str = None,
sgl_tp_size: int = 1, # tensor parallel size
sgl_dp_size: int = 1, # data parallel size
sgl_mem_fraction_static: float = 0.9, # memory fraction for static memory allocation
sgl_max_new_tokens: int = 2048, # maximum number of new tokens to generate
sgl_stop: Optional[Union[str, List[str]]] = None,
sgl_stop_token_ids: Optional[List[int]] = None,
sgl_temperature: float = 1.0,
sgl_top_p: float = 1.0,
sgl_top_k: int = -1,
sgl_min_p: float = 0.0,
sgl_frequency_penalty: float = 0.0,
sgl_presence_penalty: float = 0.0,
sgl_repetition_penalty: float = 1.0,
sgl_min_new_tokens: int = 0,
sgl_n: int = 1,
sgl_json_schema: Optional[str] = None,
sgl_regex: Optional[str] = None,
sgl_ebnf: Optional[str] = None,
sgl_structural_tag: Optional[str] = None,
sgl_ignore_eos: bool = False,
sgl_skip_special_tokens: bool = True,
sgl_spaces_between_special_tokens: bool = True,
sgl_no_stop_trim: bool = False,
sgl_custom_params: Optional[Dict[str, Any]] = None,
sgl_stream_interval: Optional[int] = None,
sgl_logit_bias: Optional[Dict[str, float]] = None,
):
self.logger = get_logger()
self.load_model(
hf_model_name_or_path=hf_model_name_or_path,
hf_cache_dir=hf_cache_dir,
hf_local_dir=hf_local_dir,
sgl_tp_size=sgl_tp_size,
sgl_dp_size=sgl_dp_size,
sgl_mem_fraction_static=sgl_mem_fraction_static, # memory fraction for static
sgl_max_new_tokens=sgl_max_new_tokens,
sgl_stop=sgl_stop,
sgl_stop_token_ids=sgl_stop_token_ids,
sgl_temperature=sgl_temperature,
sgl_top_p=sgl_top_p,
sgl_top_k=sgl_top_k,
sgl_min_p=sgl_min_p,
sgl_frequency_penalty=sgl_frequency_penalty,
sgl_presence_penalty=sgl_presence_penalty,
sgl_repetition_penalty=sgl_repetition_penalty,
sgl_min_new_tokens=sgl_min_new_tokens,
sgl_n=sgl_n,
sgl_json_schema=sgl_json_schema,
sgl_regex=sgl_regex,
sgl_ebnf=sgl_ebnf,
sgl_structural_tag=sgl_structural_tag,
sgl_ignore_eos=sgl_ignore_eos,
sgl_skip_special_tokens=sgl_skip_special_tokens,
sgl_spaces_between_special_tokens=sgl_spaces_between_special_tokens,
sgl_no_stop_trim=sgl_no_stop_trim,
sgl_custom_params=sgl_custom_params,
sgl_stream_interval=sgl_stream_interval,
sgl_logit_bias=sgl_logit_bias,
)
self.backend_initialized = False
def load_model(
self,
hf_model_name_or_path:str = None,
hf_cache_dir:str = None,
hf_local_dir:str = None,
sgl_tp_size: int = 1,
sgl_dp_size: int = 1,
sgl_mem_fraction_static: float = 0.9, # memory fraction for static memory allocation
sgl_max_new_tokens: int = 2048,
sgl_stop: Optional[Union[str, List[str]]] = None,
sgl_stop_token_ids: Optional[List[int]] = None,
sgl_temperature: float = 1.0,
sgl_top_p: float = 1.0,
sgl_top_k: int = -1,
sgl_min_p: float = 0.0,
sgl_frequency_penalty: float = 0.0,
sgl_presence_penalty: float = 0.0,
sgl_repetition_penalty: float = 1.0,
sgl_min_new_tokens: int = 0,
sgl_n: int = 1,
sgl_json_schema: Optional[str] = None,
sgl_regex: Optional[str] = None,
sgl_ebnf: Optional[str] = None,
sgl_structural_tag: Optional[str] = None,
sgl_ignore_eos: bool = False,
sgl_skip_special_tokens: bool = True,
sgl_spaces_between_special_tokens: bool = True,
sgl_no_stop_trim: bool = False,
sgl_custom_params: Optional[Dict[str, Any]] = None,
sgl_stream_interval: Optional[int] = None,
sgl_logit_bias: Optional[Dict[str, float]] = None,
):
self.hf_model_name_or_path = hf_model_name_or_path
self.hf_cache_dir = hf_cache_dir
self.hf_local_dir = hf_local_dir
self.sgl_tp_size = sgl_tp_size
self.sgl_dp_size = sgl_dp_size
self.sgl_mem_fraction_static = sgl_mem_fraction_static
self.sgl_max_new_tokens = sgl_max_new_tokens
self.sgl_stop = sgl_stop
self.sgl_stop_token_ids = sgl_stop_token_ids
self.sgl_temperature = sgl_temperature
self.sgl_top_p = sgl_top_p
self.sgl_top_k = sgl_top_k
self.sgl_min_p = sgl_min_p
self.sgl_frequency_penalty = sgl_frequency_penalty
self.sgl_presence_penalty = sgl_presence_penalty
self.sgl_repetition_penalty = sgl_repetition_penalty
self.sgl_min_new_tokens = sgl_min_new_tokens
self.sgl_n = sgl_n
self.sgl_json_schema = sgl_json_schema
self.sgl_regex = sgl_regex
self.sgl_ebnf = sgl_ebnf
self.sgl_structural_tag = sgl_structural_tag
self.sgl_ignore_eos = sgl_ignore_eos
self.sgl_skip_special_tokens = sgl_skip_special_tokens
self.sgl_spaces_between_special_tokens = sgl_spaces_between_special_tokens
self.sgl_no_stop_trim = sgl_no_stop_trim
self.sgl_custom_params = sgl_custom_params
self.sgl_stream_interval = sgl_stream_interval
self.sgl_logit_bias = sgl_logit_bias
def start_serving(self):
self.backend_initialized = True
self.logger = get_logger()
if self.hf_model_name_or_path is None:
raise ValueError("hf_model_name_or_path is required")
elif os.path.exists(self.hf_model_name_or_path):
self.logger.info(f"Using local model path: {self.hf_model_name_or_path}")
self.real_model_path = self.hf_model_name_or_path
else:
self.logger.info(f"Downloading model from HuggingFace: {self.hf_model_name_or_path}")
self.real_model_path = snapshot_download(
repo_id=self.hf_model_name_or_path,
cache_dir=self.hf_cache_dir,
local_dir=self.hf_local_dir,
)
# import sglang and set up the environment for multiprocessing
try:
import sglang as sgl
except ImportError:
raise ImportError("please install sglang first like 'pip install open-dataflow[sglang]'")
self.llm = sgl.Engine(
model_path=self.real_model_path,
tp_size=self.sgl_tp_size,
dp_size=self.sgl_dp_size,
mem_fraction_static=self.sgl_mem_fraction_static, # memory fraction for static memory allocation
)
self.sampling_params = {
"max_new_tokens": self.sgl_max_new_tokens,
"stop": self.sgl_stop,
"stop_token_ids": self.sgl_stop_token_ids,
"temperature": self.sgl_temperature,
"top_p": self.sgl_top_p,
"top_k": self.sgl_top_k,
"min_p": self.sgl_min_p,
"frequency_penalty": self.sgl_frequency_penalty,
"presence_penalty": self.sgl_presence_penalty,
"repetition_penalty": self.sgl_repetition_penalty,
"min_new_tokens": self.sgl_min_new_tokens,
"n": self.sgl_n,
"json_schema": self.sgl_json_schema,
"regex": self.sgl_regex,
"ebnf":self.sgl_ebnf,
"structural_tag": self.sgl_structural_tag,
"ignore_eos": self.sgl_ignore_eos,
"skip_special_tokens": self.sgl_skip_special_tokens,
"spaces_between_special_tokens": self.sgl_spaces_between_special_tokens,
"no_stop_trim": self.sgl_no_stop_trim,
"custom_params": self.sgl_custom_params,
"stream_interval":self.sgl_stream_interval,
"logit_bias": self.sgl_logit_bias,
}
# remove all keys equal to None
self.sampling_params = {k: v for k, v in self.sampling_params.items() if v is not None}
self.tokenizer = AutoTokenizer.from_pretrained(self.real_model_path, cache_dir=self.hf_cache_dir)
self.logger.success(f"Model loaded from {self.real_model_path} by SGLang backend")
def generate_from_input(self,
user_inputs: list[str],
system_prompt: str = "You are a helpful assistant"
) -> list[str]:
if not self.backend_initialized:
self.start_serving()
full_prompts = []
for question in user_inputs:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
]
full_prompts.append(messages)
full_template = self.tokenizer.apply_chat_template(
full_prompts,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True, # Set to False to strictly disable thinking
)
try:
responses = self.llm.generate(full_template, self.sampling_params)
except Exception as e:
self.logger.error(f"Error during Sglang Backend generation, please check your parameters.: {e}")
raise e
return [output['text'] for output in responses]
def generate_embedding_from_input(self, texts: list[str]) -> list[list[float]]:
raise NotImplementedError("SGLang backend does not support embedding generation yet. If you have experience with SGLang, please contribute to this feature in Pull Request.")
# if not self.backend_initialized:
# self.start_serving()
# self.llm.
# outputs = self.llm.embed(texts)
# return [output['embedding'] for output in outputs]
def cleanup(self):
self.logger.info("Cleaning up SGLang backend resources...")
self.backend_initialized = False
self.llm.shutdown()
del self.llm
import gc;
gc.collect()
torch.cuda.empty_cache()
\ No newline at end of file
import os
import torch
import contextlib
from typing import List, Optional, Union, Dict, Any
from PIL import Image
import re
from dataflow import get_logger
from huggingface_hub import snapshot_download
from dataflow.core import LLMServingABC
from transformers import AutoTokenizer
class LocalVLMServing_vllm(LLMServingABC):
"""
Client for serving a Vision-Language Model (VLM) locally using vLLM.
Combines the interface of APIVLMServing with the backend efficiency of vLLM.
"""
def __init__(self,
hf_model_name_or_path: str = None,
hf_cache_dir: str = None,
hf_local_dir: str = None,
vllm_tensor_parallel_size: int = 1,
vllm_temperature: float = 0.2, # Lower temperature is usually better for VLM tasks
vllm_top_p: float = 0.9,
vllm_max_tokens: int = 1024,
vllm_max_model_len: int = None,
vllm_top_k: int = 40,
vllm_repetition_penalty: float = 1.0,
vllm_seed: int = None,
vllm_gpu_memory_utilization: float = 0.9,
vllm_limit_mm_per_prompt: int = 1, # Specific to VLMs: max images per prompt
trust_remote_code: bool = True,
enable_thinking:bool =True, # Set to False to strictly disable thinking
batch_size: int = 128
):
"""
Initialize the Local VLM Serving client.
"""
self.logger = get_logger()
QWEN_VL_PATTERN = re.compile(r'Qwen-VL|Qwen[0-9\.]+-VL')
# 报Warning显示目前经过测试的VLM主要是QwenVL
if hf_model_name_or_path and not QWEN_VL_PATTERN.search(hf_model_name_or_path):
self.logger.warning(
"Model Compatibility Alert: LocalVLMServing_vllm is primarily tested with Qwen-VL models "
"(e.g., Qwen-VL-Chat, Qwen2.5-VL-Chat). Other VLMs may require additional adjustments "
"for correct functionality."
)
self.load_model(
hf_model_name_or_path=hf_model_name_or_path,
hf_cache_dir=hf_cache_dir,
hf_local_dir=hf_local_dir,
vllm_tensor_parallel_size=vllm_tensor_parallel_size,
vllm_temperature=vllm_temperature,
vllm_top_p=vllm_top_p,
vllm_max_tokens=vllm_max_tokens,
vllm_top_k=vllm_top_k,
vllm_repetition_penalty=vllm_repetition_penalty,
vllm_seed=vllm_seed,
vllm_max_model_len=vllm_max_model_len,
vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
vllm_limit_mm_per_prompt=vllm_limit_mm_per_prompt,
trust_remote_code=trust_remote_code,
enable_thinking=enable_thinking
)
self.backend_initialized = False
self.batch_size = batch_size
def load_model(self,
hf_model_name_or_path: str = None,
hf_cache_dir: str = None,
hf_local_dir: str = None,
vllm_tensor_parallel_size: int = 1,
vllm_temperature: float = 0.2,
vllm_top_p: float = 0.9,
vllm_max_tokens: int = 1024,
vllm_max_model_len: int = None,
vllm_top_k: int = 40,
vllm_repetition_penalty: float = 1.0,
vllm_seed: int = None,
vllm_gpu_memory_utilization: float = 0.9,
vllm_limit_mm_per_prompt: int = 1,
trust_remote_code: bool = True,
enable_thinking:bool =True,
):
self.hf_model_name_or_path = hf_model_name_or_path
self.hf_cache_dir = hf_cache_dir
self.hf_local_dir = hf_local_dir
self.vllm_tensor_parallel_size = vllm_tensor_parallel_size
self.vllm_temperature = vllm_temperature
self.vllm_top_p = vllm_top_p
self.vllm_max_tokens = vllm_max_tokens
self.vllm_max_model_len = vllm_max_model_len
self.vllm_top_k = vllm_top_k
self.vllm_repetition_penalty = vllm_repetition_penalty
self.vllm_seed = vllm_seed
self.vllm_gpu_memory_utilization = vllm_gpu_memory_utilization
self.vllm_limit_mm_per_prompt = vllm_limit_mm_per_prompt
self.trust_remote_code = trust_remote_code
self.enable_thinking = enable_thinking
def start_serving(self):
self.backend_initialized = True
self.logger = get_logger()
# 1. Handle Model Path (HuggingFace or Local)
if self.hf_model_name_or_path is None:
raise ValueError("hf_model_name_or_path is required")
elif os.path.exists(self.hf_model_name_or_path):
self.logger.info(f"Using local model path: {self.hf_model_name_or_path}")
self.real_model_path = self.hf_model_name_or_path
else:
self.logger.info(f"Downloading model from HuggingFace: {self.hf_model_name_or_path}")
self.real_model_path = snapshot_download(
repo_id=self.hf_model_name_or_path,
cache_dir=self.hf_cache_dir,
local_dir=self.hf_local_dir,
)
# 2. Import vLLM and Setup Environment
try:
from vllm import LLM, SamplingParams
except ImportError:
raise ImportError("please install vllm first like 'pip install open-dataflow[vllm]'")
# Set environment for multiprocessing compatibility
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "spawn"
self.sampling_params = SamplingParams(
temperature=self.vllm_temperature,
top_p=self.vllm_top_p,
max_tokens=self.vllm_max_tokens,
top_k=self.vllm_top_k,
repetition_penalty=self.vllm_repetition_penalty,
seed=self.vllm_seed
)
# 3. Initialize LLM Engine with VLM specific params
self.llm = LLM(
model=self.real_model_path,
tensor_parallel_size=self.vllm_tensor_parallel_size,
max_model_len=self.vllm_max_model_len,
gpu_memory_utilization=self.vllm_gpu_memory_utilization,
limit_mm_per_prompt={"image": self.vllm_limit_mm_per_prompt}, # Specific config for image limits
trust_remote_code=self.trust_remote_code
)
# Load tokenizer for chat templating
self.tokenizer = AutoTokenizer.from_pretrained(
self.real_model_path,
cache_dir=self.hf_cache_dir,
trust_remote_code=self.trust_remote_code,
enable_thinking=self.enable_thinking
)
self.logger.success(f"VLM Model loaded from {self.real_model_path} by vLLM backend")
def _load_image(self, image_path: str) -> Image.Image:
"""
Helper to load image from path using PIL.
Replaces API _encode_image_to_base64 logic with PIL loading for vLLM.
"""
try:
return Image.open(image_path).convert("RGB")
except Exception as e:
self.logger.error(f"Failed to load image at {image_path}: {e}")
raise e
def _run_batch_inference(self, vllm_inputs, batch_size):
all_outputs = []
# 按 batch_size 分批处理
for i in range(0, len(vllm_inputs), batch_size):
batch_inputs = vllm_inputs[i:i+batch_size]
outputs = self.llm.generate(batch_inputs, sampling_params=self.sampling_params)
all_outputs.extend([output.outputs[0].text for output in outputs])
return all_outputs
def generate_from_input_one_image(
self,
image_paths: List[str],
text_prompts: List[str],
system_prompt: str = "Describe the image in detail.",
model: str = None, # Unused, kept for interface compatibility
timeout: int = 1800 # Unused, kept for interface compatibility
) -> List[str]:
"""
Batch process single-image chat requests concurrently using vLLM.
Matches the signature of APIVLMServing_openai.generate_from_input_one_image.
"""
if not self.backend_initialized:
self.start_serving()
if len(image_paths) != len(text_prompts):
raise ValueError("`image_paths` and `text_prompts` must have the same length")
inputs = []
# Prepare inputs for vLLM
for img_path, user_text in zip(image_paths, text_prompts):
image = self._load_image(img_path)
# Construct messages using standard chat format
# Note: Specific VLMs might require specific placeholder tokens (e.g., <image>)
# but AutoTokenizer.apply_chat_template usually handles this if configured correctly.
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": user_text}
]}
]
# Apply template to get the prompt string
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Construct vLLM input dictionary
inputs.append({
"prompt": prompt,
"multi_modal_data": {
"image": image
}
})
# Run Inference
outputs = self.llm.generate(inputs, sampling_params=self.sampling_params)
# Extract Text
responses = []
for output in outputs:
responses.append(output.outputs[0].text)
return responses
def generate_from_input_multi_images(
self,
list_of_image_paths: List[List[str]],
list_of_image_labels: List[List[str]],
user_prompts: List[str],
system_prompt: str = "Analyze the provided images.",
model: str = None, # 保持接口一致性,实际不使用
timeout: int = 1800 # 保持接口一致性,实际不使用
) -> List[str]:
"""
Batch process multi-image chat requests using vLLM.
:param list_of_image_paths: Outer list is the batch, inner list contains paths for one request.
:param list_of_image_labels: Corresponding labels/prompts for each image in the inner list.
"""
if not self.backend_initialized:
self.start_serving()
if len(list_of_image_paths) != len(list_of_image_labels):
raise ValueError("`list_of_image_paths` and `list_of_image_labels` must have the same length")
vllm_inputs = []
# 遍历每一个请求 (Batch Loop)
for paths, labels, user_prompt in zip(list_of_image_paths, list_of_image_labels, user_prompts):
if len(paths) != len(labels):
raise ValueError("Inner lists of paths and labels must have the same length")
# 1. 加载该请求下的所有图片
# vLLM 支持传入 PIL Image 列表
current_images = [self._load_image(p) for p in paths]
# 检查是否超过了初始化时设定的最大图片数
if len(current_images) > self.vllm_limit_mm_per_prompt:
self.logger.warning(
f"Request contains {len(current_images)} images, but limit is {self.vllm_limit_mm_per_prompt}. "
"This might cause vLLM errors. Increase `vllm_limit_mm_per_prompt` in init."
)
# 2. 构建 User Content
# 我们需要交替插入文本(Label)和图片占位符
user_content = [{"type": "text", "text": user_prompt}]
for label, _ in zip(labels, paths):
# 插入标签文本(如果有)
if label:
user_content.append({"type": "text", "text": f"{label}\n"})
# 插入图片占位符
user_content.append({"type": "image"})
# 3. 构建完整的 Messages
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content}
]
# 4. 应用 Chat Template
# 大多数现代 VLM 的 template 会自动处理多个 {"type": "image"}
# 并将其转换为类似 <image> <image> ... 的 token 序列
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# 5. 构建 vLLM 输入 payload
# 当有多张图片时,multi_modal_data["image"] 应该是一个列表
vllm_inputs.append({
"prompt": prompt,
"multi_modal_data": {
"image": current_images
}
})
# 6. 执行批量推理
all_outputs = self._run_batch_inference(vllm_inputs, batch_size=self.batch_size)
return all_outputs
def cleanup(self):
"""
Clean up vLLM backend resources.
Identical logic to LocalModelLLMServing_vllm.cleanup.
"""
free_mem = torch.cuda.mem_get_info()[0]
total_mem = torch.cuda.get_device_properties(0).total_memory
self.logger.info(f"Free memory before cleanup: {free_mem / (1024 ** 2):.2f} MB / {total_mem / (1024 ** 2):.2f} MB")
self.logger.info("Cleaning up vLLM VLM backend resources...")
self.backend_initialized = False
# vLLM Distributed Cleanup
from vllm.distributed.parallel_state import (
destroy_model_parallel,
destroy_distributed_environment,
)
if hasattr(self, 'llm'):
del self.llm.llm_engine
del self.llm
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
import gc
import ray
gc.collect()
torch.cuda.empty_cache()
ray.shutdown()
free_mem = torch.cuda.mem_get_info()[0]
self.logger.info(f"Free memory after cleanup: {free_mem / (1024 ** 2):.2f} MB")
def generate_from_input(self, user_inputs: List[str], system_prompt: str = "Describe the image in detail.", json_schema: dict = None):
"""
保持接口一致性,实际不使用
"""
return self.generate_from_input_one_image(
image_paths=user_inputs,
text_prompts=[""] * len(user_inputs),
system_prompt=system_prompt
)
\ No newline at end of file
import subprocess
import signal
import os
import re
import time
import requests
from threading import Thread
from tqdm import tqdm
from dataflow import get_logger
from dataflow.core import LLMServingABC
from concurrent.futures import ThreadPoolExecutor, as_completed
class LocalHostLLMAPIServing_vllm(LLMServingABC):
"""
A class to serve vLLM via a subprocess (e.g., localhost API server)
"""
def __init__(self,
hf_model_name_or_path: str,
hf_cache_dir: str = None,
max_workers: int = 16,
vllm_server_port: int = 12345,
vllm_server_host: str = "127.0.0.1",
vllm_tensor_parallel_size: int = 1,
vllm_temperature: float = 0.7,
vllm_top_p: float = 0.9,
vllm_max_tokens: int = 1024,
vllm_top_k: int = 40,
vllm_max_model_len: int = None,
vllm_gpu_memory_utilization: float = 0.9,
vllm_server_start_timeout: int = 120,
):
self.logger = get_logger()
self.hf_model_name_or_path = hf_model_name_or_path
self.port = vllm_server_port
self.host = vllm_server_host
self.tensor_parallel_size = vllm_tensor_parallel_size
self.temperature = vllm_temperature
self.top_p = vllm_top_p
self.max_tokens = vllm_max_tokens
self.top_k = vllm_top_k
self.max_model_len = vllm_max_model_len
self.gpu_memory_utilization = vllm_gpu_memory_utilization
self.hf_cache_dir = hf_cache_dir
self.server_start_timeout = vllm_server_start_timeout
self.max_workers = max_workers
self.process = None
self.backend_initialized = False
def _stream_subprocess_logs(self, pipe):
"""
持续读取子进程输出,只保留 INFO/ERROR/Traceback
"""
traceback_mode = False
is_keyboard_interrupted = False
traceback_info_list = []
for line in iter(pipe.readline, ''):
line = line.rstrip("\n")
if not line:
continue
# 判断traceback
if line.startswith("Traceback"):
traceback_mode = True
traceback_info_list.append(line)
if "KeyboardInterrupt: MQLLMEngine terminated" in line:
is_keyboard_interrupted = True
continue
# 如果处于traceback模式,输出所有行直到空行结束
if traceback_mode:
traceback_info_list.append(line)
if "MQLLMEngine terminated" in line:
is_keyboard_interrupted = True
if line == "":
traceback_mode = False
continue
# 仅保留 INFO 和 ERROR
if re.match(r"^(INFO|ERROR):", line):
if "INFO:" in line:
if "POST" in line:
self.logger.debug(line)
else:
self.logger.info(line)
else:
self.logger.error(line)
if is_keyboard_interrupted:
self.logger.success("MQLLMEngine terminated")
else:
for log in traceback_info_list:
print(log)
self.is_error = not is_keyboard_interrupted
def start_serving(self):
if self.backend_initialized:
self.logger.info("vLLM server already running.")
return
command = [
"python", "-m", "vllm.entrypoints.openai.api_server",
"--model", self.hf_model_name_or_path,
"--tensor-parallel-size", str(self.tensor_parallel_size),
"--port", str(self.port),
"--gpu-memory-utilization", str(self.gpu_memory_utilization),
]
if self.max_model_len:
command += ["--max-model-len", str(self.max_model_len)]
if self.hf_cache_dir:
command += ["--download-dir", self.hf_cache_dir]
self.logger.info(f"Starting vLLM server with command: {' '.join(command)}")
self.process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
preexec_fn=os.setsid
)
# 后台线程处理日志
Thread(target=self._stream_subprocess_logs, args=(self.process.stdout,), daemon=True).start()
Thread(target=self._stream_subprocess_logs, args=(self.process.stderr,), daemon=True).start()
for i in range(self.server_start_timeout): # 增加等待时间
if hasattr(self, "is_error") and self.is_error:
break
try:
response = requests.get(f"http://{self.host}:{self.port}/v1/models", timeout=1.0)
status = response.status_code
self.logger.debug(f"[{i+1}/{self.server_start_timeout}] Status: {status}")
if status == 200:
self.backend_initialized = True
self.logger.success("vLLM server started successfully!")
return
except Exception as e:
self.logger.debug(f"[{i+1}/90] Connection failed: {repr(e)}")
time.sleep(1)
self.cleanup()
if self.is_error:
raise RuntimeError("Failed to start vLLM server. Please check the logs for more information.")
else:
raise RuntimeError("Failed to start vLLM server within timeout. You can try increase server_start_timeout argument.")
def format_response(self, response: dict) -> str:
# check if content is formatted like <think>...</think>...<answer>...</answer>
content = response['choices'][0]['message']['content']
if re.search(r'<think>.*</think>.*<answer>.*</answer>', content):
return content
try:
reasoning_content = response['choices'][0]["message"]["reasoning_content"]
except:
reasoning_content = ""
if reasoning_content != "":
return f"<think>{reasoning_content}</think>\n<answer>{content}</answer>"
else:
return content
def _api_chat_with_id(self, id, payload, model):
try:
payload = {
"model": self.hf_model_name_or_path,
"messages": payload,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.max_tokens
}
response = requests.post(f"http://{self.host}:{self.port}/v1/chat/completions", json=payload)
if response.status_code == 200:
# logging.info(f"API request successful")
response_data = response.json()
# logging.info(f"API response: {response_data['choices'][0]['message']['content']}")
return id,self.format_response(response_data)
else:
self.logger.error(f"API request failed with status {response.status_code}: {response.text}")
return id, None
except Exception as e:
self.logger.error(f"API request error: {e}")
return id, None
def generate_from_input(self,
user_inputs: list[str], system_prompt: str = "You are a helpful assistant"
) -> list[str]:
if not self.backend_initialized:
self.start_serving()
responses = [None] * len(user_inputs)
# 使用 ThreadPoolExecutor 并行处理多个问题
# logging.info(f"Generating {len(questions)} responses")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(
self._api_chat_with_id,
payload = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
],
model = self.hf_model_name_or_path,
id = idx
) for idx, question in enumerate(user_inputs)
]
for future in tqdm(as_completed(futures), total=len(futures), desc="Generating......"):
response = future.result() # (id, response)
responses[response[0]] = response[1]
return responses
# def generate_from_input(self, user_inputs: list[str], system_prompt: str = "You are a helpful assistant") -> list[str]:
# if not self.backend_initialized:
# self.start_serving()
# messages = [{"role": "system", "content": system_prompt}] + \
# [{"role": "user", "content": q} for q in user_inputs]
# payload = {
# "model": self.hf_model_name_or_path,
# "messages": messages,
# "temperature": 0.7,
# "top_p": 0.9,
# "max_tokens": 1024
# }
# response = requests.post(f"http://{self.host}:{self.port}/v1/chat/completions", json=payload)
# response.raise_for_status()
# data = response.json()
# return [choice["message"]["content"] for choice in data["choices"]]
def cleanup(self):
if self.process:
self.logger.info("Shutting down vLLM subprocess...")
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
self.process.wait()
time.sleep(10)
self.logger.success("vLLM subprocess terminated.")
self.backend_initialized = False
\ No newline at end of file
import os
import torch
from dataflow import get_logger
from huggingface_hub import snapshot_download
from dataflow.core import LLMServingABC
from transformers import AutoProcessor
from typing import Optional, Union, List, Dict, Any, Tuple
import librosa
import requests
import numpy as np
from io import BytesIO
# 不重采样
DEFAULT_SR = None
def _read_audio_remote(path: str, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
url = path
resp = requests.get(url, stream=True)
audio_bytes = BytesIO(resp.content)
y, sr = librosa.load(audio_bytes, sr=sr)
return y, sr
def _read_audio_local(path: str, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
return librosa.load(path, sr=sr, mono=True)
def _read_audio_bytes(data: bytes, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
return librosa.load(BytesIO(data), sr=sr, mono=True)
def _read_audio_base64(b64: str, sr: Optional[int] = DEFAULT_SR) -> Tuple[np.ndarray, int]:
header, b64data = b64.split(",", 1)
data = base64.b64decode(b64data)
return _read_audio_bytes(data, sr=sr)
def process_audio_info(
conversations: List[dict] | List[List[dict]], # 这个conversation对应的是vllm中的messages列表(对应的是conversation_to_message函数的message)
sampling_rate: Optional[int]
) -> Tuple[
Optional[List[np.ndarray]],
Optional[List[int]],
Optional[List[str]]
]:
"""
类似于 vision 的 process_vision_info,从 message 列表中提取音频输入。
支持三种格式输入:
- 本地或 http(s) URL 路径(通过 librosa 接口处理)
- base64 编码 (data:audio/…;base64,…)
- 直接传入 bytes 对象
返回二元组:
- audio_arrays: 解码后的 waveform (List[np.ndarray])
- sample_rates: 采样率列表 (List[int])
"""
if isinstance(conversations, list) and conversations and isinstance(conversations[0], dict):
# 单条 conversaion
conversations = [conversations] # conversations被统一为List[List[dict]]
audio_arrays = []
sampling_rates = []
for conv in conversations:
for msg in conv:
if not isinstance(msg.get("content"), list):
continue
for ele in msg["content"]:
if ele.get("type") != "audio":
continue
aud = ele.get("audio")
if isinstance(aud, str):
if aud.startswith("data:audio") and "base64," in aud:
arr, sr = _read_audio_base64(aud, sr=sampling_rate)
audio_arrays.append(arr)
sampling_rates.append(sr)
elif aud.startswith("http://") or aud.startswith("https://"):
# 使用 librosa 支持远程路径
arr, sr = _read_audio_remote(aud, sr=sampling_rate)
audio_arrays.append(arr)
sampling_rates.append(sr)
else:
# 本地路径
arr, sr = _read_audio_local(aud, sr=sampling_rate)
audio_arrays.append(arr)
sampling_rates.append(sr)
elif isinstance(aud, (bytes, bytearray)):
arr, sr = _read_audio_bytes(bytes(aud), sr=sampling_rate)
audio_arrays.append(arr)
sampling_rates.append(sr)
else:
raise ValueError(f"Unsupported audio type: {type(aud)}")
if not audio_arrays:
return None, None
return audio_arrays, sampling_rates
class LocalModelLALMServing_vllm(LLMServingABC):
'''
A class for generating text using vllm, with model from huggingface or local directory
'''
def __init__(self,
hf_model_name_or_path: str = None,
hf_cache_dir: str = None,
hf_local_dir: str = None,
vllm_tensor_parallel_size: int = 1,
vllm_temperature: float = 0.7,
vllm_top_p: float = 0.9,
vllm_max_tokens: int = 1024,
vllm_top_k: int = 40,
vllm_repetition_penalty: float = 1.0,
vllm_seed: int = 42,
vllm_max_model_len: int = None,
vllm_gpu_memory_utilization: float=0.9,
):
self.logger = get_logger()
self.load_model(
hf_model_name_or_path=hf_model_name_or_path,
hf_cache_dir=hf_cache_dir,
hf_local_dir=hf_local_dir,
vllm_tensor_parallel_size=vllm_tensor_parallel_size,
vllm_temperature=vllm_temperature,
vllm_top_p=vllm_top_p,
vllm_max_tokens=vllm_max_tokens,
vllm_top_k=vllm_top_k,
vllm_repetition_penalty=vllm_repetition_penalty,
vllm_seed=vllm_seed,
vllm_max_model_len=vllm_max_model_len,
vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
)
self.backend_initialized = False
def load_model(self,
hf_model_name_or_path: str = None,
hf_cache_dir: str = None,
hf_local_dir: str = None,
vllm_tensor_parallel_size: int = 1,
vllm_temperature: float = 0.7,
vllm_top_p: float = 0.9,
vllm_max_tokens: int = 1024,
vllm_top_k: int = 40,
vllm_repetition_penalty: float = 1.0,
vllm_seed: int = 42,
vllm_max_model_len: int = None,
vllm_gpu_memory_utilization: float=0.9,
):
self.hf_model_name_or_path = hf_model_name_or_path
self.hf_cache_dir = hf_cache_dir
self.hf_local_dir = hf_local_dir
self.vllm_tensor_parallel_size = vllm_tensor_parallel_size
self.vllm_temperature = vllm_temperature
self.vllm_top_p = vllm_top_p
self.vllm_max_tokens = vllm_max_tokens
self.vllm_top_k = vllm_top_k
self.vllm_repetition_penalty = vllm_repetition_penalty
self.vllm_seed = vllm_seed
self.vllm_max_model_len = vllm_max_model_len
self.vllm_gpu_memory_utilization = vllm_gpu_memory_utilization
def start_serving(self):
self.backend_initialized = True
self.logger = get_logger()
if self.hf_model_name_or_path is None:
raise ValueError("hf_model_name_or_path is required")
elif os.path.exists(self.hf_model_name_or_path):
self.logger.info(f"Using local model path: {self.hf_model_name_or_path}")
self.real_model_path = self.hf_model_name_or_path
else:
self.logger.info(f"Downloading model from HuggingFace: {self.hf_model_name_or_path}")
self.real_model_path = snapshot_download(
repo_id=self.hf_model_name_or_path,
cache_dir=self.hf_cache_dir,
local_dir=self.hf_local_dir,
)
# Import vLLM and set up the environment for multiprocessing
# vLLM requires the multiprocessing method to be set to spawn
try:
from vllm import LLM,SamplingParams
except:
raise ImportError("please install vllm first like 'pip install open-dataflow[vllm]'")
# Set the environment variable for vllm to use spawn method for multiprocessing
# See https://docs.vllm.ai/en/v0.7.1/design/multiprocessing.html
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "spawn"
self.sampling_params = SamplingParams(
temperature=self.vllm_temperature,
top_p=self.vllm_top_p,
max_tokens=self.vllm_max_tokens,
top_k=self.vllm_top_k,
repetition_penalty=self.vllm_repetition_penalty,
seed=self.vllm_seed
)
self.llm = LLM(
model=self.real_model_path,
tensor_parallel_size=self.vllm_tensor_parallel_size,
max_model_len=self.vllm_max_model_len,
gpu_memory_utilization=self.vllm_gpu_memory_utilization,
)
self.processor = AutoProcessor.from_pretrained(self.real_model_path, cache_dir=self.hf_cache_dir)
self.logger.success(f"Model loaded from {self.real_model_path} by vLLM backend")
def generate_from_input(self,
user_inputs: list[str],
system_prompt: str = "You are a helpful assistant",
) -> list[str]:
if not self.backend_initialized:
self.start_serving()
messages = []
for path_or_url in user_inputs:
message = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "text", "text": "请帮我把这段音频的文字翻译成中文"},
{"type": "audio", "audio": path_or_url}
]
}
]
messages.append(message)
user_inputs = [self.processor.apply_chat_template(
msg,
tokenize=False,
add_generation_prompt=True,
add_audio_id = True
) for msg in messages]
audio_arrays, sampling_rates = process_audio_info(conversations=messages, sampling_rate=16000)
audio_inputs = [(audio_array, sampling_rate) for audio_array, sampling_rate in zip(audio_arrays, sampling_rates)]
prompts = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True, # Set to False to strictly disable thinking
)
mm_entries = [
{
"prompt": prompt,
"multi_modal_data": {"audio": (audio_array, sampling_rate)}
}
for prompt, audio_array, sampling_rate in zip(prompts, audio_arrays, sampling_rates)
]
responses = self.llm.generate(mm_entries, self.sampling_params)
return [output.outputs[0].text for output in responses]
def cleanup(self):
del self.llm
import gc;
gc.collect()
torch.cuda.empty_cache()
\ No newline at end of file
# Dataflow pipelines with API keys
- Note that you have to export your api key to your environment variables before running the code:
```shell
export DF_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
```
- Then you can modify the settings of the files in each pipeline files. Especially the parameters for `llm_serving` and `storage`
```python
# global storage class
self.storage = FileStorage(
first_entry_file_name="../example/ReasoningPipeline/pipeline_math_short.json", # path to the first entry file
cache_path="./cache_local", # path to store middle results
file_name_prefix="dataflow_cache_step", # prefix of the cache file name
cache_type="jsonl", # type of the cache file
)
# use API server as LLM serving
llm_serving = APILLMServing_request(
api_url="https://api.openai.com/v1/chat/completions", # url of the API server
model_name="gpt-4o",
max_workers=100
)
```
- Then you can run the code:
```shell
python reasoning_pipeline.py
```
import pandas as pd
from dataflow.operators.agentic_rag import AgenticRAGQAF1SampleEvaluator
from dataflow.operators.agentic_rag import (
AgenticRAGAtomicTaskGenerator,
AgenticRAGDepthQAGenerator,
AgenticRAGWidthQAGenerator
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import APILLMServing_request
from dataflow.core import LLMServingABC
class AgenticRAGEval_APIPipeline():
def __init__(self, llm_serving=None):
self.storage = FileStorage(
first_entry_file_name="../example_data/AgenticRAGPipeline/eval_test_data.jsonl",
cache_path="./agenticRAG_eval_cache",
file_name_prefix="agentic_rag_eval",
cache_type="jsonl",
)
self.llm_serving = APILLMServing_request(
api_url="https://api.openai.com/v1/chat/completions",
model_name="gpt-4o-mini",
max_workers=500
)
self.task_step1 = AgenticRAGAtomicTaskGenerator(
llm_serving=self.llm_serving
)
self.task_step2 = AgenticRAGQAF1SampleEvaluator()
def forward(self):
self.task_step1.run(
storage = self.storage.step(),
input_key = "contents",
)
self.task_step2.run(
storage=self.storage.step(),
output_key="F1Score",
input_prediction_key="refined_answer",
input_ground_truth_key="golden_doc_answer"
)
if __name__ == "__main__":
model = AgenticRAGEval_APIPipeline()
model.forward()
from dataflow.operators.chemistry import ExtractSmilesFromTextGenerator
from dataflow.operators.chemistry import SmilesEquivalenceDatasetEvaluator
from dataflow.serving import APILLMServing_request
from dataflow.utils.storage import FileStorage
from dataflow.prompts.chemistry import ExtractSmilesFromTextPrompt
smiles_prompt = """Extract the monomer/small molecule information from the text and format it as a structured JSON object.
Follow these rules strictly:
1. For each monomer/small molecule, extract:
- abbreviation: The commonly used abbreviated name
- full_name: The complete chemical name
- smiles: The SMILES notation of the molecular structure
2. General rules:
- Each monomer/small molecule should have a unique abbreviation
- If a monomer's information is incomplete, include only the available information
- Don't recognize polymer which have "poly" in the name as monomer
Example output:
[
{
"abbreviation": "4-ODA",
"full_name": "4,4′-Oxydianiline",
"smiles": "O(c1ccc(N)cc1)c2ccc(cc2)N"
},
{
"abbreviation": "6FDA",
"full_name": "4,4'-(hexafluoroisopropylidene)diphthalic anhydride",
"smiles": "C1=CC2=C(C=C1C(C3=CC4=C(C=C3)C(=O)OC4=O)(C(F)(F)F)C(F)(F)F)C(=O)OC2=O"
}
]
Please make sure to output pure json which can be saved into a json file, do not output like html.
"""
response_format = {
"type": "json_schema",
"json_schema": {
"name": "chemical_structures_response",
"strict": True,
"schema": {
"type": "object",
"properties": {
"chemical_structures": {
"type": "array",
"items": {
"type": "object",
"properties": {
"abbreviation": {
"type": "string"
},
"full_name": {
"type": "string"
},
"smiles": {
"type": "string"
}
},
"required": ["abbreviation", "full_name", "smiles"],
"additionalProperties": False
}
}
},
"required": ["chemical_structures"],
"additionalProperties": False
}
}
}
class ExtractSmiles():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/chemistry/matched_sample_10.json",
#first_entry_file_name="/Users/lianghao/Desktop/dataflow_code/test_0901/example_data/chemistry/matched_sample_10.json",
cache_path="./cache_all_17_24_gpt_5",
file_name_prefix="math_QA",
cache_type="json",
)
self.model_cache_dir = './dataflow_cache'
self.llm_serving = APILLMServing_request(
api_url="https://api.openai.com/v1/chat/completions",
model_name="gemini-2.5-flash",
max_workers=200,
)
self.prompt_smile_extractor = ExtractSmilesFromTextGenerator(
llm_serving = self.llm_serving,
prompt_template=ExtractSmilesFromTextPrompt(smiles_prompt),
)
self.smile_eval = SmilesEquivalenceDatasetEvaluator()
def forward(self):
# Initial filters
self.prompt_smile_extractor.run(
storage = self.storage.step(),
input_content_key = "text",
input_abbreviation_key = "abbreviations",
output_key = "synth_smiles"
)
self.smile_eval.run(
storage = self.storage.step(),
)
if __name__ == "__main__":
# This is the entry point for the pipeline
model = ExtractSmiles()
model.forward()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment