"torchvision/csrc/cpu/decoder/subtitle_stream.cpp" did not exist on "8b9859d3aeebcd37e6a284fc751c58569857f7be"
Unverified commit 7577f0e4, authored by Cao E, committed by GitHub

Add graph runner support with torch compile on CPU (#7843)

parent 8cda5a62
@@ -70,7 +70,7 @@ jobs:
       - name: Run unit tests
         if: steps.check_amx.outcome == 'success'
-        timeout-minutes: 30
+        timeout-minutes: 36
         run: |
           docker exec -w /sglang-checkout/ ci_sglang_xeon \
             bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu"
......
@@ -134,7 +134,12 @@ Notes:
    export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
    ```
-3. A warmup step is automatically triggered when the service is started.
+3. To optimize decoding with torch.compile, add the flag `--enable-torch-compile`.
+   To set the maximum batch size covered by torch compile, use the flag `--torch-compile-max-bs`.
+   For example, `--enable-torch-compile --torch-compile-max-bs 4` enables torch compile with a
+   maximum batch size of 4.
+4. A warmup step is automatically triggered when the service is started.
    The server is ready when you see the log `The server is fired up and ready to roll!`.

 ## Benchmarking with Requests
......
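For reference, a minimal sketch of exercising these flags programmatically through sglang's test helpers (mirroring `test_cpu_graph.py` later in this diff); the model constant is just the test default and the remaining argument values are illustrative, not prescribed by the docs:

```python
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

# Launch a CPU server with torch compile enabled for decode and a
# compile batch-size cap of 4 (illustrative values).
process = popen_launch_server(
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--attention-backend", "intel_amx",
        "--trust-remote-code",
        "--enable-torch-compile",
        "--torch-compile-max-bs", "4",
    ],
)
```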
@@ -64,6 +64,9 @@ class GraphCaptureContext:

 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 def _split_tensor_dict(
     tensor_dict: Dict[str, Union[torch.Tensor, Any]]
@@ -489,9 +492,7 @@ class GroupCoordinator:
         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
......
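A minimal sketch of the pattern above (assuming only that the enum converts to its integer code, which the diff itself relies on): the conversion happens once at import time, so the compiled region only ever sees a plain Python int rather than the pybind `ReduceOp` object.

```python
import torch
import torch.distributed as dist

# Convert the reduce-op enum to a plain int once, outside any compiled code.
REDUCE_OP_SUM = int(dist.ReduceOp.SUM)


@torch.compile
def add_reduce_code(x: torch.Tensor) -> torch.Tensor:
    # REDUCE_OP_SUM is a module-level int, so Dynamo treats it as a constant
    # instead of tracing through the ReduceOp enum object.
    return x + REDUCE_OP_SUM


print(add_reduce_code(torch.zeros(2)))
```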
@@ -49,6 +49,9 @@ class IntelAMXAttnBackend(AttentionBackend):
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
             self.forward_metadata = (attn_logits, max_extend_len)

+    def get_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
......
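A hedged sketch of what this hook is for (based on how the analogous CUDA graph value is typically used; the helper below is hypothetical and not part of the diff): when a replayed batch is smaller than the captured batch size, the unused slots need a harmless sequence length, and 1 is a safe fill value for the AMX decode kernels.

```python
import torch


def pad_seq_lens(seq_lens: torch.Tensor, capture_bs: int, fill_value: int) -> torch.Tensor:
    """Hypothetical helper: pad seq_lens up to the captured batch size."""
    padded = torch.full((capture_bs,), fill_value, dtype=seq_lens.dtype)
    padded[: seq_lens.numel()] = seq_lens
    return padded


# Two live requests replayed through a graph captured for batch size 4.
print(pad_seq_lens(torch.tensor([17, 5]), capture_bs=4, fill_value=1))
# tensor([17,  5,  1,  1])
```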
@@ -352,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
                     _is_cpu_amx_available
                 ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                 _amx_process_weight_after_loading(layer, ["weight"])
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    layer.weight_scale_inv.data, requires_grad=False
+                )
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
......
@@ -343,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-            return
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

     def create_weights(
@@ -486,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-            return
-
-        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
......
@@ -414,7 +414,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache
@@ -2252,10 +2252,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-        if not _is_cpu:
-            ret["memory_usage"]["cuda_graph"] = round(
-                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
-            )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
......
@@ -214,7 +214,7 @@ class SchedulerMetricsMixin:
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

         msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
+            f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
......
This diff is collapsed.
@@ -132,6 +132,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )

+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
......
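A simplified sketch of the predicate added here (the real `ForwardMode` has more members, and `is_cuda_graph` also accepts additional modes such as IDLE): on CPU, only plain decode batches are eligible for graph replay.

```python
from enum import IntEnum, auto


class ForwardMode(IntEnum):
    # Trimmed-down stand-in for sglang's ForwardMode.
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()

    def is_cuda_graph(self) -> bool:
        # Simplified: the real check covers further modes as well.
        return self in (ForwardMode.DECODE, ForwardMode.IDLE)

    def is_cpu_graph(self) -> bool:
        return self == ForwardMode.DECODE


print(ForwardMode.IDLE.is_cuda_graph(), ForwardMode.IDLE.is_cpu_graph())  # True False
```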
@@ -20,6 +20,7 @@ import json
 import logging
 import os
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
@@ -89,6 +90,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
@@ -360,12 +362,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device == "npu":
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.cuda_graph_mem_usage = 0
+            self.graph_mem_usage = 0
             self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -608,6 +610,11 @@ class ModelRunner:
             # Set local size to hint SGLang to use shared memory based AllReduce
             os.environ["LOCAL_SIZE"] = str(self.tp_size)
             torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+            @torch.library.register_fake("sgl_kernel::shm_allgather")
+            def _(data, dim):
+                return torch.cat([data] * self.tp_size, dim=dim)
+
         else:
             logger.warning(
                 "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -1619,30 +1626,39 @@ class ModelRunner:
         )

     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if self.server_args.disable_cuda_graph:
-            return
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
+            return

         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
-        )
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
+        )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def init_threads_binding(self):
@@ -1787,18 +1803,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph

         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1833,7 +1855,7 @@ class ModelRunner:
         ):
             forward_batch.post_forward_mlp_sync_batch(ret)

-        return ret, can_run_cuda_graph
+        return ret, can_run_graph

     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
......
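The `register_fake` hunk above is the key enabler for compiling code that calls the C++ `shm_allgather` op: torch.compile needs a shape-only ("fake") implementation for any custom op it traces. Below is a self-contained sketch of the same pattern with a hypothetical op name (`demo::repeat_gather` stands in for `sgl_kernel::shm_allgather`, and `world_size` for `self.tp_size`); it assumes a recent PyTorch that provides `torch.library.custom_op` and `torch.library.register_fake`.

```python
import torch


@torch.library.custom_op("demo::repeat_gather", mutates_args=())
def repeat_gather(data: torch.Tensor, world_size: int, dim: int) -> torch.Tensor:
    # Eager CPU implementation; stands in for the shared-memory all-gather.
    return torch.cat([data] * world_size, dim=dim)


@torch.library.register_fake("demo::repeat_gather")
def _(data: torch.Tensor, world_size: int, dim: int) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing under torch.compile.
    return torch.cat([data] * world_size, dim=dim)


@torch.compile
def gathered_sum(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.demo.repeat_gather(x, 2, 0).sum(dim=0)


print(gathered_sum(torch.ones(2, 3)))
```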
@@ -230,8 +230,16 @@ except:
     is_intel_amx_backend_available = False

+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available


 def use_intel_amx_backend(layer):
......
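This hunk hoists the AMX capability probe to import time so the result is a plain module-level bool. A small sketch of why that can matter for torch.compile (the kernel choice below is purely illustrative): Dynamo can specialize on a module-level constant and resolve the branch once, whereas calling a private C-extension probe inside the compiled function would have to be handled during tracing.

```python
import torch

# Probe hardware support once at import time (mirrors the utils.py change).
try:
    _HAS_AMX = torch._C._cpu._is_amx_tile_supported()
except Exception:
    _HAS_AMX = False


@torch.compile
def scale(x: torch.Tensor) -> torch.Tensor:
    # _HAS_AMX is a plain bool, so this branch is resolved when the function
    # is compiled rather than traced through a C-extension call each time.
    if _HAS_AMX:
        return x * 2.0
    return x * 0.5


print(scale(torch.ones(3)))
```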
@@ -239,7 +239,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
   m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
   m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
-  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
   m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

   // topk
@@ -262,14 +262,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // decode
   m.def(
-      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
       "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
       "float logit_cap) -> ()");
   m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

   // extend
   m.def(
-      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
+      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
       "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
       "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
   m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
@@ -305,7 +305,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

   // bmm
-  m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
+  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
   m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

   // moe
@@ -342,7 +342,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // all reduce
   m.def("initialize(int size, int rank) -> ()");
-  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
   m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
......
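The schema edits in this hunk all add the `(a!)` alias/mutation annotation to arguments the kernels write in place; without it, torch.compile's functionalization can treat the op as pure and miss the side effect. A minimal Python sketch of the same annotation on a hypothetical op (`demo_kernel::scale_inplace`; the real ops are registered from C++ as shown above):

```python
import torch

# Register a tiny CPU op whose schema declares that `data` is mutated.
lib = torch.library.Library("demo_kernel", "FRAGMENT")
lib.define("scale_inplace(Tensor(a!) data, float factor) -> ()")


def scale_inplace(data: torch.Tensor, factor: float) -> None:
    data.mul_(factor)  # in-place write, matching the (a!) annotation


lib.impl("scale_inplace", scale_inplace, "CPU")

x = torch.ones(3)
torch.ops.demo_kernel.scale_inplace(x, 2.0)
print(x)  # tensor([2., 2., 2.])
```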
@@ -276,6 +276,7 @@ suite_xeon = {
         TestFile("cpu/test_shared_expert.py"),
         TestFile("cpu/test_topk.py"),
         TestFile("test_intel_amx_attention_backend.py"),
+        TestFile("test_cpu_graph.py"),
     ],
 }
......
"""
Usage:
python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu
"""
import copy
import os
import unittest
from types import SimpleNamespace
from test_intel_amx_attention_backend import intel_amx_benchmark
from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
class TestCPUGraph(CustomTestCase):
@intel_amx_benchmark(
extra_args=[
"--batch-size",
"1",
"--mem-fraction-static",
"0.05",
"--enable-torch-compile",
"--torch-compile-max-bs",
"1",
],
min_throughput=10,
)
def test_latency_torch_compile_cpu(self):
return DEFAULT_MLA_MODEL_NAME_FOR_TEST
def test_mmlu_torch_compile_cpu(self):
model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
cpu_ids_by_node = get_cpu_ids_by_node()
n_numa_node = len(cpu_ids_by_node)
env = copy.deepcopy(os.environ)
env["SGLANG_CPU_OMP_THREADS_BIND"] = "all"
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--attention-backend",
"intel_amx",
"--mem-fraction-static",
"0.05",
"--disable-radix",
"--trust-remote-code",
"--disable-overlap-schedule",
"--enable-torch-compile",
"--torch-compile-max-bs",
"1",
"--tp",
f"{n_numa_node}",
],
env=env,
)
try:
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
if is_in_ci():
self.assertGreater(metrics["score"], 0.45)
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
unittest.main()
@@ -3,7 +3,6 @@ Usage:
 python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu
 """

-import os
 import unittest
 from functools import wraps
 from types import SimpleNamespace
@@ -35,8 +34,6 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
                 "intel_amx",
                 "--disable-radix",
                 "--trust-remote-code",
-                "--batch-size",
-                "4",
             ]
             full_args = common_args + (extra_args or [])
@@ -60,28 +57,33 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):

 class TestIntelAMXAttnBackend(CustomTestCase):
-    @intel_amx_benchmark(min_throughput=10)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=10)
     def test_latency_mla_model(self):
         return DEFAULT_MLA_MODEL_NAME_FOR_TEST

-    @intel_amx_benchmark(min_throughput=40)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=40)
     def test_latency_default_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST

-    @intel_amx_benchmark(min_throughput=150)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=150)
     def test_latency_fp8_qwen(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8

-    @intel_amx_benchmark(min_throughput=50)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=50)
     def test_latency_fp8_moe_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE

-    @intel_amx_benchmark(extra_args=["--quantization", "w8a8_int8"], min_throughput=100)
+    @intel_amx_benchmark(
+        extra_args=["--batch-size", "4", "--quantization", "w8a8_int8"],
+        min_throughput=100,
+    )
     def test_latency_w8a8_default_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_W8A8

     @intel_amx_benchmark(
         extra_args=[
+            "--batch-size",
+            "4",
             "--quantization",
             "w8a8_int8",
             "--mem-fraction-static",
......