Unverified commit a78d8f8d authored by Lianmin Zheng, committed by GitHub

[CI] Fix test cases (#2137)

parent c5f86501
@@ -24,6 +24,8 @@ import triton.language as tl
 
 from sglang.srt.utils import is_hip
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -501,7 +503,7 @@ def _decode_grouped_att_m_fwd(
     num_warps = 4
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -557,7 +559,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
...
@@ -29,6 +29,8 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -311,7 +313,7 @@ def extend_attention_fwd(
         num_stages = 1
 
     extra_kargs = {}
-    if is_hip():
+    if is_hip_:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
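extra_kargs is splatted into the Triton kernel launch alongside num_warps and num_stages. The three keys are AMD-backend compiler options (roughly: waves_per_eu bounds wavefront occupancy per execution unit, matrix_instr_nonkdim selects the MFMA tile shape, and kpack controls packing along the K dimension; see the ROCm links above). A hedged sketch of the launch pattern with a deliberately trivial placeholder kernel; _copy_kernel and launch_copy are hypothetical, not part of the diff:

import torch
import triton
import triton.language as tl

is_hip_ = torch.version.hip is not None

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Placeholder body; the point is the launch call below.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

def launch_copy(src, dst, n, BLOCK=128):
    extra_kargs = {}
    if is_hip_:
        # These kwargs are only accepted by the AMD backend, hence the guard.
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    grid = (triton.cdiv(n, BLOCK),)
    _copy_kernel[grid](src, dst, n, BLOCK=BLOCK, num_warps=4, num_stages=1, **extra_kargs)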
...
@@ -242,15 +242,17 @@ class ModelRunner:
             )
             return get_model(vllm_config=vllm_config)
         except ImportError:
-            return get_model(
-                model_config=self.vllm_model_config,
-                load_config=self.load_config,
-                device_config=DeviceConfig(self.device),
-                parallel_config=None,
-                scheduler_config=None,
-                lora_config=None,
-                cache_config=None,
-            )
+            pass
+
+        return get_model(
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+            parallel_config=None,
+            scheduler_config=None,
+            lora_config=None,
+            cache_config=None,
+        )
 
     def get_model_config_params(self):
         sig = inspect.signature(VllmModelConfig.__init__)
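The ModelRunner hunk keeps the same observable behavior (try the newer vLLM get_model(vllm_config=...) signature, fall back to the legacy keyword signature on ImportError) but moves the fallback call out of the except block. The fallback then runs with no active exception, so if it fails too the traceback is not chained with "During handling of the above exception, another exception occurred". A minimal runnable sketch of the control flow, with new_style_load and legacy_load as hypothetical stand-ins for the two get_model() call styles:

def new_style_load():
    # Simulate an older vLLM install where the new config import fails.
    raise ImportError("vllm.config.VllmConfig not available")

def legacy_load():
    return "model loaded via legacy signature"

def load_model():
    try:
        return new_style_load()
    except ImportError:
        pass  # fall through instead of calling the fallback here

    # Runs outside the except block: if this call itself raises, the
    # traceback is not chained onto the swallowed ImportError above.
    return legacy_load()

print(load_model())  # -> model loaded via legacy signature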
...
@@ -152,15 +152,7 @@ class TestSRTEngine(unittest.TestCase):
         self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
 
-    def test_7_engine_offline_throughput(self):
-        server_args = ServerArgs(
-            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-        )
-        bench_args = BenchArgs(num_prompts=10)
-        result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
-
-    def test_8_engine_cpu_offload(self):
+    def test_7_engine_cpu_offload(self):
         prompt = "Today is a sunny day and I like"
 
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -190,6 +182,14 @@ class TestSRTEngine(unittest.TestCase):
         print(out2)
         self.assertEqual(out1, out2)
 
+    def test_8_engine_offline_throughput(self):
+        server_args = ServerArgs(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+        )
+        bench_args = BenchArgs(num_prompts=10)
+        result = throughput_test(server_args=server_args, bench_args=bench_args)
+        self.assertGreater(result["total_throughput"], 3500)
+
 
 if __name__ == "__main__":
     unittest.main()
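Swapping the test_7/test_8 names works because unittest runs a TestCase's methods in sorted name order, so the offline-throughput benchmark now executes after the CPU-offload test. A small sketch demonstrating the ordering rule:

import unittest

class Demo(unittest.TestCase):
    # Defined out of order on purpose; the loader sorts by name.
    def test_8_engine_offline_throughput(self):
        pass

    def test_7_engine_cpu_offload(self):
        pass

print(unittest.TestLoader().getTestCaseNames(Demo))
# ['test_7_engine_cpu_offload', 'test_8_engine_offline_throughput']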