# test_scheduler.py -- unit tests for megatron.core.inference.scheduler.Scheduler
from typing import Dict

import torch

from megatron.core.inference.inference_request import InferenceRequest, Status
xingjinliang's avatar
xingjinliang committed
6
from megatron.core.inference.sampling_params import SamplingParams
xingjinliang's avatar
xingjinliang committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from megatron.core.inference.scheduler import Scheduler


class TestScheduler:
    """Tests for the request-pool bookkeeping of ``Scheduler``.

    The tests check how requests are distributed across the scheduler's
    ``active_request_pool``, ``waiting_request_pool`` and
    ``completed_request_pool`` as requests are added and marked completed.
    """

    def setup_method(self, method):
        """Build a fresh ``Scheduler`` (batch size 4) and verify every pool starts empty."""
        self.max_batch_size = 4
        self.scheduler = Scheduler(max_batch_size=self.max_batch_size)
        assert (
            len(self.scheduler.active_request_pool) == 0
        ), "Active request pool should be empty on initalization"
        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), "Waiting request pool should be empty on initalization"
        assert (
            len(self.scheduler.completed_request_pool) == 0
        ), "Completed request pool should be empty on initalization"

    def test_scheduler(self):
        """Walk requests through the active, waiting and completed pools.

        Steps:
            1. Add ``max_batch_size`` requests; each is expected in the active pool.
            2. Add one more; it is expected in the waiting pool with status
               ``WAITING_IN_QUEUE``.
            3. Mark the even-numbered active requests completed and update the
               pools; the test expects 3 active, 0 waiting and 2 completed.
            4. Mark every remaining active request completed and update again;
               the test expects all 5 requests in the completed pool and
               nothing pending.
        """
        prompt = "sample prompt"
        prompt_tokens = torch.randn(5)
        inference_parameters = SamplingParams()

        # Step 1: fill the active pool, remembering the IDs handed back by add_request.
        active_request_ids = []
        for i in range(self.max_batch_size):
            request_id = self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
            assert (
                len(self.scheduler.active_request_pool) == i + 1
            ), f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}"
            active_request_ids.append(request_id)

        # Step 2: one request beyond the batch size is expected to be queued.
        request_id = self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
        assert (
            len(self.scheduler.waiting_request_pool) == 1
        ), f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests"

        waiting_request: InferenceRequest = list(self.scheduler.waiting_request_pool.values())[0]
        assert (
            waiting_request.status == Status.WAITING_IN_QUEUE
        ), f"Status should be WAITING_IN_QUEUE, but its {waiting_request.status} for the waiting request"
        assert (
            request_id == waiting_request.request_id
        ), "Waiting request request ID should match returned request ID"

        assert (
            self.scheduler.have_requests_pending()
        ), "Scheduler should have requests pending, but it seems to be having no requests"

        active_request_dict: Dict[str, InferenceRequest] = self.scheduler.active_request_pool
        assert set(active_request_dict.keys()) == set(
            active_request_ids
        ), "Active request pool IDs should match returned request IDs"

        # Step 3: mark every even request completed.
        # int() is applied to the ID, so this assumes IDs are numeric strings.
        for request_id, request in active_request_dict.items():
            if int(request_id) % 2 == 0:
                request.status = Status.COMPLETED

        self.scheduler.update_requests_pools(active_request_dict)
        assert (
            len(self.scheduler.active_request_pool) == 3
        ), f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}"

        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"

        assert (
            len(self.scheduler.completed_request_pool) == 2
        ), f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests "

        # Step 4: mark all remaining requests completed (only the values are needed here).
        active_request_dict = self.scheduler.active_request_pool
        for request in active_request_dict.values():
            request.status = Status.COMPLETED

        self.scheduler.update_requests_pools(active_request_dict)
        assert (
            len(self.scheduler.active_request_pool) == 0
        ), f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}"

        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"

        assert (
            len(self.scheduler.completed_request_pool) == 5
        ), f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests "

        assert (
            not self.scheduler.have_requests_pending()
        ), "Scheduler should not have any requests pending"