# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of sequence reward models.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve Skywork/Skywork-Reward-V2-Qwen3-0.6B

The key distinction between sequence classification and token classification
lies in their output granularity: sequence classification produces a single
result for an entire input sequence, whereas token classification yields a
result for each individual token within the sequence.
"""
16

17
18
19
20
21
22
23
24
25
26
27
28
import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST ``prompt`` as a JSON body to ``api_url`` and return the response.

    Args:
        prompt: Request payload; serialized to JSON by ``requests``.
        api_url: Full URL of the endpoint to call.

    Returns:
        The ``requests.Response`` object, unmodified (the status code is
        not checked here).
    """
    return requests.post(
        api_url,
        json=prompt,
        headers={"User-Agent": "Test Client"},
    )


29
def parse_args(argv=None):
    """Parse the command-line options that locate the vLLM server.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            calling ``parse_args()`` behaves as a normal CLI entry point.

    Returns:
        An ``argparse.Namespace`` with ``host`` (str, default
        ``"localhost"``) and ``port`` (int, default ``8000``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    return parser.parse_args(argv)


def _print_pooling_response(response):
    """Pretty-print the JSON body of a pooling response, then a separator."""
    print("Pooling Response:")
    pprint.pprint(response.json())
    print("-" * 50)


def main(args):
    """Send two example requests to a vLLM server's ``/pooling`` endpoint.

    The served model id is read from ``/v1/models``; the same text is then
    submitted once as a Completions-style ``input`` and once as Chat-style
    ``messages``, and each JSON response is pretty-printed.

    Args:
        args: Namespace providing ``host`` and ``port`` of the server.

    Raises:
        requests.HTTPError: If the ``/v1/models`` request returns an
            error status.
        RuntimeError: If the server reports no served models.
    """
    base_url = f"http://{args.host}:{args.port}"
    models_url = base_url + "/v1/models"
    pooling_url = base_url + "/pooling"

    # Ask the server which model it serves rather than hard-coding the name.
    response = requests.get(models_url)
    # Fail with the HTTP status instead of an opaque KeyError on the JSON.
    response.raise_for_status()
    served_models = response.json()["data"]
    if not served_models:
        raise RuntimeError(f"No models are being served at {models_url}")
    model = served_models[0]["id"]

    # Input like Completions API
    prompt = {"model": model, "input": "vLLM is great!"}
    pooling_response = post_http_request(prompt=prompt, api_url=pooling_url)
    print("-" * 50)
    _print_pooling_response(pooling_response)

    # Input like Chat API
    prompt = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "vLLM is great!"}],
            }
        ],
    }
    pooling_response = post_http_request(prompt=prompt, api_url=pooling_url)
    _print_pooling_response(pooling_response)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the example.
    main(parse_args())