"vscode:/vscode.git/clone" did not exist on "2f80bd9f0e1ff7e6fb19d2fe2ca3d1587bf1d0c7"
Unverified Commit 7577f0e4 authored by Cao E, committed by GitHub

Add graph runner support with torch compile on CPU (#7843)

parent 8cda5a62
......@@ -70,7 +70,7 @@ jobs:
- name: Run unit tests
if: steps.check_amx.outcome == 'success'
timeout-minutes: 30
timeout-minutes: 36
run: |
docker exec -w /sglang-checkout/ ci_sglang_xeon \
bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu"
......
......@@ -134,7 +134,12 @@ Notes:
export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
```
3. A warmup step is automatically triggered when the service is started.
3. To optimize decoding with torch.compile, add the flag `--enable-torch-compile`.
To set the maximum batch size captured by torch compile, use the flag `--torch-compile-max-bs`.
For example, `--enable-torch-compile --torch-compile-max-bs 4` enables torch compile with a
maximum batch size of 4 (see the example launch command below).
4. A warmup step is automatically triggered when the service is started.
The server is ready when you see the log `The server is fired up and ready to roll!`.
## Benchmarking with Requests
......
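For reference, a minimal sketch of a CPU server launch that exercises the new flags, assuming SGLang is installed and `<model-path>` stands in for a real model; the `all` thread binding mirrors the new test added in this commit and is only a placeholder for machine-specific core ranges:

```python
import os
import subprocess

# Hedged example: the model path and thread binding are placeholders,
# not values prescribed by this commit.
env = dict(os.environ, SGLANG_CPU_OMP_THREADS_BIND="all")
subprocess.run(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "<model-path>",      # replace with an actual model
        "--device", "cpu",
        "--attention-backend", "intel_amx",
        "--enable-torch-compile",            # enable the CPU graph runner path
        "--torch-compile-max-bs", "4",       # cap the captured batch size at 4
    ],
    env=env,
    check=True,
)
```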
......@@ -64,6 +64,9 @@ class GraphCaptureContext:
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
# use int value instead of ReduceOp.SUM to support torch compile
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
......@@ -489,9 +492,7 @@ class GroupCoordinator:
if input_.is_cpu:
if is_shm_available(input_.dtype, self.world_size, self.local_size):
torch.ops.sgl_kernel.shm_allreduce(
input_, torch.distributed.ReduceOp.SUM
)
torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
......
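The hunk above swaps the `ReduceOp.SUM` enum for the module-level `REDUCE_OP_SUM` integer because the `shm_allreduce` schema takes a plain `int reduce_op`, which torch.compile can treat as an ordinary constant. A minimal illustrative sketch of the pattern (not the sgl_kernel call itself):

```python
import torch
import torch.distributed as dist

# Convert the enum to a plain int once at import time, as the hunk above does.
REDUCE_OP_SUM = int(dist.ReduceOp.SUM)

@torch.compile
def fake_allreduce(x: torch.Tensor) -> torch.Tensor:
    # In the real code this constant is forwarded to
    # torch.ops.sgl_kernel.shm_allreduce(x, REDUCE_OP_SUM); here it is used as
    # an ordinary traced constant just to keep the sketch self-contained.
    return x * (REDUCE_OP_SUM + 1)

print(fake_allreduce(torch.ones(2)))
```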
......@@ -49,6 +49,9 @@ class IntelAMXAttnBackend(AttentionBackend):
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
self.forward_metadata = (attn_logits, max_extend_len)
def get_graph_seq_len_fill_value(self):
return 1
def forward_extend(
self,
q,
......
......@@ -352,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
_is_cpu_amx_available
), "Fp8LinearMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["weight"])
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
)
return
else:
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
......
......@@ -343,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
_is_cpu_amx_available
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["weight"])
return
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
else:
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
def create_weights(
......@@ -486,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
_is_cpu_amx_available
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
return
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
else:
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
......
......@@ -414,7 +414,7 @@ class Scheduler(
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}, "
f"available_gpu_mem={avail_mem:.2f} GB"
f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
)
# Init memory pool and cache
......@@ -2252,10 +2252,9 @@ class Scheduler(
"token_capacity": int(self.max_total_num_tokens),
}
if not _is_cpu:
ret["memory_usage"]["cuda_graph"] = round(
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
)
ret["memory_usage"]["graph"] = round(
self.tp_worker.worker.model_runner.graph_mem_usage, 2
)
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
......
......@@ -214,7 +214,7 @@ class SchedulerMetricsMixin:
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
......
......@@ -132,6 +132,9 @@ class ForwardMode(IntEnum):
or self == ForwardMode.IDLE
)
def is_cpu_graph(self):
return self == ForwardMode.DECODE
def is_dummy_first(self):
return self == ForwardMode.DUMMY_FIRST
......
......@@ -20,6 +20,7 @@ import json
import logging
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
......@@ -89,6 +90,7 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
......@@ -360,12 +362,12 @@ class ModelRunner:
self.init_cublas()
self.init_attention_backend()
self.init_device_graphs()
elif self.device == "npu":
elif self.device in ["npu", "cpu"]:
self.init_attention_backend()
self.init_device_graphs()
else:
self.graph_runner = None
self.cuda_graph_mem_usage = 0
self.graph_mem_usage = 0
self.init_attention_backend()
# auxiliary hidden capture mode. TODO: expose this to server args?
......@@ -608,6 +610,11 @@ class ModelRunner:
# Set local size to hint SGLang to use shared memory based AllReduce
os.environ["LOCAL_SIZE"] = str(self.tp_size)
torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
@torch.library.register_fake("sgl_kernel::shm_allgather")
def _(data, dim):
return torch.cat([data] * self.tp_size, dim=dim)
else:
logger.warning(
"init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
......@@ -1619,30 +1626,39 @@ class ModelRunner:
)
def init_device_graphs(self):
"""Capture cuda graphs."""
"""Capture device graphs."""
self.graph_runner = None
self.cuda_graph_mem_usage = 0
self.graph_mem_usage = 0
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return
if self.server_args.disable_cuda_graph:
if self.device != "cpu" and self.server_args.disable_cuda_graph:
return
if self.device == "cpu" and not self.server_args.enable_torch_compile:
return
tic = time.perf_counter()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.graph_runner = (
CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
graph_runners = defaultdict(
lambda: CudaGraphRunner,
{
"cpu": CPUGraphRunner,
"npu": NPUGraphRunner,
},
)
self.graph_runner = graph_runners[self.device](self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem
self.graph_mem_usage = before_mem - after_mem
logger.info(
f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
)
def init_threads_binding(self):
......@@ -1787,18 +1803,24 @@ class ModelRunner:
reinit_attn_backend: bool = False,
split_forward_count: int = 1,
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph()
mode_check = (
forward_batch.forward_mode.is_cpu_graph
if self.device == "cpu"
else forward_batch.forward_mode.is_cuda_graph
)
can_run_graph = bool(
mode_check()
and self.graph_runner
and self.graph_runner.can_run(forward_batch)
)
if can_run_cuda_graph:
if can_run_graph:
ret = self.graph_runner.replay(
forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors,
)
return ret, can_run_cuda_graph
return ret, can_run_graph
# For MLP sync
if forward_batch.global_num_tokens_cpu is not None:
......@@ -1833,7 +1855,7 @@ class ModelRunner:
):
forward_batch.post_forward_mlp_sync_batch(ret)
return ret, can_run_cuda_graph
return ret, can_run_graph
def _preprocess_logits(
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
......
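A small, self-contained sketch of the dispatch that `init_device_graphs` now uses, with toy stand-ins for `CudaGraphRunner`, `NPUGraphRunner`, and the new `CPUGraphRunner`: the `defaultdict` keeps CUDA graphs as the fallback while `cpu` and `npu` map to their own runners.

```python
from collections import defaultdict

# Toy stand-ins for the real runner classes, used only to show the dispatch.
class _CudaRunner: ...
class _NpuRunner: ...
class _CpuRunner: ...

graph_runners = defaultdict(
    lambda: _CudaRunner,  # any device not listed below falls back to CUDA graphs
    {"cpu": _CpuRunner, "npu": _NpuRunner},
)

for device in ("cuda", "npu", "cpu", "xpu"):
    print(device, "->", graph_runners[device].__name__)
```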
......@@ -230,8 +230,16 @@ except:
is_intel_amx_backend_available = False
try:
# move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
# to support torch compile
is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
except:
is_amx_tile_supported = False
def cpu_has_amx_support():
return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
return is_amx_tile_supported and is_intel_amx_backend_available
def use_intel_amx_backend(layer):
......
......@@ -239,7 +239,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
// topk
......@@ -262,14 +262,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// decode
m.def(
"decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
"decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
"Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
"float logit_cap) -> ()");
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
// extend
m.def(
"extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
"extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
"Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
"extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
......@@ -305,7 +305,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);
// bmm
m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);
// moe
......@@ -342,7 +342,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// all reduce
m.def("initialize(int size, int rank) -> ()");
m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
m.def("shm_allgather(Tensor data, int dim) -> Tensor");
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
......
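The schema changes above add `Tensor(a!)` alias annotations so that kernels which write into an argument declare the mutation; combined with a fake (meta) registration like the `shm_allgather` one earlier in this diff, this is what lets torch.compile trace through the custom ops without assuming they are pure. A hedged, self-contained Python sketch of the same pattern on a throwaway `demo::` op (not the sgl_kernel sources):

```python
import torch

# Declare the schema with a mutation annotation on `out`, mirroring the
# Tensor(a!) annotations added to the CPU kernels above.
torch.library.define("demo::inplace_scale", "(Tensor(a!) out, float s) -> ()")

@torch.library.impl("demo::inplace_scale", "cpu")
def _inplace_scale_cpu(out: torch.Tensor, s: float) -> None:
    out.mul_(s)  # in-place write, matching the (a!) annotation

@torch.library.register_fake("demo::inplace_scale")
def _inplace_scale_fake(out: torch.Tensor, s: float) -> None:
    return None  # nothing to allocate; the op only writes into `out`

@torch.compile
def run(x: torch.Tensor) -> torch.Tensor:
    torch.ops.demo.inplace_scale(x, 2.0)
    return x + 1

print(run(torch.ones(3)))
```

Without the annotation the compiler may treat the op as side-effect free, which can silently drop or reorder the in-place update.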
......@@ -276,6 +276,7 @@ suite_xeon = {
TestFile("cpu/test_shared_expert.py"),
TestFile("cpu/test_topk.py"),
TestFile("test_intel_amx_attention_backend.py"),
TestFile("test_cpu_graph.py"),
],
}
......
"""
Usage:
python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu
"""
import copy
import os
import unittest
from types import SimpleNamespace
from test_intel_amx_attention_backend import intel_amx_benchmark
from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
class TestCPUGraph(CustomTestCase):
@intel_amx_benchmark(
extra_args=[
"--batch-size",
"1",
"--mem-fraction-static",
"0.05",
"--enable-torch-compile",
"--torch-compile-max-bs",
"1",
],
min_throughput=10,
)
def test_latency_torch_compile_cpu(self):
return DEFAULT_MLA_MODEL_NAME_FOR_TEST
def test_mmlu_torch_compile_cpu(self):
model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
cpu_ids_by_node = get_cpu_ids_by_node()
n_numa_node = len(cpu_ids_by_node)
env = copy.deepcopy(os.environ)
env["SGLANG_CPU_OMP_THREADS_BIND"] = "all"
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--attention-backend",
"intel_amx",
"--mem-fraction-static",
"0.05",
"--disable-radix",
"--trust-remote-code",
"--disable-overlap-schedule",
"--enable-torch-compile",
"--torch-compile-max-bs",
"1",
"--tp",
f"{n_numa_node}",
],
env=env,
)
try:
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
if is_in_ci():
self.assertGreater(metrics["score"], 0.45)
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
unittest.main()
......@@ -3,7 +3,6 @@ Usage:
python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu
"""
import os
import unittest
from functools import wraps
from types import SimpleNamespace
......@@ -35,8 +34,6 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
"intel_amx",
"--disable-radix",
"--trust-remote-code",
"--batch-size",
"4",
]
full_args = common_args + (extra_args or [])
......@@ -60,28 +57,33 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
class TestIntelAMXAttnBackend(CustomTestCase):
@intel_amx_benchmark(min_throughput=10)
@intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=10)
def test_latency_mla_model(self):
return DEFAULT_MLA_MODEL_NAME_FOR_TEST
@intel_amx_benchmark(min_throughput=40)
@intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=40)
def test_latency_default_model(self):
return DEFAULT_MODEL_NAME_FOR_TEST
@intel_amx_benchmark(min_throughput=150)
@intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=150)
def test_latency_fp8_qwen(self):
return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8
@intel_amx_benchmark(min_throughput=50)
@intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=50)
def test_latency_fp8_moe_model(self):
return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE
@intel_amx_benchmark(extra_args=["--quantization", "w8a8_int8"], min_throughput=100)
@intel_amx_benchmark(
extra_args=["--batch-size", "4", "--quantization", "w8a8_int8"],
min_throughput=100,
)
def test_latency_w8a8_default_model(self):
return DEFAULT_MODEL_NAME_FOR_TEST_W8A8
@intel_amx_benchmark(
extra_args=[
"--batch-size",
"4",
"--quantization",
"w8a8_int8",
"--mem-fraction-static",
......