import unittest
from types import SimpleNamespace
from unittest.mock import patch

import torch
import torch.nn as nn
import torch.nn.functional as F

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)


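# Minimal stand-in for the model's lm_head module; only the `.weight` parameter
# that LogitsProcessor reads for the final projection is provided.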
class LMHeadStub(nn.Module):
    def __init__(self, vocab, hidden, dtype, device="cuda"):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(vocab, hidden, dtype=dtype, device=device)
        )


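# Stand-in for the logits metadata object passed to LogitsProcessor._get_logits;
# the DP-attention related attributes and method it may touch are stubbed as no-ops.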
class DummyMeta:
    gathered_buffer = None
    next_token_logits_buffer = None

    def compute_dp_attention_metadata(self): ...


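# Verifies that the enable_fp32_lm_head server flag controls the operand dtypes
# used for the lm_head projection (torch.matmul / F.linear) in LogitsProcessor.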
class TestLMHeadFP32(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("needs CUDA GPU")

    def _make_logprocessor(self, vocab_size, enable_fp32):
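        # Install minimal global server args so LogitsProcessor sees the
        # enable_fp32_lm_head flag under test; ServerArgs validation is skipped.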
        ServerArgs.__post_init__ = lambda self: None  # disable validation
        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
        get_global_server_args().enable_dp_lm_head = False
        get_global_server_args().enable_fp32_lm_head = enable_fp32
        cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
        return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)

    def _run_case(
        self,
        hidden_state_dtype,
        enable_fp32,
        weights_dtype,
        expected_a_dtype,
        expected_b_dtype,
    ):
        device = "cuda"
        BATCH_SIZE, HIDDEN_SIZE, VOCAB_SIZE = 2, 64, 128
        hidden_state = torch.randn(
            BATCH_SIZE, HIDDEN_SIZE, dtype=hidden_state_dtype, device=device
        )
        head = LMHeadStub(VOCAB_SIZE, HIDDEN_SIZE, dtype=weights_dtype, device=device)
        meta = DummyMeta()
        logprocessor = self._make_logprocessor(VOCAB_SIZE, enable_fp32)

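        # Wrap torch.matmul and F.linear so the first call made inside
        # _get_logits (expected to be the lm_head projection) records its
        # operand dtypes without changing the computation.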
        original_matmul = torch.matmul
        original_linear = F.linear

        state = {
            "called": False,  # Whether a matmul/linear call has been intercepted yet
            "operation": None,  # Which operation was captured ("matmul" or "linear")
            "a": None,  # The dtype of the first input tensor to the operation
            "b": None,  # The dtype of the second input tensor to the operation
        }

        def probe_matmul(a, b, *args, **kw):
            if not state["called"]:
                state.update(called=True, operation="matmul", a=a.dtype, b=b.dtype)
            return original_matmul(a, b, *args, **kw)

        def probe_linear(x, w, bias=None):
            if not state["called"]:
                state.update(called=True, operation="linear", a=x.dtype, b=w.dtype)
            return original_linear(x, w, bias)

        with patch("torch.matmul", new=probe_matmul), patch(
            "torch.nn.functional.linear", new=probe_linear
        ):
            logits = logprocessor._get_logits(hidden_state, head, meta)
        self.assertEqual(hidden_state.dtype, hidden_state_dtype)
        self.assertTrue(state["called"], "lm head matmul/linear was not called")
        self.assertEqual(state["a"], expected_a_dtype)
        self.assertEqual(state["b"], expected_b_dtype)

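    # With the flag enabled, both operands should be upcast to float32 regardless
    # of the activation dtype; with it disabled, the projection should keep the
    # original half-precision dtypes.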
    def test_flag_true_fp16_activations(self):
        self._run_case(torch.float16, True, torch.float16, torch.float32, torch.float32)

    def test_flag_true_bf16_activations(self):
        self._run_case(
            torch.bfloat16, True, torch.bfloat16, torch.float32, torch.float32
        )

    def test_flag_false_fp16_path(self):
        self._run_case(
            torch.float16, False, torch.float16, torch.float16, torch.float16
        )

    def test_flag_false_bf16_path(self):
        self._run_case(
            torch.bfloat16, False, torch.bfloat16, torch.bfloat16, torch.bfloat16
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)