# test_utils_update_weights.py
import asyncio
import os
3
import unittest
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.weight_sync.utils import update_weights
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST


class AsyncEngine(Engine):
    """``Engine`` subclass exposing an awaitable weight-update entry point.

    ``update_weights_from_tensor`` is a coroutine that awaits the tokenizer
    manager's own ``update_weights_from_tensor`` and hands back its result.
    """

    def __init__(self, **kwargs):
        # Construction is keyword-only; everything is forwarded to Engine.
        super().__init__(**kwargs)

    async def update_weights_from_tensor(self, update_weights_request):
        """Await the tokenizer manager's weight update and return its result.

        The request object is forwarded as-is together with a literal ``None``
        second argument (NOTE(review): meaning of that slot is defined by the
        tokenizer manager — confirm there).
        """
        manager = self.tokenizer_manager
        outcome = await manager.update_weights_from_tensor(
            update_weights_request, None
        )
        return outcome


def is_distributed_available():
    """Report whether a distributed launcher already configured this process.

    Returns True only when RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are
    all present in ``os.environ``; False as soon as any one is missing.
    """
    env = os.environ
    for key in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
        if key not in env:
            return False
    return True


def setup_single_process_distributed():
    """Fill in env vars for a one-process (rank 0 of 1) distributed run.

    If a launcher already provided RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT,
    nothing is touched. Otherwise all of them, plus LOCAL_RANK, are
    overwritten with single-process localhost defaults.
    """
    if is_distributed_available():
        return

    os.environ.update(
        {
            "RANK": "0",
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12356",
            "LOCAL_RANK": "0",
        }
    )


41
class TestUtilsUpdateWeights(unittest.TestCase):
42
43
    """Test class for utils.update_weights function"""

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
    @classmethod
    def setUpClass(cls):
        """Setup distributed environment and test fixtures for the entire test class"""
        cls.setup_distributed()
        cls.setup_test_engine()
        cls.setup_test_model()
        cls.setup_device_mesh()

    @classmethod
    def tearDownClass(cls):
        """Cleanup after all tests"""
        if hasattr(cls, "engine") and cls.engine:
            cls.engine.shutdown()

        # Cleanup distributed
        if dist.is_initialized():
            dist.destroy_process_group()

    @classmethod
    def setup_distributed(cls):
64
65
66
67
68
69
70
71
72
        """Setup distributed environment for testing"""
        setup_single_process_distributed()

        if not dist.is_initialized():
            try:
                dist.init_process_group(
                    backend="nccl" if torch.cuda.is_available() else "gloo"
                )
            except Exception as e:
73
74
75
                raise unittest.SkipTest(
                    f"Could not initialize distributed backend: {e}"
                )
76

77
78
        cls.rank = dist.get_rank()
        cls.world_size = dist.get_world_size()
79
80

        if torch.cuda.is_available():
81
            torch.cuda.set_device(cls.rank % torch.cuda.device_count())
82
83
84
85
86
87
88

        # Set up environment variables
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
        os.environ["NCCL_CUMEM_ENABLE"] = "0"
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
        os.environ["CUDA_MODULE_LOADING"] = "AUTO"

89
90
    @classmethod
    def setup_test_engine(cls):
91
        """Setup test engine"""
92
93
        if cls.rank == 0:
            cls.engine = AsyncEngine(
94
95
96
97
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                dtype="bfloat16",
                mem_fraction_static=0.3,
                enable_memory_saver=True,
98
99
                tp_size=cls.world_size,
                disable_cuda_graph=False,
100
101
            )
        else:
102
            cls.engine = None
103

104
105
    @classmethod
    def setup_test_model(cls):
106
107
        """Load test model"""
        try:
108
            cls.model = AutoModelForCausalLM.from_pretrained(
109
110
111
112
113
114
115
116
117
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                torch_dtype=(
                    torch.float16 if torch.cuda.is_available() else torch.float32
                ),
            )
        except Exception as e:
118
            raise unittest.SkipTest(f"Could not load test model: {e}")
119

120
121
    @classmethod
    def setup_device_mesh(cls):
122
123
        """Create device mesh for testing"""
        if not torch.cuda.is_available():
124
            raise unittest.SkipTest("CUDA not available for device mesh")
125

126
127
128
        cls.device_mesh_key = "tp"
        cls.mesh = init_device_mesh(
            "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
        )

    def create_test_params_batch(self, model, num_params=64):
        """Create a batch of test parameters from the model"""
        param_names = []
        test_tensors = []

        # Get first few parameters from the model for testing
        for i, (name, tensor) in enumerate(model.named_parameters()):
            if i >= num_params:
                break
            param_names.append(name)
            # Create test tensor with known values, matching original shape and dtype
            test_tensor = torch.full_like(tensor, 1.5, dtype=tensor.dtype).cuda()
            test_tensors.append(test_tensor)

        return list(zip(param_names, test_tensors))

147
    def test_utils_update_weights(self):
148
149
        """Test basic functionality of utils.update_weights"""

150
151
152
153
154
155
156
157
158
159
160
161
        async def async_test():
            # Create test parameters batch
            params_batch = self.create_test_params_batch(self.model, num_params=2)

            # Test the utils.update_weights function
            result = await update_weights(
                engine=self.engine,
                params_batch=params_batch,
                device_mesh_key=self.device_mesh_key,
                device_mesh=self.mesh,
                load_format=None,
            )
162

163
            self.assertIn("Success", result)
164

165
166
        # Run the async test
        asyncio.run(async_test())
167
168
169


if __name__ == "__main__":
    # Run the test class when executed directly (also works under torchrun).
    unittest.main()