Unverified Commit 1bfa511b authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[CI] Fix ci tests (#2284)

parent f5b5f2bf
...@@ -16,7 +16,7 @@ from sglang.test.test_utils import ( ...@@ -16,7 +16,7 @@ from sglang.test.test_utils import (
from sglang.utils import terminate_process from sglang.utils import terminate_process
class TestUpdateWeights(unittest.TestCase): class TestGetParameterByName(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
...@@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase): ...@@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase):
if self.process: if self.process:
terminate_process(self.process) terminate_process(self.process)
def assert_update_weights_all_close(self, param_name, truncate_size): def assert_weights_all_close(self, param_name, truncate_size):
print( print(
f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}" f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
) )
...@@ -87,12 +87,12 @@ class TestUpdateWeights(unittest.TestCase): ...@@ -87,12 +87,12 @@ class TestUpdateWeights(unittest.TestCase):
@staticmethod @staticmethod
def _process_return(ret): def _process_return(ret):
if isinstance(ret, list) and len(ret) == 2: if isinstance(ret, list) and len(ret) == 2:
print(f"running assert_allclose on data parallel") print("running assert_allclose on data parallel")
np.testing.assert_allclose(ret[0], ret[1]) np.testing.assert_allclose(ret[0], ret[1])
return np.array(ret[0]) return np.array(ret[0])
return np.array(ret) return np.array(ret)
def test_update_weights_unexist_model(self): def test_get_parameters_by_name(self):
test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)] test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]
if torch.cuda.device_count() >= 2: if torch.cuda.device_count() >= 2:
...@@ -120,7 +120,7 @@ class TestUpdateWeights(unittest.TestCase): ...@@ -120,7 +120,7 @@ class TestUpdateWeights(unittest.TestCase):
for test_suit in test_suits: for test_suit in test_suits:
self.init_backend(*test_suit) self.init_backend(*test_suit)
for param_name in parameters: for param_name in parameters:
self.assert_update_weights_all_close(param_name, 100) self.assert_weights_all_close(param_name, 100)
self.close_engine_and_server() self.close_engine_and_server()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment