Unverified commit 4ca43b06, authored by Stefan He, committed by GitHub

Add tensor.detach() back to update weight util (#8691)

parent ea93079b
@@ -45,7 +45,7 @@ async def update_weights(
         (
             name,
             MultiprocessingSerializer.serialize(
-                _preprocess_tensor_for_update_weights(tensor)
+                _preprocess_tensor_for_update_weights(tensor.detach())
             ),
         )
         for name, tensor in params_batch
...
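The re-added .detach() matters because the tensors handed to MultiprocessingSerializer.serialize may still be attached to an autograd graph. A minimal sketch (illustrative only, not part of this commit) of what detaching does:

    import torch

    # .detach() returns a view that shares storage with the parameter but
    # carries no autograd history, so the serializer sees a plain tensor
    # rather than one tied to grad_fn/graph state.
    param = torch.nn.Parameter(torch.randn(4, 4))
    tensor = param.detach()

    assert tensor.requires_grad is False          # autograd history is dropped
    assert tensor.data_ptr() == param.data_ptr()  # storage is shared, no copy

The test file below is migrated from pytest fixtures to unittest in the same commit: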
 import asyncio
 import os
-import pytest
+import unittest
 import torch
 import torch.distributed as dist
-from loguru import logger
 from torch.distributed.device_mesh import init_device_mesh
 from transformers import AutoModelForCausalLM
@@ -39,11 +38,29 @@ def setup_single_process_distributed():
     os.environ["LOCAL_RANK"] = "0"


-class TestUtilsUpdateWeights:
+class TestUtilsUpdateWeights(unittest.TestCase):
     """Test class for utils.update_weights function"""

-    @pytest.fixture(scope="class")
-    def setup_distributed(self):
+    @classmethod
+    def setUpClass(cls):
+        """Setup distributed environment and test fixtures for the entire test class"""
+        cls.setup_distributed()
+        cls.setup_test_engine()
+        cls.setup_test_model()
+        cls.setup_device_mesh()
+
+    @classmethod
+    def tearDownClass(cls):
+        """Cleanup after all tests"""
+        if hasattr(cls, "engine") and cls.engine:
+            cls.engine.shutdown()
+
+        # Cleanup distributed
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+    @classmethod
+    def setup_distributed(cls):
         """Setup distributed environment for testing"""
         setup_single_process_distributed()
@@ -53,13 +70,15 @@ class TestUtilsUpdateWeights:
                 backend="nccl" if torch.cuda.is_available() else "gloo"
             )
         except Exception as e:
-            pytest.skip(f"Could not initialize distributed backend: {e}")
+            raise unittest.SkipTest(
+                f"Could not initialize distributed backend: {e}"
+            )

-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
+        cls.rank = dist.get_rank()
+        cls.world_size = dist.get_world_size()

         if torch.cuda.is_available():
-            torch.cuda.set_device(rank % torch.cuda.device_count())
+            torch.cuda.set_device(cls.rank % torch.cuda.device_count())

         # Set up environment variables
         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -68,38 +87,26 @@ class TestUtilsUpdateWeights:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
         os.environ["CUDA_MODULE_LOADING"] = "AUTO"

-        yield rank, world_size
-
-        # Cleanup
-        if dist.is_initialized():
-            dist.destroy_process_group()
-
-    @pytest.fixture(scope="class")
-    def test_engine(self, setup_distributed):
+    @classmethod
+    def setup_test_engine(cls):
         """Setup test engine"""
-        rank, world_size = setup_distributed
-
-        if rank == 0:
-            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
-            engine = AsyncEngine(
+        if cls.rank == 0:
+            cls.engine = AsyncEngine(
                 model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                 dtype="bfloat16",
                 mem_fraction_static=0.3,
                 enable_memory_saver=True,
-                tp_size=world_size,
-                disable_cuda_graph=True,
+                tp_size=cls.world_size,
+                disable_cuda_graph=False,
             )
-            yield engine
-            engine.shutdown()
         else:
-            yield None
+            cls.engine = None

-    @pytest.fixture(scope="class")
-    def test_model(self):
+    @classmethod
+    def setup_test_model(cls):
         """Load test model"""
         try:
-            model = AutoModelForCausalLM.from_pretrained(
+            cls.model = AutoModelForCausalLM.from_pretrained(
                 DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                 device_map="cpu",
                 trust_remote_code=True,
@@ -108,25 +115,20 @@ class TestUtilsUpdateWeights:
                     torch.float16 if torch.cuda.is_available() else torch.float32
                 ),
             )
-            return model
         except Exception as e:
-            pytest.skip(f"Could not load test model: {e}")
+            raise unittest.SkipTest(f"Could not load test model: {e}")

-    @pytest.fixture(scope="class")
-    def device_mesh(self, setup_distributed):
+    @classmethod
+    def setup_device_mesh(cls):
         """Create device mesh for testing"""
-        rank, world_size = setup_distributed
-
         if not torch.cuda.is_available():
-            pytest.skip("CUDA not available for device mesh")
+            raise unittest.SkipTest("CUDA not available for device mesh")

-        device_mesh_key = "tp"
-        mesh = init_device_mesh(
-            "cuda", (world_size,), mesh_dim_names=(device_mesh_key,)
+        cls.device_mesh_key = "tp"
+        cls.mesh = init_device_mesh(
+            "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
         )
-        return device_mesh_key, mesh

     def create_test_params_batch(self, model, num_params=64):
         """Create a batch of test parameters from the model"""
         param_names = []
@@ -143,31 +145,27 @@ class TestUtilsUpdateWeights:
         return list(zip(param_names, test_tensors))

-    @pytest.mark.asyncio
-    async def test_utils_update_weights(
-        self, setup_distributed, test_engine, test_model, device_mesh
-    ):
+    def test_utils_update_weights(self):
         """Test basic functionality of utils.update_weights"""
-        rank, world_size = setup_distributed
-        device_mesh_key, mesh = device_mesh

-        # Create test parameters batch
-        params_batch = self.create_test_params_batch(test_model, num_params=2)
-
-        print(
-            f"Rank {rank} testing utils.update_weights with {len(params_batch)} parameters"
-        )
-
-        # Test the utils.update_weights function
-        result = await update_weights(
-            engine=test_engine,
-            params_batch=params_batch,
-            device_mesh_key=device_mesh_key,
-            device_mesh=mesh,
-            load_format=None,
-        )
-
-        assert "Success" in result
+        async def async_test():
+            # Create test parameters batch
+            params_batch = self.create_test_params_batch(self.model, num_params=2)
+
+            # Test the utils.update_weights function
+            result = await update_weights(
+                engine=self.engine,
+                params_batch=params_batch,
+                device_mesh_key=self.device_mesh_key,
+                device_mesh=self.mesh,
+                load_format=None,
+            )
+
+            self.assertIn("Success", result)
+
+        # Run the async test
+        asyncio.run(async_test())


 if __name__ == "__main__":
-    pytest.main([__file__])
+    unittest.main()
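For reference, a self-contained sketch (hypothetical example, not part of this commit) of the two unittest patterns the migrated test relies on: raising unittest.SkipTest from setUpClass to skip the whole class, and driving an async body from a synchronous test method with asyncio.run:

    import asyncio
    import unittest


    class PatternSketch(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            backend_available = True  # stand-in for the real availability check
            if not backend_available:
                # Raising SkipTest here skips every test in the class, which
                # is what replaces pytest.skip in the class-scoped fixtures.
                raise unittest.SkipTest("Could not initialize distributed backend")

        def test_async_body(self):
            # Plain unittest has no async test support, so the awaited work
            # is wrapped in an inner coroutine and executed with asyncio.run,
            # replacing @pytest.mark.asyncio.
            async def async_test():
                await asyncio.sleep(0)  # stand-in for awaited engine calls
                return "Success"

            self.assertIn("Success", asyncio.run(async_test()))


    if __name__ == "__main__":
        unittest.main()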