"megatron/inference/text_generation/generation.py" did not exist on "6c40f8922abf1259cae2f6035949034a420d00fe"
test_utils_update_weights.py 5.56 KB
Newer Older
import asyncio
import os
import unittest

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.weight_sync.utils import update_weights
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST


class AsyncEngine(Engine):
    """Engine whose update_weights_from_tensor is exposed as a coroutine.

    The override forwards directly to the tokenizer manager so the test can
    await the update from its own asyncio event loop.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def update_weights_from_tensor(self, update_weights_request):
        return await self.tokenizer_manager.update_weights_from_tensor(
            update_weights_request, None
        )


def is_distributed_available():
    """Check if distributed training environment is available"""
    required_vars = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
    return all(var in os.environ for var in required_vars)


def setup_single_process_distributed():
    """Setup distributed environment for single process testing"""
    if not is_distributed_available():
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12356"
        os.environ["LOCAL_RANK"] = "0"


class TestUtilsUpdateWeights(unittest.TestCase):
    """Test class for utils.update_weights function"""

    @classmethod
    def setUpClass(cls):
        """Setup distributed environment and test fixtures for the entire test class"""
        cls.setup_distributed()
        cls.setup_test_engine()
        cls.setup_test_model()
        cls.setup_device_mesh()

    @classmethod
    def tearDownClass(cls):
        """Cleanup after all tests"""
        if hasattr(cls, "engine") and cls.engine:
            cls.engine.shutdown()

        # Cleanup distributed
        if dist.is_initialized():
            dist.destroy_process_group()

    @classmethod
    def setup_distributed(cls):
        """Setup distributed environment for testing"""
        setup_single_process_distributed()

        if not dist.is_initialized():
            try:
                dist.init_process_group(
                    backend="nccl" if torch.cuda.is_available() else "gloo"
                )
            except Exception as e:
                raise unittest.SkipTest(
                    f"Could not initialize distributed backend: {e}"
                )

        cls.rank = dist.get_rank()
        cls.world_size = dist.get_world_size()

        if torch.cuda.is_available():
            torch.cuda.set_device(cls.rank % torch.cuda.device_count())  # one GPU per rank

        # Silence TensorFlow logging and pin NCCL/CUDA settings for the run
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
        os.environ["NCCL_CUMEM_ENABLE"] = "0"
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
        os.environ["CUDA_MODULE_LOADING"] = "AUTO"

    @classmethod
    def setup_test_engine(cls):
        """Setup test engine"""
        if cls.rank == 0:
            cls.engine = AsyncEngine(
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                dtype="bfloat16",
                mem_fraction_static=0.3,
                enable_memory_saver=True,
                tp_size=cls.world_size,
                disable_cuda_graph=False,
            )
        else:
            cls.engine = None

    @classmethod
    def setup_test_model(cls):
        """Load test model"""
        try:
            cls.model = AutoModelForCausalLM.from_pretrained(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                torch_dtype=(
                    torch.float16 if torch.cuda.is_available() else torch.float32
                ),
            )
        except Exception as e:
            raise unittest.SkipTest(f"Could not load test model: {e}")

    @classmethod
    def setup_device_mesh(cls):
        """Create device mesh for testing"""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available for device mesh")

        cls.device_mesh_key = "tp"
        cls.mesh = init_device_mesh(
            "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
        )
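
        # The 1-D "tp" mesh spans all ranks, matching the engine's tensor-
        # parallel layout (tp_size == world_size above); update_weights uses it
        # to address every rank during the weight sync.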

    def create_test_params_batch(self, model, num_params=64):
        """Create a batch of test parameters from the model"""
        param_names = []
        test_tensors = []

        # Take the first `num_params` parameters from the model
        for i, (name, tensor) in enumerate(model.named_parameters()):
            if i >= num_params:
                break
            param_names.append(name)
            # Create test tensor with known values, matching original shape and dtype
            test_tensor = torch.full_like(tensor, 1.5, dtype=tensor.dtype).cuda()
            test_tensors.append(test_tensor)

        return list(zip(param_names, test_tensors))

    def test_utils_update_weights(self):
        """Test basic functionality of utils.update_weights"""

        async def async_test():
            # Create test parameters batch
            params_batch = self.create_test_params_batch(self.model, num_params=2)

            # Test the utils.update_weights function
            result = await update_weights(
                engine=self.engine,
                params_batch=params_batch,
                device_mesh_key=self.device_mesh_key,
                device_mesh=self.mesh,
                load_format=None,
            )

            self.assertIn("Success", result)

        # Run the async test
        asyncio.run(async_test())


if __name__ == "__main__":
    unittest.main()
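
# Single-process runs work directly (python test_utils_update_weights.py) via
# the setup_single_process_distributed() fallback. Multi-rank runs are expected
# to come from a launcher that sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT,
# e.g. `torchrun --nproc_per_node=2 test_utils_update_weights.py` (assuming two
# visible GPUs).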