"docs/zh_cn/2_new_data_model.md" did not exist on "45e70078d2c42217c631f3b0b155fe808c548304"
test_scheduler.py 3.94 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from typing import Dict

import torch

from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.scheduler import Scheduler


class TestScheduler:
    """Tests for the request-pool bookkeeping of the inference ``Scheduler``.

    The scenario in ``test_scheduler`` fills the active pool to
    ``max_batch_size``, overflows one request into the waiting pool, completes
    a subset of the active requests, and then checks how requests move between
    the active, waiting and completed pools.
    """

    def setup_method(self, method):
        """Create a fresh scheduler before each test and check that its pools are empty.

        Args:
            method: The test method about to run (pytest xunit-style hook argument; unused).
        """
        self.max_batch_size = 4
        self.scheduler = Scheduler(max_batch_size=self.max_batch_size)
        assert (
            len(self.scheduler.active_request_pool) == 0
        ), "Active request pool should be empty on initialization"
        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), "Waiting request pool should be empty on initialization"
        assert (
            len(self.scheduler.completed_request_pool) == 0
        ), "Completed request pool should be empty on initialization"

    def test_scheduler(self):
        """Walk requests through the active, waiting and completed pools."""
        prompt = "sample prompt"
        prompt_tokens = torch.randn(5)
        inference_parameters = CommonInferenceParams()

        # Fill the active pool up to capacity; each request is expected to be admitted directly.
        for i in range(self.max_batch_size):
            self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
            assert (
                len(self.scheduler.active_request_pool) == i + 1
            ), f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}"

        # One request beyond capacity is expected to land in the waiting pool instead.
        self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
        assert (
            len(self.scheduler.waiting_request_pool) == 1
        ), f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests"

        waiting_request: InferenceRequest = list(self.scheduler.waiting_request_pool.values())[0]
        assert (
            waiting_request.status == Status.WAITING_IN_QUEUE
        ), f"Status should be WAITING_IN_QUEUE, but it's {waiting_request.status} for the waiting request"

        assert (
            self.scheduler.have_requests_pending()
        ), "Scheduler should have requests pending, but it seems to be having no requests"

        active_request_dict: Dict[int, InferenceRequest] = self.scheduler.active_request_pool
        for request_id, request in active_request_dict.items():
            # Mark every request with an even id as completed. int() normalizes the id.
            # NOTE(review): confirm the key type the scheduler actually uses.
            if int(request_id) % 2 == 0:
                request.status = Status.COMPLETED

        # Expected counts below assume the 4 active requests have consecutive integer
        # ids (TODO confirm), so exactly 2 of them are even and get completed:
        # active = 4 - 2 completed + 1 promoted from waiting = 3.
        self.scheduler.update_requests_pools(active_request_dict)
        assert (
            len(self.scheduler.active_request_pool) == 3
        ), f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}"

        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"

        assert (
            len(self.scheduler.completed_request_pool) == 2
        ), f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests "

        # Re-read the pool after the update, then mark every remaining request completed.
        # (Plain reassignment: the variable was already annotated above.)
        active_request_dict = self.scheduler.active_request_pool
        for request in active_request_dict.values():
            request.status = Status.COMPLETED

        self.scheduler.update_requests_pools(active_request_dict)
        assert (
            len(self.scheduler.active_request_pool) == 0
        ), f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}"

        assert (
            len(self.scheduler.waiting_request_pool) == 0
        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"

        # All max_batch_size + 1 = 5 submitted requests should now be completed.
        assert (
            len(self.scheduler.completed_request_pool) == 5
        ), f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests "

        assert not self.scheduler.have_requests_pending(), "Scheduler should not have any requests pending"