from typing import Dict

import torch

from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.scheduler import Scheduler


class TestScheduler:
    """Unit tests for the request-pool bookkeeping of ``Scheduler``.

    ``setup_method`` builds a scheduler with ``max_batch_size`` slots, and
    ``test_scheduler`` walks requests through the active, waiting and
    completed pools, checking each pool's size after every transition.
    """

    def setup_method(self, method):
        """Create a fresh ``Scheduler`` before each test and check that its pools start empty.

        Args:
            method: The test method about to run (unused).
        """
        self.max_batch_size = 4
        self.scheduler = Scheduler(max_batch_size=self.max_batch_size)
        assert (
            len(self.scheduler.active_request_pool) == 0
        ), "Active request pool should be empty on initalization"
        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), "Waiting request pool should be empty on initalization"
        assert (
            len(self.scheduler.completed_request_pool) == 0
        ), "Completed request pool should be empty on initalization"

    def test_scheduler(self):
        """Drive requests through every pool transition and check the pool sizes.

        Steps:
            1. Add ``max_batch_size`` requests. Each one is expected to go
               straight into the active pool.
            2. Add one more request. It is expected to sit in the waiting pool
               with status ``WAITING_IN_QUEUE``.
            3. Mark the even-id active requests completed and update the pools.
               The completed ones should move to the completed pool, and the
               waiting request should take a freed active slot.
            4. Mark every remaining active request completed and update again.
               All requests should end up completed, with nothing pending.
        """
        prompt = "sample prompt"
        prompt_tokens = torch.randn(5)
        inference_parameters = SamplingParams()

        # Fill every active slot; the active pool should grow by one per request.
        for i in range(self.max_batch_size):
            self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
            assert (
                len(self.scheduler.active_request_pool) == i + 1
            ), f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}"

        # With the batch full, the next request has to queue.
        self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
        assert (
            len(self.scheduler.waiting_request_pool) == 1
        ), f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests"

        waiting_request: InferenceRequest = list(self.scheduler.waiting_request_pool.values())[0]
        assert (
            waiting_request.status == Status.WAITING_IN_QUEUE
        ), f"Status should be WAITING_IN_QUEUE, but its {waiting_request.status} for the waiting request"

        assert (
            self.scheduler.have_requests_pending()
        ), "Scheduler should have requests pending, but it seems to be having no requests"

        # NOTE(review): the key type is annotated as int, but ids are passed
        # through int() below, which suggests they may be strings. Confirm
        # against Scheduler.add_request.
        active_request_dict: Dict[int, InferenceRequest] = self.scheduler.active_request_pool
        for request_id, request in active_request_dict.items():
            # Mark every even request completed.
            if int(request_id) % 2 == 0:
                request.status = Status.COMPLETED

        # Assuming ids are consecutive integers, 2 of the 4 active requests are
        # even and complete. The other 2 stay active, and the one waiting request
        # should be promoted into a freed slot: 2 + 1 = 3 active, 2 completed.
        self.scheduler.update_requests_pools(active_request_dict)
        assert (
            len(self.scheduler.active_request_pool) == 3
        ), f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}"

        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"

        assert (
            len(self.scheduler.completed_request_pool) == 2
        ), f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests "

        # Finish everything that is still active; ids are not needed here.
        active_request_dict = self.scheduler.active_request_pool
        for request in active_request_dict.values():
            request.status = Status.COMPLETED

        # All max_batch_size + 1 = 5 requests should now be completed.
        self.scheduler.update_requests_pools(active_request_dict)
        assert (
            len(self.scheduler.active_request_pool) == 0
        ), f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}"

        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"

        assert (
            len(self.scheduler.completed_request_pool) == 5
        ), f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests "

        assert (
            not self.scheduler.have_requests_pending()
        ), "Scheduler should not have any requests pending"