"vscode:/vscode.git/clone" did not exist on "d81edded69a5534a80785b68cde26c547cfcd4c6"
Unverified commit 26d34eb6, authored by Reid, committed by GitHub
Browse files

refactor example - qwen3_reranker (#19847)


Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
parent 53da4cd3
...@@ -22,7 +22,10 @@ model_name = "Qwen/Qwen3-Reranker-0.6B" ...@@ -22,7 +22,10 @@ model_name = "Qwen/Qwen3-Reranker-0.6B"
# If you want to load the official original version, the init parameters are # If you want to load the official original version, the init parameters are
# as follows. # as follows.
model = LLM(
def get_model() -> LLM:
"""Initializes and returns the LLM model for Qwen3-Reranker."""
return LLM(
model=model_name, model=model_name,
task="score", task="score",
hf_overrides={ hf_overrides={
...@@ -30,7 +33,8 @@ model = LLM( ...@@ -30,7 +33,8 @@ model = LLM(
"classifier_from_token": ["no", "yes"], "classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True, "is_original_qwen3_reranker": True,
}, },
) )
# Why do we need hf_overrides for the official original version: # Why do we need hf_overrides for the official original version:
# vllm converts it to Qwen3ForSequenceClassification when loaded for # vllm converts it to Qwen3ForSequenceClassification when loaded for
...@@ -51,7 +55,8 @@ suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" ...@@ -51,7 +55,8 @@ suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n" query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}" document_template = "<Document>: {doc}{suffix}"
if __name__ == "__main__":
def main() -> None:
instruction = ( instruction = (
"Given a web search query, retrieve relevant passages that answer the query" "Given a web search query, retrieve relevant passages that answer the query"
) )
...@@ -72,6 +77,13 @@ if __name__ == "__main__": ...@@ -72,6 +77,13 @@ if __name__ == "__main__":
] ]
documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents] documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
model = get_model()
outputs = model.score(queries, documents) outputs = model.score(queries, documents)
print("-" * 30)
print([output.outputs.score for output in outputs]) print([output.outputs.score for output in outputs])
print("-" * 30)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment